You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
404 lines
14 KiB
404 lines
14 KiB
# FastAPI 应用主入口
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse, FileResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.templating import Jinja2Templates
|
|
import sys
|
|
import os
|
|
import json
|
|
import subprocess
|
|
import uuid
|
|
import asyncio
|
|
import zipfile
|
|
from datetime import datetime
|
|
|
|
# Create the FastAPI application instance (title and version are passed to FastAPI).
app = FastAPI(title="Alpha Transformer", version="1.0.0")
|
|
|
|
# Chrome DevTools 健康检查端点(可选,阻止 404 日志)
|
|
@app.get("/.well-known/appspecific/com.chrome.devtools.json")
async def chrome_devtools_check():
    """Answer the Chrome DevTools well-known probe with a fixed OK payload."""
    probe_reply = {"status": "ok"}
    return probe_reply
|
|
|
|
# 提供模板目录中的 CSS 和 JS 文件
|
|
@app.get("/styles.css")
async def get_styles():
    """Serve styles.css from the templates directory."""
    stylesheet_path = "templates/styles.css"
    return FileResponse(stylesheet_path)
|
|
|
|
@app.get("/app.js")
async def get_app_js():
    """Serve app.js from the templates directory."""
    script_path = "templates/app.js"
    return FileResponse(script_path)
|
|
|
|
# Jinja2 template engine, rooted at the "templates" directory, for rendering HTML pages.
templates = Jinja2Templates(directory="templates")

# Intended registry of running tasks (task_id -> task state).
# NOTE(review): no handler visible in this file reads or writes it — confirm it is still needed.
transformer_tasks = {}

# Global configuration dictionary; starts empty.
app_config = {}
|
|
|
|
def load_config():
    """
    Load the application configuration from disk.

    Looks for config.json in the same directory as this module. Returns the
    parsed JSON content when the file exists and loads cleanly; returns an
    empty dict when the file is missing or when loading fails for any reason.
    A status line is printed in every case.
    """
    config_path = os.path.join(os.path.dirname(__file__), 'config.json')

    # Guard clause: nothing to load when the file is absent.
    if not os.path.exists(config_path):
        print(f"⚠ 配置文件不存在: {config_path}")
        return {}

    try:
        with open(config_path, 'r', encoding='utf-8') as handle:
            loaded = json.load(handle)
        print(f"✓ 已加载配置文件: {config_path}")
        return loaded
    except Exception as exc:
        # Deliberately broad: any read/parse failure falls back to an empty config.
        print(f"⚠ 加载配置文件失败: {exc}")
        return {}
|
|
|
|
# Load the configuration once, when the module is imported.
app_config = load_config()
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def home():
    """Return the front-end landing page (templates/index.html) as HTML text."""
    with open("templates/index.html", "r", encoding="utf-8") as page:
        markup = page.read()
    return markup
|
|
|
|
|
|
@app.get("/api/config/defaults")
async def get_config_defaults():
    """
    Expose the default configuration.

    Responds with the module-level app_config under the "config" key so the
    page can use it as form defaults.
    """
    payload = {"success": True, "config": app_config}
    return JSONResponse(content=payload)
|
|
|
|
|
|
@app.post("/api/config/save")
async def save_config(request: Request):
    """
    Persist the configuration submitted by the page to config.json.

    The request body must be a JSON object. The file is written first, via a
    temporary file that is atomically moved over config.json, and only then
    is the in-memory app_config replaced. A failed write therefore leaves
    both the file and the running configuration untouched.

    Responses:
        200: {"success": True, "message": ...} once the config is saved.
        400: the body is valid JSON but not an object.
        500: {"success": False, "error": ...} for any other failure.
    """
    global app_config

    try:
        data = await request.json()

        # The config is served back as form defaults, so it must be an object.
        if not isinstance(data, dict):
            return JSONResponse(
                status_code=400,
                content={"success": False, "error": "Configuration must be a JSON object"}
            )

        config_path = os.path.join(os.path.dirname(__file__), 'config.json')
        # Unique temp name in the same directory so os.replace stays on one filesystem.
        tmp_path = f"{config_path}.{uuid.uuid4().hex}.tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=4, ensure_ascii=False)
            os.replace(tmp_path, config_path)
        finally:
            # Only present when the write or the replace failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

        # Update the in-memory config only after the file is safely on disk.
        app_config = data

        return JSONResponse(content={
            "success": True,
            "message": "配置已保存"
        })
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"success": False, "error": str(e)}
        )
