|
|
import logging
|
|
|
import io
|
|
|
import getpass
|
|
|
import itertools
|
|
|
import threading
|
|
|
import time
|
|
|
from urllib.parse import urljoin
|
|
|
from pathlib import Path
|
|
|
import pandas as pd
|
|
|
import requests
|
|
|
import json
|
|
|
import sys
|
|
|
import asyncio
|
|
|
import os
|
|
|
import openai
|
|
|
import re
|
|
|
from typing import Optional, Union # Added this import
|
|
|
try:
|
|
|
from .validator_hooks import is_valid_template_expr, has_empty_datafield_candidates
|
|
|
except Exception:
|
|
|
# Fallback for direct script execution
|
|
|
try:
|
|
|
from validator_hooks import is_valid_template_expr, has_empty_datafield_candidates
|
|
|
except Exception:
|
|
|
is_valid_template_expr = None
|
|
|
has_empty_datafield_candidates = None
|
|
|
|
|
|
# --- Validation wrappers to integrate into the pipeline ---
|
|
|
|
|
|
|
|
|
def _filter_valid_templates(
    proposed_templates: dict,
    operators_meta,
    brain_session,
    settings: dict,
    parse_alpha_code_func,
):
    """Keep only the proposed templates that pass expression validation.

    Args:
        proposed_templates: Mapping of template expression -> explanation.
        operators_meta: Operator metadata passed through to the validator.
        brain_session: Authenticated BRAIN session for validator lookups.
        settings: Alpha settings passed through to the validator.
        parse_alpha_code_func: Parser callable handed to the validator.

    Returns:
        dict: The surviving subset of ``proposed_templates``. When the
        validation helpers are unavailable, the input is returned
        unchanged (safe no-op).
    """
    if not is_valid_template_expr or not parse_alpha_code_func:
        return proposed_templates

    surviving = {}
    for expr, explanation in proposed_templates.items():
        try:
            accepted = is_valid_template_expr(
                expr,
                operators_meta,
                brain_session,
                settings,
                parse_alpha_code_func,
            )
        except Exception:
            # Conservative policy: anything that raises is discarded.
            accepted = False
        if accepted:
            surviving[expr] = explanation
    return surviving
|
|
|
|
|
|
|
|
|
def _should_skip_due_to_empty_candidates(populated_info: dict) -> bool:
    """Report whether any data_field placeholder has zero candidates.

    Args:
        populated_info: Placeholder-population result to inspect.

    Returns:
        bool: ``True`` when the hook reports an empty candidate set.
        Acts as a safe no-op (``False``) when the hook is missing or
        the check itself raises.
    """
    if not has_empty_datafield_candidates:
        return False
    try:
        result = has_empty_datafield_candidates(populated_info)
    except Exception:
        return False
    return result
|
|
|
|
|
|
|
|
|
# Module-level logger. The stdout handler is attached only when no
# handler exists yet, so re-importing the module does not duplicate
# log lines.
logger = logging.getLogger(__name__)
if not logger.handlers:
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
|
|
|
|
|
|
try:
|
|
|
from . import validator as val
|
|
|
from .ace_lib import get_instrument_type_region_delay
|
|
|
except ImportError:
|
|
|
import validator as val
|
|
|
from ace_lib import get_instrument_type_region_delay
|
|
|
# Force stdout/stderr to use utf-8 on Windows to avoid UnicodeEncodeError
# (the default console code page often cannot encode the Chinese log
# messages used throughout this module). Best effort: failures are
# ignored, e.g. when the streams were already wrapped or lack a raw
# ``.buffer`` attribute.
if sys.platform.startswith('win'):
    try:
        sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
        sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
    except Exception:
        pass
|
|
|
|
|
|
# These globals are populated later by interactive_input(); until then
# every value stays None.
LLM_model_name = None   # name of the LLM model to call
LLM_API_KEY = None      # API key for the LLM gateway
llm_base_url = None     # base URL of the LLM endpoint
username = None         # BRAIN platform login (email)
password = None         # BRAIN platform password
DATA_CATEGORIES = None  # lazy cache filled by get_data_categories()
|
|
|
|
|
|
|
|
|
# Load the template summary file shipped next to this module; fall back
# to a short built-in placeholder when the file is missing or unreadable.
template_summary_path = os.path.join(
    os.path.dirname(__file__), "template_summary.md")
try:
    with open(template_summary_path, "r", encoding="utf-8") as f:
        template_summary = f.read()
    logger.info(f"✓ 已加载模板总结文件: {template_summary_path}")
except FileNotFoundError:
    logger.warning(f"⚠ 模板总结文件不存在: {template_summary_path},使用内置模板")
    template_summary = """# BRAIN论坛Alpha模板精华总结

请创建 template_summary.md 文件"""
except Exception as e:
    logger.error(f"⚠ 加载模板总结文件失败: {e},使用内置模板")
    template_summary = """# BRAIN论坛Alpha模板精华总结

请检查 template_summary.md 文件"""
|
|
|
|
|
|
|
|
|
class SingleSession(requests.Session):
    """Process-wide singleton ``requests.Session`` for the BRAIN API.

    ``__new__`` uses double-checked locking so concurrent callers all
    receive the same session object; ``__init__`` runs the base-class
    initialisation exactly once. A separate re-login lock lets callers
    serialise re-authentication.
    """

    _instance = None
    _lock = threading.Lock()
    _relogin_lock = threading.Lock()
    _initialized = False

    def __new__(cls, *args, **kwargs):
        # Fast path without taking the lock; re-check inside it to
        # avoid a race between two first-time callers.
        if cls._instance is not None:
            return cls._instance
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, *args, **kwargs):
        # Every SingleSession() call re-invokes __init__ on the shared
        # instance, so guard the base-class initialisation.
        if self._initialized:
            return
        super().__init__(*args, **kwargs)
        self._initialized = True

    def get_relogin_lock(self):
        """Return the shared lock guarding re-authentication."""
        return self._relogin_lock
|
|
|
|
|
|
|
|
|
def load_template_summary(file_path: Optional[str] = None) -> str:
    """Return the template summary text.

    Args:
        file_path: Optional path to a .txt/.md file containing the
            summary. When omitted, missing, or unreadable, the built-in
            module-level ``template_summary`` is returned instead.

    Returns:
        str: The template summary content.
    """
    if file_path:
        try:
            candidate = Path(file_path)
            if candidate.exists() and candidate.is_file():
                with open(candidate, 'r', encoding='utf-8') as fh:
                    text = fh.read()
                logger.info(f"✓ 成功从文件加载模板总结: {file_path}")
                return text
            logger.warning(f"⚠ 警告: 文件不存在: {file_path},将使用内置模板总结")
        except Exception as e:
            logger.warning(f"⚠ 警告: 读取文件时出错: {e},将使用内置模板总结")

    # Fall back to the built-in summary loaded at import time.
    logger.info("✓ 使用内置模板总结")
    return template_summary
|
|
|
|
|
|
|
|
|
def get_credentials() -> tuple[str, str]:
    """
    Return the BRAIN platform credentials from the module-level globals.

    NOTE(review): unlike the docstring this function previously carried,
    it does NOT read or persist a JSON credentials file — it simply
    returns the ``username`` / ``password`` globals, which are filled in
    by ``interactive_input()``. Until then both values are ``None``.

    Returns:
        tuple: ``(username, password)`` as provided by the user.
    """
    # Uses the module-level globals set during interactive configuration.
    global username, password
    # please input your own BRAIN Credentials into the function
    return (username, password)
|
|
|
|
|
|
|
|
|
def get_token_from_auth_server() -> str:
    """Return the LLM gateway API key from the module-level global.

    The value is collected by ``interactive_input()``; until then it is
    ``None`` despite the ``str`` annotation.
    """
    # Uses the module-level global set during interactive configuration.
    global LLM_API_KEY
    # please input your own LLM Gateway token into the function, please note, we are using kimi-k2.5 model
    return LLM_API_KEY
|
|
|
|
|
|
|
|
|
def interactive_input() -> dict:
    """
    Interactively collect every runtime setting from the user.

    Prompts (via ``input``/``getpass``) for the LLM model name, API key
    and base URL, the BRAIN credentials, an optional template-summary
    file path, the Alpha ID to transform, and the datafield candidate
    count (Top N).

    Returns:
        dict: Configuration with keys ``LLM_model_name``,
        ``LLM_API_KEY``, ``llm_base_url``, ``username``, ``password``,
        ``template_summary_path``, ``alpha_id``, ``top_n_datafield``.

    Raises:
        SystemExit: If no Alpha ID is entered.
    """
    # Welcome banner.
    logger.info("\n" + "="*60)
    logger.info("欢迎使用 Alpha Transformer 交互式配置")
    logger.info("此程序在于让您输入一个Alpha ID即可通过历史总结的Alpha模板,转化成更多的表达式")
    logger.info("72变,助您腾云驾雾")
    logger.info("如果你想修改模型,则可以使用新模型的url和api key")
    logger.info("不同模型效果不同,默认的kimi可能会产生语法错误,请检查生成的模板文件进行甄别")
    logger.info("强烈推荐你使用自己总结的模板文档,效果会更好")
    logger.info("="*60 + "\n")

    config = {}

    # 1. LLM model name (empty input keeps the default).
    logger.info("【1/6】LLM 模型配置")
    logger.info("如果你想修改模型,则可以使用新模型的名称")
    default_model = "kimi-k2.5"
    model_input = input(f"请输入 LLM 模型名称 (直接回车使用默认值: {default_model}): ").strip()
    config['LLM_model_name'] = model_input if model_input else default_model
    logger.info(f"✓ LLM 模型名称: {config['LLM_model_name']}\n")

    # 2. LLM API key (getpass hides the input; empty only triggers a warning).
    logger.info("【2/6】LLM API Key 配置")
    api_key = getpass.getpass("请输入 LLM API Key (输入时不会显示): ").strip()
    if not api_key:
        logger.warning("⚠ 警告: API Key 为空,程序可能无法正常工作")
    config['LLM_API_KEY'] = api_key
    logger.info("✓ API Key 已设置\n")

    # 3. LLM base URL (empty input keeps the default Moonshot endpoint).
    logger.info("【3/6】LLM Base URL 配置")
    logger.info("提示:不同模型有不同的URL")
    default_url = "https://api.moonshot.cn/v1"
    url_input = input(f"请输入 LLM Base URL (直接回车使用默认值: {default_url}): ").strip()
    config['llm_base_url'] = url_input if url_input else default_url
    logger.info(f"✓ LLM Base URL: {config['llm_base_url']}\n")

    # 4. BRAIN username (empty only triggers a warning).
    logger.info("【4/6】BRAIN 平台认证信息")
    username_input = input("请输入 BRAIN 平台用户名/邮箱: ").strip()
    if not username_input:
        logger.warning("⚠ 警告: 用户名为空,程序可能无法正常工作")
    config['username'] = username_input
    logger.info("✓ 用户名已设置\n")

    # 5. BRAIN password (hidden input; empty only triggers a warning).
    password_input = getpass.getpass("请输入 BRAIN 平台密码 (输入时不会显示): ").strip()
    if not password_input:
        logger.warning("⚠ 警告: 密码为空,程序可能无法正常工作")
    config['password'] = password_input
    logger.info("✓ 密码已设置\n")

    # 6. Optional template-summary file (None -> built-in summary).
    # NOTE(review): the step labels shown to the user are inconsistent
    # (this step prints 【5/6】, the next 【6/7】) — confirm intended text.
    logger.info("【5/6】模板总结文件配置")
    logger.info("强烈推荐你使用自己总结的模板文档,效果会更好")
    logger.info("提示: 如果您有 template_summary 的 .txt 或 .md 文件,请输入完整路径")
    logger.info(" 如果没有,直接回车将使用内置模板总结")
    template_path = input("请输入模板总结文件路径 (直接回车使用内置模板): ").strip()
    config['template_summary_path'] = template_path if template_path else None
    if template_path:
        logger.info(f"✓ 将尝试从文件加载: {template_path}\n")
    else:
        logger.info("✓ 将使用内置模板总结\n")

    # 7. Alpha ID — the only mandatory field; abort without it.
    logger.info("【6/7】Alpha ID 配置")
    alpha_id = input("请输入要处理的 Alpha ID: ").strip()
    if not alpha_id:
        logger.error("❌ 错误: Alpha ID 不能为空")
        sys.exit(1)
    config['alpha_id'] = alpha_id
    logger.info(f"✓ Alpha ID: {alpha_id}\n")

    # 8. Top-N datafield candidate count (invalid input falls back to default).
    logger.info("【7/7】候选数量配置 (Top N)")
    logger.info("提示: 此参数控制为每个占位符生成的数据字段候选数量")

    # Datafield top_n
    default_datafield_topn = 50
    datafield_topn_input = input(
        f"请输入数据字段候选数量 (直接回车使用默认值: {default_datafield_topn}): ").strip()
    try:
        config['top_n_datafield'] = int(
            datafield_topn_input) if datafield_topn_input else default_datafield_topn
    except ValueError:
        logger.warning(f"⚠ 警告: 输入无效,使用默认值: {default_datafield_topn}")
        config['top_n_datafield'] = default_datafield_topn
    logger.info(f"✓ 数据字段候选数量: {config['top_n_datafield']}\n")

    logger.info("="*60)
    logger.info("配置完成!开始处理...")
    logger.info("="*60 + "\n")

    return config
