# -*- coding: utf-8 -*-
import ast
import logging
import os
import random
import threading
import time
from configparser import ConfigParser
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

import httpx
import psycopg2
import psycopg2.extras
from httpx import BasicAuth

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class Settings:
    """Application settings.

    PostgreSQL connection parameters are read from ``db.conf`` (section
    ``[database]``); BRAIN API, simulation, batching and notification options
    are hard-coded below.

    Raises:
        FileNotFoundError: if ``db.conf`` does not exist.
    """

    def __init__(self):
        # Database configuration
        self.config_file = "db.conf"
        self.config = ConfigParser()
        if not os.path.exists(self.config_file):
            raise FileNotFoundError(f"数据库配置文件 {self.config_file} 不存在")
        self.config.read(self.config_file, encoding='utf-8')

        self.DATABASE_CONFIG = {
            "host": self.config.get('database', 'host'),
            "port": self.config.get('database', 'port'),
            "user": self.config.get('database', 'user'),
            "password": self.config.get('database', 'password'),
            "database": self.config.get('database', 'database')
        }

        # BRAIN API base URL
        self.BRAIN_API_URL = "https://api.worldquantbrain.com"

        # Settings sent verbatim with every simulation request
        self.SIMULATION_SETTINGS = {
            'instrumentType': 'EQUITY',
            'region': 'USA',
            'universe': 'TOP3000',
            'delay': 1,
            'decay': 0,
            'neutralization': 'INDUSTRY',
            'truncation': 0.08,
            'pasteurization': 'ON',
            'unitHandling': 'VERIFY',
            'nanHandling': 'OFF',
            'language': 'FASTEXPR',
            'visualization': False,
        }

        # Number of alphas simulated concurrently (one thread each) per batch
        self.BATCH_SIZE = 8

        # Random pause between batches, in seconds
        self.SIMULATION_INTERVAL_MIN = 5  # minimum pause
        self.SIMULATION_INTERVAL_MAX = 10  # maximum pause

        # Notification endpoint.
        # NOTE(review): the fallback URL embeds a Gotify app token in source
        # code; prefer supplying GOTIFY_URL via the environment and rotating
        # this token.
        self.GOTIFY_URL = os.getenv("GOTIFY_URL", "https://gotify.erhe.top/message?token=AvKJCJwQKU6yLP8")

        # Session lifetime in seconds (4 hours). This is tracked locally from
        # the login time; it is not read from the API response.
        self.SESSION_EXPIRY_SECONDS = 4 * 3600

    @property
    def credentials_file(self) -> str:
        """Return the path of the file holding ``['username', 'password']``."""
        return "account.txt"


@dataclass
class SimulationResult:
    """Outcome of one alpha simulation, as written to ``alpha_simulation``."""
    alpha: str
    time_consuming: float  # wall-clock seconds, rounded to 2 decimals
    status: str  # "ok" or "err"
    timestamp: str  # "%Y-%m-%d %H:%M:%S", local time
    alpha_id: Optional[str] = None
    message: Optional[str] = None


