You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

404 lines
14 KiB

# FastAPI 应用主入口
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import sys
import os
import json
import subprocess
import uuid
import asyncio
import zipfile
from datetime import datetime
# The FastAPI application object every route below is registered on.
app = FastAPI(
    title="Alpha Transformer",
    version="1.0.0",
)
# Optional endpoint: Chrome DevTools probes this well-known URL; answering it
# keeps the access log free of 404 noise.
@app.get("/.well-known/appspecific/com.chrome.devtools.json")
async def chrome_devtools_check():
    """Reply to Chrome DevTools' well-known probe with a static OK payload."""
    return dict(status="ok")
# Serve the CSS and JS assets that live in the templates directory.
@app.get("/styles.css")
async def get_styles():
    """Serve templates/styles.css.

    The path is resolved relative to this module (the same way config.json is
    located), so the asset is found regardless of the directory the server is
    launched from. Uses the module-level FileResponse import.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return FileResponse(os.path.join(module_dir, "templates", "styles.css"))
@app.get("/app.js")
async def get_app_js():
    """Serve templates/app.js.

    The path is resolved relative to this module (the same way config.json is
    located), so the asset is found regardless of the directory the server is
    launched from. Uses the module-level FileResponse import.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return FileResponse(os.path.join(module_dir, "templates", "app.js"))
# Jinja2 template engine for rendering HTML pages. The directory is resolved
# relative to this module (like config.json) instead of the process CWD.
templates = Jinja2Templates(
    directory=os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates")
)
# Registry of running tasks (task_id -> task status).
# NOTE(review): not referenced by any endpoint in this file — confirm it is still needed.
transformer_tasks = {}
# Global configuration; populated from config.json by load_config() at import time.
app_config = {}
def load_config():
    """Read config.json from this module's directory.

    Returns:
        The parsed JSON content when the file exists and parses cleanly.
        An empty dict when the file is missing or cannot be read/parsed; a
        warning line is printed in both of those cases.
    """
    path = os.path.join(os.path.dirname(__file__), 'config.json')
    if not os.path.exists(path):
        print(f"⚠ 配置文件不存在: {path}")
        return {}
    try:
        with open(path, 'r', encoding='utf-8') as handle:
            parsed = json.load(handle)
        print(f"✓ 已加载配置文件: {path}")
        return parsed
    except Exception as err:
        # Best-effort at startup: a broken config file must not stop the server.
        print(f"⚠ 加载配置文件失败: {err}")
        return {}


# Load the configuration once at startup; endpoints read this global and
# save_config replaces it.
app_config = load_config()
@app.get("/", response_class=HTMLResponse)
async def home():
    """Serve the frontend entry page, templates/index.html, as HTML.

    The path is resolved relative to this module (the same way config.json is
    located) instead of the process CWD, so the page loads no matter which
    directory the server is started from.
    """
    index_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "templates", "index.html"
    )
    with open(index_path, "r", encoding="utf-8") as f:
        return f.read()
@app.get("/api/config/defaults")
async def get_config_defaults():
    """Return the current global configuration for the frontend to prefill its form.

    Response body: ``{"success": True, "config": <app_config>}``.
    """
    payload = dict(success=True, config=app_config)
    return JSONResponse(content=payload)
@app.post("/api/config/save")
async def save_config(request: Request):
    """Persist the posted configuration to config.json and make it the active config.

    The body must be a JSON object. It is first written to a sibling temp file,
    which is then moved over config.json with os.replace (atomic on POSIX), so
    a failed write never leaves a truncated config.json behind. The in-memory
    ``app_config`` is replaced only after the file write succeeds, which keeps
    memory and disk in agreement.

    Returns:
        200 ``{"success": True, "message": ...}`` on success,
        400 when the body is not a JSON object,
        500 with the error text on any other failure.
    """
    global app_config
    try:
        data = await request.json()
        if not isinstance(data, dict):
            return JSONResponse(
                status_code=400,
                content={"success": False, "error": "Config must be a JSON object"}
            )

        config_path = os.path.join(os.path.dirname(__file__), 'config.json')
        tmp_path = config_path + '.tmp'
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=4, ensure_ascii=False)
            os.replace(tmp_path, config_path)
        finally:
            # Only present if the write or rename failed; never leave it behind.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

        # Update the live config only once it is safely on disk.
        app_config = data
        return JSONResponse(content={
            "success": True,
            "message": "配置已保存"
        })
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"success": False, "error": str(e)}
        )
