You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
AlphaGenerator/alpha_template.py

418 lines
14 KiB

import os
import random
import sys
import re
import pandas as pd
from sqlalchemy import create_engine
from itertools import product
# ==================== Global configuration ====================
# Project root: this file's absolute path truncated at the first
# 'AlphaGenerator' fragment, with that fragment re-appended.
PROJECT_PATH = os.path.abspath(__file__).partition('AlphaGenerator')[0] + 'AlphaGenerator'
# Make the project root importable.
sys.path.append(PROJECT_PATH)
PREPARE_PROMPT = os.path.join(PROJECT_PATH, 'prepare_prompt')
SQLITE_PATH = os.path.join(PREPARE_PROMPT, 'data_sets.db')
# Alpha template; placeholders use the <keyword1|keyword2|keyword3> syntax.
ALPHA_TEMPLATE = "ts_covariance(<Profit|Income|Earning>, <size>, <window>)"
# Values substituted for a <window> placeholder.
WINDOW_LIST = [5, 20, 60, 250]
# Run mode: 1 = search the name column, 2 = search the description column,
# 3 = run both searches.
MODE = 2
# Row filters applied to every database query.
REGION = 'USA'
UNIVERSE = 'TOP3000'
# Above MAX_OUTPUT_COUNT results, only ACTUAL_GENERATION_COUNT random ones are saved.
MAX_OUTPUT_COUNT = 1000
ACTUAL_GENERATION_COUNT = 500
OUTPUT_FILE = "template_output.txt"
# ==================== Initialisation ====================
engine = create_engine(f"sqlite:///{SQLITE_PATH}")
def extract_keywords_from_template(template):
    """
    Parse every ``<keyword|keyword|...>`` placeholder out of an alpha template.

    A placeholder may hold a single keyword or several keywords separated by
    ``|``. Keywords are stripped of surrounding whitespace and empty ones are
    dropped, so a blank placeholder such as ``< >`` yields no keywords rather
    than an empty-string keyword. The literal keyword ``window`` is special:
    it is removed from the keyword list and recorded in ``has_window``.

    Args:
        template: the alpha expression template string.

    Returns:
        A list with one dict per placeholder, in order of appearance:
        ``{'param_index': i, 'keywords': [kw1, ...],
        'original': '<raw text>', 'has_window': bool}``.
        The list is empty when the template contains no placeholder.
    """
    keyword_params = []
    # Capture the text between each pair of angle brackets.
    for i, match in enumerate(re.findall(r'<([^<>]+)>', template)):
        # str.split returns the whole string when there is no '|', so a single
        # code path covers both single- and multi-keyword placeholders and
        # filters empty keywords in both cases.
        keywords = [kw.strip() for kw in match.split('|') if kw.strip()]
        has_window = 'window' in keywords
        if has_window:
            # 'window' is not a searchable keyword; keep only the ordinary ones.
            keywords = [kw for kw in keywords if kw != 'window']
        keyword_params.append({
            'param_index': i,
            'keywords': keywords,
            'original': f'<{match}>',
            'has_window': has_window,
        })
    print(f"从模板解析到 {len(keyword_params)} 个参数:")
    for param in keyword_params:
        if param['has_window']:
            print(f" 参数{param['param_index'] + 1}: {param['original']} (包含window)")
        else:
            print(f" 参数{param['param_index'] + 1}: {param['original']}")
        print(f" 关键词列表: {param['keywords']}")
    return keyword_params
