You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
FactorSimulator/main.py

530 lines
20 KiB

# -*- coding: utf-8 -*-
import ast
import logging
import os
import random
import threading
import time
from configparser import ConfigParser
from dataclasses import dataclass
from datetime import datetime
from typing import List, Tuple, Dict, Any, Optional

import httpx
import psycopg2
import psycopg2.extras
from httpx import BasicAuth
# Logging setup: INFO and above, one line per record with timestamp, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
class Settings:
"""应用配置 - 完全按照原代码"""
def __init__(self):
# 数据库配置
self.config_file = "db.conf"
self.config = ConfigParser()
if not os.path.exists(self.config_file):
raise FileNotFoundError(f"数据库配置文件 {self.config_file} 不存在")
self.config.read(self.config_file, encoding='utf-8')
self.DATABASE_CONFIG = {
"host": self.config.get('database', 'host'),
"port": self.config.get('database', 'port'),
"user": self.config.get('database', 'user'),
"password": self.config.get('database', 'password'),
"database": self.config.get('database', 'database')
}
# API配置
self.BRAIN_API_URL = "https://api.worldquantbrain.com"
# 模拟配置 - 完全按照原代码
self.SIMULATION_SETTINGS = {
'instrumentType': 'EQUITY',
'region': 'USA',
'universe': 'TOP3000',
'delay': 1,
'decay': 0,
'neutralization': 'INDUSTRY',
'truncation': 0.08,
'pasteurization': 'ON',
'unitHandling': 'VERIFY',
'nanHandling': 'OFF',
'language': 'FASTEXPR',
'visualization': False,
}
# 批次配置 - 每批8个
self.BATCH_SIZE = 8
# 模拟间隔配置
self.SIMULATION_INTERVAL_MIN = 5 # 最小间隔秒数
self.SIMULATION_INTERVAL_MAX = 10 # 最大间隔秒数
# 通知配置
self.GOTIFY_URL = os.getenv("GOTIFY_URL", "https://gotify.erhe.top/message?token=AvKJCJwQKU6yLP8")
# Session有效期(秒) - 4小时
self.SESSION_EXPIRY_SECONDS = 4 * 3600 # 4小时
@property
def credentials_file(self) -> str:
"""获取凭证文件路径"""
return "account.txt"
@dataclass
class SimulationResult:
"""模拟结果实体"""
alpha: str
time_consuming: float
status: str
timestamp: str
alpha_id: Optional[str] = None
message: Optional[str] = None
class AuthService:
    """Authenticate against the BRAIN API and track the session's assumed lifetime."""

    def __init__(self, settings: Settings):
        self.settings = settings
        self.credentials_file = settings.credentials_file
        # Epoch timestamps set by a successful login(); None until then.
        self.last_login_time: Optional[float] = None
        self.login_expiry: Optional[float] = None

    def load_credentials(self) -> Tuple[str, str]:
        """Return ``(username, password)`` read from the credentials file.

        The file holds a Python literal such as ``['your_username', 'your_password']``.
        It is parsed with ``ast.literal_eval`` instead of ``eval`` so its contents
        can never execute code.

        Raises:
            FileNotFoundError: if the file does not exist (a template is written first).
            ValueError: if the contents are not a literal list/tuple with at least two items.
        """
        if not os.path.exists(self.credentials_file):
            self._create_credentials_file()
        with open(self.credentials_file, 'r', encoding='utf-8') as f:
            raw = f.read()
        format_error = f"凭证文件 {self.credentials_file} 格式错误,应为 ['your_username', 'your_password']"
        try:
            # strip(): literal_eval rejects leading whitespace before Python 3.10.
            credentials = ast.literal_eval(raw.strip())
        except (ValueError, TypeError, SyntaxError) as exc:
            raise ValueError(format_error) from exc
        if not isinstance(credentials, (list, tuple)) or len(credentials) < 2:
            raise ValueError(format_error)
        return credentials[0], credentials[1]

    def _create_credentials_file(self) -> None:
        """Write a template credentials file and abort with FileNotFoundError.

        The template must be edited by hand before the next run.
        """
        logger.error(f"未找到 {self.credentials_file} 文件")
        with open(self.credentials_file, 'w', encoding='utf-8') as f:
            f.write("['your_username', 'your_password']")
        logger.error(f"请编辑 {self.credentials_file} 文件,填写账号密码")
        raise FileNotFoundError(f"请创建并配置 {self.credentials_file} 文件")

    def login(self, client: httpx.Client) -> Dict[str, Any]:
        """Attach HTTP basic auth to *client* and POST ``/authentication``.

        On HTTP 201 the login time and expiry (``SESSION_EXPIRY_SECONDS`` later)
        are recorded and the parsed JSON body is returned.

        Raises:
            Exception: on HTTP 429 (rate limit) or any status other than 201.
        """
        username, password = self.load_credentials()
        client.auth = BasicAuth(username, password)
        response = client.post(f'{self.settings.BRAIN_API_URL}/authentication')
        logger.info(f"登录状态: {response.status_code}")
        if response.status_code == 201:
            login_data = response.json()
            expiry_seconds = self.settings.SESSION_EXPIRY_SECONDS
            self.last_login_time = time.time()
            self.login_expiry = self.last_login_time + expiry_seconds
            # Hours are derived from the setting rather than hard-coded as "4.0".
            logger.info(f"登录成功!Session有效期: {expiry_seconds}秒 ({expiry_seconds / 3600:.1f}小时)")
            return login_data
        if response.status_code == 429:
            logger.error("API rate limit exceeded")
            raise Exception("API rate limit exceeded")
        logger.error(f"登录失败: {response.status_code} - {response.text}")
        raise Exception(f"登录失败: {response.status_code} - {response.text}")

    def is_session_expired(self) -> bool:
        """Return True if nobody has logged in yet or the recorded expiry has passed."""
        if self.last_login_time is None or self.login_expiry is None:
            return True
        return time.time() >= self.login_expiry

    def needs_renewal(self, threshold_seconds: float = 1800) -> bool:
        """Return True if fewer than *threshold_seconds* (default 30 minutes) remain.

        Also True when no login has happened yet.
        """
        if self.login_expiry is None:
            return True
        return self.login_expiry - time.time() < threshold_seconds