|
|
|
|
|
|
|
|
|
def expand_dict_columns(data: pd.DataFrame) -> pd.DataFrame:
    """
    Expand dictionary columns in a DataFrame into separate columns.

    For every column whose first value is a ``dict``, new columns named
    ``{col}_{key}`` are appended holding the dict values. Original
    columns are kept.

    Args:
        data (pandas.DataFrame): The input DataFrame, possibly with
            dict-valued columns.

    Returns:
        pandas.DataFrame: DataFrame with the expanded columns appended,
        or the input unchanged when it is empty or has no dict columns.
    """
    # BUG FIX: the previous version called pd.concat([]) (ValueError)
    # when no dict columns existed, and data[x].iloc[0] raised on an
    # empty frame. Guard both cases.
    if data.empty:
        return data

    dict_columns = [
        col for col in data.columns if isinstance(data[col].iloc[0], dict)]
    if not dict_columns:
        return data

    new_columns = pd.concat(
        [
            # col=col binds the loop variable eagerly for the rename lambda.
            data[col].apply(pd.Series).rename(
                columns=lambda x, col=col: f"{col}_{x}")
            for col in dict_columns
        ],
        axis=1,
    )

    return pd.concat([data, new_columns], axis=1)
|
|
|
|
|
|
|
|
|
def start_session() -> SingleSession:
    """
    Start a new session with the WorldQuant BRAIN platform.

    This function authenticates the user, handles biometric (persona)
    authentication if required, and retries interactively on bad
    credentials.

    Returns:
        SingleSession: An authenticated session object.

    Raises:
        requests.exceptions.RequestException: If there's an error during
            the authentication process.
    """
    brain_api_url = "https://api.worldquantbrain.com"
    s = SingleSession()
    s.auth = get_credentials()
    r = s.post(brain_api_url + "/authentication")
    # NOTE(review): r.json() here will raise on a non-JSON error body —
    # confirm the API always returns JSON for this endpoint.
    logger.info(
        f"New session created (ID: {id(s)}) with authentication response: {r.status_code}, {r.json()} (新会话已创建)")
    if r.status_code == requests.status_codes.codes.unauthorized:
        if r.headers["WWW-Authenticate"] == "persona":
            # Biometric flow: the user completes persona auth in the
            # browser, then we poll the Location URL until it returns 201.
            logger.info(
                "Complete biometrics authentication and press any key to continue (请完成生物识别认证并按任意键继续): \n" + urljoin(
                    r.url, r.headers["Location"]) + "\n"
            )
            input()
            s.post(urljoin(r.url, r.headers["Location"]))
            while True:
                if s.post(urljoin(r.url, r.headers["Location"])).status_code != 201:
                    input(
                        "Biometrics authentication is not complete. Please try again and press any key when completed (生物识别认证未完成,请重试并按任意键): \n"
                    )
                else:
                    break
        else:
            # Wrong credentials: recurse to re-prompt. NOTE(review):
            # get_credentials() returns the same module globals, so this
            # recursion never asks for new input and can recurse until
            # RecursionError — confirm intended behavior.
            logger.info("\nIncorrect email or password (邮箱或密码错误)\n")
            return start_session()
    return s
|
|
|
|
|
|
|
|
|
def get_data_categories(s: SingleSession) -> list[dict]:
    """Fetch BRAIN data categories, caching them in ``DATA_CATEGORIES``.

    Args:
        s: Authenticated BRAIN session.

    Returns:
        list[dict]: Category records; the cached value on repeat calls,
        or ``[]`` on any error.
    """
    global DATA_CATEGORIES
    if DATA_CATEGORIES is not None:
        return DATA_CATEGORIES

    brain_api_url = "https://api.worldquantbrain.com"
    try:
        response = s.get(brain_api_url + "/data-categories")
        response.raise_for_status()
        payload = response.json()
    except Exception as e:
        logger.info(f"Error fetching data categories: {e}")
        return []

    # Accept either a bare list or a paginated {'results': [...]} shape.
    if isinstance(payload, list):
        DATA_CATEGORIES = payload
    elif isinstance(payload, dict):
        DATA_CATEGORIES = payload.get('results', [])
    else:
        DATA_CATEGORIES = []
    return DATA_CATEGORIES
|
|
|
|
|
|
|
|
|
def get_datafields(
    s: SingleSession,
    instrument_type: str = "EQUITY",
    region: str = "USA",
    delay: int = 1,
    universe: str = "TOP3000",
    theme: str = "false",
    dataset_id: str = "",
    data_type: str = "MATRIX",
    search: str = "",
    category: Union[str, list] = "",
) -> pd.DataFrame:
    """
    Retrieve available datafields based on specified parameters.

    Args:
        s (SingleSession): An authenticated session object.
        instrument_type (str, optional): The type of instrument. Defaults to "EQUITY".
        region (str, optional): The region. Defaults to "USA".
        delay (int, optional): The delay. Defaults to 1.
        universe (str, optional): The universe. Defaults to "TOP3000".
        theme (str, optional): The theme. Defaults to "false". (Currently
            unused; kept for interface compatibility.)
        dataset_id (str, optional): The ID of a specific dataset. Defaults to "".
        data_type (str, optional): The type of data. Defaults to "MATRIX".
        search (str, optional): A search string to filter datafields. Defaults to "".
        category (str or list, optional): A category ID or list of IDs to
            post-filter datafields. Defaults to "".

    Returns:
        pandas.DataFrame: Available datafields (dict columns expanded),
        or an empty DataFrame when nothing matches or on errors.
    """
    brain_api_url = "https://api.worldquantbrain.com"
    type_param = f"&type={data_type}" if data_type != "ALL" else ""

    # BUG FIX: the region parameter was previously emitted as a mangled
    # "®ion" (corrupted "&region"), producing an invalid query string.
    url_template = (
        brain_api_url
        + "/data-fields?"
        + f"&instrumentType={instrument_type}"
        + f"&region={region}&delay={str(delay)}&universe={universe}{type_param}&limit=50"
    )

    if dataset_id:
        url_template += f"&dataset.id={dataset_id}"

    if len(search) > 0:
        url_template += f"&search={search}"

    url_template += "&offset={x}"

    if len(search) == 0:
        # Without a search term the endpoint reports an exact total.
        try:
            count = s.get(url_template.format(x=0)).json()["count"]
        except Exception as e:
            logger.info(f"Error getting count: {e}")
            return pd.DataFrame()
        if count == 0:
            logger.info(
                f"No fields found (未找到字段): region={region}, delay={str(delay)}, universe={universe}, "
                f"type={data_type}, dataset.id={dataset_id}"
            )
            return pd.DataFrame()
    else:
        # BUG FIX: previously `count` stayed 0 on the search path, so the
        # function always took the "no fields found" branch and returned
        # an empty frame. The search endpoint has no reliable total, so
        # probe a fixed window (deeper when post-filtering by category).
        count = 500 if category else 100

    max_try = 5
    datafields_list = []
    found_count = 0
    # When filtering by category we stop early once 50 matches are found.
    target_found = 50 if category else count
    time.sleep(2)
    for x in range(0, count, 50):
        datafields = None
        for _ in range(max_try):
            try:
                resp = s.get(url_template.format(x=x))
                # Rate limited: back off and retry the same page.
                while resp.status_code == 429:
                    logger.info("status_code 429, sleep 3 seconds")
                    time.sleep(3)
                    resp = s.get(url_template.format(x=x))
                if resp.status_code == 200 and "results" in resp.json():
                    datafields = resp
                    break
            except Exception:
                pass
            time.sleep(5)
        else:
            # All retries failed for this page; move to the next offset.
            continue

        results = datafields.json().get("results", [])
        if not results:
            break

        if category:
            wanted_ids = category if isinstance(category, list) else [category]
            filtered_results = [
                item for item in results
                if isinstance(item.get('category'), dict)
                and item['category'].get('id') in wanted_ids
            ]
            datafields_list.append(filtered_results)
            found_count += len(filtered_results)
            if len(search) > 0 and found_count >= target_found:
                break
        else:
            datafields_list.append(results)

    datafields_list_flat = [
        item for sublist in datafields_list for item in sublist]

    if not datafields_list_flat:
        return pd.DataFrame()

    datafields_df = pd.DataFrame(datafields_list_flat)
    return expand_dict_columns(datafields_df)
|
|
|
|
|
|
|
|
|
def set_alpha_properties(
    s: SingleSession,
    alpha_id: str,
    name: Optional[str] = None,
    color: Optional[str] = None,
    regular_desc: Optional[str] = None,
    selection_desc: str = "None",
    combo_desc: str = "None",
    tags: Optional[list[str]] = None,
) -> requests.Response:
    """
    Update the properties of an alpha via a PATCH request.

    Args:
        s (SingleSession): An authenticated session object.
        alpha_id (str): The ID of the alpha to update.
        name (str, optional): New name for the alpha.
        color (str, optional): New color for the alpha.
        regular_desc (str, optional): Description for a regular alpha.
        selection_desc (str, optional): Selection description for a super
            alpha; the sentinel string "None" means "do not set".
        combo_desc (str, optional): Combo description for a super alpha;
            the sentinel string "None" means "do not set".
        tags (list, optional): Tags to apply to the alpha.

    Returns:
        requests.Response: The response object from the API call.
    """
    brain_api_url = "https://api.worldquantbrain.com"

    # Only explicitly-provided fields are sent in the PATCH payload.
    payload = {}
    for key, value in (("name", name), ("color", color), ("tags", tags)):
        if value is not None:
            payload[key] = value
    if regular_desc is not None:
        payload["regular"] = {"description": regular_desc}
    if selection_desc != "None":  # "None" string acts as the unset sentinel
        payload["selection"] = {"description": selection_desc}
    if combo_desc != "None":  # "None" string acts as the unset sentinel
        payload["combo"] = {"description": combo_desc}

    return s.patch(brain_api_url + "/alphas/" + alpha_id, json=payload)
|
|
|
|
|
|
|
|
|
def extract_placeholders(template_expression: str) -> list[str]:
    """Return every placeholder token found in a template expression.

    A placeholder is a self-closing angle-bracket tag whose name uses
    only letters, digits and underscores, e.g. ``<data_field/>``.

    Args:
        template_expression: Template text to scan.

    Returns:
        list[str]: Placeholders in order of appearance (duplicates kept).
    """
    placeholder_pattern = re.compile(r'(<[A-Za-z0-9_]+/>)')
    return placeholder_pattern.findall(template_expression)
