# -*- coding: utf-8 -*-
"""Alpha-factor prompt generator.

Segments a keyword text file with jieba, matches data sets from a SQLite
database (or a CSV fallback), assembles one large prompt, and optionally
sends it to several LLMs through the SiliconFlow OpenAI-compatible API.
Prompts and model replies are stored under dated folders.
"""
import csv
import os
import random
import sqlite3
import sys
import threading
import time
from datetime import datetime

import httpx
import jieba
import openai

# Project root: everything up to and including the 'AlphaGenerator' folder.
sys.path.append(os.path.join(os.path.abspath(__file__).split('AlphaGenerator')[0] + 'AlphaGenerator'))
PROJECT_PATH = os.path.join(os.path.abspath(__file__).split('AlphaGenerator')[0] + 'AlphaGenerator')
PREPARE_PROMPT = os.path.join(PROJECT_PATH, 'prepare_prompt')
KEYS_TEXT = os.path.join(PREPARE_PROMPT, 'keys_text.txt')

# Whether to call the AI models (0 = only save the prompt locally).
USE_AI = 1
# Chat-completion sampling temperature.
TEMPERATURE = 0.2
# Number of extra random data sets to append (enable when matches are scarce).
RANDOM_DATA_SETS_COUNT = 0
# When more than MAX_DATA_COUNT data sets match, sample RANDOM_DATA_COUNT of them.
MAX_DATA_COUNT = 800
RANDOM_DATA_COUNT = 500
# SQLite filter fields.
REGION = 'GLB'
UNIVERSE = 'TOP3000'
# Whether to embed the operator list into the prompt.
LOAD_OPERATOR = 1
# Allowed data-set category names (empty list disables the category filter).
CATEGORY_NAME_LIST = [
    'Analyst',
    'Fundamental',
    'Price Volume'
]
# Base prompt file.
alpha_prompt_path = os.path.join(PREPARE_PROMPT, "alpha_prompt.txt")

# SECURITY NOTE: an API key was hardcoded here and must be considered
# leaked -- rotate it.  The environment variable, when set, takes precedence.
SILICONFLOW_API_KEY = os.environ.get(
    "SILICONFLOW_API_KEY",
    "sk-pvdiisdowmuwkrpnxsrlhxaovicqibmlljwrwwvbbdjaitdl")
SILICONFLOW_BASE_URL = "https://api.siliconflow.cn/v1"
MODELS = [
    'Pro/deepseek-ai/DeepSeek-V3.1-Terminus',
    # 'deepseek-ai/DeepSeek-V3.2-Exp',
    'Qwen/Qwen3-VL-235B-A22B-Instruct',
    'Pro/moonshotai/Kimi-K2-Thinking',
    # 'MiniMaxAI/MiniMax-M2',
    # 'zai-org/GLM-4.6',
    # 'inclusionAI/Ring-flash-2.0',
    # 'inclusionAI/Ling-flash-2.0',
]

# Tokens dropped after jieba segmentation (punctuation and stop words).
_FILTER_TOKENS = frozenset([
    '\n', '\t', '\r', '\b', '\f', '\v', ':', '的', '或', '10', '天', '了',
    '可', '是', '该', ',', ' ', '、', '让', '和', '集', '/', '日', '在',
    '(', '_', '-', ')', '(', '上', '距', '与', '比', '下', '及', ')',
    '...', ';', '%', '&', '+', ',', '.', ':', ';', '<', '=', '>', '?',
    '[', ']', '|', '—', '。'
])


def process_text(text):
    """Segment *text* with jieba and return a list of unique lowercase tokens.

    Tokens present in _FILTER_TOKENS, the '"' character, and tokens of
    length <= 2 are discarded.  Returns None when nothing survives.
    """
    # Filter on the raw token (as the original did), lowercase afterwards.
    kept = [tok.lower() for tok in jieba.lcut(text) if tok not in _FILTER_TOKENS]
    results = [tok for tok in kept if tok != '"' and len(tok) > 2]
    return list(set(results)) if results else None


def keysTextLoader():
    """Load KEYS_TEXT, segment it, and return the keyword list (exits on error)."""
    if not os.path.exists(KEYS_TEXT):
        print(f"文件不存在: {KEYS_TEXT}")
        sys.exit(1)
    with open(KEYS_TEXT, 'r', encoding='utf-8') as f:
        text_list = [line.strip() for line in f if line.strip()]
    if not text_list:
        print('关键词文本无数据, 程序退出')
        sys.exit(1)
    result_str = process_text(';'.join(text_list))
    print(f'\n关键词文本处理结果: {result_str}\n')
    return result_str


