You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
234 lines
7.5 KiB
234 lines
7.5 KiB
# -*- coding: utf-8 -*-
|
|
import os
|
|
import random
|
|
import sys
|
|
import json
|
|
import openai
|
|
import httpx
|
|
import csv
|
|
from datetime import datetime
|
|
|
|
# Project root: the part of this file's absolute path up to and including the
# first 'AlphaGenerator' component. NOTE(review): if 'AlphaGenerator' is not in
# the path at all, this degenerates to "<this file's path>AlphaGenerator".
PROJECT_PATH = os.path.abspath(__file__).split('AlphaGenerator')[0] + 'AlphaGenerator'

# Make project-local modules importable when this file is run as a script.
sys.path.append(PROJECT_PATH)

# Directory holding alpha_prompt.txt, operator.txt and all_data_combined.csv.
PREPARE_PROMPT = os.path.join(PROJECT_PATH, 'prepare_prompt')

# Maximum number of data fields sampled from each dataset group.
SELECT_DATA_SET_QTY = 30

# The API key must come from the environment; never commit it to source control.
# (A key previously hard-coded here should be considered leaked and revoked.)
# Example: export SILICONFLOW_API_KEY=sk-...
SILICONFLOW_API_KEY = os.environ.get("SILICONFLOW_API_KEY", "")

SILICONFLOW_BASE_URL = "https://api.siliconflow.cn/v1"

# Models queried in order by call_ai(). Other models that have been used,
# kept here for quick switching:
#   'deepseek-ai/DeepSeek-V3.2-Exp'
#   'MiniMaxAI/MiniMax-M2'
#   'zai-org/GLM-4.6'
#   'zai-org/GLM-4.6V'
#   'Qwen/Qwen3-VL-235B-A22B-Instruct'
#   'inclusionAI/Ling-flash-2.0'
MODELS = [
    'inclusionAI/Ring-flash-2.0',
]
|
|
|
|
|
|
def txtFileLoader(file_path):
    """Return the non-blank lines of a UTF-8 text file, each stripped of surrounding whitespace.

    Args:
        file_path: Path of the text file to read.

    Returns:
        list[str]: Stripped lines, in file order, with whitespace-only lines removed.

    Exits the process with status 1 (after printing a message) if the file
    does not exist.
    """
    if not os.path.exists(file_path):
        print(f"文件不存在: {file_path}")
        # sys.exit is always available; the bare exit() builtin is only injected
        # by the site module and is meant for the interactive shell.
        sys.exit(1)

    with open(file_path, 'r', encoding='utf-8') as f:
        stripped = (line.strip() for line in f)
        return [text for text in stripped if text]
|
|
|
|
|
|
def csvFileLoader(file_path):
    """Read a UTF-8 CSV file and return all of its rows, header included.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        list[list[str]]: One list of cell strings per CSV row, as produced by
        csv.reader (a blank line yields an empty list).

    Exits the process with status 1 (after printing a message) if the file
    does not exist.
    """
    if not os.path.exists(file_path):
        print(f"文件不存在: {file_path}")
        sys.exit(1)

    # newline='' is required by the csv module so that line breaks inside
    # quoted fields are passed through verbatim instead of being translated.
    with open(file_path, 'r', encoding='utf-8', newline='') as f:
        return list(csv.reader(f))
|
|
|
|
|
|
def read_prompt(alpha_prompt_path):
    """Load the base alpha prompt text and remove its blank lines.

    Args:
        alpha_prompt_path: Path of the prompt file (alpha_prompt.txt).

    Returns:
        str: The file content, stripped at both ends, with every run of
        consecutive newlines collapsed to a single newline.

    Exits the process with status 1 (after printing a message) if the file is
    missing or contains only whitespace.
    """
    if not os.path.exists(alpha_prompt_path):
        print("alpha_prompt.txt文件不存在")
        sys.exit(1)

    with open(alpha_prompt_path, 'r', encoding='utf-8') as f:
        prompt = f.read().strip()

    if not prompt:
        print("alpha_prompt.txt是空的")
        sys.exit(1)

    # A single replace('\n\n', '\n') only halves a run of newlines, so 3+ in a
    # row would still leave blank lines; dropping empty split segments squeezes
    # every run down to one newline.
    return "\n".join(line for line in prompt.split("\n") if line)