def search_keywords_in_database(keyword_params, search_field):
    """
    Look up every keyword of every template parameter in the ``data_sets`` table.

    Each keyword is matched case-insensitively as a substring
    (``LIKE '%keyword%'``) against the ``name`` column when ``search_field``
    is ``'name'`` and against the ``description`` column for any other value.
    Rows are restricted to the module-level ``REGION`` and ``UNIVERSE``.

    Args:
        keyword_params: parameter dicts; the ``'param_index'`` and
            ``'keywords'`` entries are read.
        search_field: ``'name'`` or ``'description'``.

    Returns:
        ``{param_index: [matching names]}``. Names are unique and kept in
        first-seen order so repeated runs give the same ordering. A parameter
        without ordinary keywords maps to ``[]``. Returns ``{}`` for empty
        input and ``None`` when at least one parameter that has keywords
        matched nothing at all.
    """
    if not keyword_params:
        return {}
    # The column name comes from this fixed two-value choice, never from user
    # text, so interpolating it into the SQL is safe; all values are bound.
    column = 'name' if search_field == 'name' else 'description'
    query = (
        "SELECT DISTINCT name "
        "FROM data_sets "
        "WHERE region = ? AND universe = ? "
        f"AND LOWER({column}) LIKE LOWER(?)"
    )
    results = {}
    all_missing_params = []
    for param in keyword_params:
        param_idx = param['param_index']
        param_keywords = param['keywords']
        if not param_keywords:
            # No ordinary keywords (e.g. a pure <window> placeholder).
            results[param_idx] = []
            continue
        print(f"\n 搜索参数{param_idx + 1}的关键词:")
        print(f" 搜索字段: {search_field}")
        param_all_results = []
        param_missing_keywords = []
        for keyword in param_keywords:
            # Surround with % for substring matching.
            like_pattern = f'%{keyword}%'
            print(f" 关键词: '{keyword}',模式: '{like_pattern}'")
            try:
                df = pd.read_sql_query(
                    query,
                    engine,
                    params=(REGION, UNIVERSE, like_pattern),
                )
                names = df['name'].tolist()
            except Exception as e:
                # Deliberate best effort: a failed query is reported and the
                # keyword is treated as having no match.
                print(f" ✗ 搜索出错: {e}")
                param_missing_keywords.append(keyword)
                continue
            if names:
                # Merge the hits of all keywords of this parameter.
                param_all_results.extend(names)
                print(f" ✓ 找到 {len(names)} 个结果")
            else:
                param_missing_keywords.append(keyword)
                print(f" ✗ 没有匹配结果")
        # De-duplicate while preserving first-seen order (a set would make the
        # order depend on string hashing and differ between runs).
        param_all_results = list(dict.fromkeys(param_all_results))
        if param_all_results:
            results[param_idx] = param_all_results
            print(f" 参数{param_idx + 1}总计找到 {len(param_all_results)} 个唯一结果")
            print(f" 示例: {param_all_results[:3]}{'...' if len(param_all_results) > 3 else ''}")
        else:
            if param_missing_keywords:
                print(f" ✗ 参数{param_idx + 1}的所有关键词都没有匹配结果:")
                for kw in param_missing_keywords:
                    print(f" - {kw}")
            all_missing_params.append(param_idx)
    # Any parameter whose keywords all failed makes the whole search fail.
    if all_missing_params:
        print(f"\n错误: 以下参数在{search_field}字段中没有匹配到任何结果:")
        for param_idx in all_missing_params:
            param_keywords = keyword_params[param_idx]['keywords']
            print(f" 参数{param_idx + 1}: <{'|'.join(param_keywords)}>")
        return None
    return results