|
|
|
|
|
|
|
|
|
def parse_alpha_code(alpha_code: str, all_operators: list[dict]) -> tuple[list[str], list[str]]:
    """
    Parse an alpha expression into the operators and datafields it uses.

    Comments are stripped first; every remaining identifier is classified
    as an operator (when its name appears in ``all_operators``) or as a
    datafield (unless it is a known keyword/builtin/literal).

    Args:
        alpha_code: The alpha expression source, possibly with comments.
        all_operators: Operator metadata dicts, each with a 'name' key.

    Returns:
        tuple[list[str], list[str]]: ``(operators, datafields)``, each
        deduplicated in first-seen order (previously the order was
        nondeterministic via ``list(set(...))``).
    """
    # Remove C-style comments /* ... */ then Python-style comments # ...
    alpha_code = re.sub(r"/\*[\s\S]*?\*/", "", alpha_code)
    alpha_code = re.sub(r"#.*", "", alpha_code)

    # Set for O(1) membership instead of scanning the list per token.
    operator_names = {op['name'] for op in all_operators}

    # Identifiers that are keywords/builtins/literals, never datafields.
    # BUG FIX: the old list contained 'logger.info' (a corrupted 'print'
    # from a global print->logger.info replace, and impossible to match
    # against the identifier regex anyway) and compared .lower() output
    # against capitalized 'None'/'True'/'False', so those never matched.
    # The list is now all-lowercase and includes 'none'.
    ignored = frozenset([
        'true', 'false', 'null', 'nan', 'none', 'if', 'else', 'for',
        'while', 'return', 'and', 'or', 'not', 'in', 'is', 'try',
        'except', 'finally', 'with', 'as', 'def', 'class', 'import',
        'from', 'yield', 'lambda', 'global', 'nonlocal', 'break',
        'continue', 'pass', 'async', 'await', 'raise', 'assert', 'del',
        'print', 'input', 'len', 'min', 'max', 'sum', 'abs', 'round',
        'int', 'float', 'str', 'list', 'dict', 'set', 'tuple', 'range',
        'map', 'filter', 'zip', 'open', 'file', 'type', 'id', 'dir',
        'help', 'object', 'super', 'issubclass', 'isinstance', 'hasattr',
        'getattr', 'setattr', 'delattr', '__import__',
    ])

    found_operators: list[str] = []
    found_datafields: list[str] = []
    seen_operators: set[str] = set()
    seen_datafields: set[str] = set()

    # Scan identifiers (words starting with a letter/underscore).
    for token in re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', alpha_code):
        if token in operator_names:
            if token not in seen_operators:
                seen_operators.add(token)
                found_operators.append(token)
        elif token.lower() not in ignored:
            if token not in seen_datafields:
                seen_datafields.add(token)
                found_datafields.append(token)

    return found_operators, found_datafields
|
|
|
|
|
|
|
|
|
async def generate_alpha_description(alpha_id: str, brain_session: SingleSession) -> str:
    """
    Fetch an alpha, judge its description with an LLM, and regenerate it
    if needed (patching the platform copy).

    Args:
        alpha_id (str): The ID of the alpha to retrieve.
        brain_session (SingleSession): The active BRAIN API session.

    Returns:
        str: A JSON string with the alpha's 'settings' and 'expression'
        (description possibly regenerated), or an empty JSON object
        string ('{}') if an error occurs.
    """

    async def call_llm_new(prompt: str) -> dict:
        # Build an LLM client from the module-level configuration globals
        # and send one chat request; exits the process on auth failure.
        global LLM_model_name, LLM_API_KEY, llm_base_url
        logger.info(f"\n[call_llm_new] 准备调用 LLM...")
        logger.info(f"[call_llm_new] 模型: {LLM_model_name}")
        logger.info(f"[call_llm_new] llm_base_url: {llm_base_url}")

        try:
            logger.info(f"[call_llm_new] 正在获取 LLM token...")
            llm_api_key = get_token_from_auth_server()
            llm_base_url_value = llm_base_url  # module-level global
            logger.info(f"[call_llm_new] 创建 LLM 客户端...")
            llm_client = openai.AsyncOpenAI(
                base_url=llm_base_url_value, api_key=llm_api_key)
            logger.info(
                "[call_llm_new] LLM Gateway Authentication successful. (LLM网关认证成功)")
        except Exception as e:
            logger.error(
                f"[call_llm_new] ❌ LLM Gateway Authentication failed (LLM网关认证失败): {e}")
            sys.exit(1)

        logger.info("[call_llm_new] --- Calling LLM... (正在调用LLM...) ---")
        logger.info(f"[call_llm_new] Prompt 长度: {len(prompt)} 字符")
        try:
            # Await the async create call
            logger.info(f"[call_llm_new] 正在发送请求到 LLM...")
            response = await llm_client.chat.completions.create(
                model=LLM_model_name,
                messages=[
                    {"role": "system", "content": "You are a quantitative finance expert and a helpful assistant designed to output JSON."},
                    {"role": "user", "content": prompt},
                ],
                # response_format={"type": "json_object"},
            )
            logger.info(f"[call_llm_new] 收到 LLM 响应")

            # The async client may return a nested structure. Try to extract content robustly.
            content = None
            if isinstance(response, dict):
                # Some clients return raw dicts
                # Try common paths
                choices = response.get('choices')
                if choices and isinstance(choices, list):
                    msg = choices[0].get('message') or choices[0]
                    content = msg.get('content') if isinstance(
                        msg, dict) else None
                elif 'content' in response:
                    content = response.get('content')
            else:
                # Fallback: attempt attribute access
                try:
                    content = response.choices[0].message.content
                except Exception:
                    content = None

            if content is None:
                # As a last resort, try to stringify the response
                content = str(response)

            # If content is already a dict/list, return it directly; if it's a JSON string, parse it.
            if isinstance(content, (dict, list)):
                return content
            if isinstance(content, str):
                try:
                    return json.loads(content)
                except json.JSONDecodeError:
                    # Return wrapped string if not JSON
                    return {"text": content}

            # NOTE(review): this log line is unreachable for str/dict/list
            # content because of the returns above.
            logger.info(f"[call_llm_new] ✓ 成功返回结果")
            return {}
        except Exception as e:
            logger.error(f"[call_llm_new] ❌ Error calling LLM (调用LLM出错): {e}")
            import traceback
            logger.error(f"[call_llm_new] 错误详情: {traceback.format_exc()}")
            return {}

    logger.info(f"\n[Alpha Description] 开始获取 Alpha {alpha_id} 的详情...")

    try:
        brain_api_url = "https://api.worldquantbrain.com"
        alpha_url = f"{brain_api_url}/alphas/{alpha_id}"
        logger.info(f"[Alpha Description] 请求 URL: {alpha_url}")
        response = brain_session.get(alpha_url)
        logger.info(f"[Alpha Description] 响应状态码: {response.status_code}")
        response.raise_for_status()  # Raise an exception for HTTP errors

        alpha_data = response.json()
        logger.info(f"[Alpha Description] 成功获取 Alpha 数据")
        settings = alpha_data.get('settings', {})
        # Regular alphas carry the code under 'regular', super alphas
        # under 'combo'.
        expression_dict = alpha_data.get(
            'regular', alpha_data.get('combo', None))

        if not expression_dict or 'code' not in expression_dict:
            logger.info(
                f"Error: Alpha expression code not found for Alpha ID (未找到Alpha表达式代码): {alpha_id}")
            return json.dumps({})

        alpha_code = expression_dict['code']
        current_description = expression_dict.get('description', )

        # 1. Get all operators for parsing (no filter as per feedback)
        operators_data = get_brain_operators()
        all_operators = operators_data.get('operators', [])

        # 2. Parse the code to get operators and datafields
        found_operators_names, found_datafields_names = parse_alpha_code(
            alpha_code, all_operators)

        # 3. Get descriptions for operators
        operator_descriptions = {op['name']: op.get(
            'description', 'No description available.') for op in all_operators if op['name'] in found_operators_names}

        # 4. Get descriptions for datafields
        datafield_descriptions = {}
        if found_datafields_names:
            # Extract settings from alpha_data for the get_datafields call
            instrument_type = settings.get('instrumentType', 'EQUITY')
            region = settings.get('region', 'USA')
            universe = settings.get('universe', 'TOP3000')
            delay = settings.get('delay', 1)

            for df_name in found_datafields_names:
                # get_datafields returns a DataFrame, so we need to process it
                datafield_df = get_datafields(s=brain_session, instrument_type=instrument_type,
                                              region=region, delay=delay, universe=universe, search=df_name)
                if not datafield_df.empty:
                    # Assuming the first result is the most relevant
                    datafield_descriptions[df_name] = datafield_df.iloc[0].get(
                        'description', 'No description available.')
                else:
                    datafield_descriptions[df_name] = 'No description found.'

        # 5. Use LLM to judge if current description is good
        judgment_prompt = f"""
Given the following alpha code, its current description, and descriptions of its operators and datafields:

Alpha Code:
{alpha_code}

Current Description:
{current_description}

Operators and their descriptions:
{json.dumps(operator_descriptions, indent=2)}

Datafields and their descriptions:
{json.dumps(datafield_descriptions, indent=2)}

Alpha Settings:
{json.dumps(settings, indent=2)}

Is the current description good enough? Respond with 'yes' or 'no' in a JSON object: {{"judgment": "yes/no"}}
A "good" description should clearly explain the investment idea, rationale for data used, and rationale for operators used.
"""

        judgment_response = await call_llm_new(judgment_prompt)
        is_description_good = judgment_response.get(
            "judgment", "no").lower() == "yes"

        new_description = current_description
        if not is_description_good:
            # 6. If not good, use another LLM to generate a new description
            generation_prompt = f"""
Based on the following alpha code, its operators, datafields, and settings, generate a new, improved description.
The description should clearly explain the investment idea, rationale for data used, and rationale for operators used.
Format the output as:
"Idea: xxxxx\\nRationale for data used: xxxxx\\nRationale for operators used: xxxxxxx"

Alpha Code:
{alpha_code}

Operators and their descriptions:
{json.dumps(operator_descriptions, indent=2)}

Datafields and their descriptions:
{json.dumps(datafield_descriptions, indent=2)}

Alpha Settings:
{json.dumps(settings, indent=2)}
"""

            generated_description_response = await call_llm_new(generation_prompt)
            # Assuming LLM returns a string directly or a JSON with a 'description' key
            new_description = generated_description_response.get(
                "description", generated_description_response)
            # Handle cases where LLM might return a dict directly
            if isinstance(new_description, dict):
                new_description = json.dumps(new_description, indent=2)

            # 7. Override this new description and patch the alpha
            set_alpha_properties(
                s=brain_session,
                alpha_id=alpha_id,
                regular_desc=new_description
            )
            logger.info(
                f"Alpha {alpha_id} description updated on platform. (Alpha描述已在平台更新)")

            # Keep the local copy consistent with the platform state.
            if 'regular' in alpha_data:
                alpha_data['regular']['description'] = new_description
            elif 'combo' in alpha_data:
                alpha_data['combo']['description'] = new_description

        return json.dumps({
            'settings': settings,
            'expression': expression_dict
        })

    except requests.exceptions.RequestException as e:
        logger.info(f"Error during API request (API请求出错): {e}")
        return json.dumps({})
    except json.JSONDecodeError:
        logger.info(
            "Error: Could not decode JSON response from API. (无法解析API的JSON响应)")
        return json.dumps({})
    except Exception as e:
        logger.info(f"An unexpected error occurred (发生意外错误): {e}")
        return json.dumps({})
|
|
|
|
|
|
|
|
|
def get_brain_operators(scope_filters: Optional[list[str]] = None) -> dict:
    """
    Fetch the operator list from the WorldQuant BRAIN API, optionally
    filtered by scope.

    Args:
        scope_filters (list[str], optional): Scope names (e.g.
            ["REGULAR", "TS_OPERATOR"]) an operator must match (any of).
            When None or empty, all operators are returned.

    Returns:
        dict: ``{'operators': [...], 'count': N}``, or ``{}`` on any
        error.
    """
    brain_api_url = "https://api.worldquantbrain.com"
    try:
        session = start_session()
        response = session.get(f"{brain_api_url}/operators")
        response.raise_for_status()  # Raise an exception for HTTP errors

        operators_list = response.json()

        if not isinstance(operators_list, list):
            logger.info(
                f"Error: Expected a list of operators, but received type (预期运算符列表,但收到类型): {type(operators_list)}")
            return {}

        if scope_filters:
            operators_list = [
                op for op in operators_list
                if any(wanted in op.get('scope', []) for wanted in scope_filters)
            ]
        return {
            'operators': operators_list,
            'count': len(operators_list)
        }

    except requests.exceptions.RequestException as e:
        logger.info(
            f"Error during API request for operators (获取运算符时API请求出错): {e}")
        return {}
    except json.JSONDecodeError:
        logger.info(
            "Error: Could not decode JSON response from operators API. (无法解析运算符API的JSON响应)")
        return {}
    except Exception as e:
        logger.info(
            f"An unexpected error occurred while getting operators (获取运算符时发生意外错误): {e}")
        return {}
