You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
530 lines
20 KiB
530 lines
20 KiB
# -*- coding: utf-8 -*-
|
|
import os
|
|
import time
|
|
import random
|
|
import psycopg2
|
|
import psycopg2.extras
|
|
import httpx
|
|
import threading
|
|
from datetime import datetime
|
|
from dataclasses import dataclass
|
|
from typing import List, Tuple, Dict, Any, Optional
|
|
from httpx import BasicAuth
|
|
from configparser import ConfigParser
|
|
import logging
|
|
|
|
# Configure root logging once for the whole script: INFO level, with
# timestamp, logger name and level in front of every message.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger shared by every class and function in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Settings:
    """Application settings: database, API, simulation, batching, notification.

    The database parameters are read from the ``[database]`` section of
    ``db.conf`` (resolved relative to the current working directory); all
    other values are constants assigned in ``__init__``.

    Raises:
        FileNotFoundError: if ``db.conf`` does not exist.
    """

    def __init__(self):
        # --- Database configuration -------------------------------------
        self.config_file = "db.conf"
        self.config = ConfigParser()

        if not os.path.exists(self.config_file):
            raise FileNotFoundError(f"数据库配置文件 {self.config_file} 不存在")

        self.config.read(self.config_file, encoding='utf-8')

        # Every value is kept as the raw string returned by ConfigParser.
        db_options = ("host", "port", "user", "password", "database")
        self.DATABASE_CONFIG = {
            option: self.config.get('database', option) for option in db_options
        }

        # --- API configuration ------------------------------------------
        self.BRAIN_API_URL = "https://api.worldquantbrain.com"

        # --- Simulation settings, sent as-is with every simulation request
        self.SIMULATION_SETTINGS = dict(
            instrumentType='EQUITY',
            region='USA',
            universe='TOP3000',
            delay=1,
            decay=0,
            neutralization='INDUSTRY',
            truncation=0.08,
            pasteurization='ON',
            unitHandling='VERIFY',
            nanHandling='OFF',
            language='FASTEXPR',
            visualization=False,
        )

        # --- Batching: number of alphas simulated concurrently per batch --
        self.BATCH_SIZE = 8

        # --- Random pause between two batches, in seconds -----------------
        self.SIMULATION_INTERVAL_MIN = 5   # lower bound
        self.SIMULATION_INTERVAL_MAX = 10  # upper bound

        # --- Notification endpoint ----------------------------------------
        # NOTE(review): the fallback URL embeds an access token in source
        # control; prefer supplying GOTIFY_URL through the environment.
        self.GOTIFY_URL = os.environ.get(
            "GOTIFY_URL",
            "https://gotify.erhe.top/message?token=AvKJCJwQKU6yLP8",
        )

        # --- Session lifetime assumed after a login: 4 hours, in seconds --
        self.SESSION_EXPIRY_SECONDS = 4 * 60 * 60

    @property
    def credentials_file(self) -> str:
        """Return the path of the file holding the account credentials."""
        return "account.txt"
|
|
|
|
|
|
@dataclass
class SimulationResult:
    """Outcome of simulating one alpha expression.

    The four leading fields are required; ``alpha_id`` and ``message``
    default to ``None``.
    """

    # The alpha expression that was simulated.
    alpha: str
    # Wall-clock duration of the simulation, in seconds.
    time_consuming: float
    # Outcome marker; the rest of this file uses "ok" and "err".
    status: str
    # Completion time, formatted by the caller as a string.
    timestamp: str
    # Identifier returned by the API on success, if any.
    alpha_id: Optional[str] = None
    # Error or informational text, if any.
    message: Optional[str] = None
