You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
802 lines
32 KiB
802 lines
32 KiB
import shutil
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
|
|
from fastapi.templating import Jinja2Templates
|
|
import sys
|
|
import os
|
|
import json
|
|
import uuid
|
|
import zipfile
|
|
from datetime import datetime
|
|
import logging
|
|
import openai
|
|
from Tranformer.ace_lib import SingleSession, get_instrument_type_region_delay, get_datafields, get_datasets
|
|
from Tranformer.Transformer import run_transformer
|
|
import asyncio
|
|
import pandas as pd
|
|
|
|
# Logging setup: attach a single stdout handler, guarded so repeated imports
# of this module never duplicate log output.
logger = logging.getLogger("alpha_transformer")
if not logger.handlers:
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    logger.addHandler(stream_handler)
    logger.setLevel(logging.INFO)
|
|
|
# Create the FastAPI application instance.
app = FastAPI(title="Alpha Transformer", version="1.0.0")
|
|
|
|
# Optional endpoint probed by Chrome DevTools; answering it keeps 404 noise
# out of the access log.
@app.get("/.well-known/appspecific/com.chrome.devtools.json")
async def chrome_devtools_check():
    """Health-check stub for Chrome DevTools probes."""
    return dict(status="ok")
|
|
|
|
# Serve the CSS and JS assets that live in the templates directory.
@app.get("/styles.css")
async def get_styles():
    """Serve the frontend stylesheet (templates/styles.css)."""
    # FileResponse is already imported at module level; the former
    # function-local re-import was redundant and has been removed.
    return FileResponse("templates/styles.css")
|
|
|
|
@app.get("/app.js")
|
|
async def get_app_js():
|
|
"""提供 app.js 文件"""
|
|
from fastapi.responses import FileResponse
|
|
return FileResponse("templates/app.js")
|
|
|
|
# Configure the Jinja2 template engine used to render HTML pages.
templates = Jinja2Templates(directory="templates")
|
|
|
|
# Purge Tranformer/output before startup so results from a previous run
# never bleed into a new one. Failures are logged and skipped, not fatal.
if os.path.exists("Tranformer/output"):
    for entry in os.listdir("Tranformer/output"):
        entry_path = os.path.join("Tranformer/output", entry)
        try:
            # Real directories are removed recursively; files and symlinks
            # (including symlinks to directories) are unlinked directly.
            if os.path.isdir(entry_path) and not os.path.islink(entry_path):
                shutil.rmtree(entry_path)
            elif os.path.isfile(entry_path) or os.path.islink(entry_path):
                os.unlink(entry_path)
        except Exception as e:
            logger.error(f"⚠ 清空目录 {entry_path} 时出错: {e}")
|
|
|
|
|
|
|
|
# Bookkeeping for running transformer jobs (task_id -> task state).
transformer_tasks = {}

# Global configuration dict, populated from config.json at startup.
app_config = {}
|
|
|
|
def load_config():
    """
    Load the application configuration.

    Reads config.json from the directory containing this file. Returns the
    parsed dict, or an empty dict when the file is missing or unreadable.
    """
    config_path = os.path.join(os.path.dirname(__file__), 'config.json')

    # Missing file: warn and fall back to an empty configuration.
    if not os.path.exists(config_path):
        logger.warning(f"⚠ 配置文件不存在: {config_path}")
        return {}

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
        logger.info(f"✓ 已加载配置文件: {config_path}")
        return loaded
    except Exception as e:
        # Unparseable/unreadable file: warn and fall back to empty config.
        logger.warning(f"⚠ 加载配置文件失败: {e}")
        return {}