|
|
|
|
|
|
|
|
|
async def call_llm(prompt: str, llm_client: openai.AsyncOpenAI, max_retries: int = 3) -> dict:
    """
    Interface with a Large Language Model to process prompts and get a JSON response.
    Includes retry logic for JSON parsing errors.

    Args:
        prompt: The user prompt sent to the model.
        llm_client: An initialized AsyncOpenAI-compatible client.
        max_retries: Maximum number of attempts before giving up.

    Returns:
        The parsed JSON payload (dict, or list if the model returned a JSON
        array), or an empty dict when all attempts fail to parse.

    Raises:
        Exception: When a non-JSON error (network/API failure) persists through
            the final attempt.
    """
    # Reads the module-level model name configured elsewhere in this file.
    global LLM_model_name
    if not llm_client:
        logger.info(
            "LLM client not initialized. Please check authentication. (LLM客户端未初始化,请检查认证)")
        return {}

    logger.info("\n[LLM Call] 准备调用 LLM API...")
    logger.info(f"[LLM Call] 模型: {LLM_model_name}")
    logger.info(f"[LLM Call] Prompt 长度: {len(prompt)} 字符")
    logger.info("[LLM Call] 正在发送请求...")

    for attempt in range(max_retries):
        try:
            logger.info(f"[LLM Call] 第 {attempt + 1} 次尝试...")
            response = await llm_client.chat.completions.create(
                model=LLM_model_name,  # Or your preferred model
                messages=[
                    {"role": "system", "content": "You are a quantitative finance expert and a helpful assistant designed to output JSON."},
                    {"role": "user", "content": prompt},
                ],
                # response_format={"type": "json_object"},
            )
            logger.info(f"[LLM Call] 收到响应,状态: OK")
            content = response.choices[0].message.content
            logger.info(f"[LLM Call] 响应内容长度: {len(content)} 字符")

            # --- Enhanced JSON cleaning logic ---

            # First remove any <think>...</think> reasoning block the model emitted.
            if "<think>" in content and "</think>" in content:
                think_start = content.find("<think>")
                think_end = content.find("</think>") + len("</think>")
                content = content[:think_start] + content[think_end:]
                logger.info(f"[LLM Call] 移除了 <think> 标签")

            content = content.strip()

            # Try to clean markdown code blocks if present.
            if "```json" in content:
                content = content.split("```json")[1].split("```")[0].strip()
                logger.info(f"[LLM Call] 清理了 JSON markdown 标记")
            elif "```" in content:
                content = content.split("```")[1].split("```")[0].strip()
                logger.info(f"[LLM Call] 清理了 markdown 标记")
            else:
                # No markdown fences: fall back to extracting the outermost braces.
                logger.info(f"[LLM Call] 未找到 markdown 标记,尝试大括号提取...")
                start_idx = content.find('{')
                end_idx = content.rfind('}')
                if start_idx != -1 and end_idx != -1 and start_idx < end_idx:
                    content = content[start_idx:end_idx+1]
                    logger.info(f"[LLM Call] 通过大括号提取 JSON,长度: {len(content)} 字符")

            logger.info(f"[LLM Call] 解析 JSON...")
            result = json.loads(content)
            logger.info(f"[LLM Call] JSON 解析成功,返回 {len(result)} 个结果")
            return result

        except json.JSONDecodeError as e:
            logger.info("start *" * 88)
            logger.warning(
                f"⚠ JSON Decode Error (Attempt {attempt + 1}/{max_retries}): {e}\n\n{content}\n\n")
            logger.info("end *" * 88)
            if attempt == max_retries - 1:
                logger.error(
                    f"❌ Failed to parse JSON after {max_retries} attempts. Raw content: {content[:100]}...")
        except Exception as e:
            logger.warning(
                f"⚠ LLM Call Error (Attempt {attempt + 1}/{max_retries}): {e}")
            if attempt == max_retries - 1:
                logger.error(
                    f"❌ Failed to call LLM after {max_retries} attempts.")
                raise Exception(f"LLM 调用失败: {e}")

        # Wait before retrying (2 seconds for MiniMax 529) — but only when
        # another attempt will actually follow; previously this slept after
        # the final attempt for nothing.
        if attempt < max_retries - 1:
            logger.info(f"⏳ 等待 2 秒后重试...")
            await asyncio.sleep(2)

    return {}
|
|
|
|
|
|
|
|
|
def has_valid_placeholders(template_str: str) -> bool:
    """Return True if the template string contains at least one placeholder.

    A placeholder has the form ``<name/>`` where *name* consists of letters,
    digits, or underscores (e.g. ``<data_field/>``, ``<ts_operator/>``).
    """
    # 're' is already imported at module level; the previous function-local
    # import was redundant, and re.search avoids materializing all matches.
    return re.search(r'<[A-Za-z0-9_]+/>', template_str) is not None
|
|
|
|
|
|
|
|
|
async def propose_alpha_templates_with_retry(alpha_details: dict, template_summary: str, llm_client: openai.AsyncOpenAI, user_data_type: str = "MATRIX", max_retries: int = 20) -> dict:
    """
    Generate alpha templates via the LLM with a retry loop, keeping only
    templates that contain placeholders (e.g. ``<data_field/>``).

    Args:
        alpha_details: Seed-alpha description; must contain an 'expression' key.
        template_summary: Summary of known-good templates used as inspiration.
        llm_client: Initialized AsyncOpenAI-compatible client.
        user_data_type: "MATRIX" or "VECTOR"; steers operator choice in the prompt.
        max_retries: Maximum number of LLM attempts before giving up.

    Returns:
        dict: Mapping of template expression -> explanation, or {} when every
        attempt fails to produce a placeholder-bearing template.
    """
    if not alpha_details.get('expression'):
        logger.info("Error: Alpha expression is missing. (错误:缺少Alpha表达式)")
        return {}

    data_type_instruction = ""
    if user_data_type == "MATRIX":
        data_type_instruction = "\n**Important Note on Data Type:**\nThe user has specified the data type as **MATRIX**. Please do NOT use any vector-type operators (e.g., `vec_avg`, `vec_sum`) in your proposed templates, as they will raise errors for MATRIX type data in BRAIN. Note: 'MATRIX' is just a system identifier and does not refer to mathematical matrices."
    elif user_data_type == "VECTOR":
        data_type_instruction = "\n**Important Note on Data Type:**\nThe user has specified the data type as **VECTOR**. Please ensure you use vector-type operators (e.g., `vec_avg`, `vec_sum`) to handle the data fields before applying other operators."

    prompt = f"""
As a world-class BRAIN consultant, your task is to design new alpha templates based on an existing seed alpha.
You will be provided with the seed alpha's expression and a summary of successful alpha templates for inspiration.

**Seed Alpha Expression:**
{alpha_details['expression']}

**Inspiration: Summary of Alpha Templates:**
{template_summary}

**Your Task:**
Based on the structure and potential economic rationale of the seed alpha, by the aid of the Alpha template summary, propose 3-5 new, diverse alpha templates.

**CRITICAL RULES (必须遵守):**
1. The proposed templates must be valid BRAIN alpha expressions.
2. **MANDATORY: You MUST use placeholders like `<data_field/>` for data fields and `<operator/>` for operators. DO NOT use actual data field names like `avg_pct_change_estimate_12m_earnings_7d` directly in the template. Placeholders are REQUIRED and will be replaced programmatically later.**
3. Valid placeholder formats: `<data_field/>`, `<operator/>`, `<ts_operator/>`, `<group_operator/>`, `<integer_parameter/>`, `<float_parameter/>`
4. For each proposed template, provide a brief, clear explanation of its investment rationale.
5. Return the output as a single, valid JSON object where keys are the proposed template strings and values are their corresponding explanations. Do not include any other text or formatting outside of the JSON object.
6. The proposed new alpha template should be related to the economic sense of seed Alpha but in different format. Utilize the inspiration well.
{data_type_instruction}

**Example Output Format (占位符格式示例):**
{{
"<group_operator/>(<ts_operator/>(<data_field/>, 60), industry)": "A cross-sectional momentum signal, neutralized by industry, to capture relative strength within peer groups.",
"<operator/>(<ts_operator/>(<data_field/>, 20), <float_parameter/>)": "A simple short-term momentum operator applied to a data field with a parameter."
}}

**WARNING: If you do not use placeholders like `<data_field/>`, the template will be rejected and you will need to regenerate. Placeholders are ESSENTIAL for the template system to work.**

Now, generate the JSON object with your proposed templates. Remember: USE PLACEHOLDERS like `<data_field/>`, NOT actual field names!
"""

    logger.info(f"\n[Step 1/5] 正在调用 LLM 生成 Alpha 模板...")
    logger.info(f" - 模型: {LLM_model_name}")
    logger.info(f" - 数据类型: {user_data_type}")
    logger.info(f" - 最大重试次数: {max_retries}")
    alpha_expr = alpha_details.get('expression', {})
    if isinstance(alpha_expr, dict):
        alpha_expr = alpha_expr.get('code', 'N/A')
    logger.info(f" - 种子 Alpha: {str(alpha_expr)[:50]}...")

    # Retry loop: keep asking until the LLM produces placeholder-bearing templates.
    for attempt in range(1, max_retries + 1):
        try:
            logger.info(f"\n [尝试 {attempt}/{max_retries}] 调用 LLM...")
            proposed_templates = await call_llm(prompt, llm_client)

            # Partition the reply into templates with and without placeholders.
            valid_templates = {}
            invalid_templates = []

            for template_expr, explanation in proposed_templates.items():
                if has_valid_placeholders(template_expr):
                    valid_templates[template_expr] = explanation
                else:
                    invalid_templates.append(template_expr)

            if valid_templates:
                logger.info(f" ✓ 成功生成 {len(valid_templates)} 个有效模板(含占位符)")
                if invalid_templates:
                    logger.warning(
                        f" ⚠ 丢弃 {len(invalid_templates)} 个无效模板(无占位符)")
                return valid_templates
            else:
                logger.warning(f" ✗ 所有模板均无占位符,需要重试")
                if invalid_templates:
                    logger.info(f" 无效模板示例: {invalid_templates[0][:80]}...")

            if attempt < max_retries:
                logger.info(f" ↻ 等待重试...")
                await asyncio.sleep(1)  # Short delay to avoid hammering the API

        except Exception as e:
            logger.error(f" ✗ 调用 LLM 时发生错误: {e}")
            # 529 ("overloaded") is characteristic of MiniMax; keep retrying.
            if "overloaded" in str(e) or "529" in str(e):
                logger.warning(f" ⚠ MiniMax 529 错误,继续重试...")
            if attempt < max_retries:
                logger.info(f" ↻ 等待 2 秒后重试...")
                await asyncio.sleep(2)

    # All max_retries attempts failed to yield a placeholder-bearing template.
    logger.warning(f"\n⚠⚠⚠ 警告: 经过 {max_retries} 次重试,仍未能生成包含占位符的模板!")
    logger.info(f" 可能原因: LLM 未遵循指令,或模型不支持此格式。")
    logger.info(f" 建议: 检查 LLM 模型是否正确,或手动修改 prompt。")
    return {}
|
|
|
|
|
|
|
|
|
async def propose_alpha_templates(alpha_details: dict, template_summary: str, llm_client: openai.AsyncOpenAI, user_data_type: str = "MATRIX", max_retries: int = 20) -> dict:
    """
    Uses an LLM to propose new alpha templates based on a seed alpha's details.

    Thin wrapper that delegates to ``propose_alpha_templates_with_retry``,
    which retries until the LLM returns placeholder-bearing templates.
    """
    return await propose_alpha_templates_with_retry(alpha_details, template_summary, llm_client, user_data_type, max_retries)