def generate_alpha_combinations(template, keyword_params, search_results):
    """
    Expand a template into concrete alpha expressions.

    For each parameter (in ``param_index`` order) the replacement candidates
    are: its search results when there are any; otherwise the module-level
    ``WINDOW_LIST`` values (as strings) when the parameter has the ``window``
    flag; otherwise the placeholder text itself, with a warning. The Cartesian
    product of all candidate lists is substituted into the template.

    Args:
        template: the alpha expression template string.
        keyword_params: parameter dicts; ``'param_index'``, ``'original'`` and
            ``'has_window'`` are read.
        search_results: ``{param_index: [names]}``.

    Returns:
        A list of unique expressions in generation order. When there is
        nothing to substitute, a one-element list holding ``template``.
    """
    if not search_results and not any(p['has_window'] for p in keyword_params):
        return [template]  # nothing to replace
    all_param_replacements = []
    param_keys = []
    for param in sorted(keyword_params, key=lambda x: x['param_index']):
        param_idx = param['param_index']
        matched = search_results.get(param_idx) if search_results else None
        if matched:
            # Ordinary keywords with search hits.
            replacements = matched
        elif param['has_window']:
            # Window placeholder without usable search hits.
            replacements = [str(w) for w in WINDOW_LIST]
        else:
            # Neither hits nor window: leave the placeholder untouched.
            print(f"警告: 参数{param_idx + 1}没有可替换的内容")
            replacements = [param['original']]
        all_param_replacements.append(replacements)
        param_keys.append(param['original'])
    print(f"\n生成组合...")
    print(f"替换键 ({len(param_keys)}个参数): {param_keys}")
    total_combinations = 1
    for replacement_list in all_param_replacements:
        total_combinations *= len(replacement_list)
    print(f"预计生成 {total_combinations} 个组合")
    # De-duplicate while generating: dict keys keep insertion order, so the
    # result order is deterministic and no second full-size list is needed.
    unique_combinations = {}
    count = 0
    for values in product(*all_param_replacements):
        result = template
        for key, value in zip(param_keys, values):
            result = result.replace(key, str(value))
        unique_combinations[result] = None
        count += 1
        # Progress output for large expansions.
        if total_combinations > 1000 and count % 1000 == 0:
            print(f" 已生成 {count}/{total_combinations} 个组合...")
    print(f"实际生成 {count} 个组合,去重后 {len(unique_combinations)}")
    return list(unique_combinations)
def run_mode(mode, template):
    """
    Run the generator in the given mode and return the unique expressions.

    Mode 1 searches the ``name`` column, mode 2 the ``description`` column,
    and mode 3 runs both searches and merges their expressions. In mode 3 a
    failed search is skipped so the other one can still contribute.

    Args:
        mode: 1, 2 or 3.
        template: the alpha expression template string.

    Returns:
        A list of unique expressions in generation order (name-search results
        first), or ``None`` when the template has no placeholder, when the
        single requested search fails, or when nothing was generated.
    """
    print(f"\n{'=' * 60}")
    print(f"运行模式 {mode}")
    print(f"模板: {template}")
    print(f"{'=' * 60}")
    keyword_params = extract_keywords_from_template(template)
    if not keyword_params:
        print("错误: 模板中没有找到任何参数")
        return None
    all_results = []
    # Each stage is one searchable column; mode 3 enables both stages.
    for stage, field in ((1, 'name'), (2, 'description')):
        if mode not in (stage, 3):
            continue
        print(f"\n[模式{stage}: {field}字段搜索]")
        print("-" * 40)
        search_results = search_keywords_in_database(keyword_params, field)
        if search_results is None:
            print(f"模式{stage}失败: 有参数在{field}字段中没有匹配结果")
            if mode == stage:
                # This was the only requested search: give up.
                return None
            # Mode 3: the other stage may still succeed.
            continue
        combinations = generate_alpha_combinations(template, keyword_params, search_results)
        all_results.extend(combinations)
        print(f"模式{stage}生成 {len(combinations)} 个表达式")
    if not all_results:
        print(f"\n错误: 没有生成任何表达式")
        return None
    # Order-preserving de-duplication: going through a set would shuffle the
    # expressions differently on every run.
    unique_results = list(dict.fromkeys(all_results))
    print(f"\n总计生成 {len(all_results)} 个表达式,去重后 {len(unique_results)}")
    return unique_results
