You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
419 lines
14 KiB
419 lines
14 KiB
import os
|
|
import random
|
|
import sys
|
|
import re
|
|
import pandas as pd
|
|
from sqlalchemy import create_engine
|
|
from itertools import product
|
|
|
|
# ==================== Global configuration ====================

# Project root: the absolute path of this file truncated just after its first
# "AlphaGenerator" component. Computed once and reused below.
# NOTE(review): if the path contains no "AlphaGenerator", this evaluates to
# "<full path of this file>AlphaGenerator" -- confirm the script always runs
# from inside the project tree.
PROJECT_PATH = os.path.abspath(__file__).split('AlphaGenerator')[0] + 'AlphaGenerator'

# Make project-local modules importable.
sys.path.append(PROJECT_PATH)

PREPARE_PROMPT = os.path.join(PROJECT_PATH, 'prepare_prompt')
SQLITE_PATH = os.path.join(PREPARE_PROMPT, 'data_sets.db')

# Alpha template. Each <kw1|kw2|...> placeholder is one parameter whose
# alternatives are looked up in the data-set table; the special alternative
# "window" expands to the values in WINDOW_LIST.
# ALPHA_TEMPLATE = "is_nan(<Gross|Profit|Margin>, <window>)"
ALPHA_TEMPLATE = "not(is_nan(<Gross|Profit|Margin>))"

# Window lengths substituted for "window".
WINDOW_LIST = [5, 20, 60, 250]

# Run mode: 1 = search the name column, 2 = search the description column,
# 3 = search both and merge the generated expressions.
MODE = 3

# Data-set filters applied to every search.
REGION = 'GLB'
UNIVERSE = 'TOP3000'

# When more than MAX_OUTPUT_COUNT expressions are generated, a random sample
# of ACTUAL_GENERATION_COUNT of them is saved instead of the full list.
MAX_OUTPUT_COUNT = 2000
ACTUAL_GENERATION_COUNT = 1000

# ==================== Initialization ====================

engine = create_engine(f"sqlite:///{SQLITE_PATH}")
OUTPUT_FILE = "template_output.txt"
|
|
|
|
|
|
def extract_keywords_from_template(template):
    """
    Parse every ``<kw1|kw2|...>`` placeholder in *template*.

    Alternatives are split on '|' and stripped, and empty alternatives are
    dropped. An empty alternative would otherwise become a LIKE '%%' search
    that matches every row. The special alternative 'window' is removed from
    'keywords' and recorded as 'has_window' instead.

    Args:
        template: template string, e.g. "not(is_nan(<Gross|Profit|Margin>))".

    Returns:
        One dict per placeholder, in order of appearance:
        [{'param_index': i, 'keywords': [kw1, kw2, ...],
          'original': '<kw1|kw2|...>', 'has_window': bool}, ...]
    """
    # NOTE(review): any "<...>" span counts as a placeholder, so comparison
    # operators in a template (e.g. "a < b and c > d") would be parsed as one.
    keyword_params = []
    for i, body in enumerate(re.findall(r'<([^<>]+)>', template)):
        alternatives = [kw.strip() for kw in body.split('|')]
        alternatives = [kw for kw in alternatives if kw]
        keyword_params.append({
            'param_index': i,
            'keywords': [kw for kw in alternatives if kw != 'window'],
            'original': f'<{body}>',
            'has_window': 'window' in alternatives,
        })

    print(f"从模板解析到 {len(keyword_params)} 个参数:")
    for param in keyword_params:
        if param['has_window']:
            print(f" 参数{param['param_index'] + 1}: {param['original']} (包含window)")
        else:
            print(f" 参数{param['param_index'] + 1}: {param['original']}")
        print(f" 关键词列表: {param['keywords']}")

    return keyword_params
|
|
|
|
|
|
def _escape_like(text, escape_char='\\'):
    """Escape SQL LIKE wildcards so *text* is matched literally.

    '%' and '_' are LIKE wildcards. Without escaping, a keyword such as
    "gross_profit" would also match "grossXprofit". The result must be used
    with an ``ESCAPE`` clause naming the same *escape_char*.
    """
    return (
        text.replace(escape_char, escape_char * 2)
        .replace('%', escape_char + '%')
        .replace('_', escape_char + '_')
    )