def txtFileLoader(file_path):
    """Return the non-empty, stripped lines of a UTF-8 text file (exits if missing)."""
    if not os.path.exists(file_path):
        print(f"文件不存在: {file_path}")
        sys.exit(1)
    with open(file_path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f if line.strip()]


def csvFileLoader(file_path, keys_text):
    """Match *keys_text* keywords against a data-set CSV.

    Expected columns: 0=id, 1=name, 2=description, 11/12=searchable text
    (TODO confirm column layout against all_data_combined.csv).  Rows are
    de-duplicated by id; the first occurrence wins.
    """
    if not os.path.exists(file_path):
        print(f"文件不存在: {file_path}")
        sys.exit(1)
    data_dict = {}  # keyed by id so each data set appears at most once
    with open(file_path, 'r', encoding='utf-8') as f:
        for row in csv.reader(f):
            # Fix: skip malformed/short rows instead of crashing on row[12].
            if len(row) < 13:
                continue
            if any(key in row[11] or key in row[12] for key in keys_text):
                item_id = row[0]
                if item_id not in data_dict:
                    data_dict[item_id] = {
                        'id': int(row[0]),
                        'data_set_name': f"可以使用:{row[1]}",
                        'description': f"不可使用,仅供参考:{row[2]}"
                    }
    return list(data_dict.values())


def sqliteLoader(file_path, keys_text):
    """Match keywords against the data_sets table of a SQLite database.

    Rows are pre-filtered by REGION/UNIVERSE in SQL and, when
    CATEGORY_NAME_LIST is non-empty, by category_name.  A row is kept when
    any keyword is a substring of its name; de-duplicated by id.
    Exits on a missing file or any sqlite3 error.
    """
    if not os.path.exists(file_path):
        print(f"SQLite数据库文件不存在: {file_path}")
        sys.exit(1)
    data_dict = {}  # keyed by id so each data set appears at most once
    try:
        conn = sqlite3.connect(file_path)
        cursor = conn.cursor()
        cursor.execute(
            "SELECT id, name, description, region, universe, category_name "
            "FROM data_sets WHERE region=? AND universe=?",
            (REGION, UNIVERSE))
        for row_id, name, description, region, universe, category_name in cursor.fetchall():
            # Optional category filter (applies to every keyword equally).
            if CATEGORY_NAME_LIST and category_name not in CATEGORY_NAME_LIST:
                continue
            if any(key in name for key in keys_text):
                item_id = str(row_id)
                if item_id not in data_dict:
                    data_dict[item_id] = {
                        'id': int(row_id),
                        'data_set_name': f"可以使用:{name}",
                        'description': f"不可使用,仅供参考:{description}"
                    }
        conn.close()
        return list(data_dict.values())
    except sqlite3.Error as e:
        print(f"SQLite数据库错误: {e}")
        sys.exit(1)


def extend_data_sets(file_path, original_data_sets):
    """Optionally top up *original_data_sets* with random rows from the CSV.

    When RANDOM_DATA_SETS_COUNT > 0, that many data sets not already
    present (matched by id) are sampled from *file_path* and appended.
    The returned list keeps only 'data_set_name' and 'description' keys.
    """
    result = original_data_sets.copy()
    if not os.path.exists(file_path):
        print(f"文件不存在: {file_path}")
        return result
    all_data_sets = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for row in csv.reader(f):
            all_data_sets.append({
                'id': int(row[0]),
                'data_set_name': f"可以使用:{row[1]}",
                'description': f"不可使用,仅供参考:{row[2]}",
            })
    if RANDOM_DATA_SETS_COUNT and RANDOM_DATA_SETS_COUNT > 0:
        # Fix: set-based membership test instead of the original O(n*m) nested loop.
        original_ids = {item['id'] for item in original_data_sets}
        available_datasets = [ds for ds in all_data_sets if ds['id'] not in original_ids]
        max_can_add = len(available_datasets)
        need_to_add = RANDOM_DATA_SETS_COUNT
        if need_to_add > max_can_add:
            print(f"警告:要求添加{need_to_add}个,但只有{max_can_add}个可用")
            need_to_add = max_can_add
        result.extend(random.sample(available_datasets, need_to_add))
    # Strip the internal 'id' key before the list is embedded in the prompt.
    return [{'data_set_name': item['data_set_name'],
             'description': item['description']} for item in result]


