You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
396 lines
14 KiB
396 lines
14 KiB
# -*- coding: utf-8 -*-
|
|
import ast
import json
import os
import sys
import time
from datetime import datetime

import httpx
import psycopg2
from httpx import BasicAuth
|
|
|
|
|
|
class DatabaseManager:
    """PostgreSQL access for the alpha simulation pipeline.

    Holds one lazily (re)opened connection to ``database_name`` and exposes
    the queries the simulator needs on two tables:

    * ``alpha_prepare`` - candidate alpha expressions plus an ``unused`` flag;
    * ``simulation``    - one row per simulation attempt.

    The database and both tables are created on construction if missing.
    Connection settings default to the values this class always used, so
    ``DatabaseManager()`` keeps working unchanged; pass keyword arguments to
    target another server or account instead of editing the source.
    """

    def __init__(self, database_name="alpha", host="localhost", port="5432",
                 user="jack", password="aaaAAA111"):
        """Store connection settings, then create the database/tables if needed.

        Raises:
            Exception: whatever psycopg2 raises when the server cannot be
                reached or the schema cannot be created.
        """
        self.connection = None
        self.database_name = database_name
        # Single source of truth for both the maintenance ("postgres")
        # connection and the working connection.
        self._conn_params = {"host": host, "port": port, "user": user, "password": password}
        self.init_database()

    def _connect(self, database):
        """Open and return a new connection to *database* with the stored settings."""
        return psycopg2.connect(database=database, **self._conn_params)

    def create_database(self):
        """Create ``self.database_name`` on the server if it does not exist yet.

        Connects to the default ``postgres`` database to do so. Any error is
        printed and re-raised; the temporary connection is always closed.
        """
        # Local import: only this method needs identifier quoting.
        from psycopg2 import sql

        conn = None
        try:
            # Connect to the default postgres database to create the target one.
            conn = self._connect("postgres")
            # CREATE DATABASE cannot run inside a transaction block.
            conn.autocommit = True
            with conn.cursor() as cursor:
                # Check whether the database already exists.
                cursor.execute(
                    "SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s",
                    (self.database_name,),
                )
                if cursor.fetchone() is None:
                    # Identifiers cannot be bound as %s parameters; quote safely instead.
                    cursor.execute(
                        sql.SQL("CREATE DATABASE {}").format(sql.Identifier(self.database_name))
                    )
                    print(f"数据库 {self.database_name} 创建成功")
                else:
                    print(f"数据库 {self.database_name} 已存在")
        except Exception as e:
            print(f"创建数据库时出错: {e}")
            raise
        finally:
            # Close even when a statement above raised (previously leaked).
            if conn is not None:
                conn.close()

    def get_connection(self):
        """Return the cached connection to ``database_name``, reopening it if closed."""
        if self.connection is None or self.connection.closed:
            self.connection = self._connect(self.database_name)
        return self.connection

    def init_database(self):
        """Create the database (if needed), then the ``alpha_prepare`` and ``simulation`` tables."""
        self.create_database()

        conn = self.get_connection()
        # `with conn` commits on success and rolls back on error, so a failed
        # statement never leaves the shared connection in an aborted
        # transaction; `with conn.cursor()` closes the cursor.
        with conn, conn.cursor() as cursor:
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS alpha_prepare
                (
                    id           SERIAL PRIMARY KEY,
                    alpha        TEXT    NOT NULL UNIQUE,
                    unused       BOOLEAN NOT NULL DEFAULT TRUE,
                    created_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS simulation
                (
                    id             SERIAL PRIMARY KEY,
                    expression     TEXT NOT NULL,
                    time_consuming REAL NOT NULL,
                    status         TEXT NOT NULL,
                    timestamp      TEXT NOT NULL,
                    alpha_id       TEXT,
                    message        TEXT,
                    created_time   TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')
        print(f"数据库 {self.database_name} 表结构初始化完成")

    def get_unused_alpha(self):
        """Return every ``alpha_prepare.alpha`` whose ``unused`` flag is TRUE."""
        conn = self.get_connection()
        # Ending the transaction keeps this read from leaving the connection
        # "idle in transaction" while the caller sleeps between polls.
        with conn, conn.cursor() as cursor:
            cursor.execute("SELECT alpha FROM alpha_prepare WHERE unused = TRUE")
            return [row[0] for row in cursor.fetchall()]

    def mark_alpha_used(self, alpha):
        """Set ``unused = FALSE`` for the ``alpha_prepare`` row(s) matching *alpha*."""
        conn = self.get_connection()
        with conn, conn.cursor() as cursor:
            cursor.execute("UPDATE alpha_prepare SET unused = FALSE WHERE alpha = %s", (alpha,))

    def insert_simulation_result(self, result_data):
        """Insert one row into ``simulation``.

        Args:
            result_data: dict with required keys ``expression``,
                ``time_consuming``, ``status`` and ``timestamp``; optional
                ``alpha_id`` (default NULL) and ``message`` (default "").

        Raises:
            KeyError: if a required key is missing.
        """
        params = (
            result_data["expression"],
            result_data["time_consuming"],
            result_data["status"],
            result_data["timestamp"],
            result_data.get("alpha_id"),
            result_data.get("message", ""),
        )
        conn = self.get_connection()
        with conn, conn.cursor() as cursor:
            cursor.execute('''
                INSERT INTO simulation
                    (expression, time_consuming, status, timestamp, alpha_id, message)
                VALUES (%s, %s, %s, %s, %s, %s)
            ''', params)

    def close_connection(self):
        """Close the cached connection if it is open; safe to call repeatedly."""
        if self.connection is not None and not self.connection.closed:
            self.connection.close()
|
|
|
|
|
|
class AlphaSimulator:
    """Submit alpha expressions to the WorldQuant BRAIN API and record the outcome.

    Expressions come from the ``alpha_prepare`` table via ``DatabaseManager``.
    Every attempt (success, API-reported failure, or exception) is written to
    the ``simulation`` table, and the expression is marked as used.
    """

    # Gotify push endpoint for the end-of-run summary (the token is part of the URL).
    GOTIFY_URL = "https://gotify.erhe.top/message?token=AvKJCJwQKU6yLP8"

    def __init__(self, credentials_file='account.txt'):
        self.credentials_file = credentials_file
        self.client = None
        self.brain_api_url = 'https://api.worldquantbrain.com'
        self.db_manager = DatabaseManager()

    def __del__(self):
        """Destructor: release the HTTP client and the database connection."""
        client = getattr(self, 'client', None)
        if client is not None:
            client.close()
        if hasattr(self, 'db_manager'):
            self.db_manager.close_connection()

    @staticmethod
    def _response_detail(response):
        """Return the JSON body of *response*, or its raw text if it is not JSON."""
        try:
            return response.json()
        except ValueError:
            return response.text

    def load_credentials(self):
        """Return ``(username, password)`` read from ``self.credentials_file``.

        The file must hold a Python list/tuple literal such as
        ``['username', 'password']``. If it does not exist, an empty file is
        created and the process exits with status 1 so the user can fill it in.

        Raises:
            ValueError, SyntaxError: if the file content is not a plain literal.
        """
        if not os.path.exists(self.credentials_file):
            print(f"未找到 {self.credentials_file} 文件")
            with open(self.credentials_file, 'w') as f:
                f.write("")
            print(f"{self.credentials_file} 文件已创建,请填写账号密码, 格式: ['username', 'password']")
            sys.exit(1)

        with open(self.credentials_file) as f:
            # literal_eval accepts only literals; eval() would execute any code
            # placed in the credentials file.
            credentials = ast.literal_eval(f.read())
        return credentials[0], credentials[1]

    def login(self):
        """Authenticate against BRAIN with HTTP Basic auth.

        Replaces ``self.client`` with a new authenticated ``httpx.Client``,
        closing any previous one.

        Returns:
            The parsed login response on HTTP 201, otherwise ``{}``.
        """
        username, password = self.load_credentials()
        if self.client is not None:
            # Re-login (token refresh) would otherwise leak the old connection pool.
            self.client.close()
        self.client = httpx.Client(auth=BasicAuth(username, password))

        response = self.client.post(f'{self.brain_api_url}/authentication')
        print(f"登录状态: {response.status_code}")

        if response.status_code == 201:
            body = response.json()
            print(f"登录成功!:{body}")
            return body
        print(f"登录失败: {self._response_detail(response)}")
        return {}

    def load_alpha_list(self):
        """Return the not-yet-used alpha expressions stored in the database."""
        return self.db_manager.get_unused_alpha()

    def simulate_alpha(self, expression):
        """Run one REGULAR simulation of *expression* and wait until it finishes.

        Returns:
            ``{"status": "ok", "alpha_id": ...}`` when BRAIN produced an alpha,
            otherwise ``{"status": "err", "message": <str>}``.

        Raises:
            Exception: if :meth:`login` has not been called yet.
        """
        if self.client is None:
            raise Exception("请先登录")

        settings = {
            'instrumentType': 'EQUITY',
            'region': 'USA',
            'universe': 'TOP3000',
            'delay': 1,
            'decay': 0,
            'neutralization': 'INDUSTRY',
            'truncation': 0.08,
            'pasteurization': 'ON',
            'unitHandling': 'VERIFY',
            'nanHandling': 'OFF',
            'language': 'FASTEXPR',
            'visualization': False,
        }
        simulation_data = {
            'type': 'REGULAR',
            'settings': settings,
            'regular': expression,
        }

        sim_resp = self.client.post(f'{self.brain_api_url}/simulations', json=simulation_data)
        print(f"模拟提交状态: {sim_resp.status_code}")

        sim_progress_url = sim_resp.headers.get('location')
        if not sim_progress_url:
            # A rejected submission has no Location header; report the API's
            # reason instead of failing with a bare KeyError('location').
            message = f"HTTP {sim_resp.status_code}: {self._response_detail(sim_resp)}"
            print(f"因子模拟失败: {message}")
            return {"status": "err", "message": message}

        # Poll until the server stops sending a non-zero Retry-After.
        while True:
            sim_progress_resp = self.client.get(sim_progress_url)
            retry_after_sec = float(sim_progress_resp.headers.get("Retry-After", 0))
            if retry_after_sec == 0:
                break

            progress_body = sim_progress_resp.json()
            if progress_body:
                progress = progress_body.get('progress', 0)
                print(f"模拟进度: {float(progress) * 100}%")

            print(f"等待 {retry_after_sec} 秒...")
            time.sleep(retry_after_sec)

        final_body = sim_progress_resp.json()
        if final_body.get("status") == "ERROR":
            message = final_body.get("message", "未知错误")
            print(f"因子模拟失败: {message}")
            return {"status": "err", "message": message}

        alpha_id = final_body.get("alpha")
        if not alpha_id:
            # Without an alpha id nothing was produced; do not count it as success.
            message = final_body.get("message") or f"未返回Alpha ID, 状态: {final_body.get('status')}"
            print(f"因子模拟失败: {message}")
            return {"status": "err", "message": str(message)}

        print(f"生成的Alpha ID: {alpha_id}")
        return {"status": "ok", "alpha_id": alpha_id}

    def run_batch_simulation(self, alpha_list, batch_size=3):
        """Simulate every expression in *alpha_list*, processed in groups of *batch_size*.

        Each expression is simulated once. Its outcome is saved and it is
        marked used, whether it succeeded, failed or raised.

        Returns:
            ``(success_count, fail_count)``.
        """
        success_count = 0
        fail_count = 0

        for batch_start in range(0, len(alpha_list), batch_size):
            batch_no = batch_start // batch_size + 1
            batch = alpha_list[batch_start:batch_start + batch_size]
            print(f"\n开始处理第 {batch_no} 批因子,共 {len(batch)} 个")

            for expression in batch:
                if self._simulate_and_record(expression):
                    success_count += 1
                else:
                    fail_count += 1

            print(f"第 {batch_no} 批处理完成")

        self._print_summary(success_count, fail_count)
        return success_count, fail_count

    def _simulate_and_record(self, expression):
        """Simulate one expression, persist the outcome once, and return True on success."""
        print(f"\n模拟因子: {expression}")
        start_time = time.time()
        raised = False
        try:
            result = self.simulate_alpha(expression)
        except Exception as e:
            raised = True
            result = {"status": "err", "message": str(e)}
        time_consuming = round(time.time() - start_time, 2)

        result_data = {
            "expression": expression,
            "time_consuming": time_consuming,
            "status": result["status"],
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        }
        succeeded = result["status"] == "ok"
        if succeeded:
            result_data["alpha_id"] = result["alpha_id"]
            print(f"✅ 模拟成功 - Alpha ID: {result['alpha_id']} - 耗时: {time_consuming}秒")
        else:
            result_data["message"] = result.get("message", "")
            if raised:
                print(f"❌ 模拟异常 - {result['message']}")
            else:
                print(f"❌ 模拟失败 - {result.get('message', '未知错误')}")

        # Persist outside the simulation try-block. Previously a DB error after
        # a successful run was treated as a simulation exception, which counted
        # the alpha as both success and failure and wrote a second row.
        self.db_manager.mark_alpha_used(expression)
        self.db_manager.insert_simulation_result(result_data)
        return succeeded

    def _print_summary(self, success_count, fail_count):
        """Print the run totals and push the same summary to Gotify."""
        now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        print_result = f"\n总计: 成功 {success_count} 个, 失败 {fail_count} 个\n\n 完成时间: {now}"
        print(print_result)
        print(f"所有结果已保存到 PostgreSQL 数据库 {self.db_manager.database_name} 的 simulation 表中")
        self._send_to_gotify(print_result, now)

    def _send_to_gotify(self, text, now):
        """Best-effort push of *text* to Gotify; errors are printed, never raised."""
        title = f"alpha模拟结果 时间: {now}"
        try:
            resp = httpx.post(self.GOTIFY_URL, json={'title': title, 'message': text})
            # Surface HTTP errors (bad token, server down) instead of ignoring them.
            resp.raise_for_status()
        except Exception as e:
            print(e)
|
|
|
|
def _token_seconds_left(login_result, logged_in_at, now=None):
    """Return the seconds left on the BRAIN session token.

    ``login_result['token']['expiry']`` is the remaining lifetime reported by
    the login response, so it is only accurate at login time. Subtract the
    time elapsed since *logged_in_at* (a ``time.time()`` timestamp).
    """
    if now is None:
        now = time.time()
    return float(login_result['token']['expiry']) - (now - logged_in_at)


def main():
    """Poll the database forever and simulate new alpha expressions as they appear.

    Logs in once and logs in again when the token has under 30 minutes left.
    Sleeps 10 minutes when there is nothing to do and 5 minutes after each
    round. Exits on Ctrl+C or on a failed login.
    """
    simulator = AlphaSimulator()

    def refresh_login():
        """Log in; exit the process on failure. Returns (login_result, login_timestamp)."""
        # Timestamp taken before the request, so the remaining-time estimate errs on the safe side.
        started = time.time()
        login_result = simulator.login()
        if not login_result:
            print("登录失败,程序退出")
            sys.exit(1)
        return login_result, started

    try:
        # Initial login
        login_result, logged_in_at = refresh_login()

        while True:
            try:
                # The stored expiry never changes by itself; compute what is actually left.
                remaining = _token_seconds_left(login_result, logged_in_at)
                if remaining < 1800:
                    print(f"Token剩余 {int(remaining)}秒,重新登录...")
                    login_result, logged_in_at = refresh_login()

                alpha_list = simulator.load_alpha_list()
                if not alpha_list:
                    print("暂无待处理的alpha表达式,10分钟后重新检查...")
                    time.sleep(600)
                    continue

                print(f"共加载 {len(alpha_list)} 个需要模拟的因子表达式")

                success_count, fail_count = simulator.run_batch_simulation(alpha_list, batch_size=3)
                print(f"本轮处理完成: 成功 {success_count} 个, 失败 {fail_count} 个")

                print("等待5分钟后继续检查...")
                time.sleep(300)

            except KeyboardInterrupt:
                print("\n程序被用户中断")
                break
            except Exception as e:
                print(f"程序执行异常: {e}")
                print("5秒后重试...")
                time.sleep(5)
    finally:
        # Release the DB connection deterministically rather than relying on __del__.
        simulator.db_manager.close_connection()
|
|
|
|
|
|
if __name__ == "__main__":
    # Start the polling loop only when run as a script, not when imported.
    main()