|
|
|
|
|
|
def _int_or_default(value, default):
    """Convert a submitted value to int; return *default* when it is None or ""."""
    if value is None or value == "":
        return default
    return int(value)


def _read_json_or_empty(path, label):
    """Load a JSON result file; return [] (and log it) when the file does not exist."""
    if os.path.exists(path):
        print(f"读取{label}文件: {path}")
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    print(f"{label}文件不存在: {path}")
    return []


@app.post("/api/generate")
async def generate_alpha(request: Request):
    """
    Generate Alpha variants.

    Validates the submitted form data, writes a per-task configuration file
    into the Transformer directory, runs Transformer.py as a subprocess and
    returns its output together with the three result files.

    Responses:
        200: the subprocess finished (with any return code); the body holds
             stdout, stderr, return_code, expressions_success, candidates
             and expressions_error.
        400: a required field is missing or empty.
        408: the subprocess exceeded the 600 second timeout.
        500: any other failure.
    """
    print("=" * 50)
    print("收到生成变种请求")

    try:
        # Parse the request payload.
        data = await request.json()
        print(f"请求数据: alpha_id={data.get('alpha_id')}, llm_model={data.get('llm_model')}")

        # Unique id, used to name this task's temporary config file.
        task_id = str(uuid.uuid4())
        print(f"生成任务 ID: {task_id}")

        # Fields the client must supply.
        required_fields = [
            "alpha_id",
            "llm_api_key",
            "llm_base_url",
            "llm_model",
            "brain_username",
            "brain_password"
        ]

        for field in required_fields:
            if not data.get(field):
                print(f"缺少必填字段: {field}")
                return JSONResponse(
                    status_code=400,
                    content={"success": False, "error": f"Missing required field: {field}"}
                )

        # Directory of this script and its Transformer sub-directory
        # ('Tranformer' is the spelling this file uses for it everywhere).
        script_dir = os.path.dirname(os.path.abspath(__file__))
        transformer_dir = os.path.join(script_dir, 'Tranformer')
        print(f"Transformer 目录: {transformer_dir}")

        # Configuration handed to the Transformer script.
        config = {
            "LLM_model_name": data.get('llm_model'),
            "LLM_API_KEY": data.get('llm_api_key'),
            "llm_base_url": data.get('llm_base_url'),
            "username": data.get('brain_username'),
            "password": data.get('brain_password'),
            "alpha_id": data.get('alpha_id'),
            # An empty or missing value falls back to 50.
            "top_n_datafield": _int_or_default(data.get('top_n_datafield'), 50),
            "user_region": data.get('user_region'),
            "user_universe": data.get('user_universe'),
            # Only None / "" mean "not set"; an explicit delay of 0 is kept.
            "user_delay": _int_or_default(data.get('user_delay'), None),
            "user_category": data.get('user_category'),
            "user_data_type": data.get('user_data_type', 'MATRIX')
        }
        print(f"配置已构建: LLM_model={config['LLM_model_name']}, alpha_id={config['alpha_id']}")

        config_path = os.path.join(transformer_dir, f'config_{task_id}.json')

        try:
            # The temp file contains credentials, so it is written inside the
            # try block: the finally clause removes it even after a failed write.
            with open(config_path, 'w', encoding='utf-8') as f:
                json.dump(config, f, indent=4)
            print(f"配置文件已写入: {config_path}")

            # Run Transformer.py in a worker thread so the (up to 600 s)
            # blocking wait does not stall the event loop for other requests.
            print("启动 Transformer 脚本...")
            loop = asyncio.get_running_loop()
            process = await loop.run_in_executor(
                None,
                lambda: subprocess.run(
                    [sys.executable, '-u', os.path.join(transformer_dir, 'Transformer.py'), config_path],
                    cwd=transformer_dir,
                    capture_output=True,
                    text=True,
                    timeout=600,
                    env={**os.environ, "PYTHONIOENCODING": "utf-8"}
                )
            )
            print(f"Transformer 脚本执行完成,返回码: {process.returncode}")

            # Result files use fixed names shared by every task.
            # NOTE(review): a failed or concurrent run may leave another
            # run's files here — confirm Transformer.py rewrites them each time.
            output_dir = os.path.join(transformer_dir, 'output')
            output_file = os.path.join(output_dir, 'Alpha_generated_expressions_success.json')
            candidates_file = os.path.join(output_dir, 'Alpha_candidates.json')
            error_file = os.path.join(output_dir, 'Alpha_generated_expressions_error.json')

            # "success" reports that the request completed; callers must
            # inspect return_code for the subprocess outcome.
            result = {
                "success": True,
                "alpha_id": data.get('alpha_id'),
                "stdout": process.stdout,
                "stderr": process.stderr,
                "return_code": process.returncode
            }

            result['expressions_success'] = _read_json_or_empty(output_file, "成功表达式")
            result['candidates'] = _read_json_or_empty(candidates_file, "候选表达式")
            result['expressions_error'] = _read_json_or_empty(error_file, "错误表达式")

            print(f"成功: {len(result['expressions_success'])} 个, 候选: {len(result['candidates'])} 个, 错误: {len(result['expressions_error'])} 个")
            print("=" * 50)
            return JSONResponse(content=result)

        finally:
            # Always remove the temporary config file.
            if os.path.exists(config_path):
                os.remove(config_path)
                print(f"已清理临时配置文件: {config_path}")

    except subprocess.TimeoutExpired:
        print("任务执行超时 (600秒)")
        return JSONResponse(
            status_code=408,
            content={"success": False, "error": "Task timeout (600s)"}
        )
    except Exception as e:
        print(f"执行异常: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"success": False, "error": str(e)}
        )