|
|
|
|
|
|
|
|
|
async def propose_datafield_keywords(template_expression: str, template_explanation: str, placeholder: str, llm_client: openai.AsyncOpenAI, user_category: Optional[Union[str, list]] = None) -> list[str]:
    """
    Uses an LLM to propose search keywords for finding data fields.

    Args:
        template_expression: The alpha template containing the placeholder.
        template_explanation: Human-readable rationale for the template.
        placeholder: The placeholder (e.g. ``<data_field/>``) to find fields for.
        llm_client: Initialized AsyncOpenAI-compatible client.
        user_category: Optional data category (string or list) to steer keywords.

    Returns:
        list[str]: Proposed search keywords, or [] when the LLM reply is unusable.
    """
    if user_category:
        category_instruction = f"\n**User Specified Data Category:**\nThe user has specified the data category: {user_category}. Please ensure the proposed keywords are relevant to this category."
    else:
        category_instruction = "\n**Data Category:**\n Please propose keywords across diverse and relevant data categories."

    prompt = f"""
As a quantitative researcher, you need to find the best data fields for an alpha template placeholder.
Based on the template's logic and the placeholder's name, suggest a list of 3-5 concise search keywords to use with the WorldQuant BRAIN `get_datafields` tool.

**Alpha Template:**
`{template_expression}`

**Template Explanation:**
`{template_explanation}`

**Placeholder to Fill:**
`{placeholder}`
{category_instruction}

**Your Task:**
Provide a list of search keywords that are likely to yield relevant data fields for this placeholder. The keywords should be specific and diverse. Return the output as a single, valid JSON array of strings.

**Example Input:**
Placeholder: `<slow_moving_characteristic/>`
Explanation: "measures the time-series evolution of a fund's relative rank on a slow-moving characteristic (e.g., fund style, expense tier)"

**Example Output:**
["fund style", "expense ratio", "management fee", "turnover", "aum"]

Now, generate the JSON array of search keywords for the given placeholder.
"""
    logger.info(
        f"--- Calling LLM to get keywords for placeholder (正在调用LLM获取占位符关键词): {placeholder} ---")
    # Log the prompt BEFORE the call so it is captured even when the call
    # raises; previously it was logged only after a successful response.
    logger.info(f"AI使用如下提示词获取搜索关键词推荐:{prompt}")
    response = await call_llm(prompt, llm_client)

    # Accept either a direct list or a dict containing a 'keywords'-like key.
    if isinstance(response, list) and all(isinstance(item, str) for item in response):
        return response
    if isinstance(response, dict):
        # Common keys that might contain the list
        for key in ('keywords', 'data', 'result', 'items'):
            if key in response and isinstance(response[key], list) and all(isinstance(i, str) for i in response[key]):
                return response[key]
    logger.warning(
        f"Warning: LLM did not return a valid list of strings for keywords (警告:LLM未返回有效的关键词列表). Got: {response}")
    return []
|
|
|
|
|
|
|
|
|
async def get_datafield_candidates(s: SingleSession, alpha_details: dict, template_expression: str, template_explanation: str, placeholder: str, llm_client: openai.AsyncOpenAI, top_n: int = 50, user_region: Optional[str] = None, user_universe: Optional[str] = None, user_delay: Optional[int] = None, user_category: Optional[Union[str, list]] = None, user_data_type: str = "MATRIX") -> list[dict]:
    """
    Gets candidate data fields for a placeholder by using an LLM to generate search keywords
    and then searching in local cache files.

    Args:
        s: BRAIN session (unused here; kept for interface compatibility).
        alpha_details: Seed-alpha description; 'settings' supplies region/universe/delay defaults.
        template_expression / template_explanation: Context passed to the keyword LLM.
        placeholder: The placeholder being filled.
        llm_client: Initialized AsyncOpenAI-compatible client.
        top_n: Per-keyword cap on matched fields.
        user_region / user_universe / user_delay: Overrides for the cache-file key;
            fall back to the alpha settings.
        user_category: Category (or list of categories) selecting which cache CSVs to read.
        user_data_type: MATRIX or VECTOR (currently not used for filtering here).

    Returns:
        list[dict]: Records with 'id' (and 'description' when available), deduplicated.

    Raises:
        ValueError: When no category was specified.
        FileNotFoundError: When none of the per-category cache CSVs exist.
    """
    keywords = await propose_datafield_keywords(template_expression, template_explanation, placeholder, llm_client, user_category=user_category)
    if not keywords:
        logger.info(
            f"Could not generate keywords for placeholder (无法生成占位符关键词): {placeholder}")
        return []

    logger.info(
        f"LLM-proposed keywords for '{placeholder}' (LLM提议的关键词): {keywords}")

    # Extract settings from alpha_details for the cache file path
    settings = alpha_details.get('settings', {})
    logger.info(f"Alpha settings for datafield search (用于数据字段搜索的Alpha设置):")

    # Resolve region: user override first, then alpha settings, else fail.
    if user_region:
        region = user_region
    elif 'region' in settings:
        region = settings['region']
    else:
        logger.error(f"❌ Error: Could not determine 'region' for datafield search. It is missing in Alpha settings and not provided by user. (错误:无法确定数据搜索的地区,Alpha设置中缺失且用户未提供)")
        return []
    logger.info(f" 数据地区: {region}")

    # Resolve universe the same way.
    if user_universe:
        universe = user_universe
    elif 'universe' in settings:
        universe = settings['universe']
    else:
        logger.error(f"❌ Error: Could not determine 'universe' for datafield search. It is missing in Alpha settings and not provided by user. (错误:无法确定数据搜索的范围,Alpha设置中缺失且用户未提供)")
        return []
    logger.info(f" 数据范围: {universe}")

    # Resolve delay ('is not None' so a user-specified 0 is honored).
    if user_delay is not None:
        delay = user_delay
    elif 'delay' in settings:
        delay = settings['delay']
    else:
        logger.error(f"❌ Error: Could not determine 'delay' for datafield search. It is missing in Alpha settings and not provided by user. (错误:无法确定数据搜索的Delay,Alpha设置中缺失且用户未提供)")
        return []
    logger.info(f" Delay: {delay}")

    if user_category:
        logger.info(f" Category Filter: {user_category}")

    # Local cache files live in the project-root "dataset" folder.
    dataset_dir = Path(__file__).parent.parent / "dataset"

    # Normalize the category filter into a list (the frontend always sends a list).
    categories_to_read = user_category if isinstance(user_category, list) else [user_category] if user_category else []

    if not categories_to_read:
        logger.error(f"❌ 未指定数据类别,请先选择类别或点击'不筛选(默认)'")
        raise ValueError("未指定数据类别,请先选择类别")

    logger.info(f"[Cache] 准备读取 {len(categories_to_read)} 个类别的数据字段: {categories_to_read}")

    # Read every per-category cache file; remember which are missing.
    all_dataframes = []
    missing_files = []
    used_cache_files = []  # names of the cache files actually read

    for cat in categories_to_read:
        cache_filename = f"datafields_cache_{region}_{universe}_D{delay}_{cat}.csv"
        cache_path = dataset_dir / cache_filename

        if cache_path.exists():
            used_cache_files.append(cache_filename)
            logger.info(f"[Cache] 读取类别 '{cat}' 的缓存: {cache_filename}")
            df = pd.read_csv(cache_path)
            all_dataframes.append(df)
            logger.info(f"[Cache] 类别 '{cat}' 包含 {len(df)} 个字段")
        else:
            missing_files.append(cache_filename)
            logger.warning(f"[Cache] 类别 '{cat}' 的缓存文件不存在: {cache_filename}")

    # Summarize which cache files were used.
    logger.info(f"[Cache] ==========================================")
    logger.info(f"[Cache] 使用的缓存文件 ({len(used_cache_files)} 个):")
    for i, filename in enumerate(used_cache_files, 1):
        # BUG FIX: previously logged the literal "(unknown)" instead of the name.
        logger.info(f"[Cache] {i}. {filename}")
    logger.info(f"[Cache] ==========================================")

    if not all_dataframes:
        logger.error(f"❌ 所有数据字段缓存文件都不存在")
        logger.error(f"❌ 请先点击'下载数据字段缓存'按钮下载数据字段")
        raise FileNotFoundError(f"数据字段缓存文件不存在: {', '.join(missing_files)},请先下载数据字段缓存")

    # Merge all per-category frames and drop duplicate field ids.
    all_datafields_df = pd.concat(all_dataframes, ignore_index=True)
    all_datafields_df.drop_duplicates(subset=['id'], inplace=True)
    logger.info(f"[Cache] 合并后共有 {len(all_datafields_df)} 个唯一字段")

    # Search the merged cache for each keyword in 'id' and 'description'.
    # Hoist the lower-cased columns out of the keyword loop (computed once).
    id_lower = all_datafields_df['id'].str.lower()
    # NOTE(review): cache CSVs are assumed to carry a 'description' column —
    # guard anyway; previously .get('description', '') returned a plain str
    # when the column was missing and ''.str raised AttributeError.
    if 'description' in all_datafields_df.columns:
        desc_lower = all_datafields_df['description'].str.lower()
    else:
        desc_lower = None

    matched_results = []
    for keyword in keywords:
        keyword_lower = keyword.lower()
        mask = id_lower.str.contains(keyword_lower, na=False)
        if desc_lower is not None:
            mask = mask | desc_lower.str.contains(keyword_lower, na=False)
        matched = all_datafields_df[mask]
        if not matched.empty:
            matched_results.append(matched.head(top_n))
            logger.info(f"[Cache] 关键词 '{keyword}' 匹配到 {len(matched)} 个字段")
        else:
            logger.info(f"[Cache] 关键词 '{keyword}' 未匹配到字段")

    candidate_datafields = []
    if matched_results:
        # Merge all per-keyword hits, dedupe, cap the total, and format.
        combined_df = pd.concat(matched_results, ignore_index=True)
        combined_df.drop_duplicates(subset=['id'], inplace=True)
        combined_df = combined_df.head(top_n * len(keywords))
        cols = ['id', 'description'] if 'description' in combined_df.columns else ['id']
        candidate_datafields = combined_df[cols].to_dict(orient='records')

    logger.info(f"[Cache] 最终返回 {len(candidate_datafields)} 个候选字段")
    return candidate_datafields
|
|
|
|
|
|
|
|
|
async def get_group_datafield_candidates(template_expression: str, template_explanation: str, placeholder: str, llm_client: openai.AsyncOpenAI, top_n: int = 3) -> list[dict]:
    """
    Uses an LLM to select suitable group data fields from a predefined list.

    Returns at most *top_n* records of the form ``{"name": <field>}``; when the
    LLM reply is not a list of strings, falls back to the first *top_n*
    predefined group fields.
    """
    predefined_group_fields = ["industry", "subindustry", "sector", "market", "exchange"]

    prompt = f"""
As a quantitative researcher, you need to select the most relevant group data fields for an alpha template placeholder.
Based on the template's logic and the placeholder's name, select {top_n} group fields from the following list that are most suitable: {predefined_group_fields}.

**Alpha Template:**
`{template_expression}`

**Template Explanation:**
`{template_explanation}`

**Placeholder to Fill:**
`{placeholder}`

**Your Task:**
Provide a list of selected group data fields. Return the output as a single, valid JSON array of strings.

**Example Output Format:**
["industry", "sector"]

Now, generate the JSON array of selected group data fields.
"""
    logger.info(
        f"--- Calling LLM to select group datafields for placeholder (正在调用LLM选择分组数据字段): {placeholder} ---")
    response = await call_llm(prompt, llm_client)

    reply_is_valid = isinstance(response, list) and all(
        isinstance(entry, str) for entry in response)
    if not reply_is_valid:
        logger.warning(
            f"Warning: LLM did not return a valid list of strings for group datafields (警告:LLM未返回有效的分组数据字段列表). Got: {response}")
        # Fall back to the predefined list when the LLM reply is unusable.
        response = predefined_group_fields

    return [{"name": field} for field in response[:top_n]]
|
|
|
|
|
|
|
|
|
async def get_operator_candidates(template_expression: str, template_explanation: str, placeholder: str, llm_client: openai.AsyncOpenAI, top_n: int = 3) -> list[dict]:
    """
    Gets candidate operators for a placeholder by first fetching all REGULAR scope operators
    and then using an LLM to select the most relevant ones.

    Returns:
        list[dict]: Up to *top_n* records ``{"name": ..., "description": ...}``;
        falls back to the first *top_n* REGULAR operators when the LLM reply is
        unusable, or [] when no operators could be fetched.
    """
    operators_data = get_brain_operators(scope_filters=["REGULAR"])
    all_operators = operators_data.get('operators', [])

    if not all_operators:
        logger.info("No REGULAR scope operators found. (未找到REGULAR范围的运算符)")
        return []

    # Create a summary of available operators for the LLM
    operator_names_and_descriptions = "\n".join(
        [f"- {op['name']}: {op.get('description', 'No description available.')}" for op in all_operators])

    prompt = f"""
As a quantitative finance expert, you need to select the most relevant operators for an alpha template placeholder.
Based on the template's logic, its explanation, and the specific placeholder, select {top_n} operators from the provided list that are most suitable.

**Alpha Template:**
`{template_expression}`

**Template Explanation:**
`{template_explanation}`

**Placeholder to Fill:**
`{placeholder}`

**Available REGULAR Scope Operators:**
{operator_names_and_descriptions}

**Your Task:**
Provide a list of selected operator names. Return the output as a single, valid JSON array of strings.

**Example Output Format:**
["ts_mean", "ts_rank", "ts_decay"]

Now, generate the JSON array of selected operators.
"""
    logger.info(
        f"--- Calling LLM to select operator candidates for placeholder (正在调用LLM选择运算符候选): {placeholder} ---")
    response = await call_llm(prompt, llm_client)

    if isinstance(response, list) and all(isinstance(item, str) for item in response):
        # Build a first-wins name index once instead of the previous O(n*m)
        # nested scan over all_operators for every selected name.
        ops_by_name = {}
        for op in all_operators:
            ops_by_name.setdefault(op['name'], op)
        selected_ops_details = [
            {"name": name, "description": ops_by_name[name].get('description')}
            for name in response if name in ops_by_name
        ]
        return selected_ops_details[:top_n]

    logger.warning(
        f"Warning: LLM did not return a valid list of strings for operator candidates (警告:LLM未返回有效的运算符候选列表). Got: {response}")
    # Fallback to a default set if LLM fails
    return [{"name": op['name'], "description": op.get('description')} for op in all_operators[:top_n]]