def promptLoader(alpha_prompt_path):
    """Read the base prompt file, collapsing blank lines.  Exits on error."""
    if not os.path.exists(alpha_prompt_path):
        print("alpha_prompt.txt文件不存在")
        sys.exit(1)
    with open(alpha_prompt_path, 'r', encoding='utf-8') as f:
        prompt = f.read().strip()
    if not prompt:
        print("alpha_prompt.txt是空的")
        sys.exit(1)
    return prompt.replace('\n\n', '\n')


def operatorLoader(operator_prompt_path):
    """Read operator.txt and return its non-empty lines joined by newlines."""
    if not os.path.exists(operator_prompt_path):
        print("operator.txt文件不存在")
        sys.exit(1)
    with open(operator_prompt_path, 'r', encoding='utf-8') as f:
        operator_lines = [line.strip() for line in f if line.strip()]
    if not operator_lines:
        print("operator.txt是空的")
        sys.exit(1)
    return "\n".join(operator_lines)


def mistakesNoteBookLoader(mistakes_notebook_path):
    """Read mistakes_notebook.txt; a missing file is tolerated (returns '')."""
    if not os.path.exists(mistakes_notebook_path):
        print("mistakes_notebook.txt文件不存在")
        return ''
    with open(mistakes_notebook_path, 'r', encoding='utf-8') as f:
        notebook_lines = [line.strip() for line in f if line.strip()]
    if not notebook_lines:
        print("mistakes_notebook.txt是空的")
        sys.exit(1)
    return "\n".join(notebook_lines)


def _dated_folder(base_folder):
    """Create and return <base_folder>/<year>/<month>/<day> (idempotent)."""
    now = datetime.now()
    day_folder = os.path.join(base_folder, str(now.year),
                              f"{now.month:02d}", f"{now.day:02d}")
    os.makedirs(day_folder, exist_ok=True)
    return day_folder


def create_result_folder():
    """Return today's output folder under generated_alpha/, creating it if needed."""
    return _dated_folder("generated_alpha")


def call_siliconflow(prompt, model):
    """Send *prompt* to *model* through the SiliconFlow OpenAI-compatible API.

    Returns the reply text, or None on any failure so the caller can
    continue with the remaining models (the original aborted the whole
    run via exit(1) on unexpected errors).
    """
    try:
        client = openai.OpenAI(
            api_key=SILICONFLOW_API_KEY,
            base_url=SILICONFLOW_BASE_URL
        )
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "你是一个专业的量化投资专家,擅长编写Alpha因子。"},
                {"role": "user", "content": prompt}
            ],
            temperature=TEMPERATURE
        )
        return response.choices[0].message.content
    except openai.AuthenticationError:
        print("API密钥错误")
    except openai.RateLimitError:
        print("调用频率限制")
    except openai.APIError as e:
        print(f"API错误: {e}")
    except Exception as e:
        print(f"其他错误: {e}")
    return None


def save_result(result, folder, model_name):
    """Write *result* to <folder>/<model_name>_<HHMMSS>.txt."""
    time_filename = datetime.now().strftime("%H%M%S")
    filepath = os.path.join(folder, f"{model_name}_{time_filename}.txt")
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(result)
    print(f"结果保存到: {filepath}")


def get_user_info():
    """Query the SiliconFlow account balance and return it as a float."""
    headers = {"Authorization": f"Bearer {SILICONFLOW_API_KEY}"}
    url = "https://api.siliconflow.cn/v1/user/info"
    response = httpx.get(url, headers=headers)
    balance = response.json()['data']['totalBalance']
    print(f"余额: {balance}")
    return float(balance)


def manual_prompt(prompt):
    """Save *prompt* under manual_prompt/<date>/ for manual use in a web UI."""
    day_folder = _dated_folder(os.path.join(PROJECT_PATH, "manual_prompt"))
    # Timestamped filename so repeated runs never overwrite each other.
    filename = f"manual_prompt_{datetime.now().strftime('%Y%m%d%H%M%S')}.txt"
    filepath = os.path.join(day_folder, filename)
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(prompt)
    print(f"手动提示词保存到: {filepath}")