|
|
|
|
|
|
class AuthService:
    """Handles credentials loading, login and session-lifetime bookkeeping.

    The session lifetime is not reported by the server here; it is assumed to
    be ``settings.SESSION_EXPIRY_SECONDS`` counted from the last successful
    login.
    """

    # A session with less than this many seconds left is considered due for
    # renewal (30 minutes).
    RENEWAL_THRESHOLD_SECONDS = 1800

    def __init__(self, settings: "Settings"):
        self.settings = settings
        self.credentials_file = settings.credentials_file
        # Both stay None until the first successful login.
        self.last_login_time: Optional[float] = None
        self.login_expiry: Optional[float] = None

    def load_credentials(self) -> Tuple[str, str]:
        """Read ``(username, password)`` from the credentials file.

        The file must contain a Python list/tuple literal such as
        ``['your_username', 'your_password']``.

        Raises:
            FileNotFoundError: if the file did not exist (a template is
                written first so the user can fill it in).
            ValueError: if the file content is not a literal sequence with at
                least two items.
        """
        # Function-scope import so the block does not depend on a new
        # file-level import.
        import ast

        if not os.path.exists(self.credentials_file):
            self._create_credentials_file()

        with open(self.credentials_file, 'r', encoding='utf-8') as f:
            raw = f.read()

        # SECURITY: this used to be eval(), which executes any code placed in
        # the file. literal_eval only accepts Python literals, which is all
        # the documented file format needs.
        try:
            credentials = ast.literal_eval(raw.strip())
        except (ValueError, SyntaxError) as exc:
            raise ValueError(
                f"{self.credentials_file} must contain a literal like "
                "['your_username', 'your_password']"
            ) from exc

        if not isinstance(credentials, (list, tuple)) or len(credentials) < 2:
            raise ValueError(
                f"{self.credentials_file} must contain a list of "
                "[username, password]"
            )
        return credentials[0], credentials[1]

    def _create_credentials_file(self) -> None:
        """Write a credentials template, then abort with FileNotFoundError."""
        logger.error(f"未找到 {self.credentials_file} 文件")
        with open(self.credentials_file, 'w', encoding='utf-8') as f:
            f.write("['your_username', 'your_password']")
        logger.error(f"请编辑 {self.credentials_file} 文件,填写账号密码")
        raise FileNotFoundError(f"请创建并配置 {self.credentials_file} 文件")

    def login(self, client: "httpx.Client") -> Dict[str, Any]:
        """Authenticate ``client`` against the API and record the login time.

        Sets HTTP basic auth on ``client`` and POSTs to ``/authentication``.

        Returns:
            The decoded JSON body of the authentication response.

        Raises:
            Exception: on HTTP 429 (rate limit) or any status other than 201.
        """
        username, password = self.load_credentials()
        client.auth = BasicAuth(username, password)

        response = client.post(f'{self.settings.BRAIN_API_URL}/authentication')
        logger.info(f"登录状态: {response.status_code}")

        if response.status_code == 201:
            login_data = response.json()
            expiry_seconds = self.settings.SESSION_EXPIRY_SECONDS
            self.last_login_time = time.time()
            self.login_expiry = self.last_login_time + expiry_seconds

            # Hours are derived from the setting instead of being hard-coded,
            # so the message stays correct if the setting changes.
            logger.info(
                f"登录成功!Session有效期: {expiry_seconds}秒 "
                f"({expiry_seconds / 3600:.1f}小时)"
            )
            return login_data

        if response.status_code == 429:
            logger.error("API rate limit exceeded")
            raise Exception("API rate limit exceeded")

        logger.error(f"登录失败: {response.status_code} - {response.text}")
        raise Exception(f"登录失败: {response.status_code} - {response.text}")

    def is_session_expired(self) -> bool:
        """Return True when there is no recorded login or it has expired."""
        if not self.last_login_time or not self.login_expiry:
            return True
        return time.time() >= self.login_expiry

    def needs_renewal(self) -> bool:
        """Return True when less than 30 minutes of session time remain."""
        if not self.login_expiry:
            return True
        time_remaining = self.login_expiry - time.time()
        return time_remaining < self.RENEWAL_THRESHOLD_SECONDS