|
|
|
|
|
|
@app.post("/api/transformer/login-and-fetch-options")
async def login_and_fetch_options(request: Request):
    """
    Log in to BRAIN and fetch the region / delay / universe / category options.

    The result is used to populate the advanced-options form. On success the
    body contains "options" (region -> delay as string -> universes) and
    "categories". Missing credentials yield 400, rejected credentials 401,
    any other authentication status 400, and unexpected errors 500.
    """
    try:
        payload = await request.json()
        username = payload.get('username')
        password = payload.get('password')

        if not (username and password):
            return JSONResponse(
                status_code=400,
                content={'success': False, 'error': 'Username and password are required'}
            )

        # Make the Transformer directory importable.
        base_dir = os.path.dirname(os.path.abspath(__file__))
        transformer_dir = os.path.join(base_dir, 'Tranformer')
        if transformer_dir not in sys.path:
            sys.path.append(transformer_dir)

        # Deferred imports: these modules are only needed by this endpoint.
        from Tranformer.ace_lib import SingleSession, get_instrument_type_region_delay
        # NOTE(review): pandas is not referenced below; import kept as in the original — confirm whether it is needed.
        import pandas as pd

        # Fresh session carrying the submitted credentials.
        session = SingleSession()
        session.auth = (username, password)

        brain_api_url = "https://api.worldquantbrain.com"
        auth_response = session.post(brain_api_url + "/authentication")

        # Only 201 counts as a successful authentication.
        if auth_response.status_code != 201:
            if auth_response.status_code == 401:
                return JSONResponse(
                    status_code=401,
                    content={'success': False, 'error': 'Authentication failed: Invalid credentials'}
                )
            return JSONResponse(
                status_code=400,
                content={'success': False, 'error': f'Authentication failed: {auth_response.status_code}'}
            )

        # Region / delay / universe table.
        df = get_instrument_type_region_delay(session)

        # Data categories: accept either a bare list or a dict with a 'results' list.
        categories = []
        categories_resp = session.get(brain_api_url + "/data-categories")
        if categories_resp.status_code == 200:
            body = categories_resp.json()
            if isinstance(body, list):
                categories = body
            elif isinstance(body, dict):
                categories = body.get('results', [])

        # Build the nested mapping the front end expects: Region -> Delay -> Universe.
        equity_rows = df[df['InstrumentType'] == 'EQUITY']

        options = {}
        for _, record in equity_rows.iterrows():
            region = record['Region']
            delay_key = str(record['Delay'])  # delay is used as a string key
            universes = record['Universe']  # presumably a list — confirm against ace_lib

            # The first row seen for a (region, delay) pair wins.
            options.setdefault(region, {}).setdefault(delay_key, universes)

        return JSONResponse(content={
            'success': True,
            'options': options,
            'categories': categories
        })

    except Exception as exc:
        print(f"登录获取选项失败: {str(exc)}")
        return JSONResponse(
            status_code=500,
            content={'success': False, 'error': str(exc)}
        )