def call_ai(prompt, model):
    """Call one model, save its reply, and report the balance consumed."""
    balance = get_user_info()
    folder = create_result_folder()
    print(f"正在调用AI...{model}")
    result = call_siliconflow(prompt, model)
    if result:
        print(f"AI回复: {result[:200]}...")
        # '/' is illegal in filenames; flatten the model id.
        save_result(result, folder, model.replace("/", "_"))
        used_balance = balance - get_user_info()
        print(f'本次调用 api 使用额度 {used_balance}')
    else:
        print("AI调用失败")


def prepare_prompt(data_sets):
    """Assemble the full prompt: base prompt + operators + data fields + notes."""
    prompt = ''
    # Base prompt.
    prompt += promptLoader(alpha_prompt_path)
    # Operator section.
    prompt += "\n\n以下是我的账号有权限使用的操作符, 请严格按照操作符, 进行生成,组合因子\n\n"
    prompt += "========================= 操作符开始 =======================================\n"
    prompt += "注意: Operator: 后面的是操作符(是可以使用的),\nDescription: 此字段后面的是操作符对应的描述或使用说明(禁止使用, 仅供参考), Description字段后面的内容是使用说明, 不是操作符\n"
    prompt += "特别注意!!!! 必须按照操作符字段Operator的使用说明生成 alpha"
    if LOAD_OPERATOR:
        operator_prompt_path = os.path.join(PREPARE_PROMPT, "operator.txt")
        prompt += operatorLoader(operator_prompt_path)
    prompt += "\n========================= 操作符结束 =======================================\n\n"
    # Data-field section.
    prompt += "========================= 数据字段开始 =======================================\n"
    prompt += "注意: data_set_name: 后面的是数据字段(可以使用), description: 此字段后面的是数据字段对应的描述或使用说明(不能使用)\n\n"
    for data_set in data_sets:
        prompt += str(data_set) + '\n'
    prompt += "========================= 数据字段结束 =======================================\n\n"
    prompt += "以上数据字段和操作符, 按照Description说明组合, 但是每一个 alpha 组合的使用的数据字段和操作符不要过于集中, 在符合语法的情况下, 多尝试不同的组合\n\n输出只要语法正确的WebSim表达式, 不需要任何解释\n"
    # Past-mistakes notebook (optional).
    mistakes_note_book_path = os.path.join(PREPARE_PROMPT, "mistakes_notebook.txt")
    mistakesNoteBook = mistakesNoteBookLoader(mistakes_note_book_path)
    if mistakesNoteBook:
        prompt += mistakesNoteBook
    return prompt


def main():
    # Split the financial-logic text into keyword tags.
    keys_text = keysTextLoader()
    # Search the SQLite database for data sets matching the keywords.
    data_sets_path = os.path.join(PREPARE_PROMPT, "data_sets.db")
    result_data_sets = sqliteLoader(data_sets_path, keys_text)
    if not result_data_sets:
        print(f'搜索数据集为空, 程序退出')
        sys.exit(1)
    # Optionally top up with random data sets from the combined CSV.
    extend_csv_path = os.path.join(PREPARE_PROMPT, "all_data_combined.csv")
    if RANDOM_DATA_SETS_COUNT:
        print('=' * 100)
        print(f'扩展前 {len(result_data_sets)} 条数据')
        result_data_sets = extend_data_sets(extend_csv_path, result_data_sets)
        print(f'扩展后 {len(result_data_sets)} 条数据')
        print('=' * 100)
    print(f'从数据集中提取了 {len(result_data_sets)} 条数据')
    if len(result_data_sets) > MAX_DATA_COUNT:
        print(f'筛选数据集数量大于 {MAX_DATA_COUNT}, 随机选择其中的 {RANDOM_DATA_COUNT} 条')
        data_sets = random.sample(result_data_sets, RANDOM_DATA_COUNT)
    else:
        data_sets = result_data_sets
    # Assemble the prompt and save a local copy in the background
    # (useful when pasting the prompt into a web UI manually).
    prompt = prepare_prompt(data_sets)
    save_thread = threading.Thread(target=manual_prompt, args=(prompt,))
    save_thread.start()
    if USE_AI:
        for model in MODELS:
            call_ai(prompt, model)
            time.sleep(5)


if __name__ == "__main__":
    main()