|
|
|
|
|
|
class AlphaService:
    """Submits alpha expressions for simulation and waits for the outcome."""

    def __init__(self, client: httpx.Client, settings: Settings):
        self.client = client
        self.settings = settings

    def simulate_alpha(self, alpha: str) -> Dict[str, Any]:
        """Submit one alpha expression and block until its simulation ends.

        Args:
            alpha: the expression text, sent as the ``regular`` field.

        Returns:
            ``{"status": "ok", "alpha_id": ...}`` on success, otherwise
            ``{"status": "err", "message": ...}``.
        """
        simulation_data = {
            'type': 'REGULAR',
            'settings': self.settings.SIMULATION_SETTINGS,
            'regular': alpha,
        }

        logger.info(f"发送模拟请求,alpha长度: {len(alpha)}")
        sim_resp = self.client.post(
            f'{self.settings.BRAIN_API_URL}/simulations', json=simulation_data
        )
        logger.info(f"模拟提交状态: {sim_resp.status_code}")

        # The progress URL is delivered in the Location header; without it
        # there is nothing to poll, so dump what came back and give up.
        sim_progress_url = sim_resp.headers.get('location')
        if sim_progress_url is None:
            logger.error(f"缺少location header,状态码: {sim_resp.status_code}")
            logger.error(f"响应头: {dict(sim_resp.headers)}")
            logger.error(f"响应内容: {sim_resp.text[:500] if sim_resp.text else '空'}")
            return {"status": "err", "message": "No location header in response"}

        logger.info(f"模拟任务已创建,进度URL: {sim_progress_url}")
        return self._wait_for_simulation_result(sim_progress_url)

    def _wait_for_simulation_result(self, progress_url: str) -> Dict[str, Any]:
        """Poll ``progress_url`` until the response has no Retry-After delay.

        While the response carries a non-zero ``Retry-After`` header, the
        reported progress is logged and the method sleeps for that many
        seconds before polling again.

        Returns:
            ``{"status": "ok", "alpha_id": ...}`` when the final payload
            contains an alpha id; ``{"status": "err", "message": ...}`` when
            the payload reports ``status == "ERROR"`` or carries no alpha id.
        """
        while True:
            sim_progress_resp = self.client.get(progress_url)
            retry_after_sec = float(sim_progress_resp.headers.get("Retry-After", 0))

            if retry_after_sec == 0:
                break

            # Decode the body once per poll instead of twice.
            progress_payload = sim_progress_resp.json()
            if progress_payload:
                progress = progress_payload.get('progress', 0)
                logger.info(f"模拟进度: {progress}%")

            time.sleep(retry_after_sec)

        result_data = sim_progress_resp.json()

        if result_data.get("status") == "ERROR":
            error_message = result_data.get("message", "未知错误")
            logger.error(f"因子模拟失败: {error_message}")
            return {"status": "err", "message": error_message}

        alpha_id = result_data.get("alpha")
        if not alpha_id:
            # A final payload without an alpha id used to be returned as
            # "ok" with alpha_id=None; report it as a failure instead.
            error_message = result_data.get("message") or (
                f"模拟未返回Alpha ID (status: {result_data.get('status')})"
            )
            logger.error(f"因子模拟失败: {error_message}")
            return {"status": "err", "message": error_message}

        logger.info(f"生成的Alpha ID: {alpha_id}")
        return {"status": "ok", "alpha_id": alpha_id}
|
|
|
|
|
|
class NotificationService:
    """Sends a run summary to the configured Gotify endpoint."""

    def __init__(self, settings: Settings):
        self.settings = settings

    def send_to_gotify(self, success_count: int, fail_count: int) -> None:
        """POST the success/failure totals to ``settings.GOTIFY_URL``.

        Delivery is best-effort: any failure (network error or non-2xx
        response) is logged and swallowed so it never aborts the run.
        """
        now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        text = f"总计: 成功 {success_count} 个, 失败 {fail_count} 个\n\n完成时间: {now}"
        title = f"alpha模拟结果 时间: {now}"

        try:
            resp = httpx.post(
                self.settings.GOTIFY_URL,
                json={'title': title, 'message': text},
                timeout=10,
            )
            # Without this, a 4xx/5xx answer was still logged as a success.
            resp.raise_for_status()
        except Exception as e:
            logger.error(f"通知发送失败: {e}")
        else:
            logger.info("通知发送成功")
