# -*- coding: utf-8 -*-

import csv
import os
import random
import sys
import time
from datetime import datetime

import httpx
import jieba
import openai

# Resolve the project root from this file's absolute path: everything up to
# and including the 'AlphaGenerator' directory is treated as the project root.
PROJECT_PATH = os.path.abspath(__file__).split('AlphaGenerator')[0] + 'AlphaGenerator'
sys.path.append(PROJECT_PATH)

# Prompt assets live under <project>/prepare_prompt.
PREPARE_PROMPT = os.path.join(PROJECT_PATH, 'prepare_prompt')
KEYS_TEXT = os.path.join(PREPARE_PROMPT, 'keys_text.txt')

# Whether to actually call the models (the prompt is saved locally either way).
USE_AI = True

# Sampling temperature passed to the chat completion call.
TEMPERATURE = 0.1

# How many extra, randomly chosen data sets to append to the matched ones.
RANDOM_DATA_SETS_COUNT = 100

SILICONFLOW_API_KEY = "sk-pvdiisdowmuwkrpnxsrlhxaovicqibmlljwrwwvbbdjaitdl"
SILICONFLOW_BASE_URL = "https://api.siliconflow.cn/v1"
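
# A minimal sketch (not the original behavior): reading the key from an
# environment variable keeps the secret out of version control. The variable
# name SILICONFLOW_API_KEY is an assumption for illustration.
# SILICONFLOW_API_KEY = os.environ.get("SILICONFLOW_API_KEY", SILICONFLOW_API_KEY)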

# Models to query, in order. The commented entries are disabled candidates
# (duplicates removed).
MODELS = [
    'Pro/deepseek-ai/DeepSeek-V3.1-Terminus',
    'deepseek-ai/DeepSeek-V3.2-Exp',
    'Qwen/Qwen3-VL-235B-A22B-Instruct',
    # 'MiniMaxAI/MiniMax-M2',
    # 'zai-org/GLM-4.6',
    # 'inclusionAI/Ring-flash-2.0',
    # 'inclusionAI/Ling-flash-2.0',
]


def process_text(text):
    """Tokenize Chinese text with jieba and drop stop words and punctuation.

    Returns a de-duplicated list of lowercase tokens, or None if nothing
    survives the filtering.
    """
    filter_list = {'\n', '\t', '\r', '\b', '\f', '\v', ':', '的', '或', '10', '天', '了', '可', '是', '该', ',', ' ', '、', '让', '和', '集',
                   '/', '日', '在', '(', '_', '-', ')', '(', '上', '距', '与', '比', '下', '及', ')', '...', ';', '%', '&', '+', ',', '.',
                   ':', ';', '<', '=', '>', '?', '[', ']', '|', '—', '。'}

    # Keep only tokens that are not in the stop list, lowercased.
    results = [tl.lower() for tl in jieba.lcut(text) if tl not in filter_list]

    # Discard quote characters and short tokens (three characters or more survive).
    results = [item for item in results if item != '"' and len(item) > 2]

    # De-duplicate; note that set() does not preserve token order.
    return list(set(results)) if results else None
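
# Usage sketch (illustrative; actual tokens depend on jieba's dictionary):
#   process_text('成交量与收盘价的比值')
# might return something like ['成交量', '收盘价'], since stop words and
# tokens shorter than three characters are filtered out.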


def keysTextLoader():
    """Load keyword lines from KEYS_TEXT and tokenize them into search keys."""
    if not os.path.exists(KEYS_TEXT):
        print(f"File not found: {KEYS_TEXT}")
        exit(1)
    with open(KEYS_TEXT, 'r', encoding='utf-8') as f:
        text_list = [line.strip() for line in f if line.strip()]
    if not text_list:
        print('Keyword file is empty, exiting')
        exit(1)

    keys = process_text(';'.join(text_list))

    print(f'\nProcessed keyword tokens: {keys}\n')

    return keys


def txtFileLoader(file_path):
    """Return the non-empty, stripped lines of a text file."""
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        exit(1)
    with open(file_path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f if line.strip()]


def csvFileLoader(file_path, keys_text):
    """Scan the combined CSV and return the data sets whose searchable
    columns contain any of the given keys."""
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        exit(1)

    data_dict = {}  # keyed by data-set id to avoid duplicates

    with open(file_path, 'r', encoding='utf-8') as f:
        reader = csv.reader(f)
        for row in reader:
            if len(row) < 13:
                continue  # skip malformed rows lacking the searchable columns
            for key in keys_text:
                if key in row[11] or key in row[12]:
                    item_id = row[0]
                    # Keep only the first record seen for each id.
                    if item_id not in data_dict:
                        data_dict[item_id] = {
                            'id': int(row[0]),
                            # The Chinese markers are part of the prompt payload:
                            # "可以使用" = "may be used";
                            # "不可使用,仅供参考" = "do not use, reference only".
                            'data_set_name': f"可以使用:{row[1]}",
                            'description': f"不可使用,仅供参考:{row[2]}"
                        }

    # Return the collected records as a list.
    return list(data_dict.values())
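
# Assumed CSV layout (inferred from the column indexes above, not documented
# in the repository): column 0 = id, column 1 = field name, column 2 =
# description, columns 11 and 12 = searchable description text.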


def extend_data_sets(file_path, original_data_sets):
    """Append up to RANDOM_DATA_SETS_COUNT randomly chosen data sets not
    already in original_data_sets, then strip the ids."""
    result = original_data_sets.copy()

    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        return result

    all_data_sets = []
    with open(file_path, 'r', encoding='utf-8') as f:
        reader = csv.reader(f)
        for row in reader:
            all_data_sets.append({
                'id': int(row[0]),
                'data_set_name': f"可以使用:{row[1]}",
                'description': f"不可使用,仅供参考:{row[2]}",
            })

    if RANDOM_DATA_SETS_COUNT and RANDOM_DATA_SETS_COUNT > 0:
        # Collect the data sets that are not already present.
        original_ids = {item['id'] for item in original_data_sets}
        available_datasets = [ds for ds in all_data_sets if ds['id'] not in original_ids]

        # Cap the number to add at what is actually available.
        need_to_add = RANDOM_DATA_SETS_COUNT
        if need_to_add > len(available_datasets):
            print(f"Warning: {need_to_add} requested but only {len(available_datasets)} available")
            need_to_add = len(available_datasets)

        # Randomly pick the required number and append them.
        result.extend(random.sample(available_datasets, need_to_add))

    # Drop the id field; only name and description go into the prompt.
    return [{'data_set_name': item['data_set_name'], 'description': item['description']}
            for item in result]


def promptLoader(alpha_prompt_path):
    """Read the base alpha prompt, collapsing blank lines."""
    if not os.path.exists(alpha_prompt_path):
        print("alpha_prompt.txt not found")
        exit(1)
    with open(alpha_prompt_path, 'r', encoding='utf-8') as f:
        prompt = f.read().strip()
    if not prompt:
        print("alpha_prompt.txt is empty")
        exit(1)
    return prompt.replace('\n\n', '\n')


def operatorLoader(operator_prompt_path):
    """Read the list of permitted operators, one per line."""
    if not os.path.exists(operator_prompt_path):
        print("operator.txt not found")
        exit(1)
    with open(operator_prompt_path, 'r', encoding='utf-8') as f:
        operator_lines = [line.strip() for line in f if line.strip()]
    if not operator_lines:
        print("operator.txt is empty")
        exit(1)
    return "\n".join(operator_lines)


def mistakesNoteBookLoader(mistakes_notebook_path):
    """Read past-mistake notes. The file is optional, so a missing or empty
    file yields an empty string instead of aborting."""
    if not os.path.exists(mistakes_notebook_path):
        print("mistakes_notebook.txt not found")
        return ''
    with open(mistakes_notebook_path, 'r', encoding='utf-8') as f:
        mistakes_notebook_lines = [line.strip() for line in f if line.strip()]
    if not mistakes_notebook_lines:
        print("mistakes_notebook.txt is empty")
        return ''
    return "\n".join(mistakes_notebook_lines)


def create_result_folder():
    """Create (if needed) and return generated_alpha/YYYY/MM/DD for today."""
    now = datetime.now()
    day_folder = os.path.join("generated_alpha", str(now.year),
                              f"{now.month:02d}", f"{now.day:02d}")
    # makedirs creates the intermediate year/month folders as well.
    os.makedirs(day_folder, exist_ok=True)
    return day_folder


def call_siliconflow(prompt, model):
    """Send one chat completion request; return the reply text or None."""
    try:
        client = openai.OpenAI(
            api_key=SILICONFLOW_API_KEY,
            base_url=SILICONFLOW_BASE_URL
        )

        response = client.chat.completions.create(
            model=model,
            messages=[
                # System prompt (kept in Chinese, as sent to the model):
                # "You are a professional quantitative investment expert,
                # skilled at writing alpha factors."
                {"role": "system", "content": "你是一个专业的量化投资专家,擅长编写Alpha因子。"},
                {"role": "user", "content": prompt}
            ],
            temperature=TEMPERATURE
        )

        return response.choices[0].message.content

    except openai.AuthenticationError:
        print("Invalid API key")
    except openai.RateLimitError:
        print("Rate limit hit")
    except openai.APIError as e:
        print(f"API error: {e}")
    except Exception as e:
        # Known API errors fall through and return None so the caller can
        # report the failure; unknown failures abort the run.
        print(f"Unexpected error: {e}")
        exit(1)
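

# A minimal retry sketch (an addition, not part of the original flow): retry
# the call a few times when it returns None, e.g. after a rate-limit error.
# The attempt count and delay are illustrative assumptions.
def call_siliconflow_with_retry(prompt, model, attempts=3, delay=10):
    for _ in range(attempts):
        result = call_siliconflow(prompt, model)
        if result is not None:
            return result
        time.sleep(delay)  # back off before trying again
    return None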


def save_result(result, folder, model_name):
    """Write one model reply to <folder>/<model>_<HHMMSS>.txt."""
    now = datetime.now()
    time_filename = now.strftime("%H%M%S")

    filename = f"{model_name}_{time_filename}.txt"
    filepath = os.path.join(folder, filename)

    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(result)

    print(f"Result saved to: {filepath}")


def get_user_info():
    """Query the SiliconFlow account endpoint and return the balance."""
    headers = {"Authorization": f"Bearer {SILICONFLOW_API_KEY}"}
    url = "https://api.siliconflow.cn/v1/user/info"
    response = httpx.get(url, headers=headers)
    data = response.json()['data']
    balance = data['totalBalance']
    print(f"Balance: {balance}")
    return float(balance)


def manual_prompt(prompt):
    """Save the assembled prompt under manual_prompt/YYYY/MM/DD so it can be
    pasted into a model's web UI by hand."""
    now = datetime.now()
    day_folder = os.path.join(PROJECT_PATH, "manual_prompt", str(now.year),
                              f"{now.month:02d}", f"{now.day:02d}")
    os.makedirs(day_folder, exist_ok=True)

    # Append the save time to the file name.
    filename = f"manual_prompt_{now.strftime('%Y%m%d%H%M%S')}.txt"
    filepath = os.path.join(day_folder, filename)

    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(prompt)

    print(f"Manual prompt saved to: {filepath}")


def call_ai(prompt, model):
    """Call one model, save its reply, and report the API credit consumed."""
    balance = get_user_info()

    folder = create_result_folder()

    print(f"Calling AI... {model}")
    result = call_siliconflow(prompt, model)

    if result:
        print(f"AI reply: {result[:200]}...")
        model_name = model.replace("/", "_")
        save_result(result, folder, model_name)
        # Balance is queried again after the call; the difference is the cost.
        used_balance = balance - get_user_info()
        print(f'This API call used {used_balance} credits')
    else:
        print("AI call failed")


def prepare_prompt(data_sets):
    """Assemble the full prompt: base instructions, permitted operators,
    matched data fields, and optional past-mistake notes. The section
    markers and instructions are kept in Chinese, as sent to the model."""
    prompt = ''

    # Base prompt.
    alpha_prompt_path = os.path.join(PREPARE_PROMPT, "alpha_prompt.txt")
    prompt += promptLoader(alpha_prompt_path)

    # Operators section: tells the model to use only the listed operators;
    # the Description lines are usage notes, not operators themselves.
    prompt += "\n\n以下是我的账号有权限使用的操作符, 请严格按照操作符, 进行生成,组合因子\n\n"
    prompt += "========================= 操作符开始 =======================================\n"
    prompt += "注意: Operator: 后面的是操作符(是可以使用的),\nDescription: 此字段后面的是操作符对应的描述或使用说明(禁止使用, 仅供参考), Description字段后面的内容是使用说明, 不是操作符\n"
    prompt += "特别注意!!!! 必须按照操作符字段Operator的使用说明生成 alpha"
    operator_prompt_path = os.path.join(PREPARE_PROMPT, "operator.txt")
    prompt += operatorLoader(operator_prompt_path)
    prompt += "\n========================= 操作符结束 =======================================\n\n"

    # Data fields section: data_set_name values are usable; the description
    # fields are reference only.
    prompt += "========================= 数据字段开始 =======================================\n"
    prompt += "注意: data_set_name: 后面的是数据字段(可以使用), description: 此字段后面的是数据字段对应的描述或使用说明(不能使用), description_cn字段后面的内容是中文使用说明(不能使用)\n\n"
    for data_set in data_sets:
        prompt += str(data_set) + '\n'
    prompt += "========================= 数据字段结束 =======================================\n\n"

    # Ask for varied combinations rather than reusing the same few fields.
    prompt += "以上数据字段和操作符, 按照Description说明组合, 但是每一个 alpha 组合的使用的数据字段和操作符不要过于集中, 在符合语法的情况下, 多尝试不同的组合\n\n"

    # Optional notes about past mistakes, appended if present.
    mistakes_note_book_path = os.path.join(PREPARE_PROMPT, "mistakes_notebook.txt")
    mistakesNoteBook = mistakesNoteBookLoader(mistakes_note_book_path)
    if mistakesNoteBook:
        prompt += mistakesNoteBook

    return prompt
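
# Rough shape of the assembled prompt, for orientation (the exact wording
# comes from the text files and the Chinese section markers above):
#   <alpha_prompt.txt>
#   ========================= 操作符开始 ... 操作符结束 ==========================
#   ========================= 数据字段开始 ... 数据字段结束 ======================
#   <combination instructions>
#   <mistakes_notebook.txt, if present>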


def main():
    # Split the financial-logic text into keyword tags.
    keys_text = keysTextLoader()

    # Search the combined data-set CSV for fields matching those tags.
    data_sets_path = os.path.join(PREPARE_PROMPT, "all_data_combined.csv")
    result_data_sets = csvFileLoader(data_sets_path, keys_text)

    if not result_data_sets:
        print('Data-set search returned nothing, exiting')
        exit(1)

    # Extend the matched data sets with random extras from the same CSV.
    if RANDOM_DATA_SETS_COUNT:
        print('=' * 100)
        print(f'{len(result_data_sets)} records before extension')
        result_data_sets = extend_data_sets(data_sets_path, result_data_sets)
        print(f'{len(result_data_sets)} records after extension')
        print('=' * 100)

    print(f'Extracted {len(result_data_sets)} records from the data sets')
    # Cap very large result sets to keep the prompt small.
    if len(result_data_sets) > 500:
        data_sets = random.sample(result_data_sets, 10)
    else:
        data_sets = result_data_sets

    # Assemble the prompt.
    prompt = prepare_prompt(data_sets)

    # Save the generated prompt locally, in case you want to paste it into a
    # model's web UI by hand.
    manual_prompt(prompt)

    if USE_AI:
        for model in MODELS:
            call_ai(prompt, model)
            time.sleep(5)


if __name__ == "__main__":
    main()