|
|
|
|
|
|
|
|
|
async def get_parameter_candidates(param_type: str, template_expression: str, template_explanation: str, placeholder: str, llm_client: openai.AsyncOpenAI) -> list[dict]:
    """
    Uses an LLM to suggest sensible numerical candidates for parameters.

    Returns records of the form ``{"value": <number>}``; when the LLM reply is
    not a list of numbers, falls back to conventional defaults for the given
    *param_type* ("integer_parameter" / "float_parameter"), or [] otherwise.
    """
    if param_type == "integer_parameter":
        param_description = "an integer value, typically a window length or count (e.g., `d` in `ts_mean(x, d)`)"
    else:
        param_description = "a floating-point number, typically a threshold or factor"

    prompt = f"""
As a quantitative finance expert, you need to suggest sensible numerical candidates for a placeholder parameter.
Based on the alpha template's logic, its explanation, and the placeholder's type and context, propose 3-5 diverse numerical candidates.

**Alpha Template:**
`{template_expression}`

**Template Explanation:**
`{template_explanation}`

**Placeholder to Fill:**
`{placeholder}`

**Parameter Type:**
This placeholder represents {param_description}.

**Your Task:**
Provide a list of numerical candidates that are appropriate for this parameter. Return the output as a single, valid JSON array of numbers.

**Example Output (for integer_parameter):**
[10, 20, 60, 120, 252]

**Example Output (for float_parameter):**
[0.01, 0.05, 0.1, 0.2, 0.5]

Now, generate the JSON array of numerical candidates.
"""
    logger.info(
        f"--- Calling LLM to suggest candidates for {param_type} placeholder (正在调用LLM建议参数候选): {placeholder} ---")
    response = await call_llm(prompt, llm_client)

    is_number_list = isinstance(response, list) and all(
        isinstance(entry, (int, float)) for entry in response)
    if is_number_list:
        return [{"value": entry} for entry in response]

    logger.warning(
        f"Warning: LLM did not return a valid list of numbers for {param_type} candidates (警告:LLM未返回有效的数字候选列表). Got: {response}")

    # Fall back to conventional window lengths / thresholds per parameter type.
    fallback_values = {
        "integer_parameter": [10, 20, 60, 120, 252],
        "float_parameter": [0.01, 0.05, 0.1, 0.2, 0.5],
    }
    return [{"value": v} for v in fallback_values.get(param_type, [])]
|
|
|
|
|
|
|
|
|
async def judge_placeholder_type(placeholder: str, template_expression: str, template_explanation: str, operator_summary: str, llm_client: openai.AsyncOpenAI) -> str:
    """
    Uses an LLM to judge the type of placeholder (e.g., "data_field", "integer_parameter", "group_operator").

    Returns:
        str: One of the type names enumerated in the prompt, or "unknown" when
        the LLM reply is missing, malformed, or not a JSON object.
    """
    prompt = f"""
As a world-class quantitative finance expert, your task is to classify the type of a placeholder within an alpha expression.
You will be provided with the alpha template, its explanation, the specific placeholder, and a comprehensive summary of available BRAIN operators and data field characteristics.

**Alpha Template:**
`{template_expression}`

**Template Explanation:**
`{template_explanation}`

**Placeholder to Classify:**
`{placeholder}`

**Available BRAIN Operators and Data Field Characteristics:**
{operator_summary}

**Your Task:**
Classify the `{placeholder}` based on the provided context. The classification should be one of the following types:
- "data_field": If the placeholder clearly represents a financial data series (e.g., price, volume, fundamental ratio).
- "group_data_field": If the placeholder represents a categorical field used for grouping or neutralization (e.g., `industry` in `group_zscore(x, industry)`).
- "operator": If the placeholder represents a BRAIN operator that performs a calculation or transformation.
- "vector_operator": If the placeholder represents a vector operator (e.g., vec_avg, vec_sum).
- "integer_parameter": If the placeholder represents an integer value, typically a window length or count (e.g., `d` in `ts_mean(x, d)`).
- "float_parameter": If the placeholder represents a floating-point number, typically a threshold or factor.
- "string_parameter": If the placeholder represents a string value, like a group name (e.g., `industry` in `group_zscore(x, industry)`).
- "unknown": If the type cannot be determined from the context.

Return the classification as a single JSON object with a key "placeholder_type" and its corresponding value. Do not include any other text or formatting outside of the JSON object.

**Example Output Format:**
{{"placeholder_type": "data_field"}}
{{"placeholder_type": "integer_parameter"}}

Now, classify the placeholder.
"""
    logger.info(
        f"--- Calling LLM to judge type for placeholder (正在调用LLM判断占位符类型): {placeholder} ---")

    response = await call_llm(prompt, llm_client)
    # call_llm returns {} on failure, but a model may also emit a JSON array;
    # guard before .get() so a list reply degrades to "unknown" instead of
    # raising AttributeError.
    if isinstance(response, dict):
        return response.get("placeholder_type", "unknown")
    return "unknown"
|
|
|
|
|
|
|
|
|
async def populate_template(s: SingleSession, alpha_details: dict, template_expression: str, template_explanation: str, operator_summary: str, llm_client: openai.AsyncOpenAI, top_n_datafield: int = 50, user_region: Optional[str] = None, user_universe: Optional[str] = None, user_delay: Optional[int] = None, user_category: Optional[Union[str, list]] = None, user_data_type: str = "MATRIX") -> dict:
    """
    Populates placeholders in an alpha template with candidate data fields, operators, or parameters.

    For each placeholder, an LLM first classifies its type, then the matching
    candidate-gathering helper is invoked.

    Returns:
        dict: Mapping of placeholder -> {"type": <classified type>,
        "candidates": <list of candidates>}; {} when the template has no
        placeholders.
    """
    # BUG FIX: the template was previously parsed twice (duplicated docstring
    # and a second extract_placeholders call after the guard); parse once.
    placeholders = extract_placeholders(template_expression)

    if not placeholders:
        logger.info("No placeholders found in the template. (模板中未找到占位符)")
        return {}

    logger.info(f"Found placeholders in template (在模板中找到占位符): {placeholders}")

    populated_placeholders = {}

    for ph in placeholders:
        # Use LLM to judge placeholder type
        ph_type = await judge_placeholder_type(ph, template_expression, template_explanation, operator_summary, llm_client)
        logger.info(f"'{ph}' judged as type (判断类型为): {ph_type}")

        if ph_type == "data_field":
            candidates = await get_datafield_candidates(s, alpha_details, template_expression, template_explanation, ph, llm_client, top_n=top_n_datafield, user_region=user_region, user_universe=user_universe, user_delay=user_delay, user_category=user_category, user_data_type=user_data_type)
            populated_placeholders[ph] = {
                "type": "data_field", "candidates": candidates}
        elif ph_type == "group_data_field":
            candidates = await get_group_datafield_candidates(template_expression, template_explanation, ph, llm_client)
            populated_placeholders[ph] = {
                "type": "group_data_field", "candidates": candidates}
        elif ph_type in ["operator", "group_operator", "ts_operator", "vector_operator"]:
            candidates = await get_operator_candidates(template_expression, template_explanation, ph, llm_client)
            populated_placeholders[ph] = {
                "type": ph_type, "candidates": candidates}
        elif ph_type in ["integer_parameter", "float_parameter"]:
            candidates = await get_parameter_candidates(ph_type, template_expression, template_explanation, ph, llm_client)
            populated_placeholders[ph] = {
                "type": ph_type, "candidates": candidates}
        elif ph_type == "string_parameter":
            # Add logic for string_parameter if needed, for now it returns empty
            populated_placeholders[ph] = {
                "type": "string_parameter", "candidates": []}
        else:
            logger.info(
                f"Could not determine type for placeholder (无法确定占位符类型): {ph} (LLM classified as {ph_type})")
            populated_placeholders[ph] = {"type": "unknown", "candidates": []}

    return populated_placeholders
|
|
|
|
|
|
|
|
|
def get_datafield_prefix(datafield_name: str) -> str:
    """Return the portion of a datafield name before the first underscore.

    Names without an underscore are returned unchanged
    (e.g. 'anl44_estimates' -> 'anl44', 'close' -> 'close').
    """
    # str.partition yields the whole string as the first element when no
    # separator is present, matching the previous split-based behavior.
    return datafield_name.partition('_')[0]