|
|
|
|
|
|
class DatabaseManager:
    """Owns one lazily-created PostgreSQL connection for the whole run.

    The batch runner in this file calls ``update_alpha_simulation`` from
    several worker threads that all share this single connection, so every
    operation on it is serialized with a lock.
    """

    def __init__(self, settings: Settings):
        self.settings = settings
        self.connection = None
        # Re-entrant because the public methods call get_connection() while
        # already holding the lock.
        self._lock = threading.RLock()

    def get_connection(self) -> psycopg2.extensions.connection:
        """Return the shared connection, (re)opening it if absent or closed."""
        with self._lock:
            if self.connection is None or self.connection.closed:
                db = self.settings.DATABASE_CONFIG
                self.connection = psycopg2.connect(
                    host=db["host"],
                    port=db["port"],
                    database=db["database"],
                    user=db["user"],
                    password=db["password"],
                )
            return self.connection

    def get_unused_alpha(self) -> List[Tuple[int, str]]:
        """Return ``(id, alpha)`` for every row flagged unused with an alpha."""
        with self._lock:
            conn = self.get_connection()
            # The context manager closes the cursor, not the connection.
            with conn.cursor() as cursor:
                cursor.execute(
                    "SELECT id, alpha FROM alpha_simulation WHERE unused = TRUE AND alpha IS NOT NULL"
                )
                return cursor.fetchall()

    def update_alpha_simulation(self, record_id: int, result: SimulationResult) -> None:
        """Store ``result`` on row ``record_id`` and mark the row as used.

        The statement is committed immediately. If it fails, the transaction
        is rolled back so the shared connection stays usable, and the
        original exception is re-raised.
        """
        params = (
            result.time_consuming,
            result.status,
            result.timestamp,
            result.alpha_id,
            result.message or "",
            record_id,
        )
        # Hold the lock across execute + commit so one thread's rollback can
        # never discard another thread's not-yet-committed update.
        with self._lock:
            conn = self.get_connection()
            try:
                with conn.cursor() as cursor:
                    cursor.execute('''
                        UPDATE alpha_simulation
                        SET unused = FALSE,
                            time_consuming = %s,
                            status = %s,
                            timestamp = %s,
                            alpha_id = %s,
                            message = %s
                        WHERE id = %s
                    ''', params)
                conn.commit()
            except Exception:
                # A failed statement leaves the transaction aborted; without
                # a rollback every later statement on this connection fails.
                if not conn.closed:
                    conn.rollback()
                raise

    def close_connection(self) -> None:
        """Close the shared connection if it is open; safe to call twice."""
        with self._lock:
            if self.connection and not self.connection.closed:
                self.connection.close()