def search_keywords_in_database(keyword_params, search_field, con=None, region=None, universe=None):
    """
    Search the data_sets table for every keyword of every template parameter.

    Each keyword is matched as a case-insensitive literal substring of the
    chosen column. Rows are restricted to the given region and universe.

    Args:
        keyword_params: parameter dicts as produced by
            extract_keywords_from_template (uses 'param_index' and 'keywords').
        search_field: 'name' searches the name column; any other value
            searches the description column.
        con: engine/connection passed to pandas.read_sql_query. Defaults to
            the module-level ``engine``.
        region: region filter. Defaults to REGION.
        universe: universe filter. Defaults to UNIVERSE.

    Returns:
        {param_index: [matched data-field names]}, de-duplicated in first-seen
        order. Parameters with no ordinary keywords (window-only) map to [].
        Returns {} when keyword_params is empty. Returns None if any parameter
        that has keywords matched nothing at all.
    """
    if not keyword_params:
        return {}

    if con is None:
        con = engine
    if region is None:
        region = REGION
    if universe is None:
        universe = UNIVERSE

    # The column name comes from a fixed two-value whitelist, never from user
    # input, so formatting it into the SQL is safe. All values remain bound
    # parameters.
    column = 'name' if search_field == 'name' else 'description'
    query = (
        "SELECT DISTINCT name FROM data_sets "
        "WHERE region = ? AND universe = ? "
        f"AND LOWER({column}) LIKE LOWER(?) ESCAPE '\\'"
    )

    results = {}
    missing = {}  # param_index -> keywords, for params where nothing matched

    for param in keyword_params:
        param_idx = param['param_index']
        param_keywords = param['keywords']

        if not param_keywords:
            # No ordinary keywords (window-only placeholder): nothing to search.
            results[param_idx] = []
            continue

        print(f"\n 搜索参数{param_idx + 1}的关键词:")
        print(f" 搜索字段: {search_field}")

        matched = {}  # dict used as an insertion-ordered set of names
        param_missing_keywords = []

        for keyword in param_keywords:
            # Substring match; wildcards inside the keyword are escaped.
            like_pattern = f'%{_escape_like(keyword)}%'
            print(f" 关键词: '{keyword}',模式: '{like_pattern}'")

            try:
                df = pd.read_sql_query(query, con, params=(region, universe, like_pattern))
            except Exception as e:
                # A failed query counts as "no match" for this keyword only.
                print(f" ✗ 搜索出错: {e}")
                param_missing_keywords.append(keyword)
                continue

            names = df['name'].tolist()
            if names:
                matched.update(dict.fromkeys(names))  # merge across keywords
                print(f" ✓ 找到 {len(names)} 个结果")
            else:
                param_missing_keywords.append(keyword)
                print(f" ✗ 没有匹配结果")

        param_all_results = list(matched)
        if param_all_results:
            results[param_idx] = param_all_results
            print(f" 参数{param_idx + 1}总计找到 {len(param_all_results)} 个唯一结果")
            print(f" 示例: {param_all_results[:3]}{'...' if len(param_all_results) > 3 else ''}")
        else:
            print(f" ✗ 参数{param_idx + 1}的所有关键词都没有匹配结果:")
            for kw in param_missing_keywords:
                print(f" - {kw}")
            missing[param_idx] = param_keywords

    # Any parameter whose keywords all failed makes the whole search fail.
    if missing:
        print(f"\n错误: 以下参数在{search_field}字段中没有匹配到任何结果:")
        for param_idx, param_keywords in missing.items():
            print(f" 参数{param_idx + 1}: <{'|'.join(param_keywords)}>")
        return None

    return results