|
|
|
|
|
|
|
|
|
async def generate_new_alphas(alpha_description, brain_session, template_summary: Optional[str] = None, top_n_datafield: int = 50, user_region: Optional[str] = None, user_universe: Optional[str] = None, user_delay: Optional[int] = None, user_category: Optional[Union[str, list]] = None, user_data_type: str = "MATRIX", max_retries: int = 20):
    """
    Main function to generate new alpha templates based on a seed alpha.

    Pipeline (as implemented below):
      1. Load template and operator summaries, authenticate the LLM gateway
         (terminates the process via sys.exit(1) on failure).
      2. Ask the LLM to propose alpha templates and filter them through
         _filter_valid_templates.
      3. For each surviving template, gather candidate values for every
         placeholder, saving progress incrementally to
         output/Alpha_candidates.json.
      4. Expand every candidate combination into concrete expressions
         (skipping combinations whose datafields have mismatched prefixes).
      5. Syntax-check each expression and write *_success.json /
         *_error.json files under ./output.

    Args:
        alpha_description: The alpha description JSON string.
        brain_session: The BRAIN session object.
        template_summary: Optional template summary string. If None, will load from built-in.
        top_n_datafield: Number of data field candidates to retrieve (default: 50).
        user_region / user_universe / user_delay / user_category: optional
            overrides restricting the datafield search scope (None means use
            the seed alpha's defaults).
        user_data_type: Data type for datafield search (MATRIX or VECTOR).
        max_retries: retry budget forwarded to the template-proposal step.

    Side effects:
        Writes JSON files under ./output next to this script; may call
        sys.exit(1) on unrecoverable errors.
    """
    # Declare module-level globals used for LLM configuration.
    global LLM_model_name, LLM_API_KEY, llm_base_url

    # Load template summary if not provided
    if template_summary is None:
        template_summary = load_template_summary()
    # --- Load Operator Summary ---
    operator_summary = get_brain_operators(scope_filters=["REGULAR"])

    try:
        llm_api_key = get_token_from_auth_server()
        llm_base_url_value = llm_base_url  # read from the module-level global
        llm_client = openai.AsyncOpenAI(
            base_url=llm_base_url_value, api_key=llm_api_key)
        logger.info("✓ LLM Gateway 认证成功")
    except Exception as e:
        logger.error(f"❌ LLM Gateway 认证失败: {e}")
        sys.exit(1)

    # Seed alpha details are passed in as a JSON string.
    details = json.loads(alpha_description)

    if not details:
        logger.error(f"Failed to retrieve details for Alpha (获取Alpha详情失败)")
        sys.exit(1)

    logger.info("Alpha Details Retrieved (已获取Alpha详情):")
    logger.info(json.dumps(details, indent=4))

    # --- Step 4: Propose New Alpha Templates ---
    logger.info(f"\n{'='*60}")
    logger.info("[Step 2/5] 正在生成 Alpha 模板提议...")
    logger.info(f"{'='*60}")
    proposed_templates = await propose_alpha_templates(details, template_summary, llm_client, user_data_type=user_data_type, max_retries=max_retries)

    if not proposed_templates:
        logger.error("Failed to generate proposed alpha templates. (生成提议模板失败)")
        sys.exit(1)

    logger.info(
        "\n--- Proposed Alpha Templates (JSON) (建议的Alpha模板,多样性会受到模型和模板总结文档的影响) ---")
    logger.info(json.dumps(proposed_templates, indent=4))

    # --- Validation: Drop templates with suspicious literal identifiers ---
    try:
        operators_meta = get_brain_operators().get('operators', [])
        proposed_templates = _filter_valid_templates(
            proposed_templates,
            operators_meta,
            brain_session,
            details.get('settings', {}),
            parse_alpha_code,
        )
    except Exception as e:
        # Validation is best-effort: on any failure keep the unfiltered set.
        logger.warning(f"⚠ 模板校验步骤出现异常,跳过校验: {e}")

    if not proposed_templates:
        logger.error("❌ 所有模板在校验后被丢弃,无法继续。")
        sys.exit(1)

    # --- Step 5: Process all proposed templates and gather candidates ---
    # --- Step 6: Prepare for Output ---
    logger.info(f"\n{'='*60}")
    logger.info("[Step 3/5] 正在处理模板并收集候选数据字段...")
    logger.info(f"{'='*60}")

    # Ensure the output directory exists next to this script
    output_dir = Path(__file__).parent / "output"
    try:
        output_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"✓ 输出目录已准备: {output_dir}")
    except Exception as e:
        logger.warning(
            f"Warning: could not create directory {output_dir}: {e}")

    output_filepath = output_dir / f"Alpha_candidates.json"

    # Maps template expression -> explanation, seed settings, and candidates.
    final_output = {}

    # --- Step 5: Process all proposed templates and gather candidates ---
    total_templates = len(proposed_templates)
    for idx, (template_expr, template_expl) in enumerate(proposed_templates.items(), 1):
        logger.info(
            f"\n[模板 {idx}/{total_templates}] 正在处理: '{template_expr[:60]}...'")
        try:
            populated_info = await populate_template(brain_session, details, template_expr, template_expl, operator_summary, llm_client, top_n_datafield=top_n_datafield, user_region=user_region, user_universe=user_universe, user_delay=user_delay, user_category=user_category, user_data_type=user_data_type)

            # Skip templates where any data_field placeholder has zero candidates
            if _should_skip_due_to_empty_candidates(populated_info):
                logger.warning("⚠ 该模板存在数据字段候选为空的占位符,跳过此模板。")
                continue

            final_output[template_expr] = {
                "template_explanation": template_expl,
                "seed_alpha_settings": details.get('settings', {}),
                "placeholder_candidates": populated_info
            }

            # --- Incremental Saving ---
            # Rewrite the whole file after each template so a crash loses at
            # most the current template's work.
            try:
                with output_filepath.open('w', encoding='utf-8') as f:
                    json.dump(final_output, f, indent=4)
                logger.info(f"✓ Progress saved to {output_filepath.name}")
            except IOError as e:
                logger.warning(f"⚠️ Warning: Failed to save progress: {e}")

        except Exception as e:
            # A single failing template must not abort the whole batch.
            logger.error(f"❌ Error processing template '{template_expr}': {e}")
            logger.info("Skipping this template and continuing...")
            continue

    logger.info(f"\n{'='*60}")
    logger.info("[Step 4/5] 正在生成 Alpha 表达式组合...")
    logger.info(f"{'='*60}")
    logger.info(f"✓ 已处理 {len(final_output)} 个有效模板")

    logger.info("\n--- Final Consolidated Output (最终合并输出) ---")
    logger.info(json.dumps(final_output, indent=4))

    # All unique concrete expressions produced from the templates.
    generated_expressions = set()

    for template_expression, template_data in final_output.items():
        placeholder_candidates = template_data["placeholder_candidates"]
        # NOTE(review): seed_alpha_settings is bound but never used in this
        # expansion loop.
        seed_alpha_settings = template_data["seed_alpha_settings"]

        # Prepare a dictionary to hold lists of candidates for each placeholder
        candidates_for_placeholders = {}
        # NOTE(review): the loop variable `details` shadows the seed-alpha
        # `details` dict from above; harmless here only because the outer
        # `details` is not read again after this point.
        for placeholder, details in placeholder_candidates.items():
            # Extract only the 'value' or 'name' from the candidates list
            if details["type"] == "data_field":
                candidates_for_placeholders[placeholder] = [
                    c["id"] for c in details["candidates"]]
            elif details["type"] in ["integer_parameter", "float_parameter"]:
                candidates_for_placeholders[placeholder] = [
                    str(c["value"]) for c in details["candidates"]]
            elif details["type"] == "group_data_field":
                candidates_for_placeholders[placeholder] = [
                    c["name"] for c in details["candidates"]]
            elif details["type"] == "operator":
                candidates_for_placeholders[placeholder] = [
                    c["name"] for c in details["candidates"]]
            else:
                # Unknown placeholder types contribute no candidates, which
                # makes the Cartesian product below empty for this template.
                candidates_for_placeholders[placeholder] = []

        # --- Step 3: Implement logic to generate all alpha expression combinations from the candidates ---
        # Generate all possible combinations of placeholder values
        placeholder_names = list(candidates_for_placeholders.keys())
        all_combinations_values = list(itertools.product(
            *candidates_for_placeholders.values()))

        for combination_values in all_combinations_values:

            # --- ATOM Mode ---
            # Collect the datafield values in this combination so we can
            # enforce that they all come from the same dataset (same prefix).
            datafield_values_in_combo = []
            placeholder_types = {ph: details["type"]
                                 for ph, details in placeholder_candidates.items()}

            for i, placeholder_name in enumerate(placeholder_names):
                if placeholder_types.get(placeholder_name) == 'data_field':
                    datafield_values_in_combo.append(combination_values[i])

            if len(datafield_values_in_combo) > 1:
                first_prefix = get_datafield_prefix(
                    datafield_values_in_combo[0])
                if not all(get_datafield_prefix(df) == first_prefix for df in datafield_values_in_combo):
                    continue  # Skip this combination as prefixes do not match

            # Substitute every placeholder with its chosen candidate value.
            current_expression = template_expression
            for i, placeholder_name in enumerate(placeholder_names):
                current_expression = current_expression.replace(
                    placeholder_name, combination_values[i])

            # Check for duplicates before adding
            if current_expression not in generated_expressions:
                generated_expressions.add(current_expression)

    # dump all unique generated expressions to a file, a list of strings in json file
    logger.info(f"\n{'='*60}")
    logger.info("[Step 5/5] 正在验证生成的表达式...")
    logger.info(f"{'='*60}")
    logger.info(f"✓ 生成的唯一 Alpha 表达式总数: {len(generated_expressions)}")

    validator = val.ExpressionValidator()
    logger.info("请注意,该文件仅用于验证表达式的格式正确性,\n不保证表达式在实际使用中的逻辑正确性或可执行性。\n")
    logger.info("不在内置函数列表中的operator将无法检查,如有需要,请使用AI按需修改本源代码添加")

    expressions_data = list(generated_expressions)
    # Extract the expression list.
    # NOTE(review): expressions_data is always a list here, so the dict
    # branch below is dead code kept from an earlier file-loading version.
    if isinstance(expressions_data, dict) and "expressions" in expressions_data:
        expressions = expressions_data["expressions"]
    elif isinstance(expressions_data, list):
        expressions = expressions_data
    else:
        logger.error("错误: JSON文件格式不正确,需要包含表达式列表")
        return

    # Validate the expressions.
    valid_expressions = []
    invalid_expressions = []

    logger.info(f"开始验证 {len(expressions)} 个表达式...")
    for i, expr in enumerate(expressions, 1):
        # Progress heartbeat every 10 expressions.
        if i % 10 == 0:
            logger.info(f"已验证 {i}/{len(expressions)} 个表达式")

        result = validator.check_expression(expr)
        if result["valid"]:
            valid_expressions.append(expr)
        else:
            invalid_expressions.append(
                {"expression": expr, "errors": result["errors"]})

    # Build the output file paths.
    name = "Alpha_generated_expressions"
    valid_output_path = os.path.join(output_dir, f"{name}_success.json")
    invalid_output_path = os.path.join(output_dir, f"{name}_error.json")

    # Persist the results as JSON files.
    logger.info(f"\n验证完成!")
    logger.info(f"有效表达式: {len(valid_expressions)}")
    logger.info(f"无效表达式: {len(invalid_expressions)}")

    # Save the valid expressions.
    try:
        with open(valid_output_path, 'w', encoding='utf-8') as f:
            json.dump(valid_expressions, f, ensure_ascii=False, indent=2)
        logger.info(f"有效表达式已保存到: {valid_output_path}")
    except Exception as e:
        logger.error(f"错误: 保存有效表达式失败 - {e}")

    # Save the invalid expressions (with their error details).
    try:
        with open(invalid_output_path, 'w', encoding='utf-8') as f:
            json.dump(invalid_expressions, f, ensure_ascii=False, indent=2)
        logger.info(f"无效表达式已保存到: {invalid_output_path},文件包含错误详情")
        logger.info("查看该文件,你将获得修改模板的灵感,你可以定位到错误的模板并在APP里修改")
    except Exception as e:
        logger.error(f"错误: 保存无效表达式失败 - {e}")

    logger.info("请注意,该文件仅用于验证表达式的格式正确性,\n不保证表达式在实际使用中的逻辑正确性或可执行性。\n")
    logger.info("不在内置函数列表中的operator将无法检查,如有需要,请使用AI按需修改validator源代码添加")

    logger.info("不同模型效果不同,默认的kimi模型可能会产生Alpha语法错误,请检查生成的模板文件进行甄别")
    logger.info("下一步,请下载已完成的模板,放入APP首页进行解析和语法检查,强烈建议生成表达式后手动尝试回测")
|
|
|
|
|
|
async def main():
    """
    Main execution function.

    Builds the run configuration either from a JSON config file passed as
    the first CLI argument or from interactive prompts, publishes the
    credential/model settings as module-level globals, starts a BRAIN
    session, authenticates the LLM gateway, then drives
    generate_new_alphas() for the configured alpha_id.

    Side effects:
        Mutates module globals; calls sys.exit(1) on configuration or
        authentication failure.
    """

    # Check for command line argument for config file
    if len(sys.argv) > 1:
        config_path = sys.argv[1]
        if os.path.exists(config_path):
            try:
                with open(config_path, 'r', encoding='utf-8') as f:
                    config = json.load(f)
                logger.info(f"✓ 已从命令行参数加载配置: {config_path}")
                # Ensure all required fields are present or set defaults
                if 'top_n_datafield' not in config:
                    config['top_n_datafield'] = 50
                if 'template_summary_path' not in config:
                    config['template_summary_path'] = None
            except Exception as e:
                logger.error(f"❌ 加载配置文件失败: {e}")
                sys.exit(1)
        else:
            logger.error(f"❌ 配置文件不存在: {config_path}")
            sys.exit(1)
    else:
        # --- Step 0: Collect configuration via interactive prompts ---
        logger.info("\n" + "="*60)
        logger.info("交互式配置输入模式")
        logger.info("="*60 + "\n")

        config = interactive_input()

    # Publish credentials/model settings as module-level globals so helper
    # functions elsewhere in this module can read them.
    global LLM_model_name, LLM_API_KEY, llm_base_url, username, password

    logger.info("\n[Config] 正在设置全局变量...")

    required_config_fields = ['LLM_model_name', 'LLM_API_KEY',
                              'llm_base_url', 'username', 'password', 'alpha_id']
    missing_fields = [f for f in required_config_fields if f not in config]

    if missing_fields:
        logger.error(f"❌ [Config] 配置缺少必填字段: {missing_fields}")
        logger.error(f"❌ [Config] 当前配置内容: {list(config.keys())}")
        sys.exit(1)

    LLM_model_name = config['LLM_model_name']
    LLM_API_KEY = config['LLM_API_KEY']
    llm_base_url = config['llm_base_url']
    username = config['username']
    password = config['password']

    logger.info(f"✓ [Config] LLM_model_name: {LLM_model_name}")
    logger.info(f"✓ [Config] llm_base_url: {llm_base_url}")
    logger.info(f"✓ [Config] username: {username}")
    logger.info(f"✓ [Config] alpha_id: {config['alpha_id']}")

    # --- Step 1: Load the template summary ---
    template_summary = load_template_summary(
        config.get('template_summary_path'))

    # --- Step 2: Start the BRAIN session ---
    logger.info("--- 正在启动 BRAIN 会话... ---")
    s = start_session()

    # --- Step 3: Authenticate against the LLM Gateway ---
    # NOTE(review): llm_client created here appears unused downstream;
    # generate_new_alphas builds its own client.
    llm_client = None
    logger.info("--- 正在认证 LLM Gateway... ---")
    try:
        llm_api_key = get_token_from_auth_server()
        llm_base_url_value = llm_base_url
        llm_client = openai.AsyncOpenAI(
            base_url=llm_base_url_value, api_key=llm_api_key)
        logger.info("✓ LLM Gateway 认证成功")
    except Exception as e:
        logger.error(f"❌ LLM Gateway 认证失败: {e}")
        sys.exit(1)

    # --- Step 4: Fetch the seed alpha details ---
    alpha_id = config['alpha_id']
    logger.info(f"\n--- 正在获取 Alpha ID: {alpha_id} 的详情... ---")

    # --- Step 4.5: Choose the datafield search scope ---
    # Config-file mode reuses the file's values; interactive mode prompts.
    if len(sys.argv) > 1:
        user_datafield_config = {
            'user_region': config.get('user_region'),
            'user_universe': config.get('user_universe'),
            'user_delay': config.get('user_delay'),
            'user_category': config.get('user_category'),
            'user_data_type': config.get('user_data_type', 'MATRIX')
        }
    else:
        user_datafield_config = interactive_datafield_selection(s)

    details_str = await generate_alpha_description(alpha_id, brain_session=s)
    await generate_new_alphas(
        alpha_description=details_str,
        brain_session=s,
        template_summary=template_summary,
        top_n_datafield=config.get('top_n_datafield', 50),
        user_region=user_datafield_config.get('user_region'),
        user_universe=user_datafield_config.get('user_universe'),
        user_delay=user_datafield_config.get('user_delay'),
        user_category=user_datafield_config.get('user_category'),
        user_data_type=user_datafield_config.get('user_data_type', 'MATRIX'),
        max_retries=config.get('max_retries', 20)
    )