|
|
|
|
|
|
class AlphaSimulator:
    """Top-level driver: logs in, simulates alphas in batches, stores results.

    Each batch starts one thread per alpha and waits for all of them before
    the next batch begins. All threads share the same HTTP client and the same
    ``DatabaseManager``.
    """

    def __init__(self):
        # Created first so that worker threads can always rely on it.
        self._count_lock = threading.Lock()
        self.settings = Settings()
        self.db_manager = DatabaseManager(self.settings)
        self.auth_service = AuthService(self.settings)
        self.notification_service = NotificationService(self.settings)
        self.client: Optional[httpx.Client] = None
        self.alpha_service: Optional[AlphaService] = None
        self.success_count = 0
        self.fail_count = 0
        self.total_processed = 0

    def __del__(self):
        """Best-effort cleanup of the database connection and HTTP client."""
        # __init__ may have raised before these attributes were assigned
        # (e.g. missing db.conf), so never assume they exist.
        db_manager = getattr(self, 'db_manager', None)
        if db_manager is not None:
            db_manager.close_connection()
        client = getattr(self, 'client', None)
        if client:
            client.close()

    def initialize(self) -> None:
        """Create the HTTP client and log in; re-raises any failure."""
        try:
            self.client = httpx.Client()
            self.login()
            logger.info("模拟器初始化成功")
        except Exception as e:
            logger.error(f"模拟器初始化失败: {e}")
            raise

    def login(self) -> Dict[str, Any]:
        """Log in with the current client and rebuild the alpha service.

        Returns:
            Whatever ``AuthService.login`` returns for the current client.
        """
        logger.info("正在登录...")
        login_result = self.auth_service.login(self.client)
        self.alpha_service = AlphaService(self.client, self.settings)

        expires_in = self.settings.SESSION_EXPIRY_SECONDS
        hours = int(expires_in // 3600)
        minutes = int((expires_in % 3600) // 60)
        logger.info(f"登录成功!Session将在 {hours}小时{minutes}分钟后过期")

        return login_result

    def check_and_refresh_session_if_needed(self) -> None:
        """Log in again when the session is missing, expired or near expiry.

        A session with less than 30 minutes left gets a brand-new HTTP client
        before logging in; the other two cases reuse the existing client.
        """
        if not self.auth_service.last_login_time:
            logger.info("未找到登录信息,需要登录")
            self.login()
            return

        if self.auth_service.is_session_expired():
            expired_seconds = time.time() - self.auth_service.login_expiry
            logger.info(f"Session已过期(已过期 {expired_seconds:.0f} 秒),需要重新登录")
            self.login()
            return

        if self.auth_service.needs_renewal():
            time_remaining = self.auth_service.login_expiry - time.time()
            minutes = int(time_remaining // 60)
            seconds = int(time_remaining % 60)
            logger.info(f"Session将在{minutes}分{seconds}秒后过期(小于30分钟),需要重新登录")

            # Replace the client object before logging in again.
            if self.client:
                self.client.close()
            self.client = httpx.Client()
            self.login()
            return

        time_remaining = self.auth_service.login_expiry - time.time()
        hours = int(time_remaining // 3600)
        minutes = int((time_remaining % 3600) // 60)
        seconds = int(time_remaining % 60)
        logger.info(f"Session仍然有效,剩余时间: {hours}小时{minutes}分{seconds}秒")

    def load_alpha_list(self) -> List[Tuple[int, str]]:
        """Return the ``(id, alpha)`` rows still waiting to be simulated."""
        return self.db_manager.get_unused_alpha()

    def run_batch_simulation(self, alpha_records: List[Tuple[int, str]]) -> Tuple[int, int]:
        """Simulate every record, ``BATCH_SIZE`` at a time.

        Before each batch the session is checked and refreshed if needed.
        Batches are separated by a random pause between
        ``SIMULATION_INTERVAL_MIN`` and ``SIMULATION_INTERVAL_MAX`` seconds.
        A summary is logged and a notification sent at the end.

        Returns:
            ``(success_count, fail_count)`` for this run.
        """
        total_count = len(alpha_records)
        batch_size = self.settings.BATCH_SIZE
        batch_count = (total_count + batch_size - 1) // batch_size  # ceiling division

        logger.info(f"开始批次模拟,总共 {total_count} 个因子,每批 {self.settings.BATCH_SIZE} 个, 当前时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

        # Reset the counters for this run.
        with self._count_lock:
            self.success_count = 0
            self.fail_count = 0
        self.total_processed = total_count

        for batch_idx in range(batch_count):
            # Check the session before every batch; log in again if required.
            logger.info("检查session状态...")
            self.check_and_refresh_session_if_needed()

            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, total_count)
            batch_records = alpha_records[start_idx:end_idx]

            logger.info(f"\n=== 开始第 {batch_idx + 1}/{batch_count} 批,本批 {len(batch_records)} 个因子 ===\t 当前时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

            # Start one worker thread per alpha in this batch.
            threads = []
            thread_results = [None] * len(batch_records)

            for i, (record_id, alpha) in enumerate(batch_records):
                thread = threading.Thread(
                    target=self._simulate_single_in_batch,
                    args=(record_id, alpha, i, thread_results),
                )
                thread.daemon = True
                threads.append(thread)
                thread.start()
                logger.info(f"启动第 {batch_idx + 1} 批第 {i + 1} 个因子: {alpha[:50]}... (ID: {record_id})")

            # Wait for the whole batch before moving on.
            for thread in threads:
                thread.join()

            logger.info(f"=== 第 {batch_idx + 1}/{batch_count} 批完成 ===, 当前时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

            # Random pause between batches (not after the last one).
            if batch_idx < batch_count - 1:
                sleep_time = random.uniform(
                    self.settings.SIMULATION_INTERVAL_MIN,
                    self.settings.SIMULATION_INTERVAL_MAX,
                )
                logger.info(f"等待 {sleep_time:.1f} 秒后开始下一批...")
                time.sleep(sleep_time)

        # Log the summary and send the notification.
        self._print_summary()

        return self.success_count, self.fail_count

    def _record_outcome(self, success: bool) -> None:
        """Increment the success or failure counter under the counter lock."""
        # ``+=`` on an attribute is a read-modify-write; worker threads run
        # concurrently, so unsynchronized increments can be lost.
        with self._count_lock:
            if success:
                self.success_count += 1
            else:
                self.fail_count += 1

    def _simulate_single_in_batch(self, record_id: int, alpha: str, index: int, results_list: list) -> None:
        """Thread target: simulate one alpha, persist and count the outcome.

        Args:
            record_id: primary key of the row to update.
            alpha: expression to simulate.
            index: position of this alpha inside its batch.
            results_list: per-batch list; slot ``index`` receives the result.
        """
        start_time = time.time()

        try:
            result = self.alpha_service.simulate_alpha(alpha)
            time_consuming = round(time.time() - start_time, 2)

            simulation_result = SimulationResult(
                alpha=alpha,
                time_consuming=time_consuming,
                status=result["status"],
                timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                alpha_id=result.get("alpha_id"),
                message=result.get("message", ""),
            )

            # Persist immediately so progress survives an interrupted run.
            self.db_manager.update_alpha_simulation(record_id, simulation_result)

            if result["status"] == "ok":
                self._record_outcome(success=True)
                logger.info(f"✅ 批次内第 {index + 1} 个因子模拟成功 - Alpha ID: {result.get('alpha_id')} - 耗时: {time_consuming}秒")
            else:
                self._record_outcome(success=False)
                logger.error(f"❌ 批次内第 {index + 1} 个因子模拟失败 - {result.get('message', '未知错误')}")

            results_list[index] = simulation_result

        except Exception as e:
            time_consuming = round(time.time() - start_time, 2)

            simulation_result = SimulationResult(
                alpha=alpha,
                time_consuming=time_consuming,
                status="err",
                message=str(e),
                timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            )

            # Count and publish the failure before touching the database: if
            # the database was what failed, the write below fails again and
            # previously took the whole thread down without recording anything.
            results_list[index] = simulation_result
            self._record_outcome(success=False)
            logger.error(f"❌ 批次内第 {index + 1} 个因子模拟异常 - {str(e)}")

            try:
                self.db_manager.update_alpha_simulation(record_id, simulation_result)
            except Exception as db_error:
                logger.error(f"保存失败记录到数据库时出错 (ID: {record_id}): {db_error}")

    def _print_summary(self) -> None:
        """Log the run totals and forward them to the notification service."""
        now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        logger.info(f"\n总计: 成功 {self.success_count} 个, 失败 {self.fail_count} 个,共处理 {self.total_processed} 个因子")
        logger.info(f"完成时间: {now}")
        logger.info(f"所有结果已保存到 PostgreSQL 数据库 {self.settings.DATABASE_CONFIG['database']} 的 alpha_simulation 表中")

        # Send the notification.
        self.notification_service.send_to_gotify(self.success_count, self.fail_count)
|
|
|
|
|
|
def main():
    """Script entry point: build the simulator, run every pending alpha.

    Any error is logged rather than propagated. The database connection and
    HTTP client are released explicitly in ``finally`` instead of relying on
    the simulator's ``__del__``.
    """
    logger.info("开始运行Alpha因子模拟器...")

    simulator = None
    try:
        # Build the simulator and log in.
        simulator = AlphaSimulator()
        simulator.initialize()

        # Load the alpha expressions that still need to be simulated.
        alpha_records = simulator.load_alpha_list()
        if not alpha_records:
            logger.info("暂无待处理的alpha表达式")
            return

        logger.info(f"共加载 {len(alpha_records)} 个需要模拟的因子表达式")

        # Run all batches.
        success_count, fail_count = simulator.run_batch_simulation(alpha_records)

        logger.info(f"模拟任务完成!成功: {success_count} 个, 失败: {fail_count} 个,共处理 {len(alpha_records)} 个因子")

    except KeyboardInterrupt:
        logger.info("用户中断执行")
    except Exception as e:
        logger.error(f"模拟过程发生错误: {e}", exc_info=True)
    finally:
        # simulator stays None when the constructor itself failed.
        if simulator is not None:
            try:
                simulator.db_manager.close_connection()
                if simulator.client is not None:
                    simulator.client.close()
            except Exception as cleanup_error:
                logger.warning(f"释放资源时出错: {cleanup_error}")
        logger.info("Alpha因子模拟器运行结束")


if __name__ == "__main__":
    main()