class AuthService:
    """Log in to the BRAIN API and track the (locally computed) session expiry."""

    def __init__(self, settings: Settings):
        self.settings = settings
        self.credentials_file = settings.credentials_file
        # time.time() of the last successful login, or None before any login.
        self.last_login_time: Optional[float] = None
        # last_login_time + SESSION_EXPIRY_SECONDS, or None before any login.
        self.login_expiry: Optional[float] = None

    def load_credentials(self) -> Tuple[str, str]:
        """Read ``(username, password)`` from the credentials file.

        The file must contain a Python literal list/tuple such as
        ``['user', 'pass']``.

        Raises:
            FileNotFoundError: if the file is missing (a template is written first).
        """
        if not os.path.exists(self.credentials_file):
            self._create_credentials_file()

        with open(self.credentials_file, 'r', encoding='utf-8') as f:
            # literal_eval accepts the same list/tuple literal that eval() did,
            # but cannot execute arbitrary code placed in the file.
            credentials = ast.literal_eval(f.read().strip())
        return credentials[0], credentials[1]

    def _create_credentials_file(self) -> None:
        """Write a placeholder credentials file, then abort.

        Raises:
            FileNotFoundError: always, so the user fills in real credentials.
        """
        logger.error(f"未找到 {self.credentials_file} 文件")
        with open(self.credentials_file, 'w', encoding='utf-8') as f:
            f.write("['your_username', 'your_password']")
        logger.error(f"请编辑 {self.credentials_file} 文件,填写账号密码")
        raise FileNotFoundError(f"请创建并配置 {self.credentials_file} 文件")

    def login(self, client: httpx.Client) -> Dict[str, Any]:
        """Authenticate ``client`` with HTTP Basic auth and POST /authentication.

        On success (HTTP 201) records the login time and local expiry.

        Returns:
            The decoded JSON body of the authentication response.

        Raises:
            Exception: on HTTP 429 (rate limited) or any other non-201 status.
        """
        username, password = self.load_credentials()
        client.auth = BasicAuth(username, password)

        response = client.post(f'{self.settings.BRAIN_API_URL}/authentication')
        logger.info(f"登录状态: {response.status_code}")

        if response.status_code == 201:
            login_data = response.json()
            self.last_login_time = time.time()
            self.login_expiry = self.last_login_time + self.settings.SESSION_EXPIRY_SECONDS
            expiry_hours = self.settings.SESSION_EXPIRY_SECONDS / 3600
            logger.info(f"登录成功!Session有效期: {self.settings.SESSION_EXPIRY_SECONDS}秒 ({expiry_hours:.1f}小时)")
            return login_data
        elif response.status_code == 429:
            logger.error("API rate limit exceeded")
            raise Exception("API rate limit exceeded")
        else:
            logger.error(f"登录失败: {response.status_code} - {response.text}")
            raise Exception(f"登录失败: {response.status_code} - {response.text}")

    def is_session_expired(self) -> bool:
        """Return True if never logged in or the local expiry time has passed."""
        if not self.last_login_time or not self.login_expiry:
            return True
        return time.time() >= self.login_expiry

    def needs_renewal(self) -> bool:
        """Return True if never logged in or fewer than 30 minutes remain."""
        if not self.login_expiry:
            return True
        time_remaining = self.login_expiry - time.time()
        return time_remaining < 1800  # 30 minutes


class AlphaService:
    """Submit alpha expressions to the BRAIN simulation endpoint and await results."""

    def __init__(self, client: httpx.Client, settings: Settings):
        self.client = client
        self.settings = settings

    def simulate_alpha(self, alpha: str) -> Dict[str, Any]:
        """Submit one alpha expression and block until its simulation finishes.

        Returns:
            ``{"status": "ok", "alpha_id": ...}`` on success, otherwise
            ``{"status": "err", "message": ...}``.
        """
        simulation_data = {
            'type': 'REGULAR',
            'settings': self.settings.SIMULATION_SETTINGS,
            'regular': alpha
        }

        logger.info(f"发送模拟请求,alpha长度: {len(alpha)}")
        sim_resp = self.client.post(f'{self.settings.BRAIN_API_URL}/simulations', json=simulation_data)
        logger.info(f"模拟提交状态: {sim_resp.status_code}")

        if 'location' not in sim_resp.headers:
            # Dump as much of the response as is useful for debugging.
            logger.error(f"缺少location header,状态码: {sim_resp.status_code}")
            logger.error(f"响应头: {dict(sim_resp.headers)}")
            logger.error(f"响应内容: {sim_resp.text[:500] if sim_resp.text else '空'}")
            return {"status": "err", "message": "No location header in response"}

        sim_progress_url = sim_resp.headers['location']
        logger.info(f"模拟任务已创建,进度URL: {sim_progress_url}")
        return self._wait_for_simulation_result(sim_progress_url)

    def _wait_for_simulation_result(self, progress_url: str) -> Dict[str, Any]:
        """Poll ``progress_url`` until the server stops sending ``Retry-After``.

        Returns:
            ``{"status": "ok", "alpha_id": ...}`` only when the final response
            is not an HTTP error and carries an alpha id; otherwise
            ``{"status": "err", "message": ...}``.
        """
        while True:
            sim_progress_resp = self.client.get(progress_url)
            retry_after_sec = float(sim_progress_resp.headers.get("Retry-After", 0))
            # A missing/zero Retry-After means the simulation has finished.
            if retry_after_sec == 0:
                break
            progress_data = sim_progress_resp.json()  # parse once per poll
            if progress_data:
                progress = progress_data.get('progress', 0)
                logger.info(f"模拟进度: {progress}%")
            time.sleep(retry_after_sec)

        if sim_progress_resp.status_code >= 400:
            # e.g. an expired session: do not mistake the error body for a result.
            error_message = f"HTTP {sim_progress_resp.status_code} - {sim_progress_resp.text[:500]}"
            logger.error(f"查询模拟结果失败: {error_message}")
            return {"status": "err", "message": error_message}

        result_data = sim_progress_resp.json()
        if result_data.get("status") == "ERROR":
            error_message = result_data.get("message", "未知错误")
            logger.error(f"因子模拟失败: {error_message}")
            return {"status": "err", "message": error_message}

        alpha_id = result_data.get("alpha")
        if not alpha_id:
            # Any other terminal state without an alpha id must not be
            # recorded as a success with alpha_id=None.
            error_message = result_data.get("message") or f"模拟结果缺少Alpha ID (status: {result_data.get('status')})"
            logger.error(f"因子模拟失败: {error_message}")
            return {"status": "err", "message": error_message}

        logger.info(f"生成的Alpha ID: {alpha_id}")
        return {"status": "ok", "alpha_id": alpha_id}


