增加获取数据集名称, 指定地区和股票池

main
Jack 4 weeks ago
parent faaa8e3802
commit 22ef11dfcc
  1. 133
      rpc_batch_fetch_dataset/rpc_create_dataset_record.py

@ -0,0 +1,133 @@
# -*- coding: utf-8 -*-
import httpx
import json
import time
import random
import csv
# Query parameters for the WorldQuant Brain data-sets endpoint:
# one page of up to `limit` records, ordered by coverage (descending),
# restricted to the given delay / region / universe.
params = dict(
    delay=1,
    limit=50,
    order='-coverage',
    region='EUR',
    universe='TOP2500',
)

# When truthy, the fetched records are also dumped to a CSV file.
need_save = 0
def login():
    """Log in to the WorldQuant Brain API.

    Fetches the account credentials from a Nacos config server
    (dataId=wq_account, group=quantify), then authenticates against the
    Brain API using HTTP Basic auth.

    Returns:
        An authenticated ``httpx.Client`` on success (HTTP 201 from the
        authentication endpoint), or ``None`` when either the credential
        fetch or the authentication fails.
    """
    # Credentials are stored centrally in Nacos rather than in this file.
    nacos_resp = httpx.get('http://192.168.31.41:30848/nacos/v1/cs/configs?dataId=wq_account&group=quantify', timeout=10)
    if nacos_resp.status_code != 200:
        print('获取账号密码失败')
        return None
    config = nacos_resp.json()
    username = config['user_name']
    password = config['password']
    print(f"正在登录账户: {username}")
    # Basic auth is attached to the client, so every subsequent request
    # made through it is automatically authenticated.
    client = httpx.Client(auth=httpx.BasicAuth(username, password), timeout=10)
    response = client.post('https://api.worldquantbrain.com/authentication')
    print(f"登录状态: {response.status_code}")
    if response.status_code == 201:
        print("登录成功!")
        return client
    else:
        print(f"登录失败: {response.json()}")
        # Release the connection pool on failure; the caller only gets
        # a usable client back on the success path.
        client.close()
        return None
def request_with_retry(client, url, max_retries=3, retry_delay=10):
    """GET *url* through *client*, retrying on failure.

    A response counts as success only when it has status 200 and a
    non-empty body.  Up to *max_retries* attempts are made, sleeping
    *retry_delay* seconds between consecutive attempts (but not after
    the last one).

    Returns:
        The successful response, or ``None`` when every attempt failed.
    """
    attempt = 0
    while attempt < max_retries:
        try:
            resp = client.get(url, timeout=10)
        except Exception as e:
            print(f"请求异常: {e},重试 {attempt + 1}/{max_retries}")
        else:
            if resp.status_code == 200 and resp.content:
                return resp
            print(f"请求失败,状态码: {resp.status_code},内容长度: {len(resp.content)},重试 {attempt + 1}/{max_retries}")
        attempt += 1
        # No sleep after the final attempt — just fall through and give up.
        if attempt < max_retries:
            time.sleep(retry_delay)
    return None
# ---- main: page through the data-sets endpoint and report the results ----
client = login()
if client:
    all_results = []
    offset = 0
    # Fetch pages of `limit` records until a failed request, a JSON
    # parse error, or a short/empty page signals the end of the data.
    while True:
        params['offset'] = offset
        param_str = '&'.join([f"{k}={v}" for k, v in params.items()])
        url = f'https://api.worldquantbrain.com/data-sets?instrumentType=EQUITY&{param_str}'
        print(f"请求URL: {url}")
        response = request_with_retry(client, url)
        if response is None:
            print(f"offset {offset} 请求失败,停止获取")
            break
        try:
            data = response.json()
        except Exception as e:
            print(f"解析JSON失败: {e},响应内容前500字符: {response.text[:500]}")
            break
        if not data or 'results' not in data or not data['results']:
            break
        all_results.extend(data['results'])
        print(f"已获取 offset {offset},共 {len(data['results'])} 条,总计 {len(all_results)} 条记录")
        # A page shorter than `limit` is the last one.
        if len(data['results']) < params['limit']:
            break
        offset += params['limit']
        # Randomized pause between pages to avoid hammering the API.
        time.sleep(random.uniform(3, 5))
    # Save to a CSV file when enabled via `need_save`.
    if all_results and need_save:
        filename = f"{params['region']}_{params['universe']}_D{params['delay']}.csv"
        # CSV header: the union of keys across all records, since
        # individual records may carry different fields.
        fieldnames = set()
        for item in all_results:
            fieldnames.update(item.keys())
        # Flatten nested dicts/lists into JSON strings so they survive
        # the round-trip through CSV.
        for item in all_results:
            for key, value in item.items():
                if isinstance(value, (dict, list)):
                    item[key] = json.dumps(value, ensure_ascii=False)
        with open(filename, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=sorted(fieldnames))
            writer.writeheader()
            writer.writerows(all_results)
        # BUG FIX: the message previously printed a literal placeholder
        # instead of the actual output file name.
        print(f"已保存 {len(all_results)} 条记录到 {filename}")
    elif not all_results:
        # BUG FIX: was a plain `else`, which printed "no data fetched"
        # even when data WAS fetched but saving was disabled.
        print("没有获取到任何数据")
    if all_results:
        # Report the IDs of datasets whose coverage exceeds 0.5.
        print("\ncoverage > 0.5 的数据 ID:")
        coverage_ids = []
        for item in all_results:
            if 'coverage' in item and item['coverage'] > 0.5:
                coverage_ids.append(item['id'])
                print(item['id'])
        print(f"\n{len(coverage_ids)} 个数据集 coverage > 0.5")
    client.close()
Loading…
Cancel
Save