# Lazily created lock that serializes Transformer runs (see generate_alpha).
_generate_lock = None


def _get_generate_lock():
    """Return the module-wide asyncio.Lock guarding Transformer runs, creating it on first use.

    The lock is created from inside a running coroutine rather than at import
    time, so on Python < 3.10 it binds to the server's event loop and not to
    whatever loop happened to exist during import.
    """
    global _generate_lock
    if _generate_lock is None:
        _generate_lock = asyncio.Lock()
    return _generate_lock


def _read_output_json(path, label):
    """Load one Transformer output JSON file, or return [] when it does not exist.

    ``label`` is the file description interpolated into the progress log lines.
    """
    if os.path.exists(path):
        print(f"读取{label}文件: {path}")
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    print(f"{label}文件不存在: {path}")
    return []


@app.post("/api/generate")
async def generate_alpha(request: Request):
    """Generate Alpha variants by running Tranformer/Transformer.py as a subprocess.

    Request body (JSON object):
        Required: alpha_id, llm_api_key, llm_base_url, llm_model,
        brain_username, brain_password.
        Optional: top_n_datafield (int, default 50), user_region, user_universe,
        user_delay (int), user_category, user_data_type (default 'MATRIX').

    The fields are written to a per-task temp config file, which is deleted
    afterwards because it holds credentials. Transformer.py is run with that
    file's path. The three JSON files it leaves in Tranformer/output are
    returned together with its stdout, stderr, and return code.

    The subprocess runs in the default thread-pool executor, so the event loop
    keeps serving other requests during the run of up to 600 s. Runs are
    serialized by a lock because every run writes the same fixed-name output
    files.

    Returns:
        200 with the result payload. ``success`` is True even if the script
        exited non-zero; inspect ``return_code``.
        400 for a missing or invalid field.
        408 when the script exceeds 600 s.
        500 for any other failure.
    """
    print("=" * 50)
    print("收到生成变种请求")
    try:
        data = await request.json()
        if not isinstance(data, dict):
            return JSONResponse(
                status_code=400,
                content={"success": False, "error": "Request body must be a JSON object"}
            )
        print(f"请求数据: alpha_id={data.get('alpha_id')}, llm_model={data.get('llm_model')}")

        # Unique id so concurrent requests never share a temp config file.
        task_id = str(uuid.uuid4())
        print(f"生成任务 ID: {task_id}")

        required_fields = [
            "alpha_id",
            "llm_api_key",
            "llm_base_url",
            "llm_model",
            "brain_username",
            "brain_password"
        ]
        for field in required_fields:
            if not data.get(field):
                print(f"缺少必填字段: {field}")
                return JSONResponse(
                    status_code=400,
                    content={"success": False, "error": f"Missing required field: {field}"}
                )

        # The directory is spelled 'Tranformer' on disk (sic).
        script_dir = os.path.dirname(os.path.abspath(__file__))
        transformer_dir = os.path.join(script_dir, 'Tranformer')
        print(f"Transformer 目录: {transformer_dir}")

        # Bad numeric input is a client error, not a server failure.
        try:
            top_n_datafield = int(data.get('top_n_datafield', 50))
            user_delay = int(data.get('user_delay')) if data.get('user_delay') else None
        except (TypeError, ValueError) as e:
            return JSONResponse(
                status_code=400,
                content={"success": False, "error": f"Invalid integer field: {e}"}
            )

        # Config consumed by Transformer.py (key names are its contract).
        config = {
            "LLM_model_name": data.get('llm_model'),
            "LLM_API_KEY": data.get('llm_api_key'),
            "llm_base_url": data.get('llm_base_url'),
            "username": data.get('brain_username'),
            "password": data.get('brain_password'),
            "alpha_id": data.get('alpha_id'),
            "top_n_datafield": top_n_datafield,
            "user_region": data.get('user_region'),
            "user_universe": data.get('user_universe'),
            "user_delay": user_delay,
            "user_category": data.get('user_category'),
            "user_data_type": data.get('user_data_type', 'MATRIX')
        }
        print(f"配置已构建: LLM_model={config['LLM_model_name']}, alpha_id={config['alpha_id']}")

        config_path = os.path.join(transformer_dir, f'config_{task_id}.json')
        cmd = [sys.executable, '-u', os.path.join(transformer_dir, 'Transformer.py'), config_path]
        output_dir = os.path.join(transformer_dir, 'output')

        async with _get_generate_lock():
            try:
                # Written inside the try so the credential-bearing file is always cleaned up.
                with open(config_path, 'w', encoding='utf-8') as f:
                    json.dump(config, f, indent=4)
                print(f"配置文件已写入: {config_path}")

                print("启动 Transformer 脚本...")
                loop = asyncio.get_running_loop()
                # The child is forced to emit UTF-8 (PYTHONIOENCODING), so decode
                # it as UTF-8 too instead of the locale encoding.
                process = await loop.run_in_executor(
                    None,
                    lambda: subprocess.run(
                        cmd,
                        cwd=transformer_dir,
                        capture_output=True,
                        text=True,
                        encoding='utf-8',
                        errors='replace',
                        timeout=600,
                        env={**os.environ, "PYTHONIOENCODING": "utf-8"}
                    )
                )
                print(f"Transformer 脚本执行完成,返回码: {process.returncode}")

                result = {
                    "success": True,
                    "alpha_id": data.get('alpha_id'),
                    "stdout": process.stdout,
                    "stderr": process.stderr,
                    "return_code": process.returncode
                }
                # NOTE(review): these files have fixed names and are not cleared
                # before a run, so if Transformer.py fails before writing them a
                # previous run's results are returned — confirm intended behavior.
                result['expressions_success'] = _read_output_json(
                    os.path.join(output_dir, 'Alpha_generated_expressions_success.json'), '成功表达式')
                result['candidates'] = _read_output_json(
                    os.path.join(output_dir, 'Alpha_candidates.json'), '候选表达式')
                result['expressions_error'] = _read_output_json(
                    os.path.join(output_dir, 'Alpha_generated_expressions_error.json'), '错误表达式')

                print(f"成功: {len(result['expressions_success'])} 个, 候选: {len(result['candidates'])} 个, 错误: {len(result['expressions_error'])}")
                print("=" * 50)
                return JSONResponse(content=result)
            finally:
                if os.path.exists(config_path):
                    os.remove(config_path)
                    print(f"已清理临时配置文件: {config_path}")
    except subprocess.TimeoutExpired:
        # subprocess.run has already killed the child at this point.
        print("任务执行超时 (600秒)")
        return JSONResponse(
            status_code=408,
            content={"success": False, "error": "Task timeout (600s)"}
        )
    except Exception as e:
        print(f"执行异常: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"success": False, "error": str(e)}
        )