|
|
|
|
|
|
|
|
|
def interactive_datafield_selection(s: SingleSession) -> dict:
    """
    Interactively ask the user for datafield search configuration (Region, Universe, Delay).

    Also prompts for data categories and the data type (MATRIX/VECTOR). Any
    prompt left blank or answered with an invalid value falls back to None,
    meaning "use the seed alpha's default". Only EQUITY instrument options
    are offered.

    Args:
        s: Active BRAIN session used to fetch valid option combinations.

    Returns:
        dict with keys user_region, user_universe, user_delay,
        user_category, user_data_type (values may be None); an empty dict
        if the valid option combinations could not be fetched.
    """
    logger.info("\n" + "="*60)
    logger.info("【附加配置】数据字段搜索范围配置")
    logger.info("正在获取有效的 Region/Universe/Delay 组合...")

    try:
        df = get_instrument_type_region_delay(s)
    except Exception as e:
        # Best-effort: fall back to seed-alpha defaults when the options
        # cannot be retrieved.
        logger.warning(f"⚠ 获取配置选项失败: {e}")
        logger.info("将使用 Seed Alpha 的默认设置")
        return {}

    # Filter for EQUITY only as per current logic
    df_equity = df[df['InstrumentType'] == 'EQUITY']

    if df_equity.empty:
        logger.info("未找到 EQUITY 类型的配置选项。")
        return {}

    # 1. Select Region
    regions = df_equity['Region'].unique().tolist()
    logger.info(f"\n可用地区 (Region): {regions}")
    region_input = input(f"请输入地区 (直接回车使用 Seed Alpha 默认值): ").strip()

    selected_region = None
    if region_input:
        if region_input in regions:
            selected_region = region_input
        else:
            logger.warning(f"⚠ 输入无效,将使用默认值")

    # 2. Select Delay
    # If region is selected, filter delays for that region
    if selected_region:
        delays = df_equity[df_equity['Region'] ==
                           selected_region]['Delay'].unique().tolist()
    else:
        delays = df_equity['Delay'].unique().tolist()

    logger.info(f"\n可用延迟 (Delay): {delays}")
    delay_input = input(f"请输入延迟 (直接回车使用 Seed Alpha 默认值): ").strip()

    selected_delay = None
    if delay_input:
        try:
            d_val = int(delay_input)
            if d_val in delays:
                selected_delay = d_val
            else:
                logger.warning(f"⚠ 输入不在列表中,将使用默认值")
        except ValueError:
            logger.warning(f"⚠ 输入无效,将使用默认值")

    # 3. Select Universe
    # If region and delay are selected, filter universes
    if selected_region and selected_delay is not None:
        subset = df_equity[(df_equity['Region'] == selected_region) & (
            df_equity['Delay'] == selected_delay)]
        if not subset.empty:
            # assumes the 'Universe' column holds a list per row — TODO confirm
            universes = subset.iloc[0]['Universe']
        else:
            universes = []
    else:
        # Just show all unique universes if we can't filter precisely
        universes = set()
        for u_list in df_equity['Universe']:
            universes.update(u_list)
        universes = list(universes)

    logger.info(f"\n可用范围 (Universe): {universes}")
    universe_input = input(f"请输入范围 (直接回车使用 Seed Alpha 默认值): ").strip()

    selected_universe = None
    if universe_input:
        if universe_input in universes:
            selected_universe = universe_input
        else:
            logger.warning(f"⚠ 输入无效,将使用默认值")

    # 4. Select Category
    logger.info("\n正在获取数据类别 (Data Categories)...")
    categories = get_data_categories(s)

    selected_category = None
    if categories:
        logger.info("\n可用类别 (Categories):")
        for i, cat in enumerate(categories):
            logger.info(f"{i+1}. {cat['name']} (ID: {cat['id']})")

        cat_input = input(f"请输入类别编号或ID (多个用逗号分隔, 直接回车不筛选): ").strip()

        if cat_input:
            selected_categories = []
            inputs = [x.strip() for x in cat_input.split(',')]

            for inp in inputs:
                # Check if input is an index
                if inp.isdigit():
                    idx = int(inp) - 1
                    if 0 <= idx < len(categories):
                        selected_categories.append(categories[idx]['id'])
                        logger.info(f"已选择类别: {categories[idx]['name']}")
                else:
                    # Check if input is an ID
                    found = False
                    for cat in categories:
                        if cat['id'] == inp:
                            selected_categories.append(cat['id'])
                            logger.info(f"已选择类别: {cat['name']}")
                            found = True
                            break
                    if not found:
                        logger.warning(f"⚠ 输入无效: {inp}")

            if selected_categories:
                selected_category = selected_categories
            else:
                logger.warning(f"⚠ 未选择有效类别,将不筛选类别")
    else:
        logger.warning("⚠ 无法获取类别列表,跳过类别选择")

    # 5. Select Data Type
    logger.info("\n可用数据类型 (Data Type): [MATRIX, VECTOR]")
    data_type_input = input(f"请输入数据类型 (直接回车默认 MATRIX): ").strip().upper()

    selected_data_type = "MATRIX"
    if data_type_input == "VECTOR":
        # VECTOR is easy to misuse, so require an explicit confirmation.
        logger.warning(
            "⚠ 警告: 请确保您输入的原型Alpha中正确地使用了vector operator,否则极容易造成数据类型错误")
        confirm = input("确认使用 VECTOR 吗? (y/n): ").strip().lower()
        if confirm == 'y':
            selected_data_type = "VECTOR"
        else:
            logger.info("已取消 VECTOR 选择,使用默认值 MATRIX")
    elif data_type_input and data_type_input != "MATRIX":
        logger.warning(f"⚠ 输入无效,将使用默认值 MATRIX")

    return {
        'user_region': selected_region,
        'user_universe': selected_universe,
        'user_delay': selected_delay,
        'user_category': selected_category,
        'user_data_type': selected_data_type
    }
|
|
|
|
|
|
|
|
|
async def run_transformer(config: dict) -> dict:
    """
    Importable async entry point that runs the alpha-generation pipeline.

    Non-interactive counterpart of main(): raises instead of exiting the
    process, so callers can handle failures.

    Args:
        config: Configuration dict with the required fields:
            - LLM_model_name: LLM model name
            - LLM_API_KEY: LLM API key
            - llm_base_url: base URL of the LLM service
            - username: BRAIN username
            - password: BRAIN password
            - alpha_id: seed alpha ID
        and the optional fields:
            - top_n_datafield: datafield candidate limit (default 50)
            - template_summary_path: path to a template summary file
            - user_region / user_universe / user_delay / user_category:
              datafield search scope overrides
            - user_data_type: data type (default "MATRIX")
            - max_retries: maximum retry count (default 20)

    Returns:
        dict: {"success": True, "message": ...} on completion.

    Raises:
        ValueError: if required config fields are missing.
        Exception: re-raises LLM gateway authentication failures.
    """
    # Keep module-level globals in sync for compatibility with the rest of
    # the module, which reads credentials from module scope.
    global LLM_model_name, LLM_API_KEY, llm_base_url, username, password

    logger.info("\n[Config] 正在设置全局变量...")

    required_config_fields = ['LLM_model_name', 'LLM_API_KEY',
                              'llm_base_url', 'username', 'password', 'alpha_id']
    missing_fields = [f for f in required_config_fields if f not in config]

    if missing_fields:
        logger.error(f"❌ [Config] 配置缺少必填字段: {missing_fields}")
        logger.error(f"❌ [Config] 当前配置内容: {list(config.keys())}")
        raise ValueError(f"Missing required config fields: {missing_fields}")

    LLM_model_name = config['LLM_model_name']
    LLM_API_KEY = config['LLM_API_KEY']
    llm_base_url = config['llm_base_url']
    username = config['username']
    password = config['password']

    logger.info(f"✓ [Config] LLM_model_name: {LLM_model_name}")
    logger.info(f"✓ [Config] llm_base_url: {llm_base_url}")
    logger.info(f"✓ [Config] username: {username}")
    logger.info(f"✓ [Config] alpha_id: {config['alpha_id']}")

    # --- Step 1: Load the template summary ---
    template_summary = load_template_summary(
        config.get('template_summary_path'))

    # --- Step 2: Start the BRAIN session ---
    logger.info("--- 正在启动 BRAIN 会话... ---")
    s = start_session()

    # --- Step 3: Authenticate against the LLM Gateway ---
    # NOTE(review): llm_client created here appears unused downstream;
    # generate_new_alphas builds its own client.
    llm_client = None
    logger.info("--- 正在认证 LLM Gateway... ---")
    try:
        llm_api_key = get_token_from_auth_server()
        llm_base_url_value = llm_base_url
        llm_client = openai.AsyncOpenAI(
            base_url=llm_base_url_value, api_key=llm_api_key)
        logger.info("✓ LLM Gateway 认证成功")
    except Exception as e:
        logger.error(f"❌ LLM Gateway 认证失败: {e}")
        raise

    # --- Step 4: Fetch the seed alpha details ---
    alpha_id = config['alpha_id']
    logger.info(f"\n--- 正在获取 Alpha ID: {alpha_id} 的详情... ---")

    # --- Step 4.5: Take the datafield search scope from the config ---
    user_datafield_config = {
        'user_region': config.get('user_region'),
        'user_universe': config.get('user_universe'),
        'user_delay': config.get('user_delay'),
        'user_category': config.get('user_category'),
        'user_data_type': config.get('user_data_type', 'MATRIX')
    }

    details_str = await generate_alpha_description(alpha_id, brain_session=s)
    await generate_new_alphas(
        alpha_description=details_str,
        brain_session=s,
        template_summary=template_summary,
        top_n_datafield=config.get('top_n_datafield', 50),
        user_region=user_datafield_config.get('user_region'),
        user_universe=user_datafield_config.get('user_universe'),
        user_delay=user_datafield_config.get('user_delay'),
        user_category=user_datafield_config.get('user_category'),
        user_data_type=user_datafield_config.get('user_data_type', 'MATRIX'),
        max_retries=config.get('max_retries', 20)
    )

    return {"success": True, "message": "Alpha generation completed"}
|
|
|
|
|
|
|
|
|
def run_transformer_sync(config: dict) -> dict:
    """Synchronous wrapper around :func:`run_transformer`.

    Drives the async pipeline on a fresh event loop via ``asyncio.run`` so
    callers that are not already inside an event loop can invoke it
    directly.

    Args:
        config: Configuration dict; see ``run_transformer`` for the
            required and optional fields.

    Returns:
        dict: The result dict produced by ``run_transformer``.
    """
    pipeline = run_transformer(config)
    return asyncio.run(pipeline)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # To allow asyncio to run in environments like Jupyter notebooks
    # On Windows, Python 3.8+ made the Proactor event loop the default,
    # which is incompatible with some libraries that expect selector
    # semantics; force the selector policy there.
    if sys.platform.startswith('win') and sys.version_info[:2] >= (3, 8):
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

    asyncio.run(main())
|
|
|
|