class NotificationService:
    """Send a run summary to a Gotify server."""

    def __init__(self, settings: Settings):
        self.settings = settings

    def send_to_gotify(self, success_count: int, fail_count: int) -> None:
        """Post the success/failure totals to ``settings.GOTIFY_URL``.

        Best-effort: any failure (network error or non-2xx response) is
        logged and swallowed so it never aborts the run.
        """
        now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        text = f"总计: 成功 {success_count} 个, 失败 {fail_count} 个\n\n完成时间: {now}"
        title = f"alpha模拟结果 时间: {now}"
        try:
            resp = httpx.post(
                self.settings.GOTIFY_URL,
                json={'title': title, 'message': text},
                timeout=10
            )
            # Without this, an HTTP error (bad token, 5xx) was logged as success.
            resp.raise_for_status()
            logger.info("通知发送成功")
        except Exception as e:
            logger.error(f"通知发送失败: {e}")


class DatabaseManager:
    """Own a single lazily-opened PostgreSQL connection shared by all threads.

    psycopg2 allows sharing a connection between threads, but the connection
    has one transaction at a time; every unit of work therefore runs under a
    lock and is committed or rolled back before the lock is released.
    """

    def __init__(self, settings: Settings):
        self.settings = settings
        self.connection = None
        # Re-entrant so locked methods can call get_connection().
        self._lock = threading.RLock()

    def get_connection(self) -> psycopg2.extensions.connection:
        """Return the open connection, (re)connecting if absent or closed."""
        with self._lock:
            if self.connection is None or self.connection.closed:
                self.connection = psycopg2.connect(
                    host=self.settings.DATABASE_CONFIG["host"],
                    port=self.settings.DATABASE_CONFIG["port"],
                    database=self.settings.DATABASE_CONFIG["database"],
                    user=self.settings.DATABASE_CONFIG["user"],
                    password=self.settings.DATABASE_CONFIG["password"]
                )
            return self.connection

    def get_unused_alpha(self) -> List[Tuple[int, str]]:
        """Return ``(id, alpha)`` for every row with ``unused = TRUE`` and a non-NULL alpha."""
        with self._lock:
            conn = self.get_connection()
            # `with conn` ends the transaction (no lingering "idle in
            # transaction"); `with cursor` closes the cursor.
            with conn, conn.cursor() as cursor:
                cursor.execute("SELECT id, alpha FROM alpha_simulation WHERE unused = TRUE AND alpha IS NOT NULL")
                return cursor.fetchall()

    def update_alpha_simulation(self, record_id: int, result: SimulationResult) -> None:
        """Mark row ``record_id`` as used and store the simulation outcome."""
        with self._lock:
            conn = self.get_connection()
            # `with conn` commits on success and rolls back on error, so one
            # failed UPDATE cannot leave the shared connection stuck in an
            # aborted transaction for every later update.
            with conn, conn.cursor() as cursor:
                cursor.execute('''
                    UPDATE alpha_simulation
                    SET unused = FALSE,
                        time_consuming = %s,
                        status = %s,
                        timestamp = %s,
                        alpha_id = %s,
                        message = %s
                    WHERE id = %s
                ''', (
                    result.time_consuming,
                    result.status,
                    result.timestamp,
                    result.alpha_id,
                    result.message or "",
                    record_id
                ))

    def close_connection(self) -> None:
        """Close the connection if it is open."""
        with self._lock:
            if self.connection and not self.connection.closed:
                self.connection.close()


