You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
214 lines
7.7 KiB
214 lines
7.7 KiB
# -*- coding: utf-8 -*-
|
|
import os.path
|
|
import httpx
|
|
import time
|
|
from httpx import BasicAuth
|
|
from typing import Dict, Any, Optional, Tuple
|
|
|
|
from .models import AlphaMetrics, TrainMetrics, TestMetrics, AlphaInfo
|
|
|
|
|
|
class WorldQuantBrainSimulate:
    """Small client for submitting alpha simulations to the WorldQuant BRAIN API.

    Intended call order: :meth:`login`, then :meth:`simulate_alpha` (any
    number of times), then :meth:`close`.
    """

    def __init__(self, credentials_file='account.txt'):
        """Store configuration only; no file or network access happens here.

        Args:
            credentials_file: Path of the text file that holds the account
                credentials as a Python list literal, e.g.
                ``['username', 'password']``.
        """
        self.credentials_file = credentials_file
        self.client = None  # set to an httpx.Client by login()
        self.brain_api_url = 'https://api.worldquantbrain.com'

    def load_credentials(self) -> Tuple[str, str]:
        """Read the username and password from the local credentials file.

        When the file does not exist, an empty one is created, a hint about
        the expected format is printed and the process exits with status 1.

        Returns:
            ``(username, password)``: the first two items of the literal
            stored in the file.

        Raises:
            SystemExit: The credentials file was missing.
            ValueError: The file content is not a plain Python literal.
            SyntaxError: The file content cannot be parsed (e.g. it is empty).
        """
        import ast  # local import: only this method needs it

        if not os.path.exists(self.credentials_file):
            print("未找到 account.txt 文件")
            with open(self.credentials_file, 'w') as f:
                f.write("")
            print("account.txt 文件已创建,请填写账号密码, 格式: ['username', 'password']")
            # SystemExit directly instead of the site-provided exit() helper,
            # which is not guaranteed to exist in every interpreter setup.
            raise SystemExit(1)

        with open(self.credentials_file) as f:
            # SECURITY: this used to be eval(), which executes arbitrary code
            # found in the file. literal_eval accepts only literals, which
            # covers the documented ['username', 'password'] format.
            credentials = ast.literal_eval(f.read().strip())
        return credentials[0], credentials[1]

    def login(self) -> bool:
        """Authenticate against the API with HTTP basic auth.

        Creates ``self.client`` (closing a previously created client first)
        and posts to the ``/authentication`` endpoint.

        Returns:
            True when the server answers with status 201, otherwise False.
        """
        username, password = self.load_credentials()

        # Logging in twice must not leak the earlier connection pool.
        if self.client is not None:
            self.client.close()
        self.client = httpx.Client(auth=BasicAuth(username, password))

        response = self.client.post(f'{self.brain_api_url}/authentication')
        print(f"登录状态: {response.status_code}")

        if response.status_code == 201:
            print("登录成功!")
            return True

        print(f"登录失败: {self._response_detail(response)}")
        return False

    @staticmethod
    def _response_detail(response) -> Any:
        """Return the decoded JSON body of *response*, or its raw text if the body is not JSON."""
        try:
            return response.json()
        except ValueError:
            return response.text

    def simulate_alpha(self, expression: str, settings: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Submit an alpha expression for simulation and wait for the outcome.

        Args:
            expression: The alpha expression, sent as the ``regular`` field.
            settings: Optional overrides merged on top of the default
                simulation settings.

        Returns:
            ``{"status": "success", "alpha_id": ..., "metrics": ...}`` when the
            simulation produced an alpha, otherwise
            ``{"status": "error", "message": ...}``.

        Raises:
            Exception: If :meth:`login` has not been called yet.
        """
        if self.client is None:
            raise Exception("请先登录")

        default_settings = {
            'instrumentType': 'EQUITY',
            'region': 'USA',
            'universe': 'TOP3000',
            'delay': 1,
            'decay': 0,
            'neutralization': 'INDUSTRY',
            'truncation': 0.08,
            'pasteurization': 'ON',
            'unitHandling': 'VERIFY',
            'nanHandling': 'OFF',
            'language': 'FASTEXPR',
            'visualization': False,
        }
        if settings:
            default_settings.update(settings)

        simulation_data = {
            'type': 'REGULAR',
            'settings': default_settings,
            'regular': expression,
        }

        sim_resp = self.client.post(f'{self.brain_api_url}/simulations', json=simulation_data)
        print(f"模拟提交状态: {sim_resp.status_code}")

        # A rejected submission carries no progress URL; report it through the
        # regular error shape instead of failing with a KeyError.
        sim_progress_url = sim_resp.headers.get('location')
        if not sim_progress_url:
            detail = self._response_detail(sim_resp)
            print(f"模拟提交失败: {detail}")
            return {"status": "error", "message": detail}
        print(f"进度URL: {sim_progress_url}")

        # Poll until the response no longer carries a Retry-After header.
        while True:
            sim_progress_resp = self.client.get(sim_progress_url)
            retry_after_sec = float(sim_progress_resp.headers.get("Retry-After", 0))
            if retry_after_sec == 0:
                break
            print(sim_progress_resp.json())
            print(f"等待 {retry_after_sec} 秒...")
            time.sleep(retry_after_sec)

        outcome = sim_progress_resp.json()  # decode the final payload once

        # A failed simulation reports its reason in "message".
        if outcome.get("status") == "ERROR":
            result = outcome.get("message")
            print(f"因子模拟失败: {result}")
            return {"status": "error", "message": result}

        alpha_id = outcome.get("alpha")
        if not alpha_id:
            # Finished without an ERROR status but also without an alpha id.
            result = outcome.get("message") or f"模拟未返回 Alpha ID (status: {outcome.get('status')})"
            print(f"因子模拟失败: {result}")
            return {"status": "error", "message": result}
        print(f"生成的Alpha ID: {alpha_id}")

        # Fetch the detailed performance metrics for the new alpha.
        metrics = self.get_alpha_metrics(alpha_id)
        return {"status": "success", "alpha_id": alpha_id, "metrics": metrics}

    @staticmethod
    def _empty_metrics() -> "AlphaMetrics":
        """Build an AlphaMetrics whose parts are all default-constructed."""
        return AlphaMetrics(
            train_metrics=TrainMetrics(),
            is_metrics=TestMetrics(),
            test_metrics=TestMetrics(),
            alpha_info=AlphaInfo(),
        )

    def get_alpha_metrics(self, alpha_id: str) -> "AlphaMetrics":
        """Fetch and parse the metrics of one alpha.

        This is best effort: a non-200/201 answer or any exception while
        fetching or parsing yields a default-constructed AlphaMetrics instead
        of propagating.

        Args:
            alpha_id: Identifier of the alpha to look up.

        Returns:
            The parsed metrics, or default-constructed metrics on failure.

        Raises:
            Exception: If :meth:`login` has not been called yet.
        """
        if self.client is None:
            raise Exception("请先登录")

        try:
            alpha_resp = self.client.get(f'{self.brain_api_url}/alphas/{alpha_id}')
            if alpha_resp.status_code in (200, 201):
                return self._parse_alpha_metrics(alpha_resp.json())
            return self._empty_metrics()
        except Exception as e:  # deliberate catch-all: metrics are optional
            print(f"获取指标时出错: {str(e)}")
            return self._empty_metrics()

    @staticmethod
    def _build_test_metrics(section: Optional[Dict[str, Any]]) -> "TestMetrics":
        """Build a TestMetrics from one result section; a missing or empty section gives defaults."""
        if not section:
            return TestMetrics()
        return TestMetrics(
            sharpe_ratio=section.get('sharpe'),
            annual_return=section.get('returns'),
            max_drawdown=section.get('drawdown'),
            turnover=section.get('turnover'),
            fitness=section.get('fitness'),
            pnl=section.get('pnl'),
        )

    def _parse_alpha_metrics(self, alpha_data: Dict[str, Any]) -> "AlphaMetrics":
        """Extract the key metrics from an alpha payload.

        Args:
            alpha_data: Decoded JSON of an ``/alphas/{id}`` response. The
                ``train``, ``is`` and ``test`` sections are optional.

        Returns:
            AlphaMetrics built from the ``train``, ``is`` and ``test``
            sections plus the top-level info fields. When the ``is`` section
            has a ``checks`` entry, the checks are attached to
            ``alpha_info.checks`` keyed by lower-cased check name.
        """
        # Training-set figures.
        train_metrics = TrainMetrics()
        train_data = alpha_data.get('train')
        if train_data:
            train_metrics = TrainMetrics(
                sharpe_ratio=train_data.get('sharpe'),
                annual_return=train_data.get('returns'),
                max_drawdown=train_data.get('drawdown'),
                turnover=train_data.get('turnover'),
                fitness=train_data.get('fitness'),
                pnl=train_data.get('pnl'),
                book_size=train_data.get('bookSize'),
                long_count=train_data.get('longCount'),
                short_count=train_data.get('shortCount'),
                margin=train_data.get('margin'),
            )

        # 'is' and 'test' sections share the same shape.
        is_data = alpha_data.get('is') or {}
        is_metrics = self._build_test_metrics(is_data)
        test_metrics = self._build_test_metrics(alpha_data.get('test'))

        # Basic information about the alpha itself.
        alpha_info = AlphaInfo(
            grade=alpha_data.get('grade'),
            stage=alpha_data.get('stage'),
            status=alpha_data.get('status'),
            date_created=alpha_data.get('dateCreated'),
        )

        # Check results; is_data is always a dict here, so a null 'is'
        # section can no longer raise a TypeError on the membership test.
        if 'checks' in is_data:
            check_results = {}
            for check in is_data['checks'] or []:
                check_name = str(check.get('name') or '')
                check_results[check_name.lower()] = {
                    'result': check.get('result', ''),
                    'value': check.get('value', None),
                    'limit': check.get('limit', None),
                }
            alpha_info.checks = check_results

        return AlphaMetrics(
            train_metrics=train_metrics,
            is_metrics=is_metrics,
            test_metrics=test_metrics,
            alpha_info=alpha_info,
            alpha_id=alpha_data.get('id'),
        )

    def close(self):
        """Close the HTTP client, if one was created, and forget it."""
        if self.client:
            self.client.close()
            # Reset so later calls hit the "login first" guard rather than a
            # closed-client error.
            self.client = None