@app.post("/api/transformer/login-and-fetch-options")
async def login_and_fetch_options(request: Request):
    """Log in to BRAIN and fetch the choices for the advanced-options form.

    Request body: JSON object with ``username`` and ``password``.

    Response on success:
        options: nested dict Region -> Delay (as str) -> the row's Universe
            value. It is built from the EQUITY rows of
            get_instrument_type_region_delay(); the first row seen for a
            region/delay pair wins.
        categories: the /data-categories payload, either the list itself or
            its 'results' list. Empty if that request did not return 200.

    Status codes:
        400 for a non-object body, missing credentials, or an unexpected auth status.
        401 when BRAIN rejects the credentials.
        500 for any other failure.
    """
    try:
        data = await request.json()
        if not isinstance(data, dict):
            return JSONResponse(
                status_code=400,
                content={'success': False, 'error': 'Request body must be a JSON object'}
            )
        username = data.get('username')
        password = data.get('password')
        if not username or not password:
            return JSONResponse(
                status_code=400,
                content={'success': False, 'error': 'Username and password are required'}
            )

        script_dir = os.path.dirname(os.path.abspath(__file__))
        transformer_dir = os.path.join(script_dir, 'Tranformer')
        # script_dir: required for the package-qualified import below
        # (`Tranformer.ace_lib`) when the server is not started from this directory.
        # transformer_dir: presumably lets modules inside Tranformer import each
        # other by bare name — confirm.
        for import_root in (transformer_dir, script_dir):
            if import_root not in sys.path:
                sys.path.append(import_root)
        from Tranformer.ace_lib import SingleSession, get_instrument_type_region_delay

        # NOTE(review): the BRAIN requests below are synchronous and run on the
        # event-loop thread, so other requests stall until they finish. They are
        # not offloaded to a thread because SingleSession's sharing semantics are
        # not visible here. If it is a shared instance, concurrent logins would
        # overwrite each other's `auth`. Confirm before changing.
        session = SingleSession()
        session.auth = (username, password)
        brain_api_url = "https://api.worldquantbrain.com"
        response = session.post(brain_api_url + "/authentication")
        if response.status_code == 401:
            return JSONResponse(
                status_code=401,
                content={'success': False, 'error': 'Authentication failed: Invalid credentials'}
            )
        if response.status_code != 201:
            return JSONResponse(
                status_code=400,
                content={'success': False, 'error': f'Authentication failed: {response.status_code}'}
            )

        df = get_instrument_type_region_delay(session)

        categories_resp = session.get(brain_api_url + "/data-categories")
        categories = []
        if categories_resp.status_code == 200:
            categories_data = categories_resp.json()
            # The endpoint may return either a bare list or a paginated {'results': [...]} object.
            if isinstance(categories_data, list):
                categories = categories_data
            elif isinstance(categories_data, dict):
                categories = categories_data.get('results', [])

        # Build Region -> Delay -> Universe; delays become str keys for JSON.
        df_equity = df[df['InstrumentType'] == 'EQUITY']
        options = {}
        for _, row in df_equity.iterrows():
            # Universe is expected to be a list of universe names.
            options.setdefault(row['Region'], {}).setdefault(str(row['Delay']), row['Universe'])

        return JSONResponse(content={
            'success': True,
            'options': options,
            'categories': categories
        })
    except Exception as e:
        print(f"登录获取选项失败: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={'success': False, 'error': str(e)}
        )