|
|
|
|
|
|
def read_operator(operator_prompt_path):
    """Load the operator listing as newline-joined, stripped, non-blank lines.

    Args:
        operator_prompt_path: Path of the operator description file.

    Returns:
        str: The non-blank lines of the file, each stripped, joined with '\n'.

    Exits the process with status 1 (after printing a message) if the file is
    missing or has no non-blank lines.
    """
    # Messages report the actual path: the caller reads "operator.txt", so a
    # hard-coded file name here would point the user at the wrong file.
    if not os.path.exists(operator_prompt_path):
        print(f"文件不存在: {operator_prompt_path}")
        sys.exit(1)

    with open(operator_prompt_path, 'r', encoding='utf-8') as f:
        stripped = (line.strip() for line in f)
        operator_lines = [text for text in stripped if text]

    if not operator_lines:
        print(f"文件是空的: {operator_prompt_path}")
        sys.exit(1)

    return "\n".join(operator_lines)
|
|
|
|
|
|
def create_result_folder(folder_name="generated_alpha"):
    """Ensure the output directory exists and return its path.

    Args:
        folder_name: Directory to create. The default "generated_alpha" is
            resolved against the current working directory.

    Returns:
        str: *folder_name*, unchanged.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(folder_name, exist_ok=True)
    return folder_name
|
|
|
|
|
|
def call_siliconflow(prompt, model):
    """Send *prompt* to *model* on SiliconFlow as one user message and return the reply text.

    Args:
        prompt: Full prompt text.
        model: SiliconFlow model identifier.

    Returns:
        str | None: Content of the first choice. Returns None when the API
        reports an authentication, rate-limit or other API error, or returns
        no choices, so the caller can move on to the next model.

    Any other unexpected exception is printed and the process exits with
    status 1.
    """
    try:
        client = openai.OpenAI(
            api_key=SILICONFLOW_API_KEY,
            base_url=SILICONFLOW_BASE_URL,
        )
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
        )
        # An empty choice list is a failed call for this model, not a reason to
        # abort the whole run via the catch-all branch below.
        if not response.choices:
            return None
        return response.choices[0].message.content
    except openai.AuthenticationError:
        print("API密钥错误")
    except openai.RateLimitError:
        print("调用频率限制")
    except openai.APIError as e:
        print(f"API错误: {e}")
    except Exception as e:
        print(f"其他错误: {e}")
        sys.exit(1)
    return None
|
|
|
|
|
|
def save_result(result, folder):
    """Write *result* to <folder>/<YYYY-MM-DD>/<HHMMSS>.txt and return that path.

    If a file with the same timestamp already exists (two results saved within
    the same second), a numeric suffix is appended (<HHMMSS>_1.txt, ...) instead
    of overwriting the earlier result.

    Args:
        result: Text to save.
        folder: Base output directory; a per-day subdirectory is created in it.

    Returns:
        str: Path of the file that was written.
    """
    now = datetime.now()
    full_folder_path = os.path.join(folder, now.strftime("%Y-%m-%d"))

    if not os.path.exists(full_folder_path):
        os.makedirs(full_folder_path, exist_ok=True)
        print(f"创建文件夹: {full_folder_path}")

    base_name = now.strftime("%H%M%S")
    suffix = 0
    while True:
        filename = f"{base_name}.txt" if suffix == 0 else f"{base_name}_{suffix}.txt"
        filepath = os.path.join(full_folder_path, filename)
        try:
            # 'x' (exclusive create) fails rather than clobbering an existing file.
            f = open(filepath, 'x', encoding='utf-8')
        except FileExistsError:
            suffix += 1
            continue
        with f:
            f.write(result)
        break

    print(f"结果保存到: {filepath}")
    return filepath
|
|
|
|
|
|
def get_user_info():
    """Query the SiliconFlow account endpoint, print the balance and return it as a float.

    Returns:
        float: The ``data.balance`` value from the /user/info response.

    Raises:
        httpx.HTTPStatusError: If the endpoint answers with an error status.
    """
    headers = {"Authorization": f"Bearer {SILICONFLOW_API_KEY}"}
    # Reuse the configured base URL rather than repeating it.
    url = f"{SILICONFLOW_BASE_URL}/user/info"

    response = httpx.get(url, headers=headers)
    # Surface HTTP errors directly instead of a confusing KeyError on an error payload.
    response.raise_for_status()

    balance = response.json()['data']['balance']
    print(f"余额: {balance}")
    return float(balance)
|
|
|
|
|
|
def manual_prompt(prompt):
    """Save *prompt* under PROJECT_PATH/manual_prompt for pasting into a model UI by hand.

    The directory is created (and announced) if it is missing. The file is
    named manual_prompt_<YYYYmmddHHMMSS>.txt using the current local time.

    Args:
        prompt: Prompt text to write.
    """
    target_dir = os.path.join(PROJECT_PATH, "manual_prompt")
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
        print(f"创建文件夹: {target_dir}")

    # Timestamp the file name (one-second resolution) so separate runs keep separate files.
    stamp = datetime.now().strftime('%Y%m%d%H%M%S')
    target_file = os.path.join(target_dir, f"manual_prompt_{stamp}.txt")

    with open(target_file, 'w', encoding='utf-8') as out:
        out.write(prompt)

    print(f"手动提示词保存到: {target_file}")
|
|
|
|
|
|
def call_ai(prompt):
    """Send *prompt* to every model in MODELS, saving each reply and reporting its cost.

    For each model the first 200 characters of the reply are printed, the full
    reply is saved with save_result(), and the drop in account balance caused
    by that model's call is printed. Failed calls are reported and skipped.

    Args:
        prompt: Full prompt text.
    """
    balance = get_user_info()
    folder = create_result_folder()

    for model in MODELS:
        print(f"正在调用AI...{model}")
        result = call_siliconflow(prompt, model)

        if not result:
            print("AI调用失败")
            continue

        print(f"AI回复: {result[:200]}...")
        save_result(result, folder)

        # Re-baseline after every call so the printed amount is this model's
        # cost only, not the running total since the first model.
        remaining = get_user_info()
        print(f'本次调用 api 使用额度 {balance - remaining}')
        balance = remaining
|
|
|
|
|
|
def _operator_section(operator):
    """Wrap the operator listing in start/end markers with usage notes, one item per line."""
    return (
        "\n\n以下是我的账号有权限使用的操作符, 请严格按照操作符, 进行生成,组合因子\n\n"
        "========================= 操作符开始 =======================================\n"
        "注意: Operator: 后面的是操作符,\nDescription: 此字段后面的是操作符对应的描述或使用说明, Description字段后面的内容是使用说明, 不是操作符\n"
        "特别注意!!!! 必须按照操作符字段Operator的使用说明生成 alpha\n"
        f"{operator}\n"
        "========================= 操作符结束 =======================================\n\n"
    )


def _group_data_fields(rows):
    """Group CSV rows (header skipped) by dataset_id in column 2 as (field, description) pairs."""
    groups = {}
    for row in rows[1:]:
        # csv.reader yields [] for blank lines; skip any row too short to index.
        if len(row) < 3:
            continue
        field, description, dataset_id = row[0], row[1], row[2]
        groups.setdefault(dataset_id, []).append((field, description))
    return groups


def _select_data_fields(groups, qty):
    """Take every field of groups smaller than *qty*, else a random sample of *qty* fields."""
    selected = []
    for fields in groups.values():
        if len(fields) < qty:
            selected.extend(fields)
        else:
            selected.extend(random.sample(fields, qty))
    return selected


def _data_field_section(fields):
    """Render (field, description) pairs between the data-field start/end markers."""
    lines = [
        "========================= 数据字段开始 =======================================\n",
        "注意: DataField: 后面的是数据字段, DataFieldDescription: 此字段后面的是数据字段对应的描述或使用说明, DataFieldDescription字段后面的内容是使用说明, 不是数据字段\n\n",
    ]
    for field, description in fields:
        lines.append(f"DataField: {field}\n")
        lines.append(f"DataFieldDescription: {description}\n")
    lines.append("========================= 数据字段结束 =======================================\n\n")
    return "".join(lines)


def prepare_prompt():
    """Assemble the full generation prompt from the files in PREPARE_PROMPT.

    The prompt consists of:
      1. the base prompt from alpha_prompt.txt;
      2. the permitted operators from operator.txt, between marker lines;
      3. data fields from all_data_combined.csv. The fields are grouped by
         dataset_id, and at most SELECT_DATA_SET_QTY fields are sampled at
         random from each group, because the full list is very large.

    Returns:
        str: The assembled prompt.
    """
    parts = [read_prompt(os.path.join(PREPARE_PROMPT, "alpha_prompt.txt"))]

    operator = read_operator(os.path.join(PREPARE_PROMPT, "operator.txt"))
    parts.append(_operator_section(operator))

    data_rows = csvFileLoader(os.path.join(PREPARE_PROMPT, "all_data_combined.csv"))
    selected = _select_data_fields(_group_data_fields(data_rows), SELECT_DATA_SET_QTY)
    parts.append(_data_field_section(selected))

    # Join once instead of repeated string += (quadratic in the worst case).
    return "".join(parts)
|
|
|
|
|
|
def main():
    """Build the prompt and save it locally; API submission is left switched off."""
    prompt_text = prepare_prompt()

    # Save the generated prompt to disk so it can be pasted into a model's web page by hand.
    manual_prompt(prompt_text)

    # To send the prompt to the models through the API instead, uncomment the next line.
    # call_ai(prompt_text)
|
|
|
|
|
|
# Script entry point: build (and save) the prompt only when run directly, not on import.
if __name__ == "__main__": main()
|
|
|