class AlphaSimulator:
    """Load unused alphas from the database and simulate them in threaded batches."""

    def __init__(self):
        self.settings = Settings()
        self.db_manager = DatabaseManager(self.settings)
        self.auth_service = AuthService(self.settings)
        self.notification_service = NotificationService(self.settings)
        self.client: Optional[httpx.Client] = None
        self.alpha_service: Optional[AlphaService] = None
        self.success_count = 0
        self.fail_count = 0
        self.total_processed = 0
        # Worker threads update the counters concurrently; `+=` is not atomic.
        self._counter_lock = threading.Lock()

    def close(self) -> None:
        """Release the database connection and HTTP client (safe to call twice)."""
        # getattr: __init__ may have failed before these attributes were set.
        db_manager = getattr(self, 'db_manager', None)
        if db_manager is not None:
            db_manager.close_connection()
        client = getattr(self, 'client', None)
        if client is not None:
            client.close()
            self.client = None

    def __del__(self):
        """Last-resort cleanup; prefer calling close() explicitly."""
        self.close()

    def initialize(self) -> None:
        """Create the HTTP client and log in.

        Raises:
            Exception: whatever login raised, after logging it.
        """
        try:
            self.client = httpx.Client()
            self.login()
            logger.info("模拟器初始化成功")
        except Exception as e:
            logger.error(f"模拟器初始化失败: {e}")
            raise

    def login(self) -> Dict[str, Any]:
        """Log in on the current client and bind a fresh AlphaService to it."""
        logger.info("正在登录...")
        login_result = self.auth_service.login(self.client)
        self.alpha_service = AlphaService(self.client, self.settings)

        expires_in = self.settings.SESSION_EXPIRY_SECONDS
        hours = int(expires_in // 3600)
        minutes = int((expires_in % 3600) // 60)
        logger.info(f"登录成功!Session将在 {hours}小时{minutes}分钟后过期")

        return login_result

    def _relogin_with_new_client(self) -> None:
        """Discard the current HTTP client (and its cookies) and log in on a fresh one."""
        if self.client:
            self.client.close()
        self.client = httpx.Client()
        self.login()

    def check_and_refresh_session_if_needed(self) -> None:
        """Log in again if never logged in, expired, or within 30 minutes of expiry."""
        if not self.auth_service.last_login_time:
            logger.info("未找到登录信息,需要登录")
            self.login()
            return

        if self.auth_service.is_session_expired():
            expired_seconds = time.time() - self.auth_service.login_expiry
            logger.info(f"Session已过期(已过期 {expired_seconds:.0f} 秒),需要重新登录")
            self._relogin_with_new_client()
            return

        if self.auth_service.needs_renewal():
            time_remaining = self.auth_service.login_expiry - time.time()
            minutes = int(time_remaining // 60)
            seconds = int(time_remaining % 60)
            logger.info(f"Session将在{minutes}分{seconds}秒后过期(小于30分钟),需要重新登录")
            self._relogin_with_new_client()
            return

        time_remaining = self.auth_service.login_expiry - time.time()
        hours = int(time_remaining // 3600)
        minutes = int((time_remaining % 3600) // 60)
        seconds = int(time_remaining % 60)
        logger.info(f"Session仍然有效,剩余时间: {hours}小时{minutes}分{seconds}秒")

    def load_alpha_list(self) -> List[Tuple[int, str]]:
        """Return ``(id, alpha)`` for every unused alpha in the database."""
        return self.db_manager.get_unused_alpha()

    def run_batch_simulation(self, alpha_records: List[Tuple[int, str]]) -> Tuple[int, int]:
        """Simulate all records in batches of ``BATCH_SIZE`` concurrent threads.

        Each batch is fully joined before the next starts; the session is
        checked before every batch and a random pause separates batches.

        Returns:
            ``(success_count, fail_count)`` for this run.
        """
        total_count = len(alpha_records)
        batch_size = self.settings.BATCH_SIZE
        batch_count = (total_count + batch_size - 1) // batch_size

        logger.info(f"开始批次模拟,总共 {total_count} 个因子,每批 {self.settings.BATCH_SIZE} 个, 当前时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

        # Reset counters (no worker threads are running yet).
        with self._counter_lock:
            self.success_count = 0
            self.fail_count = 0
        self.total_processed = total_count

        for batch_idx in range(batch_count):
            # Re-login before a batch if the session is (nearly) expired.
            logger.info("检查session状态...")
            self.check_and_refresh_session_if_needed()

            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, total_count)
            batch_records = alpha_records[start_idx:end_idx]

            logger.info(f"\n=== 开始第 {batch_idx + 1}/{batch_count} 批,本批 {len(batch_records)} 个因子 ===\t 当前时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

            # Start one thread per alpha in this batch.
            threads = []
            thread_results = [None] * len(batch_records)

            for i, (record_id, alpha) in enumerate(batch_records):
                thread = threading.Thread(
                    target=self._simulate_single_in_batch,
                    args=(record_id, alpha, i, thread_results)
                )
                thread.daemon = True
                threads.append(thread)
                thread.start()
                logger.info(f"启动第 {batch_idx + 1} 批第 {i + 1} 个因子: {alpha[:50]}... (ID: {record_id})")

            # Wait for the whole batch to finish.
            for thread in threads:
                thread.join()

            logger.info(f"=== 第 {batch_idx + 1}/{batch_count} 批完成 ===, 当前时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

            # Random pause between batches (not after the last one).
            if batch_idx < batch_count - 1:
                sleep_time = random.uniform(self.settings.SIMULATION_INTERVAL_MIN, self.settings.SIMULATION_INTERVAL_MAX)
                logger.info(f"等待 {sleep_time:.1f} 秒后开始下一批...")
                time.sleep(sleep_time)

        # Log the summary and send the notification.
        self._print_summary()

        return self.success_count, self.fail_count

    def _record_outcome(self, succeeded: bool) -> None:
        """Thread-safely increment the success or failure counter."""
        with self._counter_lock:
            if succeeded:
                self.success_count += 1
            else:
                self.fail_count += 1

    def _simulate_single_in_batch(self, record_id: int, alpha: str, index: int, results_list: list) -> None:
        """Thread target: simulate one alpha, persist the result, update counters.

        Stores the SimulationResult in ``results_list[index]``. Never raises:
        simulation errors are recorded as status "err".
        """
        start_time = time.time()
        try:
            result = self.alpha_service.simulate_alpha(alpha)
            time_consuming = round(time.time() - start_time, 2)

            simulation_result = SimulationResult(
                alpha=alpha,
                time_consuming=time_consuming,
                status=result["status"],
                timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                alpha_id=result.get("alpha_id"),
                message=result.get("message", "")
            )

            # Persist immediately so progress survives a crash mid-run.
            self.db_manager.update_alpha_simulation(record_id, simulation_result)

            if result["status"] == "ok":
                self._record_outcome(True)
                logger.info(f"✅ 批次内第 {index + 1} 个因子模拟成功 - Alpha ID: {result['alpha_id']} - 耗时: {time_consuming}秒")
            else:
                self._record_outcome(False)
                logger.error(f"❌ 批次内第 {index + 1} 个因子模拟失败 - {result.get('message', '未知错误')}")

            results_list[index] = simulation_result

        except Exception as e:
            time_consuming = round(time.time() - start_time, 2)
            simulation_result = SimulationResult(
                alpha=alpha,
                time_consuming=time_consuming,
                status="err",
                message=str(e),
                timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            )
            # Guard the write: a second exception here would otherwise kill
            # the thread before the failure is counted.
            try:
                self.db_manager.update_alpha_simulation(record_id, simulation_result)
            except Exception as db_error:
                logger.error(f"保存模拟结果到数据库失败 (ID: {record_id}) - {db_error}")
            self._record_outcome(False)
            logger.error(f"❌ 批次内第 {index + 1} 个因子模拟异常 - {str(e)}")
            results_list[index] = simulation_result

    def _print_summary(self) -> None:
        """Log the run totals and send them to Gotify."""
        now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        logger.info(f"\n总计: 成功 {self.success_count} 个, 失败 {self.fail_count} 个,共处理 {self.total_processed} 个因子")
        logger.info(f"完成时间: {now}")
        logger.info(f"所有结果已保存到 PostgreSQL 数据库 {self.settings.DATABASE_CONFIG['database']} 的 alpha_simulation 表中")

        self.notification_service.send_to_gotify(self.success_count, self.fail_count)


def main():
    """Entry point: log in, load unused alphas, simulate them, always clean up."""
    logger.info("开始运行Alpha因子模拟器...")
    simulator: Optional[AlphaSimulator] = None

    try:
        simulator = AlphaSimulator()
        simulator.initialize()  # includes login

        alpha_records = simulator.load_alpha_list()

        if not alpha_records:
            logger.info("暂无待处理的alpha表达式")
            return

        logger.info(f"共加载 {len(alpha_records)} 个需要模拟的因子表达式")

        success_count, fail_count = simulator.run_batch_simulation(alpha_records)

        logger.info(f"模拟任务完成!成功: {success_count} 个, 失败: {fail_count} 个,共处理 {len(alpha_records)} 个因子")

    except KeyboardInterrupt:
        logger.info("用户中断执行")
    except Exception as e:
        logger.error(f"模拟过程发生错误: {e}", exc_info=True)
    finally:
        # Release DB/HTTP resources deterministically instead of relying on __del__.
        if simulator is not None:
            simulator.close()
        logger.info("Alpha因子模拟器运行结束")


if __name__ == "__main__":
    main()