class AlphaService:
    """Submit alpha expressions to the BRAIN ``/simulations`` endpoint and wait for the outcome."""

    def __init__(self, client: httpx.Client, settings: Settings):
        self.client = client
        self.settings = settings

    def simulate_alpha(self, alpha: str) -> Dict[str, Any]:
        """Simulate a single alpha expression.

        Posts the expression together with ``settings.SIMULATION_SETTINGS``, then
        polls the progress URL returned in the ``Location`` header.

        Returns:
            ``{"status": "ok", "alpha_id": ...}`` on success, otherwise
            ``{"status": "err", "message": ...}``.
        """
        simulation_data = {
            'type': 'REGULAR',
            'settings': self.settings.SIMULATION_SETTINGS,  # shared settings, sent as-is
            'regular': alpha,
        }
        logger.info(f"发送模拟请求,alpha长度: {len(alpha)}")
        sim_resp = self.client.post(f'{self.settings.BRAIN_API_URL}/simulations', json=simulation_data)
        logger.info(f"模拟提交状态: {sim_resp.status_code}")
        if 'location' not in sim_resp.headers:
            # Without a progress URL nothing can be polled; dump the response for debugging.
            logger.error(f"缺少location header,状态码: {sim_resp.status_code}")
            logger.error(f"响应头: {dict(sim_resp.headers)}")
            logger.error(f"响应内容: {sim_resp.text[:500] if sim_resp.text else ''}")
            return {"status": "err", "message": "No location header in response"}
        sim_progress_url = sim_resp.headers['location']
        logger.info(f"模拟任务已创建,进度URL: {sim_progress_url}")
        return self._wait_for_simulation_result(sim_progress_url)

    @staticmethod
    def _json_body(response: httpx.Response) -> Dict[str, Any]:
        """Return the response body as a dict, or ``{}`` if it is not a JSON object."""
        try:
            payload = response.json()
        except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
            return {}
        return payload if isinstance(payload, dict) else {}

    def _wait_for_simulation_result(self, progress_url: str) -> Dict[str, Any]:
        """Poll *progress_url* until it stops sending ``Retry-After``, then interpret the result.

        A non-zero ``Retry-After`` header means the simulation is still running;
        its value is the number of seconds to wait before polling again.
        """
        while True:
            sim_progress_resp = self.client.get(progress_url)
            retry_after_sec = float(sim_progress_resp.headers.get("Retry-After", 0))
            if retry_after_sec == 0:
                break
            progress_data = self._json_body(sim_progress_resp)
            if progress_data:
                logger.info(f"模拟进度: {progress_data.get('progress', 0)}%")
            time.sleep(retry_after_sec)

        if sim_progress_resp.status_code >= 400:
            # An error reply has no alpha id; it must not be reported as "ok".
            error_message = f"HTTP {sim_progress_resp.status_code}: {sim_progress_resp.text[:500]}"
            logger.error(f"因子模拟失败: {error_message}")
            return {"status": "err", "message": error_message}

        result_data = self._json_body(sim_progress_resp)
        if result_data.get("status") == "ERROR":
            error_message = result_data.get("message", "未知错误")
            logger.error(f"因子模拟失败: {error_message}")
            return {"status": "err", "message": error_message}

        alpha_id = result_data.get("alpha")
        if not alpha_id:
            # Previously returned {"status": "ok", "alpha_id": None} in this case.
            error_message = result_data.get("message") or f"模拟结果中没有Alpha ID (status: {result_data.get('status')})"
            logger.error(f"因子模拟失败: {error_message}")
            return {"status": "err", "message": error_message}

        logger.info(f"生成的Alpha ID: {alpha_id}")
        return {"status": "ok", "alpha_id": alpha_id}