def save_results(results, filename):
    """
    Write the expressions to ``filename``, one per line, and print a summary.

    When there are more than ``MAX_OUTPUT_COUNT`` results, a random sample of
    ``ACTUAL_GENERATION_COUNT`` of them is saved instead. The sample size is
    clamped to the number of available results so a configuration where
    ``ACTUAL_GENERATION_COUNT`` exceeds the result count cannot make
    ``random.sample`` raise. Sampling is unseeded, so each run differs.

    Args:
        results: list of expression strings.
        filename: path of the output file; it is overwritten.

    Returns:
        ``True`` when the file was written, ``False`` when there was nothing
        to save or writing failed.
    """
    if not results:
        print("没有结果需要保存")
        return False
    total_count = len(results)
    was_sampled = total_count > MAX_OUTPUT_COUNT
    if was_sampled:
        # random.sample raises ValueError if asked for more items than exist.
        sample_size = min(ACTUAL_GENERATION_COUNT, total_count)
        print(f"\n警告: 生成结果 {total_count} 条,超过阈值 {MAX_OUTPUT_COUNT}")
        print(f"将随机抽取 {sample_size} 条结果保存")
        save_results_list = random.sample(results, sample_size)
        print(f"随机抽取 {len(save_results_list)} 条结果")
    else:
        save_results_list = results
    try:
        with open(filename, 'w', encoding='utf-8') as f:
            f.writelines(alpha + '\n' for alpha in save_results_list)
        print(f"\n已保存 {len(save_results_list)} 个Alpha表达式到: {filename}")
        # Preview of the first ten saved expressions.
        print(f"\n前10个结果:")
        print("-" * 60)
        for i, alpha in enumerate(save_results_list[:10], 1):
            print(f"{i:3d}. {alpha}")
        if len(save_results_list) > 10:
            print(f"... 还有 {len(save_results_list) - 10} 个表达式")
        if was_sampled:
            # Report the size actually sampled, not the configured one.
            print(f"\n统计信息:")
            print(f" - 原始生成: {total_count}")
            print(f" - 随机抽取: {len(save_results_list)}")
            print(f" - 抽取比例: {len(save_results_list) / total_count * 100:.1f}%")
        return True
    except Exception as e:
        # Best-effort boundary: report the failure and signal it via the
        # return value instead of propagating.
        print(f"保存文件时出错: {e}")
        return False
def main():
    """
    Script entry point.

    Prints the active configuration, checks that the database answers a
    trivial query, runs the configured mode on the configured template and
    saves the expressions. Returns early when the database check fails and
    exits with status 1 when no expression was produced.
    """
    rule = "=" * 60
    banner = [
        rule,
        "Alpha表达式生成器 v2.0 - 支持多关键词参数",
        rule,
        f"数据库: {SQLITE_PATH}",
        f"区域: {REGION}, 股票池: {UNIVERSE}",
        f"模板: {ALPHA_TEMPLATE}",
        f"窗口列表: {WINDOW_LIST}",
        f"运行模式: {MODE}",
        f"输出文件: {OUTPUT_FILE}",
        "语法说明: 使用 <keyword1|keyword2|keyword3> 支持多关键词",
        rule,
    ]
    for line in banner:
        print(line)
    # Connectivity probe: a trivial query that touches no table.
    try:
        pd.read_sql_query("SELECT 1", engine)
        print("数据库连接成功")
    except Exception as exc:
        print(f"数据库连接失败: {exc}")
        print(f"请检查数据库路径: {SQLITE_PATH}")
        return
    generated = run_mode(MODE, ALPHA_TEMPLATE)
    if not generated:
        print("\n程序终止: 有参数没有匹配到数据,请检查关键词或模板")
        sys.exit(1)
    save_results(generated, OUTPUT_FILE)
# Run only when executed as a script; any uncaught error is reported with a
# traceback and turned into exit status 1.
if __name__ == '__main__':
    try:
        main()
    except Exception as exc:
        import traceback

        print(f"程序运行出错: {exc}")
        traceback.print_exc()
        sys.exit(1)