|
|
|
|
|
|
@app.get("/api/health")
async def health_check():
    """Health probe used to verify that the service is running."""
    return dict(status="healthy", service="alpha-transformer")
|
|
|
|
|
|
@app.get("/api/download/{alpha_id}")
async def download_results(alpha_id: str):
    """
    Download the generation results as a zip archive.

    The archive holds whichever of the three result JSON files (success,
    candidates, error) exist in the Transformer output directory. It is built
    in memory, so no zip file is left behind on disk after the request.

    Args:
        alpha_id: Used only to name the download. Characters other than ASCII
            letters, digits, '-' and '_' are dropped so the value cannot
            affect paths or break the Content-Disposition header.

    Responses:
        200: the zip archive, as an attachment named "<alpha_id>_<timestamp>.zip".
        404: none of the result files exist.
        500: any other failure.
    """
    import io  # local import: io is not imported at file level

    try:
        transformer_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'Tranformer')
        output_dir = os.path.join(transformer_dir, 'output')

        result_names = (
            'Alpha_generated_expressions_success.json',
            'Alpha_candidates.json',
            'Alpha_generated_expressions_error.json',
        )

        # Keep only the result files that are actually present.
        available = []
        for arcname in result_names:
            filepath = os.path.join(output_dir, arcname)
            if os.path.exists(filepath):
                available.append((arcname, filepath))

        if not available:
            return JSONResponse(
                status_code=404,
                content={"success": False, "error": "No result files available for download"}
            )

        # alpha_id comes straight from the URL: reduce it to a safe file-name stem.
        safe_id = "".join(
            ch for ch in alpha_id
            if ch.isascii() and (ch.isalnum() or ch in "-_")
        ) or "alpha"

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        zip_filename = f"{safe_id}_{timestamp}.zip"

        # Build the archive in memory instead of writing it to the output directory.
        buffer = io.BytesIO()
        with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for arcname, filepath in available:
                zipf.write(filepath, arcname)
        archive = buffer.getvalue()

        return StreamingResponse(
            iter([archive]),
            media_type='application/zip',
            headers={
                "Content-Disposition": f'attachment; filename="{zip_filename}"',
                "Content-Length": str(len(archive)),
            }
        )

    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"success": False, "error": str(e)}
        )
|
|
|
|
|
|
if __name__ == "__main__":
    # When run as a script, serve the FastAPI app with uvicorn (the ASGI
    # server) on host 0.0.0.0, port 8000.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|