import asyncio
import json
import logging
import os
import shutil
import sys
import uuid
import zipfile
from datetime import datetime

import openai
import pandas as pd
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
from fastapi.templating import Jinja2Templates

from Tranformer.ace_lib import (
    SingleSession,
    get_instrument_type_region_delay,
    get_datafields,
    get_datasets,
)
from Tranformer.Transformer import run_transformer

# Configure logging
logger = logging.getLogger("alpha_transformer")
if not logger.handlers:
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)

# Create the FastAPI application instance
app = FastAPI(title="Alpha Transformer", version="1.0.0")


# Chrome DevTools health-check endpoint (optional; suppresses 404 log noise)
@app.get("/.well-known/appspecific/com.chrome.devtools.json")
async def chrome_devtools_check():
    """Chrome DevTools health check."""
    return {"status": "ok"}


# Serve the CSS and JS files from the templates directory
@app.get("/styles.css")
async def get_styles():
    """Serve the styles.css file."""
    return FileResponse("templates/styles.css")


@app.get("/app.js")
async def get_app_js():
    """Serve the app.js file."""
    return FileResponse("templates/app.js")


# Configure the Jinja2 template engine for rendering HTML pages
# (currently unused: home() below reads the index file directly)
templates = Jinja2Templates(directory="templates")

# Clear the Tranformer/output directory before the app starts
if os.path.exists("Tranformer/output"):
    for filename in os.listdir("Tranformer/output"):
        file_path = os.path.join("Tranformer/output", filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            logger.error(f"⚠ Failed to clear {file_path}: {e}")

# Tracks running tasks (task_id -> task state)
transformer_tasks = {}

# Global configuration
app_config = {}


def load_config():
    """
    Load the configuration file.

    Reads config.json if it exists; otherwise returns an empty dict.
    """
    config_path = os.path.join(os.path.dirname(__file__), 'config.json')
    if os.path.exists(config_path):
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
            logger.info(f"✓ Loaded config file: {config_path}")
            return config
        except Exception as e:
            logger.warning(f"⚠ Failed to load config file: {e}")
            return {}
    else:
        logger.warning(f"⚠ Config file not found: {config_path}")
        return {}


# Load configuration at startup
app_config = load_config()


@app.get("/", response_class=HTMLResponse)
async def home():
    # Read and return the front-end index page
    with open("templates/index.html", "r", encoding="utf-8") as f:
        return f.read()


@app.get("/api/config/defaults")
async def get_config_defaults():
    """
    Get the default configuration.

    Returns the contents of config.json to use as form defaults.
    """
    return JSONResponse(content={
        "success": True,
        "config": app_config
    })


@app.post("/api/config/save")
async def save_config(request: Request):
    """
    Save the configuration to config.json.

    Lets the user persist configuration edited on the page.
    """
    global app_config
    try:
        data = await request.json()

        # Update the global configuration
        app_config = data

        # Write to file
        config_path = os.path.join(os.path.dirname(__file__), 'config.json')
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=4, ensure_ascii=False)

        return JSONResponse(content={
            "success": True,
            "message": "Configuration saved"
        })
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"success": False, "error": str(e)}
        )
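
# A sketch of a plausible config.json, inferred from the fields that
# /api/generate requires below; the exact keys the front end stores may differ:
#
#   {
#       "alpha_id": "...",
#       "llm_api_key": "...",
#       "llm_base_url": "https://api.example.com/v1",
#       "llm_model": "...",
#       "brain_username": "...",
#       "brain_password": "..."
#   }
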
@app.post("/api/generate")
async def generate_alpha(request: Request):
    """
    API endpoint for generating Alpha variants.

    Receives the front-end form data and runs the Transformer script to
    execute the Alpha generation task.
    """
    logger.info("=" * 50)
    logger.info("Received generation request")
    try:
        # Parse the request payload
        data = await request.json()
        logger.info(f"Request data: alpha_id={data.get('alpha_id')}, llm_model={data.get('llm_model')}")

        # Generate a unique task ID
        task_id = str(uuid.uuid4())
        logger.info(f"Generated task ID: {task_id}")

        # Fields that must be present in the request
        required_fields = [
            "alpha_id", "llm_api_key", "llm_base_url",
            "llm_model", "brain_username", "brain_password"
        ]

        # Verify that all required fields were provided
        for field in required_fields:
            if not data.get(field):
                logger.warning(f"Missing required field: {field}")
                return JSONResponse(
                    status_code=400,
                    content={"success": False, "error": f"Missing required field: {field}"}
                )

        # Resolve the script directory and the Transformer subdirectory
        script_dir = os.path.dirname(os.path.abspath(__file__))
        transformer_dir = os.path.join(script_dir, 'Tranformer')
        logger.info(f"Transformer directory: {transformer_dir}")

        # Build the configuration passed to the Transformer script
        config = {
            "LLM_model_name": data.get('llm_model'),
            "LLM_API_KEY": data.get('llm_api_key'),
            "llm_base_url": data.get('llm_base_url'),
            "username": data.get('brain_username'),
            "password": data.get('brain_password'),
            "alpha_id": data.get('alpha_id'),
            "top_n_datafield": int(data.get('top_n_datafield', 50)),
            "user_region": data.get('user_region'),
            "user_universe": data.get('user_universe'),
            "user_delay": int(data.get('user_delay')) if data.get('user_delay') else None,
            "user_category": data.get('user_category'),
            "user_data_type": data.get('user_data_type', 'MATRIX'),
            "max_retries": int(data.get('max_retries', 20))
        }
        logger.info(f"Config built: LLM_model={config['LLM_model_name']}, alpha_id={config['alpha_id']}, max_retries={config['max_retries']}")

        # Check whether the datafield cache exists; download it automatically if not
        dataset_dir = os.path.join(script_dir, 'dataset')
        os.makedirs(dataset_dir, exist_ok=True)

        user_category = data.get('user_category', [])
        categories_to_check = user_category if isinstance(user_category, list) else ([user_category] if user_category else [])

        # Log the user's category selection
        logger.info(f"[Generate] Selected categories: {categories_to_check}")
        logger.info(f"[Generate] Region: {data.get('user_region')}, Delay: {data.get('user_delay')}, Universe: {data.get('user_universe')}")

        # Create a session to fetch dataset information
        session = SingleSession()
        session.auth = (data.get('brain_username'), data.get('brain_password'))
        brain_api_url = "https://api.worldquantbrain.com"
        auth_response = session.post(brain_api_url + "/authentication")
        if auth_response.status_code != 201:
            return JSONResponse(
                status_code=401,
                content={"success": False, "error": "Authentication failed"}
            )

        # Fetch all available datasets
        datasets_df = get_datasets(
            s=session,
            instrument_type="EQUITY",
            region=data.get('user_region', 'USA'),
            delay=int(data.get('user_delay', 1)),
            universe=data.get('user_universe', 'TOP3000'),
            theme="ALL"
        )
        if datasets_df.empty:
            return JSONResponse(
                status_code=404,
                content={"success": False, "error": "No datasets found"}
            )

        # Prefer the category_name column (the category column holds dicts)
        category_column = 'category_name' if 'category_name' in datasets_df.columns else 'category'

        # If no categories were specified, use all of them
        if not categories_to_check:
            if category_column == 'category':
                # category is a dict; extract its name
                categories_to_check = datasets_df[category_column].apply(lambda x: x['name'] if isinstance(x, dict) and 'name' in x else x).unique().tolist()
            else:
                categories_to_check = datasets_df[category_column].unique().tolist()
            logger.info(f"No categories specified; using all {len(categories_to_check)} categories: {categories_to_check}")

        # Build a mapping from category to dataset IDs
        category_to_datasets = {}
        for cat in categories_to_check:
            if category_column == 'category':
                # category is a dict; match on its name field
                cat_datasets = datasets_df[datasets_df[category_column].apply(lambda x: x.get('name') if isinstance(x, dict) else x) == cat]
            else:
                cat_datasets = datasets_df[datasets_df[category_column] == cat]
            if not cat_datasets.empty:
                category_to_datasets[cat] = cat_datasets['id'].tolist()
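
        # Cache files are keyed by region/universe/delay/category. A
        # hypothetical example name (the category value comes from the user's
        # selection): datafields_cache_USA_TOP3000_D1_fundamental.csv.
        # /api/download-datafields below writes files with the same scheme.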
        # Check the cache file for each category
        missing_categories = []
        for cat in categories_to_check:
            cache_filename = f"datafields_cache_{data.get('user_region', 'USA')}_{data.get('user_universe', 'TOP3000')}_D{data.get('user_delay', 1)}_{cat}.csv"
            cache_path = os.path.join(dataset_dir, cache_filename)
            if not os.path.exists(cache_path):
                missing_categories.append(cat)

        # If caches are missing, make sure every missing category maps to a dataset
        if missing_categories:
            # Reject categories that do not exist in the datasets
            invalid_categories = [cat for cat in missing_categories if cat not in category_to_datasets]
            if invalid_categories:
                error_msg = f"The following categories do not exist in the datasets and cannot be downloaded: {', '.join(invalid_categories)}. Check the category names or click the 'Download datafield cache' button first."
                logger.error(error_msg)
                return JSONResponse(
                    status_code=400,
                    content={"success": False, "error": error_msg}
                )

            logger.info(f"Missing datafield caches detected: {missing_categories}; downloading automatically...")
            downloaded_count = 0
            failed_categories = []

            for cat in missing_categories:
                dataset_ids = category_to_datasets[cat]
                all_datafields = []
                try:
                    # Download every dataset under this category
                    for dataset_id in dataset_ids:
                        logger.info(f"Downloading dataset {dataset_id} (category: {cat})...")
                        datafields_df = get_datafields(
                            s=session,
                            instrument_type="EQUITY",
                            region=data.get('user_region', 'USA'),
                            delay=int(data.get('user_delay', 1)),
                            universe=data.get('user_universe', 'TOP3000'),
                            dataset_id=dataset_id,
                            data_type=data.get('user_data_type', 'MATRIX')
                        )
                        if not datafields_df.empty:
                            all_datafields.append(datafields_df)
                            logger.info(f"Dataset {dataset_id} contains {len(datafields_df)} fields")

                    if all_datafields:
                        # Merge all datafields under this category
                        combined_df = pd.concat(all_datafields, ignore_index=True)
                        combined_df.drop_duplicates(subset=['id'], inplace=True)
                        cache_filename = f"datafields_cache_{data.get('user_region', 'USA')}_{data.get('user_universe', 'TOP3000')}_D{data.get('user_delay', 1)}_{cat}.csv"
                        cache_path = os.path.join(dataset_dir, cache_filename)
                        combined_df.to_csv(cache_path, index=False)
                        downloaded_count += 1
                        logger.info(f"Auto-download finished: {cache_filename} ({len(combined_df)} fields)")
                    else:
                        failed_categories.append(cat)
                        logger.error(f"Category '{cat}' returned no datafields")
                except Exception as e:
                    failed_categories.append(cat)
                    logger.error(f"Failed to download category '{cat}': {e}")

            # If any category failed to download, return an error
            if failed_categories:
                error_msg = f"The following categories failed to download: {', '.join(failed_categories)}. Check the network connection or download them manually."
                logger.error(error_msg)
                return JSONResponse(
                    status_code=500,
                    content={"success": False, "error": error_msg}
                )

            logger.info(f"Auto-download complete: {downloaded_count}/{len(missing_categories)} categories")
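
        # Note: run_transformer is awaited inside this request handler, so the
        # HTTP response stays open for the entire generation run; long jobs
        # keep the client waiting until the Transformer finishes.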
        # Invoke the Transformer module directly (replaces subprocess)
        logger.info("Starting Transformer run...")
        try:
            transformer_result = await run_transformer(config)
            logger.info(f"Transformer run finished: {transformer_result}")

            # Output file paths
            output_file = os.path.join(transformer_dir, 'output', 'Alpha_generated_expressions_success.json')
            candidates_file = os.path.join(transformer_dir, 'output', 'Alpha_candidates.json')
            error_file = os.path.join(transformer_dir, 'output', 'Alpha_generated_expressions_error.json')

            # Build the response payload
            result = {
                "success": True,
                "alpha_id": data.get('alpha_id'),
                "stdout": "",
                "stderr": "",
                "return_code": 0
            }

            # Read the successfully generated expressions
            if os.path.exists(output_file):
                logger.info(f"Reading success expressions file: {output_file}")
                with open(output_file, 'r', encoding='utf-8') as f:
                    result['expressions_success'] = json.load(f)
            else:
                logger.info(f"Success expressions file not found: {output_file}")
                result['expressions_success'] = []

            # Read the candidate expressions
            if os.path.exists(candidates_file):
                logger.info(f"Reading candidates file: {candidates_file}")
                with open(candidates_file, 'r', encoding='utf-8') as f:
                    result['candidates'] = json.load(f)
            else:
                logger.info(f"Candidates file not found: {candidates_file}")
                result['candidates'] = []

            # Read the expressions that failed generation
            if os.path.exists(error_file):
                logger.info(f"Reading error expressions file: {error_file}")
                with open(error_file, 'r', encoding='utf-8') as f:
                    result['expressions_error'] = json.load(f)
            else:
                logger.info(f"Error expressions file not found: {error_file}")
                result['expressions_error'] = []

            logger.info(f"Success: {len(result['expressions_success'])}, candidates: {len(result['candidates'])}, errors: {len(result['expressions_error'])}")
            logger.info("=" * 50)

            return JSONResponse(content=result)
        except Exception as transformer_error:
            logger.error(f"Transformer run failed: {transformer_error}")
            raise
    except Exception as e:
        logger.warning(f"Request failed: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"success": False, "error": str(e)}
        )


@app.post("/api/transformer/login-and-fetch-options")
async def login_and_fetch_options(request: Request):
    """
    Log in to BRAIN and fetch the region, delay, universe, and category options.

    Used to populate the advanced-options form.
    """
    try:
        data = await request.json()
        username = data.get('username')
        password = data.get('password')

        if not username or not password:
            return JSONResponse(
                status_code=400,
                content={'success': False, 'error': 'Username and password are required'}
            )

        # Add the Transformer directory to sys.path
        script_dir = os.path.dirname(os.path.abspath(__file__))
        transformer_dir = os.path.join(script_dir, 'Tranformer')
        if transformer_dir not in sys.path:
            sys.path.append(transformer_dir)

        # Create a fresh session
        session = SingleSession()
        session.auth = (username, password)
        brain_api_url = "https://api.worldquantbrain.com"
        response = session.post(brain_api_url + "/authentication")

        if response.status_code == 201:
            # Authentication succeeded
            pass
        elif response.status_code == 401:
            return JSONResponse(
                status_code=401,
                content={'success': False, 'error': 'Authentication failed: Invalid credentials'}
            )
        else:
            return JSONResponse(
                status_code=400,
                content={'success': False, 'error': f'Authentication failed: {response.status_code}'}
            )

        # Fetch the region/delay/universe options
        df = get_instrument_type_region_delay(session)

        # Fetch the data categories
        categories_resp = session.get(brain_api_url + "/data-categories")
        categories = []
        if categories_resp.status_code == 200:
            categories_data = categories_resp.json()
            if isinstance(categories_data, list):
                categories = categories_data
            elif isinstance(categories_data, dict):
                categories = categories_data.get('results', [])

        # Convert the DataFrame into the nested dict the front end expects
        # Structure: Region -> Delay -> Universe
        df_equity = df[df['InstrumentType'] == 'EQUITY']
        options = {}
        for _, row in df_equity.iterrows():
            region = row['Region']
            delay = row['Delay']
            universes = row['Universe']  # this is a list
            if region not in options:
                options[region] = {}
            # Use the delay, as a string, for the dict key
            delay_str = str(delay)
            if delay_str not in options[region]:
                options[region][delay_str] = universes

        return JSONResponse(content={
            'success': True,
            'options': options,
            'categories': categories
        })
    except Exception as e:
        logger.warning(f"Login/option fetch failed: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={'success': False, 'error': str(e)}
        )


@app.get("/api/health")
async def health_check():
    """Health-check endpoint used to verify the service is running."""
    return {"status": "healthy", "service": "alpha-transformer"}
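
# The options payload returned by login_and_fetch_options nests universes
# under region and then delay, e.g. (illustrative values only, assuming the
# USA/TOP3000 defaults used elsewhere in this file):
#
#   {"USA": {"1": ["TOP3000", ...]}, ...}
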
@app.post("/api/test-llm")
async def test_llm_connection(request: Request):
    """
    Test the LLM connection.

    Takes the LLM configuration, attempts a request, and returns the result.
    Automatically retries on 529 status codes, 3 times by default.
    """
    try:
        data = await request.json()
        api_key = data.get('llm_api_key')
        base_url = data.get('llm_base_url')
        model = data.get('llm_model')
        max_retries = data.get('max_retries', 3)

        if not api_key or not base_url or not model:
            return JSONResponse(
                status_code=400,
                content={"success": False, "error": "Missing required LLM configuration"}
            )

        logger.info(f"Testing LLM connection: {base_url}, model: {model}, max retries: {max_retries}")

        # Create the client
        client = openai.AsyncOpenAI(
            api_key=api_key,
            base_url=base_url
        )

        # Retry loop
        for attempt in range(1, max_retries + 1):
            try:
                response = await client.chat.completions.create(
                    model=model,
                    messages=[{"role": "user", "content": "Hello, this is a connection test. Reply with 'OK' only."}],
                    max_tokens=10,
                    timeout=30
                )

                # Inspect the response
                if response and response.choices and len(response.choices) > 0:
                    content = response.choices[0].message.content
                    logger.info(f"✓ LLM connection test succeeded: {content[:50]}...")
                    return JSONResponse(content={
                        "success": True,
                        "message": "LLM connection succeeded",
                        "response": content[:100]
                    })
                else:
                    return JSONResponse(
                        status_code=500,
                        content={"success": False, "error": "Empty response from LLM"}
                    )
            except Exception as e:
                error_msg = str(e)
                # Check for a 529 error (MiniMax overload)
                if "529" in error_msg or "overloaded" in error_msg.lower():
                    if attempt < max_retries:
                        logger.warning(f"⚠ LLM returned 529; attempt {attempt}/{max_retries} failed, retrying in 2 seconds...")
                        await asyncio.sleep(2)
                        continue
                    else:
                        logger.error(f"✗ LLM connection test failed: still receiving 529 after {max_retries} retries")
                        return JSONResponse(
                            status_code=529,
                            content={"success": False, "error": f"LLM service overloaded (529) after {max_retries} retries"}
                        )
                else:
                    # Re-raise any other error
                    raise
    except Exception as e:
        error_msg = str(e)
        logger.warning(f"✗ LLM connection test failed: {error_msg}")
        return JSONResponse(
            status_code=500,
            content={"success": False, "error": error_msg}
        )
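
# A sketch of exercising the endpoint above (field names match the handler;
# the key and URL values are placeholders):
#
#   curl -X POST http://localhost:8000/api/test-llm \
#     -H 'Content-Type: application/json' \
#     -d '{"llm_api_key": "sk-...", "llm_base_url": "https://api.example.com/v1",
#          "llm_model": "my-model", "max_retries": 3}'
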
@app.post("/api/download-datafields")
async def download_datafields(request: Request):
    """
    Download the datafield cache.

    Fetches all datafields for the given region/delay/universe/category and
    saves them to CSV.
    """
    try:
        data = await request.json()
        username = data.get('username')
        password = data.get('password')
        region = data.get('region')
        delay = data.get('delay')
        universe = data.get('universe')
        data_type = data.get('data_type', 'MATRIX')
        category = data.get('category', [])  # list of categories

        if not all([username, password, region, delay, universe]):
            return JSONResponse(
                status_code=400,
                content={"success": False, "error": "Missing required parameters"}
            )

        # Create a session
        session = SingleSession()
        session.auth = (username, password)
        brain_api_url = "https://api.worldquantbrain.com"
        response = session.post(brain_api_url + "/authentication")
        if response.status_code != 201:
            return JSONResponse(
                status_code=401,
                content={"success": False, "error": "Authentication failed"}
            )

        # Fetch all dataset information
        datasets_df = get_datasets(
            s=session,
            instrument_type="EQUITY",
            region=region,
            delay=delay,
            universe=universe,
            theme="ALL"
        )
        if datasets_df.empty:
            return JSONResponse(
                status_code=404,
                content={"success": False, "error": "No datasets found"}
            )

        # Normalize the category parameter (the front end always sends a list)
        categories_to_download = category if isinstance(category, list) else ([category] if category else [])

        # If no categories were specified, use all of them
        if not categories_to_download:
            if 'category_id' in datasets_df.columns:
                categories_to_download = datasets_df['category_id'].unique().tolist()
            elif 'category_name' in datasets_df.columns:
                categories_to_download = datasets_df['category_name'].unique().tolist()
            else:
                # category is a dict; extract its id
                categories_to_download = datasets_df['category'].apply(lambda x: x['id'] if isinstance(x, dict) and 'id' in x else x).unique().tolist()
            logger.info(f"No categories specified; using all {len(categories_to_download)} categories")

        # Build a mapping from category to dataset IDs (matched on category_id)
        category_to_datasets = {}
        for cat in categories_to_download:
            if 'category_id' in datasets_df.columns:
                cat_datasets = datasets_df[datasets_df['category_id'] == cat]
            elif 'category_name' in datasets_df.columns:
                cat_datasets = datasets_df[datasets_df['category_name'] == cat]
            else:
                cat_datasets = datasets_df[datasets_df['category'].apply(lambda x: x.get('id') if isinstance(x, dict) else x) == cat]
            if not cat_datasets.empty:
                category_to_datasets[cat] = cat_datasets['id'].tolist()

        # Save CSVs into the dataset folder
        dataset_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'dataset')
        os.makedirs(dataset_dir, exist_ok=True)

        # Skip categories whose cache files already exist
        categories_to_download_filtered = []
        already_exists = []
        for cat in categories_to_download:
            cache_filename = f"datafields_cache_{region}_{universe}_D{delay}_{cat}.csv"
            cache_path = os.path.join(dataset_dir, cache_filename)
            if os.path.exists(cache_path):
                already_exists.append(cat)
                logger.info(f"Cache file for category '{cat}' already exists; skipping download: {cache_filename}")
            else:
                categories_to_download_filtered.append(cat)

        if already_exists:
            logger.info(f"Caches already exist for these categories; skipping: {already_exists}")

        if not categories_to_download_filtered:
            logger.info("Cache files already exist for every category; nothing to download")
            return JSONResponse(content={
                "success": True,
                "count": 0,
                "message": "Cache files already exist for every category",
                "skipped_categories": already_exists
            })

        categories_to_download = categories_to_download_filtered

        logger.info(f"Starting datafield download: region={region}, delay={delay}, universe={universe}, type={data_type}, categories={categories_to_download}")
        total_count = 0
        saved_files = []

        # Download and save each category separately
        for cat in categories_to_download:
            if cat not in category_to_datasets:
                logger.warning(f"Category '{cat}' does not exist in the datasets; skipping")
                continue

            dataset_ids = category_to_datasets[cat]
            all_datafields = []

            # Download every dataset under this category
            for dataset_id in dataset_ids:
                logger.info(f"Downloading dataset {dataset_id} (category: {cat})...")
                datafields_df = get_datafields(
                    s=session,
                    instrument_type="EQUITY",
                    region=region,
                    delay=delay,
                    universe=universe,
                    dataset_id=dataset_id,
                    data_type=data_type
                )
                if not datafields_df.empty:
                    all_datafields.append(datafields_df)
                    logger.info(f"Dataset {dataset_id} contains {len(datafields_df)} fields")

            if all_datafields:
                # Merge all datafields under this category
                combined_df = pd.concat(all_datafields, ignore_index=True)
                combined_df.drop_duplicates(subset=['id'], inplace=True)
                cache_filename = f"datafields_cache_{region}_{universe}_D{delay}_{cat}.csv"
                cache_path = os.path.join(dataset_dir, cache_filename)
                combined_df.to_csv(cache_path, index=False)
                total_count += len(combined_df)
                saved_files.append(cache_filename)
                logger.info(f"Datafields for category '{cat}' saved: {cache_path}, {len(combined_df)} fields total")
            else:
                logger.warning(f"No datafields found for category '{cat}'")

        if total_count == 0:
            return JSONResponse(
                status_code=404,
                content={"success": False, "error": "No datafields found for any category"}
            )

        return JSONResponse(content={
            "success": True,
            "count": total_count,
            "cache_files": saved_files
        })
    except Exception as e:
        logger.error(f"Datafield download failed: {e}")
        return JSONResponse(
            status_code=500,
            content={"success": False, "error": str(e)}
        )
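
# Note: generate_alpha keys its cache filenames by category *name*, while the
# endpoint above prefers a category_id column when the datasets DataFrame has
# one, so the two paths only share cache files when the identifier the front
# end sends matches the filename suffix used here.
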
@app.get("/api/download/{alpha_id}")
async def download_results(alpha_id: str):
    """
    Download the generated results as a zip archive.

    Contains three JSON files: success, candidates, and error.
    """
    try:
        transformer_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'Tranformer')
        output_dir = os.path.join(transformer_dir, 'output')

        # Files to include if they exist
        files_to_zip = {
            'Alpha_generated_expressions_success.json': os.path.join(output_dir, 'Alpha_generated_expressions_success.json'),
            'Alpha_candidates.json': os.path.join(output_dir, 'Alpha_candidates.json'),
            'Alpha_generated_expressions_error.json': os.path.join(output_dir, 'Alpha_generated_expressions_error.json')
        }

        # Timestamped archive name
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        zip_filename = f"{alpha_id}_{timestamp}.zip"

        # Save into the save_zip folder (which is never cleared at startup)
        save_zip_dir = os.path.join(transformer_dir, 'save_zip')
        os.makedirs(save_zip_dir, exist_ok=True)
        zip_path = os.path.join(save_zip_dir, zip_filename)

        # Create the zip archive
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for arcname, filepath in files_to_zip.items():
                if os.path.exists(filepath):
                    zipf.write(filepath, arcname)

        # Return the file
        return FileResponse(
            zip_path,
            media_type='application/zip',
            filename=zip_filename
        )
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"success": False, "error": str(e)}
        )


if __name__ == "__main__":
    # Start the FastAPI app with uvicorn as the ASGI server
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
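
# A minimal usage sketch, assuming this module is saved as app.py (the actual
# filename may differ):
#
#   python app.py                                    # serves on 0.0.0.0:8000
#   curl -OJ http://localhost:8000/api/download/<alpha_id>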