@app.get("/api/health")
async def health_check():
    """Liveness probe: report that the service is up and identify it."""
    return dict(status="healthy", service="alpha-transformer")
@app.get("/api/download/{alpha_id}")
async def download_results(alpha_id: str):
    """Download the latest Transformer results as a zip archive.

    The archive bundles whichever of the three output files exist (success,
    candidates, error). It is built in memory and sent directly, so no zip
    files accumulate in the output directory.

    Args:
        alpha_id: Used only to name the downloaded file. Characters outside
            [A-Za-z0-9_-] are replaced with '_' because the value comes from
            the URL and ends up in the Content-Disposition header.

    Returns:
        The zip (application/zip) as an attachment named
        ``<alpha_id>_<YYYYmmdd_HHMMSS>.zip``.
        404 JSON when none of the output files exist.
        500 JSON with the error text on any other failure.
    """
    import io
    import re
    from fastapi.responses import Response

    try:
        output_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'Tranformer', 'output'
        )
        arcnames = (
            'Alpha_generated_expressions_success.json',
            'Alpha_candidates.json',
            'Alpha_generated_expressions_error.json',
        )
        existing = {
            name: os.path.join(output_dir, name)
            for name in arcnames
            if os.path.isfile(os.path.join(output_dir, name))
        }
        if not existing:
            return JSONResponse(
                status_code=404,
                content={"success": False, "error": "No result files to download"}
            )

        buffer = io.BytesIO()
        with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for arcname, filepath in existing.items():
                zipf.write(filepath, arcname)

        safe_id = re.sub(r'[^A-Za-z0-9_-]', '_', alpha_id)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        zip_filename = f"{safe_id}_{timestamp}.zip"
        return Response(
            content=buffer.getvalue(),
            media_type='application/zip',
            headers={"Content-Disposition": f'attachment; filename="{zip_filename}"'}
        )
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"success": False, "error": str(e)}
        )
if __name__ == "__main__":
    # Script entry point: serve the app with the uvicorn ASGI server on all
    # interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")