|
|
|
|
|
|
def generate_alpha_combinations(template, keyword_params, search_results, window_list=None):
    """
    Expand *template* into every combination of candidate replacements.

    Each ``<...>`` placeholder is an independent parameter. The i-th
    placeholder in the template is replaced by a value chosen for the
    parameter whose 'param_index' is i, so repeated identical placeholders
    (e.g. two ``<window>``) vary independently. A parameter's candidates are
    its matched field names plus, when 'window' is one of its alternatives,
    the window values.

    Args:
        template: template string containing ``<...>`` placeholders.
        keyword_params: parameter dicts (uses 'param_index', 'original',
            'has_window').
        search_results: {param_index: [matched data-field names]}.
        window_list: values used for 'window' alternatives. Defaults to the
            module-level WINDOW_LIST.

    Returns:
        Unique generated expressions in generation order. Returns [template]
        when there is nothing to substitute.
    """
    if not search_results and not any(p['has_window'] for p in keyword_params):
        return [template]  # nothing to substitute

    search_results = search_results or {}
    if window_list is None:
        window_list = WINDOW_LIST
    window_values = [str(w) for w in window_list]

    # Collect the candidate list for each parameter, in placeholder order.
    ordered_params = sorted(keyword_params, key=lambda p: p['param_index'])
    all_param_replacements = []
    param_keys = []
    for param in ordered_params:
        param_idx = param['param_index']
        candidates = list(search_results.get(param_idx) or [])
        if param['has_window']:
            candidates.extend(window_values)
        if not candidates:
            # No search hits and no window alternative (should not happen).
            print(f"警告: 参数{param_idx + 1}没有可替换的内容")
            candidates = [param['original']]  # leave the placeholder as-is
        all_param_replacements.append(candidates)
        param_keys.append(param['original'])

    print(f"\n生成组合...")
    print(f"替换键 ({len(param_keys)}个参数): {param_keys}")

    total_combinations = 1
    for replacement_list in all_param_replacements:
        total_combinations *= len(replacement_list)
    print(f"预计生成 {total_combinations} 个组合")

    # re.split with one capture group alternates literal text (even indices)
    # and placeholder bodies (odd indices). Placeholder k sits at index 2k+1.
    pieces = re.split(r'<([^<>]+)>', template)

    unique = {}  # dict as an insertion-ordered set: reproducible output order
    count = 0
    for values in product(*all_param_replacements):
        chosen = {p['param_index']: str(v) for p, v in zip(ordered_params, values)}
        result = ''.join(
            piece if pos % 2 == 0 else chosen.get(pos // 2, f'<{piece}>')
            for pos, piece in enumerate(pieces)
        )
        unique[result] = None
        count += 1

        # Progress display for large expansions.
        if total_combinations > 1000 and count % 1000 == 0:
            print(f" 已生成 {count}/{total_combinations} 个组合...")

    unique_combinations = list(unique)
    print(f"实际生成 {count} 个组合,去重后 {len(unique_combinations)} 个")

    return unique_combinations
|
|
|
|
|
|
def run_mode(mode, template):
    """
    Run the requested search mode(s) for one template.

    Mode 1 searches the data-set ``name`` column and mode 2 the
    ``description`` column. Mode 3 runs both and merges the expressions; a
    failure in one of its stages does not abort the other.

    Args:
        mode: 1, 2 or 3.
        template: template string with ``<...>`` placeholders.

    Returns:
        De-duplicated list of generated expressions in first-seen order, or
        None in any of these cases:
        - the template has no placeholders;
        - the single requested search (mode 1 or 2) leaves a parameter
          unmatched;
        - nothing was generated.
    """
    print(f"\n{'=' * 60}")
    print(f"运行模式 {mode}")
    print(f"模板: {template}")
    print(f"{'=' * 60}")

    keyword_params = extract_keywords_from_template(template)
    if not keyword_params:
        print("错误: 模板中没有找到任何参数")
        return None

    all_results = []

    # (stage number, column searched). Stage n runs for mode n and for mode 3.
    for stage, field in ((1, 'name'), (2, 'description')):
        if mode not in (stage, 3):
            continue

        print(f"\n[模式{stage}: {field}字段搜索]")
        print("-" * 40)

        search_results = search_keywords_in_database(keyword_params, field)
        if search_results is None:
            print(f"模式{stage}失败: 有参数在{field}字段中没有匹配结果")
            if mode == stage:
                return None  # single-field mode: no fallback
            continue  # mode 3: the other stage may still succeed

        combinations = generate_alpha_combinations(template, keyword_params, search_results)
        all_results.extend(combinations)
        print(f"模式{stage}生成 {len(combinations)} 个表达式")

    if not all_results:
        print("\n错误: 没有生成任何表达式")
        return None

    # dict.fromkeys keeps first-seen order, so the output is reproducible
    # across runs (set() ordering of strings varies with hash randomization).
    unique_results = list(dict.fromkeys(all_results))
    print(f"\n总计生成 {len(all_results)} 个表达式,去重后 {len(unique_results)} 个")

    return unique_results
|
|
|
|
|
|
def save_results(results, filename, max_output_count=None, sample_count=None):
    """
    Write the generated expressions to *filename*, one per line.

    If there are more than *max_output_count* results, a random sample of
    *sample_count* of them is written instead. The sample is capped at the
    number of results so random.sample cannot raise. No fixed seed is used,
    so the sample differs between runs.

    Args:
        results: iterable of expression strings.
        filename: output path (overwritten, UTF-8).
        max_output_count: sampling threshold. Defaults to MAX_OUTPUT_COUNT.
        sample_count: sample size. Defaults to ACTUAL_GENERATION_COUNT.

    Returns:
        True on success. False if there is nothing to save or the file could
        not be written.
    """
    if not results:
        print("没有结果需要保存")
        return False

    if max_output_count is None:
        max_output_count = MAX_OUTPUT_COUNT
    if sample_count is None:
        sample_count = ACTUAL_GENERATION_COUNT

    results = list(results)  # allow any iterable; needed for sampling/slicing
    total_count = len(results)
    sampled = total_count > max_output_count

    if sampled:
        # Clamp: random.sample raises ValueError for k > population or k < 0.
        sample_count = min(max(sample_count, 0), total_count)
        print(f"\n警告: 生成结果 {total_count} 条,超过阈值 {max_output_count} 条")
        print(f"将随机抽取 {sample_count} 条结果保存")
        save_results_list = random.sample(results, sample_count)
        print(f"随机抽取 {len(save_results_list)} 条结果")
    else:
        save_results_list = results

    try:
        with open(filename, 'w', encoding='utf-8') as f:
            f.writelines(alpha + '\n' for alpha in save_results_list)
    except Exception as e:
        print(f"保存文件时出错: {e}")
        return False

    print(f"\n已保存 {len(save_results_list)} 个Alpha表达式到: {filename}")

    # Preview the first 10 saved expressions.
    print(f"\n前10个结果:")
    print("-" * 60)
    for i, alpha in enumerate(save_results_list[:10], 1):
        print(f"{i:3d}. {alpha}")
    if len(save_results_list) > 10:
        print(f"... 还有 {len(save_results_list) - 10} 个表达式")

    if sampled:
        print(f"\n统计信息:")
        print(f" - 原始生成: {total_count} 条")
        print(f" - 随机抽取: {sample_count} 条")
        print(f" - 抽取比例: {sample_count / total_count * 100:.1f}%")

    return True
|
|
|
|
|
|
def main():
    """
    Entry point: print the configuration, check the database, run the
    configured mode and save the results.

    Returns early (after printing a hint) if the database file is missing or
    unreachable. Exits with status 1 if no expressions were generated or the
    output file could not be written.
    """
    print("=" * 60)
    print("Alpha表达式生成器 v2.0 - 支持多关键词参数")
    print("=" * 60)
    print(f"数据库: {SQLITE_PATH}")
    print(f"区域: {REGION}, 股票池: {UNIVERSE}")
    print(f"模板: {ALPHA_TEMPLATE}")
    print(f"窗口列表: {WINDOW_LIST}")
    print(f"运行模式: {MODE}")
    print(f"输出文件: {OUTPUT_FILE}")
    print("语法说明: 使用 <keyword1|keyword2|keyword3> 支持多关键词")
    print("=" * 60)

    # SQLite creates an empty database file when opening a path that does not
    # exist. "SELECT 1" alone would then succeed, and every later search would
    # fail with "no such table". So check for the file before connecting.
    if not os.path.isfile(SQLITE_PATH):
        print(f"数据库文件不存在: {SQLITE_PATH}")
        print(f"请检查数据库路径: {SQLITE_PATH}")
        return

    try:
        pd.read_sql_query("SELECT 1", engine)  # connectivity probe
    except Exception as e:
        print(f"数据库连接失败: {e}")
        print(f"请检查数据库路径: {SQLITE_PATH}")
        return
    else:
        print("数据库连接成功")

    results = run_mode(MODE, ALPHA_TEMPLATE)
    if not results:
        print("\n程序终止: 有参数没有匹配到数据,请检查关键词或模板")
        sys.exit(1)

    # save_results reports its own error; signal failure via the exit code.
    if not save_results(results, OUTPUT_FILE):
        sys.exit(1)
|
|
|
|
|
|
if __name__ == '__main__':
    import traceback

    try:
        main()
    except Exception as exc:
        # Top-level boundary: report the error, dump the stack, exit non-zero.
        print(f"程序运行出错: {exc}")
        traceback.print_exc()
        sys.exit(1)