You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
59 lines
2.0 KiB
59 lines
2.0 KiB
# -*- coding: utf-8 -*-
|
|
import os
|
|
import json
|
|
import time
|
|
from typing import List, Any
|
|
|
|
|
|
def load_alpha_list(file_path: str) -> List[str]:
    """Load alpha factors from a text file, one factor per line.

    Lines are stripped of surrounding whitespace, and lines that are empty
    after stripping are skipped.

    If ``file_path`` does not exist, an empty file is created at that path,
    a hint asking the user to fill it in is printed, and an empty list is
    returned.

    Args:
        file_path: Path of the UTF-8 text file holding the factors.

    Returns:
        The non-empty, stripped lines of the file in their original order.
    """
    if os.path.exists(file_path):
        with open(file_path, 'r', encoding='utf-8') as handle:
            stripped_lines = (raw_line.strip() for raw_line in handle)
            # Consume the generator while the file is still open.
            return [entry for entry in stripped_lines if entry]

    # The file is missing: create an empty placeholder for the user to fill.
    print(f"{file_path} 文件不存在")
    with open(file_path, 'w', encoding='utf-8') as handle:
        handle.write("")
    print(f"已创建 {file_path} 文件, 请添加因子后重新运行, 一行一个因子")
    return []
|
|
|
|
|
|
def save_results_to_file(results: List[Any], result_dir: str = 'result') -> str:
    """Write simulation results to a timestamped JSON file.

    Each entry of ``results`` is converted to a plain dict before dumping:

    * an object exposing ``__dict__`` (e.g. a dataclass instance) contributes
      a copy of its attribute dict; anything else must provide ``copy()``
      (e.g. a ``dict``);
    * a numeric ``time_consuming`` value is rounded to 2 decimals;
    * any value that itself exposes ``__dict__`` (e.g. a metrics object) is
      replaced by a new dict of its attributes, with floats rounded to
      6 decimals.

    The entries passed in, and the objects nested inside them, are never
    modified.

    Args:
        results: Result objects or dicts to serialize.
        result_dir: Directory that receives the file; created when missing.

    Returns:
        The path of the written file, in the form
        ``{result_dir}/simulation_results-{unix_seconds}.json``.

    Raises:
        TypeError: Raised by ``json.dump`` when a value is still not JSON
            serializable after the conversion above.
    """
    serializable_results = []
    for result in results:
        # Shallow copy, so the top-level entry is left untouched.
        if hasattr(result, '__dict__'):
            result_dict = result.__dict__.copy()
        else:
            result_dict = result.copy()

        # Round only real numbers; a missing or None timing must not crash
        # the whole export.
        elapsed = result_dict.get('time_consuming')
        if isinstance(elapsed, (int, float)):
            result_dict['time_consuming'] = round(elapsed, 2)

        # Flatten nested objects (e.g. metrics) into NEW dicts. Rounding the
        # live ``__dict__`` in place would silently alter the caller's
        # objects.
        for key, value in list(result_dict.items()):
            if hasattr(value, '__dict__'):
                result_dict[key] = {
                    attr_name: round(attr_value, 6)
                    if isinstance(attr_value, float) else attr_value
                    for attr_name, attr_value in value.__dict__.items()
                }

        serializable_results.append(result_dict)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(result_dir, exist_ok=True)

    result_name = f"{result_dir}/simulation_results-{str(int(time.time()))}.json"
    with open(result_name, 'w', encoding='utf-8') as f:
        json.dump(serializable_results, f, ensure_ascii=False, indent=2)

    print(f"结果已保存到 {result_name}")
    return result_name