class NotificationService:
    """Push a run summary to a Gotify server."""

    def __init__(self, settings: Settings):
        self.settings = settings

    def send_to_gotify(self, success_count: int, fail_count: int) -> None:
        """POST the success/failure totals to ``settings.GOTIFY_URL``.

        Best-effort: any network or HTTP error is logged and swallowed so a
        notification problem never aborts the run.
        """
        now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        text = f"总计: 成功 {success_count} 个, 失败 {fail_count} 个\n\n完成时间: {now}"
        title = f"alpha模拟结果 时间: {now}"
        try:
            resp = httpx.post(
                self.settings.GOTIFY_URL,
                json={'title': title, 'message': text},
                timeout=10,
            )
            # Without this check a 4xx/5xx reply was logged as a successful send.
            resp.raise_for_status()
        except Exception as e:
            logger.error(f"通知发送失败: {e}")
            return
        logger.info("通知发送成功")
class DatabaseManager:
    """Manage one lazily opened psycopg2 connection to the ``alpha_simulation`` table.

    The simulator's batch worker threads share this object (and therefore one
    connection and one transaction). Each statement plus its commit or rollback
    is serialized through ``self._lock``. Without it, a rollback triggered by
    one thread could discard another thread's pending update.
    """

    def __init__(self, settings: Settings):
        self.settings = settings
        self.connection = None
        # Re-entrant: the query methods hold it while calling get_connection().
        self._lock = threading.RLock()

    def get_connection(self) -> psycopg2.extensions.connection:
        """Return the open connection, connecting first if it is missing or closed."""
        with self._lock:
            if self.connection is None or self.connection.closed:
                db = self.settings.DATABASE_CONFIG
                self.connection = psycopg2.connect(
                    host=db["host"],
                    port=db["port"],
                    database=db["database"],
                    user=db["user"],
                    password=db["password"],
                )
            return self.connection

    def get_unused_alpha(self) -> List[Tuple[int, str]]:
        """Return ``(id, alpha)`` for every row still flagged unused with a non-NULL alpha."""
        with self._lock:
            conn = self.get_connection()
            with conn.cursor() as cursor:  # closes the cursor; does not commit
                cursor.execute("SELECT id, alpha FROM alpha_simulation WHERE unused = TRUE AND alpha IS NOT NULL")
                return cursor.fetchall()

    def update_alpha_simulation(self, record_id: int, result: SimulationResult) -> None:
        """Mark row *record_id* as used and store the simulation outcome, then commit.

        On failure the transaction is rolled back before the error is re-raised.
        Otherwise the aborted transaction would make every later statement on
        this shared connection fail too.
        """
        with self._lock:
            conn = self.get_connection()
            try:
                with conn.cursor() as cursor:
                    cursor.execute('''
                        UPDATE alpha_simulation
                        SET unused = FALSE,
                            time_consuming = %s,
                            status = %s,
                            timestamp = %s,
                            alpha_id = %s,
                            message = %s
                        WHERE id = %s
                    ''', (
                        result.time_consuming,
                        result.status,
                        result.timestamp,
                        result.alpha_id,
                        result.message or "",
                        record_id,
                    ))
                conn.commit()
            except Exception:
                if not conn.closed:
                    conn.rollback()
                raise

    def close_connection(self) -> None:
        """Close the connection if it is open; safe to call repeatedly."""
        with self._lock:
            if self.connection is not None and not self.connection.closed:
                self.connection.close()