|
|
|
|
# Load the configuration once at startup.
app_config = load_config()
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
|
|
async def home():
|
|
# 读取并返回前端首页 HTML
|
|
with open("templates/index.html", "r", encoding="utf-8") as f:
|
|
return f.read()
|
|
|
|
|
|
@app.get("/api/config/defaults")
|
|
async def get_config_defaults():
|
|
"""
|
|
获取默认配置
|
|
返回 config.json 中的配置作为表单默认值
|
|
"""
|
|
return JSONResponse(content={
|
|
"success": True,
|
|
"config": app_config
|
|
})
|
|
|
|
|
|
@app.post("/api/config/save")
|
|
async def save_config(request: Request):
|
|
"""
|
|
保存配置到 config.json
|
|
用于在页面上修改配置后保存
|
|
"""
|
|
try:
|
|
data = await request.json()
|
|
|
|
# 更新全局配置
|
|
global app_config
|
|
app_config = data
|
|
|
|
# 写入文件
|
|
config_path = os.path.join(os.path.dirname(__file__), 'config.json')
|
|
with open(config_path, 'w', encoding='utf-8') as f:
|
|
json.dump(data, f, indent=4, ensure_ascii=False)
|
|
|
|
return JSONResponse(content={
|
|
"success": True,
|
|
"message": "配置已保存"
|
|
})
|
|
except Exception as e:
|
|
return JSONResponse(
|
|
status_code=500,
|
|
content={"success": False, "error": str(e)}
|
|
)
|
|
|
|
|
|
@app.post("/api/generate")
|
|
async def generate_alpha(request: Request):
|
|
"""
|
|
生成 Alpha 变种的 API 端点
|
|
接收前端表单数据,启动 Transformer 脚本执行 Alpha 生成任务
|
|
"""
|
|
logger.info("=" * 50)
|
|
logger.info("收到生成变种请求")
|
|
|
|
try:
|
|
# 解析请求数据
|
|
data = await request.json()
|
|
logger.info(f"请求数据: alpha_id={data.get('alpha_id')}, llm_model={data.get('llm_model')}")
|
|
|
|
# 生成唯一任务 ID
|
|
task_id = str(uuid.uuid4())
|
|
logger.info(f"生成任务 ID: {task_id}")
|
|
|
|
# 定义必须提交的字段
|
|
required_fields = [
|
|
"alpha_id",
|
|
"llm_api_key",
|
|
"llm_base_url",
|
|
"llm_model",
|
|
"brain_username",
|
|
"brain_password"
|
|
]
|
|
|
|
# 检查必填字段是否完整
|
|
for field in required_fields:
|
|
if not data.get(field):
|
|
logger.warning(f"缺少必填字段: {field}")
|
|
return JSONResponse(
|
|
status_code=400,
|
|
content={"success": False, "error": f"Missing required field: {field}"}
|
|
)
|
|
|
|
# 获取脚本所在目录和 Transformer 子目录
|
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
|
transformer_dir = os.path.join(script_dir, 'Tranformer')
|
|
logger.info(f"Transformer 目录: {transformer_dir}")
|
|
|
|
# 构建传递给 Transformer 脚本的配置
|
|
config = {
|
|
"LLM_model_name": data.get('llm_model'),
|
|
"LLM_API_KEY": data.get('llm_api_key'),
|
|
"llm_base_url": data.get('llm_base_url'),
|
|
"username": data.get('brain_username'),
|
|
"password": data.get('brain_password'),
|
|
"alpha_id": data.get('alpha_id'),
|
|
"top_n_datafield": int(data.get('top_n_datafield', 50)),
|
|
"user_region": data.get('user_region'),
|
|
"user_universe": data.get('user_universe'),
|
|
"user_delay": int(data.get('user_delay')) if data.get('user_delay') else None,
|
|
"user_category": data.get('user_category'),
|
|
"user_data_type": data.get('user_data_type', 'MATRIX'),
|
|
"max_retries": int(data.get('max_retries', 20))
|
|
}
|
|
logger.info(f"配置已构建: LLM_model={config['LLM_model_name']}, alpha_id={config['alpha_id']}, max_retries={config['max_retries']}")
|
|
|
|
# 检查数据字段缓存是否存在,不存在则自动下载
|
|
dataset_dir = os.path.join(script_dir, 'dataset')
|
|
os.makedirs(dataset_dir, exist_ok=True)
|
|
|
|
user_category = data.get('user_category', [])
|
|
categories_to_check = user_category if isinstance(user_category, list) else [user_category] if user_category else []
|
|
|
|
# 输出用户选择的类别
|
|
logger.info(f"[Generate] 用户选择的类别: {categories_to_check}")
|
|
logger.info(f"[Generate] Region: {data.get('user_region')}, Delay: {data.get('user_delay')}, Universe: {data.get('user_universe')}")
|
|
|
|
# 创建会话来获取数据集信息
|
|
session = SingleSession()
|
|
session.auth = (data.get('brain_username'), data.get('brain_password'))
|
|
brain_api_url = "https://api.worldquantbrain.com"
|
|
auth_response = session.post(brain_api_url + "/authentication")
|
|
|
|
if auth_response.status_code != 201:
|
|
return JSONResponse(
|
|
status_code=401,
|
|
content={"success": False, "error": "Authentication failed"}
|
|
)
|
|
|
|
# 获取所有可用的数据集
|
|
datasets_df = get_datasets(
|
|
s=session,
|
|
instrument_type="EQUITY",
|
|
region=data.get('user_region', 'USA'),
|
|
delay=int(data.get('user_delay', 1)),
|
|
universe=data.get('user_universe', 'TOP3000'),
|
|
theme="ALL"
|
|
)
|
|
|
|
if datasets_df.empty:
|
|
return JSONResponse(
|
|
status_code=404,
|
|
content={"success": False, "error": "No datasets found"}
|
|
)
|
|
|
|
# 使用 category_name 列(category 列是字典类型)
|
|
category_column = 'category_name' if 'category_name' in datasets_df.columns else 'category'
|
|
|
|
# 如果没有指定类别,使用所有类别
|
|
if not categories_to_check:
|
|
if category_column == 'category':
|
|
# category 是字典,提取 name
|
|
categories_to_check = datasets_df[category_column].apply(lambda x: x['name'] if isinstance(x, dict) and 'name' in x else x).unique().tolist()
|
|
else:
|
|
categories_to_check = datasets_df[category_column].unique().tolist()
|
|
logger.info(f"未指定类别,将使用所有 {len(categories_to_check)} 个类别: {categories_to_check}")
|
|
|
|
# 构建类别到 dataset_id 的映射
|
|
category_to_datasets = {}
|
|
for cat in categories_to_check:
|
|
if category_column == 'category':
|
|
# category 是字典,需要匹配 name 字段
|
|
cat_datasets = datasets_df[datasets_df[category_column].apply(lambda x: x.get('name') if isinstance(x, dict) else x) == cat]
|
|
else:
|
|
cat_datasets = datasets_df[datasets_df[category_column] == cat]
|
|
if not cat_datasets.empty:
|
|
category_to_datasets[cat] = cat_datasets['id'].tolist()
|
|
|
|
# 检查每个类别的缓存文件
|
|
missing_categories = []
|
|
for cat in categories_to_check:
|
|
cache_filename = f"datafields_cache_{data.get('user_region', 'USA')}_{data.get('user_universe', 'TOP3000')}_D{data.get('user_delay', 1)}_{cat}.csv"
|
|
cache_path = os.path.join(dataset_dir, cache_filename)
|
|
if not os.path.exists(cache_path):
|
|
missing_categories.append(cat)
|
|
|
|
# 如果有缺失的缓存,检查是否都能对应到数据集
|
|
if missing_categories:
|
|
# 检查是否有类别在数据集中不存在
|
|
invalid_categories = [cat for cat in missing_categories if cat not in category_to_datasets]
|
|
if invalid_categories:
|
|
error_msg = f"以下类别在数据集中不存在,无法下载: {', '.join(invalid_categories)}。请检查类别名称或先点击'下载数据字段缓存'按钮。"
|
|
logger.error(error_msg)
|
|
return JSONResponse(
|
|
status_code=400,
|
|
content={"success": False, "error": error_msg}
|
|
)
|
|
|
|
logger.info(f"检测到缺失的数据字段缓存: {missing_categories},自动下载中...")
|
|
|
|
downloaded_count = 0
|
|
failed_categories = []
|
|
|
|
for cat in missing_categories:
|
|
dataset_ids = category_to_datasets[cat]
|
|
all_datafields = []
|
|
|
|
try:
|
|
# 下载该类别下的所有数据集
|
|
for dataset_id in dataset_ids:
|
|
logger.info(f"下载数据集 {dataset_id} (类别: {cat})...")
|
|
datafields_df = get_datafields(
|
|
s=session,
|
|
instrument_type="EQUITY",
|
|
region=data.get('user_region', 'USA'),
|
|
delay=int(data.get('user_delay', 1)),
|
|
universe=data.get('user_universe', 'TOP3000'),
|
|
dataset_id=dataset_id,
|
|
data_type=data.get('user_data_type', 'MATRIX')
|
|
)
|
|
|
|
if not datafields_df.empty:
|
|
all_datafields.append(datafields_df)
|
|
logger.info(f"数据集 {dataset_id} 包含 {len(datafields_df)} 个字段")
|
|
|
|
if all_datafields:
|
|
# 合并该类别下的所有数据字段
|
|
combined_df = pd.concat(all_datafields, ignore_index=True)
|
|
combined_df.drop_duplicates(subset=['id'], inplace=True)
|
|
|
|
cache_filename = f"datafields_cache_{data.get('user_region', 'USA')}_{data.get('user_universe', 'TOP3000')}_D{data.get('user_delay', 1)}_{cat}.csv"
|
|
cache_path = os.path.join(dataset_dir, cache_filename)
|
|
combined_df.to_csv(cache_path, index=False)
|
|
downloaded_count += 1
|
|
logger.info(f"自动下载完成: {cache_filename} ({len(combined_df)} 个字段)")
|
|
else:
|
|
failed_categories.append(cat)
|
|
logger.error(f"类别 '{cat}' 没有获取到任何数据字段")
|
|
|
|
except Exception as e:
|
|
failed_categories.append(cat)
|
|
logger.error(f"下载类别 '{cat}' 失败: {e}")
|
|
|
|
# 如果有下载失败的类别,返回错误
|
|
if failed_categories:
|
|
error_msg = f"以下类别下载失败: {', '.join(failed_categories)}。请检查网络连接或手动下载。"
|
|
logger.error(error_msg)
|
|
return JSONResponse(
|
|
status_code=500,
|
|
content={"success": False, "error": error_msg}
|
|
)
|
|
|
|
logger.info(f"自动下载完成,共下载 {downloaded_count}/{len(missing_categories)} 个类别")
|
|
|
|
# 直接调用 Transformer 模块(替代 subprocess)
|
|
logger.info(f"启动 Transformer 执行...")
|
|
try:
|
|
transformer_result = await run_transformer(config)
|
|
logger.info(f"Transformer 执行完成: {transformer_result}")
|
|
|
|
# 定义输出文件路径
|
|
output_file = os.path.join(transformer_dir, 'output', 'Alpha_generated_expressions_success.json')
|
|
candidates_file = os.path.join(transformer_dir, 'output', 'Alpha_candidates.json')
|
|
error_file = os.path.join(transformer_dir, 'output', 'Alpha_generated_expressions_error.json')
|
|
|
|
# 构建响应数据
|
|
result = {
|
|
"success": True,
|
|
"alpha_id": data.get('alpha_id'),
|
|
"stdout": "",
|
|
"stderr": "",
|
|
"return_code": 0
|
|
}
|
|
|
|
# 读取成功生成的表达式
|
|
if os.path.exists(output_file):
|
|
logger.info(f"读取成功表达式文件: {output_file}")
|
|
with open(output_file, 'r', encoding='utf-8') as f:
|
|
result['expressions_success'] = json.load(f)
|
|
else:
|
|
logger.info(f"成功表达式文件不存在: {output_file}")
|
|
result['expressions_success'] = []
|
|
|
|
# 读取候选表达式
|
|
if os.path.exists(candidates_file):
|
|
logger.info(f"读取候选表达式文件: {candidates_file}")
|
|
with open(candidates_file, 'r', encoding='utf-8') as f:
|
|
result['candidates'] = json.load(f)
|
|
else:
|
|
logger.info(f"候选表达式文件不存在: {candidates_file}")
|
|
result['candidates'] = []
|
|
|
|
# 读取生成失败的表达式
|
|
if os.path.exists(error_file):
|
|
logger.info(f"读取错误表达式文件: {error_file}")
|
|
with open(error_file, 'r', encoding='utf-8') as f:
|
|
result['expressions_error'] = json.load(f)
|
|
else:
|
|
logger.info(f"错误表达式文件不存在: {error_file}")
|
|
result['expressions_error'] = []
|
|
|
|
logger.info(f"成功: {len(result['expressions_success'])} 个, 候选: {len(result['candidates'])} 个, 错误: {len(result['expressions_error'])} 个")
|
|
logger.info("=" * 50)
|
|
return JSONResponse(content=result)
|
|
|
|
except Exception as transformer_error:
|
|
logger.error(f"Transformer 执行失败: {transformer_error}")
|
|
raise
|
|
except Exception as e:
|
|
logger.warning(f"执行异常: {str(e)}")
|
|
return JSONResponse(
|
|
status_code=500,
|
|
content={"success": False, "error": str(e)}
|
|
)
|
|
|
|
|
|
@app.post("/api/transformer/login-and-fetch-options")
|
|
async def login_and_fetch_options(request: Request):
|
|
"""
|
|
登录 BRAIN 并获取地区、Delay、Universe、类别等选项
|
|
用于填充高级选项表单
|
|
"""
|
|
try:
|
|
data = await request.json()
|
|
username = data.get('username')
|
|
password = data.get('password')
|
|
|
|
if not username or not password:
|
|
return JSONResponse(
|
|
status_code=400,
|
|
content={'success': False, 'error': 'Username and password are required'}
|
|
)
|
|
|
|
# 添加 Transformer 目录到 sys.path
|
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
|
transformer_dir = os.path.join(script_dir, 'Tranformer')
|
|
if transformer_dir not in sys.path:
|
|
sys.path.append(transformer_dir)
|
|
|
|
# 创建新的会话实例
|
|
session = SingleSession()
|
|
session.auth = (username, password)
|
|
|
|
brain_api_url = "https://api.worldquantbrain.com"
|
|
response = session.post(brain_api_url + "/authentication")
|
|
|
|
if response.status_code == 201:
|
|
# 认证成功
|
|
pass
|
|
elif response.status_code == 401:
|
|
return JSONResponse(
|
|
status_code=401,
|
|
content={'success': False, 'error': 'Authentication failed: Invalid credentials'}
|
|
)
|
|
else:
|
|
return JSONResponse(
|
|
status_code=400,
|
|
content={'success': False, 'error': f'Authentication failed: {response.status_code}'}
|
|
)
|
|
|
|
# 获取 region/delay/universe 选项
|
|
df = get_instrument_type_region_delay(session)
|
|
|
|
# 获取数据类别
|
|
categories_resp = session.get(brain_api_url + "/data-categories")
|
|
categories = []
|
|
if categories_resp.status_code == 200:
|
|
categories_data = categories_resp.json()
|
|
if isinstance(categories_data, list):
|
|
categories = categories_data
|
|
elif isinstance(categories_data, dict):
|
|
categories = categories_data.get('results', [])
|
|
|
|
# 转换 DataFrame 为前端需要的嵌套字典结构
|
|
# 结构: Region -> Delay -> Universe
|
|
df_equity = df[df['InstrumentType'] == 'EQUITY']
|
|
|
|
options = {}
|
|
for _, row in df_equity.iterrows():
|
|
region = row['Region']
|
|
delay = row['Delay']
|
|
universes = row['Universe'] # 这是一个列表
|
|
|
|
if region not in options:
|
|
options[region] = {}
|
|
|
|
# 将 delay 转换为字符串作为字典的键
|
|
delay_str = str(delay)
|
|
if delay_str not in options[region]:
|
|
options[region][delay_str] = universes
|
|
|
|
return JSONResponse(content={
|
|
'success': True,
|
|
'options': options,
|
|
'categories': categories
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.warning(f"登录获取选项失败: {str(e)}")
|
|
return JSONResponse(
|
|
status_code=500,
|
|
content={'success': False, 'error': str(e)}
|
|
)
|
|
|
|
|
|
@app.get("/api/health")
|
|
async def health_check():
|
|
"""健康检查端点,用于验证服务是否正常运行"""
|
|
return {"status": "healthy", "service": "alpha-transformer"}
|
|
|
|
|
|
@app.post("/api/test-llm")
|
|
async def test_llm_connection(request: Request):
|
|
"""
|
|
测试 LLM 连接
|
|
接收 LLM 配置,尝试连接并返回测试结果
|
|
对 529 状态码会自动重试,默认重试 3 次
|
|
"""
|
|
try:
|
|
data = await request.json()
|
|
|
|
api_key = data.get('llm_api_key')
|
|
base_url = data.get('llm_base_url')
|
|
model = data.get('llm_model')
|
|
max_retries = data.get('max_retries', 3)
|
|
|
|
if not api_key or not base_url or not model:
|
|
return JSONResponse(
|
|
status_code=400,
|
|
content={"success": False, "error": "Missing required LLM configuration"}
|
|
)
|
|
|
|
logger.info(f"测试 LLM 连接: {base_url}, 模型: {model}, 最大重试次数: {max_retries}")
|
|
|
|
# 创建客户端
|
|
client = openai.AsyncOpenAI(
|
|
api_key=api_key,
|
|
base_url=base_url
|
|
)
|
|
|
|
# 重试机制
|
|
last_error = None
|
|
for attempt in range(1, max_retries + 1):
|
|
try:
|
|
response = await client.chat.completions.create(
|
|
model=model,
|
|
messages=[{"role": "user", "content": "Hello, this is a connection test. Reply with 'OK' only."}],
|
|
max_tokens=10,
|
|
timeout=30
|
|
)
|
|
|
|
# 检查响应
|
|
if response and response.choices and len(response.choices) > 0:
|
|
content = response.choices[0].message.content
|
|
logger.info(f"✓ LLM 连接测试成功: {content[:50]}...")
|
|
return JSONResponse(content={
|
|
"success": True,
|
|
"message": "LLM 连接成功",
|
|
"response": content[:100]
|
|
})
|
|
else:
|
|
return JSONResponse(
|
|
status_code=500,
|
|
content={"success": False, "error": "Empty response from LLM"}
|
|
)
|
|
|
|
except Exception as e:
|
|
error_msg = str(e)
|
|
last_error = e
|
|
|
|
# 检查是否是 529 错误(MiniMax 过载错误)
|
|
if "529" in error_msg or "overloaded" in error_msg.lower():
|
|
if attempt < max_retries:
|
|
logger.warning(f"⚠ LLM 返回 529 错误,第 {attempt}/{max_retries} 次尝试失败,2 秒后重试...")
|
|
await asyncio.sleep(2)
|
|
continue
|
|
else:
|
|
logger.error(f"✗ LLM 连接测试失败,已重试 {max_retries} 次,仍然返回 529 错误")
|
|
return JSONResponse(
|
|
status_code=529,
|
|
content={"success": False, "error": f"LLM service overloaded (529) after {max_retries} retries"}
|
|
)
|
|
else:
|
|
# 其他错误直接抛出
|
|
raise
|
|
|
|
except Exception as e:
|
|
error_msg = str(e)
|
|
logger.warning(f"✗ LLM 连接测试失败: {error_msg}")
|
|
return JSONResponse(
|
|
status_code=500,
|
|
content={"success": False, "error": error_msg}
|
|
)
|
|
|
|
|
|
@app.post("/api/download-datafields")
|
|
async def download_datafields(request: Request):
|
|
"""
|
|
下载数据字段缓存
|
|
全量获取指定 region/delay/universe/category 的数据字段并保存到 CSV
|
|
"""
|
|
try:
|
|
data = await request.json()
|
|
username = data.get('username')
|
|
password = data.get('password')
|
|
region = data.get('region')
|
|
delay = data.get('delay')
|
|
universe = data.get('universe')
|
|
data_type = data.get('data_type', 'MATRIX')
|
|
category = data.get('category', []) # 类别列表
|
|
|
|
if not all([username, password, region, delay, universe]):
|
|
return JSONResponse(
|
|
status_code=400,
|
|
content={"success": False, "error": "Missing required parameters"}
|
|
)
|
|
|
|
# 创建会话
|
|
session = SingleSession()
|
|
session.auth = (username, password)
|
|
|
|
brain_api_url = "https://api.worldquantbrain.com"
|
|
response = session.post(brain_api_url + "/authentication")
|
|
|
|
if response.status_code != 201:
|
|
return JSONResponse(
|
|
status_code=401,
|
|
content={"success": False, "error": "Authentication failed"}
|
|
)
|
|
|
|
# 获取所有数据集信息
|
|
datasets_df = get_datasets(
|
|
s=session,
|
|
instrument_type="EQUITY",
|
|
region=region,
|
|
delay=delay,
|
|
universe=universe,
|
|
theme="ALL"
|
|
)
|
|
|
|
if datasets_df.empty:
|
|
return JSONResponse(
|
|
status_code=404,
|
|
content={"success": False, "error": "No datasets found"}
|
|
)
|
|
|
|
# 处理类别参数(前端总是传递类别列表)
|
|
categories_to_download = category if isinstance(category, list) else [category] if category else []
|
|
|
|
# 如果没有指定类别,使用所有类别
|
|
if not categories_to_download:
|
|
if 'category_id' in datasets_df.columns:
|
|
categories_to_download = datasets_df['category_id'].unique().tolist()
|
|
elif 'category_name' in datasets_df.columns:
|
|
categories_to_download = datasets_df['category_name'].unique().tolist()
|
|
else:
|
|
# category 是字典,提取 id
|
|
categories_to_download = datasets_df['category'].apply(lambda x: x['id'] if isinstance(x, dict) and 'id' in x else x).unique().tolist()
|
|
logger.info(f"未指定类别,将使用所有 {len(categories_to_download)} 个类别")
|
|
|
|
# 构建类别到 dataset_id 的映射(使用 category_id 匹配)
|
|
category_to_datasets = {}
|
|
for cat in categories_to_download:
|
|
if 'category_id' in datasets_df.columns:
|
|
cat_datasets = datasets_df[datasets_df['category_id'] == cat]
|
|
elif 'category_name' in datasets_df.columns:
|
|
cat_datasets = datasets_df[datasets_df['category_name'] == cat]
|
|
else:
|
|
cat_datasets = datasets_df[datasets_df['category'].apply(lambda x: x.get('id') if isinstance(x, dict) else x) == cat]
|
|
if not cat_datasets.empty:
|
|
category_to_datasets[cat] = cat_datasets['id'].tolist()
|
|
|
|
# 保存到 CSV (dataset 文件夹)
|
|
dataset_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'dataset')
|
|
os.makedirs(dataset_dir, exist_ok=True)
|
|
|
|
# 检查哪些类别已经存在缓存文件
|
|
categories_to_download_filtered = []
|
|
already_exists = []
|
|
for cat in categories_to_download:
|
|
cache_filename = f"datafields_cache_{region}_{universe}_D{delay}_{cat}.csv"
|
|
cache_path = os.path.join(dataset_dir, cache_filename)
|
|
if os.path.exists(cache_path):
|
|
already_exists.append(cat)
|
|
logger.info(f"类别 '{cat}' 的缓存文件已存在,跳过下载: {cache_filename}")
|
|
else:
|
|
categories_to_download_filtered.append(cat)
|
|
|
|
if already_exists:
|
|
logger.info(f"以下类别缓存已存在,将跳过: {already_exists}")
|
|
|
|
if not categories_to_download_filtered:
|
|
logger.info("所有类别的缓存文件都已存在,无需下载")
|
|
return JSONResponse(content={
|
|
"success": True,
|
|
"count": 0,
|
|
"message": "所有类别的缓存文件都已存在",
|
|
"skipped_categories": already_exists
|
|
})
|
|
|
|
categories_to_download = categories_to_download_filtered
|
|
|
|
logger.info(f"开始下载数据字段: region={region}, delay={delay}, universe={universe}, type={data_type}, categories={categories_to_download}")
|
|
|
|
total_count = 0
|
|
saved_files = []
|
|
|
|
# 每个类别单独下载并保存
|
|
for cat in categories_to_download:
|
|
if cat not in category_to_datasets:
|
|
logger.warning(f"类别 '{cat}' 在数据集中不存在,跳过")
|
|
continue
|
|
|
|
dataset_ids = category_to_datasets[cat]
|
|
all_datafields = []
|
|
|
|
# 下载该类别下的所有数据集
|
|
for dataset_id in dataset_ids:
|
|
logger.info(f"下载数据集 {dataset_id} (类别: {cat})...")
|
|
datafields_df = get_datafields(
|
|
s=session,
|
|
instrument_type="EQUITY",
|
|
region=region,
|
|
delay=delay,
|
|
universe=universe,
|
|
dataset_id=dataset_id,
|
|
data_type=data_type
|
|
)
|
|
|
|
if not datafields_df.empty:
|
|
all_datafields.append(datafields_df)
|
|
logger.info(f"数据集 {dataset_id} 包含 {len(datafields_df)} 个字段")
|
|
|
|
if all_datafields:
|
|
# 合并该类别下的所有数据字段
|
|
combined_df = pd.concat(all_datafields, ignore_index=True)
|
|
combined_df.drop_duplicates(subset=['id'], inplace=True)
|
|
|
|
cache_filename = f"datafields_cache_{region}_{universe}_D{delay}_{cat}.csv"
|
|
cache_path = os.path.join(dataset_dir, cache_filename)
|
|
combined_df.to_csv(cache_path, index=False)
|
|
total_count += len(combined_df)
|
|
saved_files.append(cache_filename)
|
|
logger.info(f"类别 '{cat}' 数据字段已保存: {cache_path}, 共 {len(combined_df)} 个字段")
|
|
else:
|
|
logger.warning(f"类别 '{cat}' 未找到数据字段")
|
|
|
|
if total_count == 0:
|
|
return JSONResponse(
|
|
status_code=404,
|
|
content={"success": False, "error": "No datafields found for any category"}
|
|
)
|
|
|
|
return JSONResponse(content={
|
|
"success": True,
|
|
"count": total_count,
|
|
"cache_files": saved_files
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.error(f"下载数据字段失败: {e}")
|
|
return JSONResponse(
|
|
status_code=500,
|
|
content={"success": False, "error": str(e)}
|
|
)
|
|
|
|
|
|
@app.get("/api/download/{alpha_id}")
|
|
async def download_results(alpha_id: str):
|
|
"""
|
|
下载生成结果的 zip 压缩包
|
|
包含三个 JSON 文件:success, candidates, error
|
|
"""
|
|
try:
|
|
transformer_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'Tranformer')
|
|
output_dir = os.path.join(transformer_dir, 'output')
|
|
|
|
# 检查文件是否存在
|
|
files_to_zip = {
|
|
'Alpha_generated_expressions_success.json': os.path.join(output_dir, 'Alpha_generated_expressions_success.json'),
|
|
'Alpha_candidates.json': os.path.join(output_dir, 'Alpha_candidates.json'),
|
|
'Alpha_generated_expressions_error.json': os.path.join(output_dir, 'Alpha_generated_expressions_error.json')
|
|
}
|
|
|
|
# 生成时间戳
|
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
zip_filename = f"{alpha_id}_{timestamp}.zip"
|
|
|
|
# 保存到 save_zip 文件夹(不会被清空)
|
|
save_zip_dir = os.path.join(transformer_dir, 'save_zip')
|
|
os.makedirs(save_zip_dir, exist_ok=True)
|
|
zip_path = os.path.join(save_zip_dir, zip_filename)
|
|
|
|
# 创建 zip 文件
|
|
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
|
|
for arcname, filepath in files_to_zip.items():
|
|
if os.path.exists(filepath):
|
|
zipf.write(filepath, arcname)
|
|
|
|
# 返回文件
|
|
return FileResponse(
|
|
zip_path,
|
|
media_type='application/zip',
|
|
filename=zip_filename
|
|
)
|
|
|
|
except Exception as e:
|
|
return JSONResponse(
|
|
status_code=500,
|
|
content={"success": False, "error": str(e)}
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
# 启动 FastAPI 应用,使用 uvicorn 作为 ASGI 服务器
|
|
import uvicorn
|
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|