class AlphaSimulator:
    """Orchestrate a run: log in, simulate pending alphas in batches, store and report results."""

    def __init__(self):
        self.settings = Settings()
        self.db_manager = DatabaseManager(self.settings)
        self.auth_service = AuthService(self.settings)
        self.notification_service = NotificationService(self.settings)
        self.client: Optional[httpx.Client] = None
        self.alpha_service: Optional[AlphaService] = None
        self.success_count = 0
        self.fail_count = 0
        self.total_processed = 0
        # Batch worker threads update the counters concurrently. `x += 1` on an
        # attribute is a non-atomic read-modify-write, so increments take this lock.
        self._counter_lock = threading.Lock()

    def __del__(self):
        """Close the DB connection and HTTP client when the simulator is collected.

        getattr() is used because __init__ may have raised (e.g. db.conf missing)
        before these attributes were assigned.
        """
        db_manager = getattr(self, 'db_manager', None)
        if db_manager is not None:
            db_manager.close_connection()
        client = getattr(self, 'client', None)
        if client is not None:
            client.close()

    def initialize(self) -> None:
        """Create the HTTP client and log in; any failure is logged and re-raised."""
        try:
            self.client = httpx.Client()
            self.login()
            logger.info("模拟器初始化成功")
        except Exception as e:
            logger.error(f"模拟器初始化失败: {e}")
            raise

    def login(self) -> Dict[str, Any]:
        """Log in with the current client and rebuild the AlphaService bound to it."""
        logger.info("正在登录...")
        login_result = self.auth_service.login(self.client)
        self.alpha_service = AlphaService(self.client, self.settings)
        expires_in = self.settings.SESSION_EXPIRY_SECONDS
        hours = int(expires_in // 3600)
        minutes = int((expires_in % 3600) // 60)
        logger.info(f"登录成功!Session将在 {hours}小时{minutes}分钟后过期")
        return login_result

    def check_and_refresh_session_if_needed(self) -> None:
        """Log in again if there is no session, it has expired, or it expires within 30 minutes.

        In the "expires soon" case the HTTP client is replaced with a new one
        before logging in; the other two cases reuse the current client.
        """
        if not self.auth_service.last_login_time:
            logger.info("未找到登录信息,需要登录")
            self.login()
            return

        if self.auth_service.is_session_expired():
            expired_seconds = time.time() - self.auth_service.login_expiry
            logger.info(f"Session已过期(已过期 {expired_seconds:.0f} 秒),需要重新登录")
            self.login()
            return

        if self.auth_service.needs_renewal():
            time_remaining = self.auth_service.login_expiry - time.time()
            minutes = int(time_remaining // 60)
            seconds = int(time_remaining % 60)
            logger.info(f"Session将在{minutes}分{seconds}秒后过期(小于30分钟),需要重新登录")
            if self.client:
                self.client.close()
            self.client = httpx.Client()
            self.login()
            return

        time_remaining = self.auth_service.login_expiry - time.time()
        hours = int(time_remaining // 3600)
        minutes = int((time_remaining % 3600) // 60)
        seconds = int(time_remaining % 60)
        logger.info(f"Session仍然有效,剩余时间: {hours}小时{minutes}分钟{seconds}秒")

    def load_alpha_list(self) -> List[Tuple[int, str]]:
        """Return ``(id, alpha)`` pairs for all unused alphas in the database."""
        return self.db_manager.get_unused_alpha()

    def run_batch_simulation(self, alpha_records: List[Tuple[int, str]]) -> Tuple[int, int]:
        """Simulate *alpha_records* in batches of ``settings.BATCH_SIZE``.

        Each alpha in a batch runs in its own thread, and the whole batch is
        joined before the next one starts. The session is checked before every
        batch, and a random pause is inserted between batches.

        Returns:
            ``(success_count, fail_count)`` for this run.
        """
        total_count = len(alpha_records)
        batch_count = (total_count + self.settings.BATCH_SIZE - 1) // self.settings.BATCH_SIZE
        logger.info(f"开始批次模拟,总共 {total_count} 个因子,每批 {self.settings.BATCH_SIZE} 个, 当前时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

        # Reset per-run counters (no worker threads are running at this point).
        self.success_count = 0
        self.fail_count = 0
        self.total_processed = total_count

        for batch_idx in range(batch_count):
            logger.info("检查session状态...")
            self.check_and_refresh_session_if_needed()

            start_idx = batch_idx * self.settings.BATCH_SIZE
            end_idx = min(start_idx + self.settings.BATCH_SIZE, total_count)
            batch_records = alpha_records[start_idx:end_idx]
            logger.info(f"\n=== 开始第 {batch_idx + 1}/{batch_count} 批,本批 {len(batch_records)} 个因子 ===\t 当前时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

            threads = []
            thread_results = [None] * len(batch_records)
            for i, (record_id, alpha) in enumerate(batch_records):
                thread = threading.Thread(
                    target=self._simulate_single_in_batch,
                    args=(record_id, alpha, i, thread_results),
                )
                thread.daemon = True
                threads.append(thread)
                thread.start()
                logger.info(f"启动第 {batch_idx + 1} 批第 {i + 1} 个因子: {alpha[:50]}... (ID: {record_id})")

            for thread in threads:
                thread.join()
            logger.info(f"=== 第 {batch_idx + 1}/{batch_count} 批完成 ===, 当前时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

            if batch_idx < batch_count - 1:
                sleep_time = random.uniform(self.settings.SIMULATION_INTERVAL_MIN, self.settings.SIMULATION_INTERVAL_MAX)
                logger.info(f"等待 {sleep_time:.1f} 秒后开始下一批...")
                time.sleep(sleep_time)

        self._print_summary()
        return self.success_count, self.fail_count

    def _record_outcome(self, succeeded: bool) -> None:
        """Increment success_count or fail_count under the counter lock."""
        with self._counter_lock:
            if succeeded:
                self.success_count += 1
            else:
                self.fail_count += 1

    def _simulate_single_in_batch(self, record_id: int, alpha: str, index: int, results_list: list) -> None:
        """Worker-thread body: simulate one alpha, persist the outcome, update counters.

        Stores the SimulationResult in ``results_list[index]``. Never lets an
        exception escape, so the thread always finishes and is always counted.
        """
        start_time = time.time()
        try:
            result = self.alpha_service.simulate_alpha(alpha)
            time_consuming = round(time.time() - start_time, 2)
            simulation_result = SimulationResult(
                alpha=alpha,
                time_consuming=time_consuming,
                status=result["status"],
                timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                alpha_id=result.get("alpha_id"),
                message=result.get("message", ""),
            )
            # Persist immediately so finished work survives a crash later in the run.
            self.db_manager.update_alpha_simulation(record_id, simulation_result)
            if result["status"] == "ok":
                self._record_outcome(True)
                logger.info(f"✅ 批次内第 {index + 1} 个因子模拟成功 - Alpha ID: {result['alpha_id']} - 耗时: {time_consuming}秒")
            else:
                self._record_outcome(False)
                logger.error(f"❌ 批次内第 {index + 1} 个因子模拟失败 - {result.get('message', '未知错误')}")
            results_list[index] = simulation_result
        except Exception as e:
            time_consuming = round(time.time() - start_time, 2)
            simulation_result = SimulationResult(
                alpha=alpha,
                time_consuming=time_consuming,
                status="err",
                message=str(e),
                timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            )
            try:
                self.db_manager.update_alpha_simulation(record_id, simulation_result)
            except Exception as db_error:
                # The original failure may itself have been a DB error; log it
                # instead of letting the worker thread die uncounted.
                logger.error(f"❌ 批次内第 {index + 1} 个因子结果保存失败 (ID: {record_id}) - {db_error}")
            self._record_outcome(False)
            logger.error(f"❌ 批次内第 {index + 1} 个因子模拟异常 - {str(e)}")
            results_list[index] = simulation_result

    def _print_summary(self) -> None:
        """Log the run totals and push them to Gotify."""
        now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        logger.info(f"\n总计: 成功 {self.success_count} 个, 失败 {self.fail_count} 个,共处理 {self.total_processed} 个因子")
        logger.info(f"完成时间: {now}")
        logger.info(f"所有结果已保存到 PostgreSQL 数据库 {self.settings.DATABASE_CONFIG['database']} 的 alpha_simulation 表中")
        self.notification_service.send_to_gotify(self.success_count, self.fail_count)
def main():
    """Run one full simulation pass.

    Builds the simulator, logs in, loads every unused alpha from the database,
    and simulates them batch by batch. Ctrl+C and any other error are logged
    rather than propagated, and a closing log line is always written.
    """
    logger.info("开始运行Alpha因子模拟器...")
    try:
        simulator = AlphaSimulator()
        simulator.initialize()  # also performs the login

        alpha_records = simulator.load_alpha_list()
        if alpha_records:
            logger.info(f"共加载 {len(alpha_records)} 个需要模拟的因子表达式")
            success_count, fail_count = simulator.run_batch_simulation(alpha_records)
            logger.info(f"模拟任务完成!成功: {success_count} 个, 失败: {fail_count} 个,共处理 {len(alpha_records)} 个因子")
        else:
            logger.info("暂无待处理的alpha表达式")
    except KeyboardInterrupt:
        logger.info("用户中断执行")
    except Exception as e:
        logger.error(f"模拟过程发生错误: {e}", exc_info=True)
    finally:
        logger.info("Alpha因子模拟器运行结束")


if __name__ == "__main__":
    main()