main
parent
beda7fafb4
commit
7e7757fc63
@ -1,26 +0,0 @@ |
|||||||
# 使用 Python 3.12 官方镜像 |
|
||||||
FROM python:3.12-slim |
|
||||||
|
|
||||||
# 设置工作目录 |
|
||||||
WORKDIR /app |
|
||||||
|
|
||||||
# 设置环境变量 |
|
||||||
ENV PYTHONPATH=/app |
|
||||||
ENV PYTHONDONTWRITEBYTECODE=1 |
|
||||||
ENV PYTHONUNBUFFERED=1 |
|
||||||
|
|
||||||
# 安装系统依赖(如果需要连接 PostgreSQL) |
|
||||||
RUN apt-get update && apt-get install -y \ |
|
||||||
gcc \ |
|
||||||
libpq-dev \ |
|
||||||
&& rm -rf /var/lib/apt/lists/* |
|
||||||
|
|
||||||
# 复制 requirements.txt 并安装 Python 依赖 |
|
||||||
COPY requirements.txt . |
|
||||||
RUN pip install --no-cache-dir -r requirements.txt |
|
||||||
|
|
||||||
# 复制项目文件 |
|
||||||
COPY . . |
|
||||||
|
|
||||||
# 设置启动命令 |
|
||||||
CMD ["python", "main.py"] |
|
||||||
@ -1,52 +0,0 @@ |
|||||||
# -*- coding: utf-8 -*- |
|
||||||
"""配置文件""" |
|
||||||
|
|
||||||
import os |
|
||||||
|
|
||||||
|
|
||||||
class Settings: |
|
||||||
"""应用配置""" |
|
||||||
|
|
||||||
# 数据库配置 |
|
||||||
DATABASE_CONFIG = { |
|
||||||
"host": "192.168.31.201", |
|
||||||
"port": "5432", |
|
||||||
"user": "jack", |
|
||||||
"password": "aaaAAA111", |
|
||||||
"database": "alpha" |
|
||||||
} |
|
||||||
|
|
||||||
# API配置 |
|
||||||
BRAIN_API_URL = "https://api.worldquantbrain.com" |
|
||||||
|
|
||||||
# 模拟配置 |
|
||||||
SIMULATION_SETTINGS = { |
|
||||||
'instrumentType': 'EQUITY', |
|
||||||
'region': 'USA', |
|
||||||
'universe': 'TOP3000', |
|
||||||
'delay': 1, |
|
||||||
'decay': 0, |
|
||||||
'neutralization': 'INDUSTRY', |
|
||||||
'truncation': 0.08, |
|
||||||
'pasteurization': 'ON', |
|
||||||
'unitHandling': 'VERIFY', |
|
||||||
'nanHandling': 'OFF', |
|
||||||
'language': 'FASTEXPR', |
|
||||||
'visualization': False, |
|
||||||
} |
|
||||||
|
|
||||||
# 批处理配置 |
|
||||||
BATCH_SIZE = 3 |
|
||||||
CHECK_INTERVAL = 300 # 5分钟 |
|
||||||
TOKEN_REFRESH_THRESHOLD = 1800 # 30分钟 |
|
||||||
|
|
||||||
# 通知配置 |
|
||||||
GOTIFY_URL = "https://gotify.erhe.top/message?token=AvKJCJwQKU6yLP8" |
|
||||||
|
|
||||||
@property |
|
||||||
def credentials_file(self) -> str: |
|
||||||
"""获取凭证文件路径""" |
|
||||||
return os.path.join(os.path.dirname(os.path.dirname(__file__)), 'account.txt') |
|
||||||
|
|
||||||
|
|
||||||
settings = Settings() |
|
||||||
@ -1,176 +0,0 @@ |
|||||||
# -*- coding: utf-8 -*- |
|
||||||
"""数据库管理层""" |
|
||||||
|
|
||||||
import psycopg2 |
|
||||||
from typing import List |
|
||||||
from config.settings import settings |
|
||||||
from models.entities import SimulationResult |
|
||||||
|
|
||||||
|
|
||||||
class DatabaseManager: |
|
||||||
"""数据库管理类""" |
|
||||||
|
|
||||||
def __init__(self): |
|
||||||
self.connection = None |
|
||||||
self.init_database() |
|
||||||
|
|
||||||
def create_database(self) -> None: |
|
||||||
"""创建数据库(如果不存在)""" |
|
||||||
try: |
|
||||||
# 先连接到默认的postgres数据库来创建alpha数据库 |
|
||||||
conn = psycopg2.connect( |
|
||||||
host=settings.DATABASE_CONFIG["host"], |
|
||||||
port=settings.DATABASE_CONFIG["port"], |
|
||||||
database="postgres", |
|
||||||
user=settings.DATABASE_CONFIG["user"], |
|
||||||
password=settings.DATABASE_CONFIG["password"] |
|
||||||
) |
|
||||||
conn.autocommit = True |
|
||||||
cursor = conn.cursor() |
|
||||||
|
|
||||||
# 检查数据库是否存在 |
|
||||||
cursor.execute("SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s", |
|
||||||
(settings.DATABASE_CONFIG["database"],)) |
|
||||||
exists = cursor.fetchone() |
|
||||||
|
|
||||||
if not exists: |
|
||||||
cursor.execute(f"CREATE DATABASE {settings.DATABASE_CONFIG['database']}") |
|
||||||
print(f"数据库 {settings.DATABASE_CONFIG['database']} 创建成功") |
|
||||||
else: |
|
||||||
print(f"数据库 {settings.DATABASE_CONFIG['database']} 已存在") |
|
||||||
|
|
||||||
cursor.close() |
|
||||||
conn.close() |
|
||||||
|
|
||||||
except Exception as e: |
|
||||||
print(f"创建数据库时出错: {e}") |
|
||||||
raise |
|
||||||
|
|
||||||
def get_connection(self) -> psycopg2.extensions.connection: |
|
||||||
"""获取数据库连接""" |
|
||||||
if self.connection is None or self.connection.closed: |
|
||||||
self.connection = psycopg2.connect( |
|
||||||
host=settings.DATABASE_CONFIG["host"], |
|
||||||
port=settings.DATABASE_CONFIG["port"], |
|
||||||
database=settings.DATABASE_CONFIG["database"], |
|
||||||
user=settings.DATABASE_CONFIG["user"], |
|
||||||
password=settings.DATABASE_CONFIG["password"] |
|
||||||
) |
|
||||||
return self.connection |
|
||||||
|
|
||||||
def init_database(self) -> None: |
|
||||||
"""初始化数据库和表结构""" |
|
||||||
# 先创建数据库 |
|
||||||
self.create_database() |
|
||||||
|
|
||||||
# 然后连接到此数据库创建表 |
|
||||||
conn = self.get_connection() |
|
||||||
cursor = conn.cursor() |
|
||||||
|
|
||||||
# 创建 alpha_prepare 表 |
|
||||||
cursor.execute(''' |
|
||||||
CREATE TABLE IF NOT EXISTS alpha_prepare |
|
||||||
( |
|
||||||
id |
|
||||||
SERIAL |
|
||||||
PRIMARY |
|
||||||
KEY, |
|
||||||
alpha |
|
||||||
TEXT |
|
||||||
NOT |
|
||||||
NULL |
|
||||||
UNIQUE, |
|
||||||
unused |
|
||||||
BOOLEAN |
|
||||||
NOT |
|
||||||
NULL |
|
||||||
DEFAULT |
|
||||||
TRUE, |
|
||||||
created_time |
|
||||||
TIMESTAMP |
|
||||||
DEFAULT |
|
||||||
CURRENT_TIMESTAMP |
|
||||||
) |
|
||||||
''') |
|
||||||
|
|
||||||
# 创建 simulation 表 |
|
||||||
cursor.execute(''' |
|
||||||
CREATE TABLE IF NOT EXISTS simulation |
|
||||||
( |
|
||||||
id |
|
||||||
SERIAL |
|
||||||
PRIMARY |
|
||||||
KEY, |
|
||||||
expression |
|
||||||
TEXT |
|
||||||
NOT |
|
||||||
NULL, |
|
||||||
time_consuming |
|
||||||
REAL |
|
||||||
NOT |
|
||||||
NULL, |
|
||||||
status |
|
||||||
TEXT |
|
||||||
NOT |
|
||||||
NULL, |
|
||||||
timestamp |
|
||||||
TEXT |
|
||||||
NOT |
|
||||||
NULL, |
|
||||||
alpha_id |
|
||||||
TEXT, |
|
||||||
message |
|
||||||
TEXT, |
|
||||||
created_time |
|
||||||
TIMESTAMP |
|
||||||
DEFAULT |
|
||||||
CURRENT_TIMESTAMP |
|
||||||
) |
|
||||||
''') |
|
||||||
|
|
||||||
conn.commit() |
|
||||||
print(f"数据库 {settings.DATABASE_CONFIG['database']} 表结构初始化完成") |
|
||||||
|
|
||||||
def get_unused_alpha(self) -> List[str]: |
|
||||||
"""获取所有未使用的alpha表达式""" |
|
||||||
conn = self.get_connection() |
|
||||||
cursor = conn.cursor() |
|
||||||
|
|
||||||
cursor.execute("SELECT alpha FROM alpha_prepare WHERE unused = TRUE") |
|
||||||
results = cursor.fetchall() |
|
||||||
|
|
||||||
alpha_list = [result[0] for result in results] |
|
||||||
return alpha_list |
|
||||||
|
|
||||||
def mark_alpha_used(self, alpha: str) -> None: |
|
||||||
"""将alpha标记为已使用""" |
|
||||||
conn = self.get_connection() |
|
||||||
cursor = conn.cursor() |
|
||||||
|
|
||||||
cursor.execute("UPDATE alpha_prepare SET unused = FALSE WHERE alpha = %s", (alpha,)) |
|
||||||
conn.commit() |
|
||||||
|
|
||||||
def insert_simulation_result(self, result: SimulationResult) -> None: |
|
||||||
"""插入模拟结果到simulation表""" |
|
||||||
conn = self.get_connection() |
|
||||||
cursor = conn.cursor() |
|
||||||
|
|
||||||
cursor.execute(''' |
|
||||||
INSERT INTO simulation |
|
||||||
(expression, time_consuming, status, timestamp, alpha_id, message) |
|
||||||
VALUES (%s, %s, %s, %s, %s, %s) |
|
||||||
''', ( |
|
||||||
result.expression, |
|
||||||
result.time_consuming, |
|
||||||
result.status, |
|
||||||
result.timestamp, |
|
||||||
result.alpha_id, |
|
||||||
result.message or "" |
|
||||||
)) |
|
||||||
|
|
||||||
conn.commit() |
|
||||||
|
|
||||||
def close_connection(self) -> None: |
|
||||||
"""关闭数据库连接""" |
|
||||||
if self.connection and not self.connection.closed: |
|
||||||
self.connection.close() |
|
||||||
@ -1,126 +0,0 @@ |
|||||||
# -*- coding: utf-8 -*- |
|
||||||
"""Alpha模拟器核心类""" |
|
||||||
|
|
||||||
import time |
|
||||||
import httpx |
|
||||||
from datetime import datetime |
|
||||||
from typing import List, Tuple |
|
||||||
from config.settings import settings |
|
||||||
from models.entities import SimulationResult, TokenInfo |
|
||||||
from core.database import DatabaseManager |
|
||||||
from services.auth_service import AuthService |
|
||||||
from services.alpha_service import AlphaService |
|
||||||
from services.notification_service import NotificationService |
|
||||||
|
|
||||||
|
|
||||||
class AlphaSimulator: |
|
||||||
"""Alpha模拟器主类""" |
|
||||||
|
|
||||||
def __init__(self): |
|
||||||
self.db_manager = DatabaseManager() |
|
||||||
self.auth_service = AuthService() |
|
||||||
self.client = None |
|
||||||
self.token_info = None |
|
||||||
self.alpha_service = None |
|
||||||
|
|
||||||
def __del__(self): |
|
||||||
"""析构函数,确保数据库连接被关闭""" |
|
||||||
if hasattr(self, 'db_manager'): |
|
||||||
self.db_manager.close_connection() |
|
||||||
|
|
||||||
def initialize(self) -> None: |
|
||||||
"""初始化模拟器""" |
|
||||||
self.client = httpx.Client() |
|
||||||
self.login() |
|
||||||
|
|
||||||
def login(self) -> TokenInfo: |
|
||||||
"""登录并初始化alpha服务""" |
|
||||||
self.token_info = self.auth_service.login(self.client) |
|
||||||
self.alpha_service = AlphaService(self.client) |
|
||||||
return self.token_info |
|
||||||
|
|
||||||
def needs_token_refresh(self) -> bool: |
|
||||||
"""检查是否需要刷新token""" |
|
||||||
if not self.token_info: |
|
||||||
return True |
|
||||||
return self.token_info.expiry < settings.TOKEN_REFRESH_THRESHOLD |
|
||||||
|
|
||||||
def load_alpha_list(self) -> List[str]: |
|
||||||
"""从数据库加载未使用的alpha表达式""" |
|
||||||
return self.db_manager.get_unused_alpha() |
|
||||||
|
|
||||||
def run_batch_simulation(self, alpha_list: List[str]) -> Tuple[int, int]: |
|
||||||
"""运行批量模拟""" |
|
||||||
success_count = 0 |
|
||||||
fail_count = 0 |
|
||||||
|
|
||||||
for i in range(0, len(alpha_list), settings.BATCH_SIZE): |
|
||||||
batch = alpha_list[i:i + settings.BATCH_SIZE] |
|
||||||
print(f"\n开始处理第 {i // settings.BATCH_SIZE + 1} 批因子,共 {len(batch)} 个") |
|
||||||
|
|
||||||
for expression in batch: |
|
||||||
result = self._simulate_single_alpha(expression) |
|
||||||
if result.status == "ok": |
|
||||||
success_count += 1 |
|
||||||
else: |
|
||||||
fail_count += 1 |
|
||||||
|
|
||||||
print(f"第 {i // settings.BATCH_SIZE + 1} 批处理完成") |
|
||||||
|
|
||||||
self._print_summary(success_count, fail_count) |
|
||||||
return success_count, fail_count |
|
||||||
|
|
||||||
def _simulate_single_alpha(self, expression: str) -> SimulationResult: |
|
||||||
"""模拟单个Alpha表达式""" |
|
||||||
print(f"\n模拟因子: {expression}") |
|
||||||
start_time = time.time() |
|
||||||
|
|
||||||
try: |
|
||||||
result = self.alpha_service.simulate_alpha(expression) |
|
||||||
end_time = time.time() |
|
||||||
time_consuming = round(end_time - start_time, 2) |
|
||||||
|
|
||||||
simulation_result = SimulationResult( |
|
||||||
expression=expression, |
|
||||||
time_consuming=time_consuming, |
|
||||||
status=result["status"], |
|
||||||
timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), |
|
||||||
alpha_id=result.get("alpha_id"), |
|
||||||
message=result.get("message", "") |
|
||||||
) |
|
||||||
|
|
||||||
if result["status"] == "ok": |
|
||||||
print(f"✅ 模拟成功 - Alpha ID: {result['alpha_id']} - 耗时: {time_consuming}秒") |
|
||||||
else: |
|
||||||
print(f"❌ 模拟失败 - {result.get('message', '未知错误')}") |
|
||||||
|
|
||||||
except Exception as e: |
|
||||||
end_time = time.time() |
|
||||||
time_consuming = round(end_time - start_time, 2) |
|
||||||
simulation_result = SimulationResult( |
|
||||||
expression=expression, |
|
||||||
time_consuming=time_consuming, |
|
||||||
status="err", |
|
||||||
message=str(e), |
|
||||||
timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S") |
|
||||||
) |
|
||||||
print(f"❌ 模拟异常 - {str(e)}") |
|
||||||
|
|
||||||
# 保存结果并标记为已使用 |
|
||||||
self._save_simulation_result(simulation_result) |
|
||||||
return simulation_result |
|
||||||
|
|
||||||
def _save_simulation_result(self, result: SimulationResult) -> None: |
|
||||||
"""保存模拟结果到数据库""" |
|
||||||
self.db_manager.mark_alpha_used(result.expression) |
|
||||||
self.db_manager.insert_simulation_result(result) |
|
||||||
|
|
||||||
def _print_summary(self, success_count: int, fail_count: int) -> None: |
|
||||||
"""打印总结信息""" |
|
||||||
now = datetime.now().strftime('%Y-%m-%d %H:%M:%S') |
|
||||||
print(f"\n总计: 成功 {success_count} 个, 失败 {fail_count} 个") |
|
||||||
print(f"完成时间: {now}") |
|
||||||
print(f"所有结果已保存到 PostgreSQL 数据库 {settings.DATABASE_CONFIG['database']} 的 simulation 表中") |
|
||||||
|
|
||||||
# 发送通知 |
|
||||||
NotificationService.send_to_gotify(success_count, fail_count) |
|
||||||
@ -0,0 +1,7 @@ |
|||||||
|
[database] |
||||||
|
host = 192.168.31.201 |
||||||
|
; host = 127.0.0.1 |
||||||
|
port = 5432 |
||||||
|
database = alpha |
||||||
|
user = jack |
||||||
|
password = aaaAAA111 |
||||||
@ -1,9 +0,0 @@ |
|||||||
services: |
|
||||||
factorsimulator: |
|
||||||
build: |
|
||||||
context: . |
|
||||||
dockerfile: Dockerfile |
|
||||||
container_name: factor-simulator |
|
||||||
environment: |
|
||||||
- PYTHONPATH=/app |
|
||||||
restart: unless-stopped |
|
||||||
@ -1,45 +1,529 @@ |
|||||||
# -*- coding: utf-8 -*- |
# -*- coding: utf-8 -*- |
||||||
"""主程序入口""" |
import os |
||||||
|
|
||||||
import time |
import time |
||||||
from core.simulator import AlphaSimulator |
import random |
||||||
from config.settings import settings |
import psycopg2 |
||||||
from utils.helpers import retry_on_exception |
import psycopg2.extras |
||||||
|
import httpx |
||||||
|
import threading |
||||||
|
from datetime import datetime |
||||||
|
from dataclasses import dataclass |
||||||
|
from typing import List, Tuple, Dict, Any, Optional |
||||||
|
from httpx import BasicAuth |
||||||
|
from configparser import ConfigParser |
||||||
|
import logging |
||||||
|
|
||||||
|
# 配置日志 |
||||||
|
logging.basicConfig( |
||||||
|
level=logging.INFO, |
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' |
||||||
|
) |
||||||
|
logger = logging.getLogger(__name__) |
||||||
|
|
||||||
|
|
||||||
|
class Settings: |
||||||
|
"""应用配置 - 完全按照原代码""" |
||||||
|
|
||||||
|
def __init__(self): |
||||||
|
# 数据库配置 |
||||||
|
self.config_file = "db.conf" |
||||||
|
self.config = ConfigParser() |
||||||
|
|
||||||
|
if not os.path.exists(self.config_file): |
||||||
|
raise FileNotFoundError(f"数据库配置文件 {self.config_file} 不存在") |
||||||
|
|
||||||
|
self.config.read(self.config_file, encoding='utf-8') |
||||||
|
|
||||||
|
self.DATABASE_CONFIG = { |
||||||
|
"host": self.config.get('database', 'host'), |
||||||
|
"port": self.config.get('database', 'port'), |
||||||
|
"user": self.config.get('database', 'user'), |
||||||
|
"password": self.config.get('database', 'password'), |
||||||
|
"database": self.config.get('database', 'database') |
||||||
|
} |
||||||
|
|
||||||
|
# API配置 |
||||||
|
self.BRAIN_API_URL = "https://api.worldquantbrain.com" |
||||||
|
|
||||||
|
# 模拟配置 - 完全按照原代码 |
||||||
|
self.SIMULATION_SETTINGS = { |
||||||
|
'instrumentType': 'EQUITY', |
||||||
|
'region': 'USA', |
||||||
|
'universe': 'TOP3000', |
||||||
|
'delay': 1, |
||||||
|
'decay': 0, |
||||||
|
'neutralization': 'INDUSTRY', |
||||||
|
'truncation': 0.08, |
||||||
|
'pasteurization': 'ON', |
||||||
|
'unitHandling': 'VERIFY', |
||||||
|
'nanHandling': 'OFF', |
||||||
|
'language': 'FASTEXPR', |
||||||
|
'visualization': False, |
||||||
|
} |
||||||
|
|
||||||
|
# 批次配置 - 每批8个 |
||||||
|
self.BATCH_SIZE = 8 |
||||||
|
|
||||||
|
# 模拟间隔配置 |
||||||
|
self.SIMULATION_INTERVAL_MIN = 5 # 最小间隔秒数 |
||||||
|
self.SIMULATION_INTERVAL_MAX = 10 # 最大间隔秒数 |
||||||
|
|
||||||
|
# 通知配置 |
||||||
|
self.GOTIFY_URL = os.getenv("GOTIFY_URL", "https://gotify.erhe.top/message?token=AvKJCJwQKU6yLP8") |
||||||
|
|
||||||
|
# Session有效期(秒) - 4小时 |
||||||
|
self.SESSION_EXPIRY_SECONDS = 4 * 3600 # 4小时 |
||||||
|
|
||||||
|
@property |
||||||
|
def credentials_file(self) -> str: |
||||||
|
"""获取凭证文件路径""" |
||||||
|
return "account.txt" |
||||||
|
|
||||||
|
|
||||||
|
@dataclass |
||||||
|
class SimulationResult: |
||||||
|
"""模拟结果实体""" |
||||||
|
alpha: str |
||||||
|
time_consuming: float |
||||||
|
status: str |
||||||
|
timestamp: str |
||||||
|
alpha_id: Optional[str] = None |
||||||
|
message: Optional[str] = None |
||||||
|
|
||||||
|
|
||||||
|
class AuthService: |
||||||
|
"""认证服务类 - 按照原代码逻辑""" |
||||||
|
|
||||||
|
def __init__(self, settings: Settings): |
||||||
|
self.settings = settings |
||||||
|
self.credentials_file = settings.credentials_file |
||||||
|
self.last_login_time: Optional[float] = None |
||||||
|
self.login_expiry: Optional[float] = None |
||||||
|
|
||||||
|
def load_credentials(self) -> Tuple[str, str]: |
||||||
|
"""加载凭证""" |
||||||
|
if not os.path.exists(self.credentials_file): |
||||||
|
self._create_credentials_file() |
||||||
|
|
||||||
|
with open(self.credentials_file, 'r', encoding='utf-8') as f: |
||||||
|
credentials = eval(f.read()) |
||||||
|
return credentials[0], credentials[1] |
||||||
|
|
||||||
|
def _create_credentials_file(self) -> None: |
||||||
|
"""创建凭证文件""" |
||||||
|
logger.error(f"未找到 {self.credentials_file} 文件") |
||||||
|
with open(self.credentials_file, 'w', encoding='utf-8') as f: |
||||||
|
f.write("['your_username', 'your_password']") |
||||||
|
logger.error(f"请编辑 {self.credentials_file} 文件,填写账号密码") |
||||||
|
raise FileNotFoundError(f"请创建并配置 {self.credentials_file} 文件") |
||||||
|
|
||||||
|
def login(self, client: httpx.Client) -> Dict[str, Any]: |
||||||
|
"""登录并设置client的认证信息 - 按照原代码""" |
||||||
|
username, password = self.load_credentials() |
||||||
|
client.auth = BasicAuth(username, password) |
||||||
|
|
||||||
|
response = client.post(f'{self.settings.BRAIN_API_URL}/authentication') |
||||||
|
logger.info(f"登录状态: {response.status_code}") |
||||||
|
|
||||||
|
if response.status_code == 201: |
||||||
|
login_data = response.json() |
||||||
|
self.last_login_time = time.time() |
||||||
|
self.login_expiry = self.last_login_time + self.settings.SESSION_EXPIRY_SECONDS |
||||||
|
|
||||||
|
logger.info(f"登录成功!Session有效期: {self.settings.SESSION_EXPIRY_SECONDS}秒 (4.0小时)") |
||||||
|
|
||||||
|
return login_data |
||||||
|
elif response.status_code == 429: |
||||||
|
logger.error("API rate limit exceeded") |
||||||
|
raise Exception("API rate limit exceeded") |
||||||
|
else: |
||||||
|
logger.error(f"登录失败: {response.status_code} - {response.text}") |
||||||
|
raise Exception(f"登录失败: {response.status_code} - {response.text}") |
||||||
|
|
||||||
|
def is_session_expired(self) -> bool: |
||||||
|
"""检查session是否已过期""" |
||||||
|
if not self.last_login_time or not self.login_expiry: |
||||||
|
return True |
||||||
|
|
||||||
|
current_time = time.time() |
||||||
|
return current_time >= self.login_expiry |
||||||
|
|
||||||
|
def needs_renewal(self) -> bool: |
||||||
|
"""检查session是否需要续期(小于30分钟)""" |
||||||
|
if not self.login_expiry: |
||||||
|
return True |
||||||
|
|
||||||
|
current_time = time.time() |
||||||
|
time_remaining = self.login_expiry - current_time |
||||||
|
return time_remaining < 1800 # 30分钟 |
||||||
|
|
||||||
|
|
||||||
|
class AlphaService: |
||||||
|
"""Alpha表达式服务类 - 完全按照原代码逻辑""" |
||||||
|
|
||||||
|
def __init__(self, client: httpx.Client, settings: Settings): |
||||||
|
self.client = client |
||||||
|
self.settings = settings |
||||||
|
|
||||||
|
def simulate_alpha(self, alpha: str) -> Dict[str, Any]: |
||||||
|
"""模拟单个Alpha表达式 - 完全按照原代码""" |
||||||
|
simulation_data = { |
||||||
|
'type': 'REGULAR', |
||||||
|
'settings': self.settings.SIMULATION_SETTINGS, # 直接使用settings |
||||||
|
'regular': alpha |
||||||
|
} |
||||||
|
|
||||||
|
logger.info(f"发送模拟请求,alpha长度: {len(alpha)}") |
||||||
|
sim_resp = self.client.post(f'{self.settings.BRAIN_API_URL}/simulations', json=simulation_data) |
||||||
|
logger.info(f"模拟提交状态: {sim_resp.status_code}") |
||||||
|
|
||||||
|
if 'location' not in sim_resp.headers: |
||||||
|
# 打印更多调试信息 |
||||||
|
logger.error(f"缺少location header,状态码: {sim_resp.status_code}") |
||||||
|
logger.error(f"响应头: {dict(sim_resp.headers)}") |
||||||
|
logger.error(f"响应内容: {sim_resp.text[:500] if sim_resp.text else '空'}") |
||||||
|
return {"status": "err", "message": "No location header in response"} |
||||||
|
|
||||||
|
sim_progress_url = sim_resp.headers['location'] |
||||||
|
logger.info(f"模拟任务已创建,进度URL: {sim_progress_url}") |
||||||
|
return self._wait_for_simulation_result(sim_progress_url) |
||||||
|
|
||||||
|
def _wait_for_simulation_result(self, progress_url: str) -> Dict[str, Any]: |
||||||
|
"""等待模拟结果 - 按照原代码逻辑""" |
||||||
|
while True: |
||||||
|
sim_progress_resp = self.client.get(progress_url) |
||||||
|
retry_after_sec = float(sim_progress_resp.headers.get("Retry-After", 0)) |
||||||
|
|
||||||
|
if retry_after_sec == 0: |
||||||
|
break |
||||||
|
|
||||||
|
if sim_progress_resp.json(): |
||||||
|
result = sim_progress_resp.json() |
||||||
|
progress = result.get('progress', 0) |
||||||
|
logger.info(f"模拟进度: {progress}%") |
||||||
|
|
||||||
|
time.sleep(retry_after_sec) |
||||||
|
|
||||||
|
result_data = sim_progress_resp.json() |
||||||
|
|
||||||
|
if result_data.get("status") == "ERROR": |
||||||
|
error_message = result_data.get("message", "未知错误") |
||||||
|
logger.error(f"因子模拟失败: {error_message}") |
||||||
|
return {"status": "err", "message": error_message} |
||||||
|
|
||||||
|
alpha_id = result_data.get("alpha") |
||||||
|
logger.info(f"生成的Alpha ID: {alpha_id}") |
||||||
|
|
||||||
|
return {"status": "ok", "alpha_id": alpha_id} |
||||||
|
|
||||||
|
|
||||||
|
class NotificationService: |
||||||
|
"""通知服务类""" |
||||||
|
|
||||||
|
def __init__(self, settings: Settings): |
||||||
|
self.settings = settings |
||||||
|
|
||||||
|
def send_to_gotify(self, success_count: int, fail_count: int) -> None: |
||||||
|
"""发送结果到Gotify""" |
||||||
|
now = datetime.now().strftime('%Y-%m-%d %H:%M:%S') |
||||||
|
text = f"总计: 成功 {success_count} 个, 失败 {fail_count} 个\n\n完成时间: {now}" |
||||||
|
title = f"alpha模拟结果 时间: {now}" |
||||||
|
|
||||||
@retry_on_exception(max_retries=3, delay=5.0) |
try: |
||||||
def main_loop(simulator: AlphaSimulator) -> None: |
resp = httpx.post( |
||||||
"""主循环""" |
self.settings.GOTIFY_URL, |
||||||
if simulator.needs_token_refresh(): |
json={'title': title, 'message': text}, |
||||||
print("Token需要刷新,重新登录...") |
timeout=10 |
||||||
simulator.login() |
) |
||||||
|
logger.info("通知发送成功") |
||||||
|
except Exception as e: |
||||||
|
logger.error(f"通知发送失败: {e}") |
||||||
|
|
||||||
alpha_list = simulator.load_alpha_list() |
|
||||||
if not alpha_list: |
|
||||||
print("暂无待处理的alpha表达式,10分钟后重新检查...") |
|
||||||
time.sleep(600) |
|
||||||
return |
|
||||||
|
|
||||||
print(f"共加载 {len(alpha_list)} 个需要模拟的因子表达式") |
class DatabaseManager: |
||||||
success_count, fail_count = simulator.run_batch_simulation(alpha_list) |
"""数据库管理类""" |
||||||
print(f"本轮处理完成: 成功 {success_count} 个, 失败 {fail_count} 个") |
|
||||||
|
def __init__(self, settings: Settings): |
||||||
|
self.settings = settings |
||||||
|
self.connection = None |
||||||
|
|
||||||
|
def get_connection(self) -> psycopg2.extensions.connection: |
||||||
|
"""获取数据库连接""" |
||||||
|
if self.connection is None or self.connection.closed: |
||||||
|
self.connection = psycopg2.connect( |
||||||
|
host=self.settings.DATABASE_CONFIG["host"], |
||||||
|
port=self.settings.DATABASE_CONFIG["port"], |
||||||
|
database=self.settings.DATABASE_CONFIG["database"], |
||||||
|
user=self.settings.DATABASE_CONFIG["user"], |
||||||
|
password=self.settings.DATABASE_CONFIG["password"] |
||||||
|
) |
||||||
|
return self.connection |
||||||
|
|
||||||
|
def get_unused_alpha(self) -> List[Tuple[int, str]]: |
||||||
|
"""获取所有未使用的alpha表达式(返回ID和alpha值)""" |
||||||
|
conn = self.get_connection() |
||||||
|
cursor = conn.cursor() |
||||||
|
cursor.execute("SELECT id, alpha FROM alpha_simulation WHERE unused = TRUE AND alpha IS NOT NULL") |
||||||
|
return cursor.fetchall() |
||||||
|
|
||||||
|
def update_alpha_simulation(self, record_id: int, result: SimulationResult) -> None: |
||||||
|
"""更新alpha_simulation表的记录""" |
||||||
|
conn = self.get_connection() |
||||||
|
cursor = conn.cursor() |
||||||
|
cursor.execute(''' |
||||||
|
UPDATE alpha_simulation |
||||||
|
SET unused = FALSE, |
||||||
|
time_consuming = %s, |
||||||
|
status = %s, |
||||||
|
timestamp = %s, |
||||||
|
alpha_id = %s, |
||||||
|
message = %s |
||||||
|
WHERE id = %s |
||||||
|
''', ( |
||||||
|
result.time_consuming, |
||||||
|
result.status, |
||||||
|
result.timestamp, |
||||||
|
result.alpha_id, |
||||||
|
result.message or "", |
||||||
|
record_id |
||||||
|
)) |
||||||
|
conn.commit() |
||||||
|
|
||||||
|
def close_connection(self) -> None: |
||||||
|
"""关闭数据库连接""" |
||||||
|
if self.connection and not self.connection.closed: |
||||||
|
self.connection.close() |
||||||
|
|
||||||
|
|
||||||
|
class AlphaSimulator: |
||||||
|
"""Alpha模拟器主类 - 按照原代码逻辑""" |
||||||
|
|
||||||
|
def __init__(self): |
||||||
|
self.settings = Settings() |
||||||
|
self.db_manager = DatabaseManager(self.settings) |
||||||
|
self.auth_service = AuthService(self.settings) |
||||||
|
self.notification_service = NotificationService(self.settings) |
||||||
|
self.client: Optional[httpx.Client] = None |
||||||
|
self.alpha_service: Optional[AlphaService] = None |
||||||
|
self.success_count = 0 |
||||||
|
self.fail_count = 0 |
||||||
|
self.total_processed = 0 |
||||||
|
|
||||||
|
def __del__(self): |
||||||
|
"""析构函数,确保数据库连接被关闭""" |
||||||
|
if hasattr(self, 'db_manager'): |
||||||
|
self.db_manager.close_connection() |
||||||
|
if self.client: |
||||||
|
self.client.close() |
||||||
|
|
||||||
|
def initialize(self) -> None: |
||||||
|
"""初始化模拟器""" |
||||||
|
try: |
||||||
|
self.client = httpx.Client() |
||||||
|
self.login() |
||||||
|
logger.info("模拟器初始化成功") |
||||||
|
except Exception as e: |
||||||
|
logger.error(f"模拟器初始化失败: {e}") |
||||||
|
raise |
||||||
|
|
||||||
|
def login(self) -> Dict[str, Any]: |
||||||
|
"""登录并初始化alpha服务""" |
||||||
|
logger.info("正在登录...") |
||||||
|
login_result = self.auth_service.login(self.client) |
||||||
|
self.alpha_service = AlphaService(self.client, self.settings) |
||||||
|
|
||||||
|
expires_in = self.settings.SESSION_EXPIRY_SECONDS |
||||||
|
hours = int(expires_in // 3600) |
||||||
|
minutes = int((expires_in % 3600) // 60) |
||||||
|
logger.info(f"登录成功!Session将在 {hours}小时{minutes}分钟后过期") |
||||||
|
|
||||||
|
return login_result |
||||||
|
|
||||||
|
def check_and_refresh_session_if_needed(self) -> None: |
||||||
|
""" |
||||||
|
检查并刷新session(如果需要) |
||||||
|
如果session即将过期(小于30分钟),则重新登录 |
||||||
|
""" |
||||||
|
if not self.auth_service.last_login_time: |
||||||
|
logger.info("未找到登录信息,需要登录") |
||||||
|
self.login() |
||||||
|
return |
||||||
|
|
||||||
|
if self.auth_service.is_session_expired(): |
||||||
|
current_time = time.time() |
||||||
|
expired_seconds = current_time - self.auth_service.login_expiry |
||||||
|
logger.info(f"Session已过期(已过期 {expired_seconds:.0f} 秒),需要重新登录") |
||||||
|
self.login() |
||||||
|
return |
||||||
|
|
||||||
|
if self.auth_service.needs_renewal(): |
||||||
|
time_remaining = self.auth_service.login_expiry - time.time() |
||||||
|
minutes = int(time_remaining // 60) |
||||||
|
seconds = int(time_remaining % 60) |
||||||
|
logger.info(f"Session将在{minutes}分{seconds}秒后过期(小于30分钟),需要重新登录") |
||||||
|
|
||||||
|
# 创建新的client对象 |
||||||
|
if self.client: |
||||||
|
self.client.close() |
||||||
|
self.client = httpx.Client() |
||||||
|
self.login() |
||||||
|
return |
||||||
|
|
||||||
|
time_remaining = self.auth_service.login_expiry - time.time() |
||||||
|
hours = int(time_remaining // 3600) |
||||||
|
minutes = int((time_remaining % 3600) // 60) |
||||||
|
seconds = int(time_remaining % 60) |
||||||
|
|
||||||
|
logger.info(f"Session仍然有效,剩余时间: {hours}小时{minutes}分{seconds}秒") |
||||||
|
|
||||||
|
def load_alpha_list(self) -> List[Tuple[int, str]]: |
||||||
|
"""从数据库加载未使用的alpha表达式(包含ID)""" |
||||||
|
return self.db_manager.get_unused_alpha() |
||||||
|
|
||||||
|
def run_batch_simulation(self, alpha_records: List[Tuple[int, str]]) -> Tuple[int, int]: |
||||||
|
"""运行批次模拟,每批8个,等一批全部完成再开始下一批""" |
||||||
|
total_count = len(alpha_records) |
||||||
|
batch_count = (total_count + self.settings.BATCH_SIZE - 1) // self.settings.BATCH_SIZE |
||||||
|
|
||||||
|
logger.info(f"开始批次模拟,总共 {total_count} 个因子,每批 {self.settings.BATCH_SIZE} 个, 当前时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") |
||||||
|
|
||||||
|
# 重置计数器 |
||||||
|
self.success_count = 0 |
||||||
|
self.fail_count = 0 |
||||||
|
self.total_processed = total_count |
||||||
|
|
||||||
|
for batch_idx in range(batch_count): |
||||||
|
# 在每批开始前检查session状态,如果需要则重新登录 |
||||||
|
logger.info("检查session状态...") |
||||||
|
self.check_and_refresh_session_if_needed() |
||||||
|
|
||||||
|
start_idx = batch_idx * self.settings.BATCH_SIZE |
||||||
|
end_idx = min(start_idx + self.settings.BATCH_SIZE, total_count) |
||||||
|
batch_records = alpha_records[start_idx:end_idx] |
||||||
|
|
||||||
|
logger.info(f"\n=== 开始第 {batch_idx + 1}/{batch_count} 批,本批 {len(batch_records)} 个因子 ===\t 当前时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") |
||||||
|
|
||||||
|
# 创建并启动本批的所有线程 |
||||||
|
threads = [] |
||||||
|
thread_results = [None] * len(batch_records) |
||||||
|
|
||||||
|
for i, (record_id, alpha) in enumerate(batch_records): |
||||||
|
thread = threading.Thread( |
||||||
|
target=self._simulate_single_in_batch, |
||||||
|
args=(record_id, alpha, i, thread_results) |
||||||
|
) |
||||||
|
thread.daemon = True |
||||||
|
threads.append(thread) |
||||||
|
thread.start() |
||||||
|
logger.info(f"启动第 {batch_idx + 1} 批第 {i + 1} 个因子: {alpha[:50]}... (ID: {record_id})") |
||||||
|
|
||||||
|
# 等待本批所有线程完成 |
||||||
|
for thread in threads: |
||||||
|
thread.join() |
||||||
|
|
||||||
|
logger.info(f"=== 第 {batch_idx + 1}/{batch_count} 批完成 ===, 当前时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") |
||||||
|
|
||||||
|
# 每批之间随机sleep |
||||||
|
if batch_idx < batch_count - 1: |
||||||
|
sleep_time = random.uniform(self.settings.SIMULATION_INTERVAL_MIN, self.settings.SIMULATION_INTERVAL_MAX) |
||||||
|
logger.info(f"等待 {sleep_time:.1f} 秒后开始下一批...") |
||||||
|
time.sleep(sleep_time) |
||||||
|
|
||||||
|
# 打印总结并发送通知 |
||||||
|
self._print_summary() |
||||||
|
|
||||||
|
# 返回最终结果 |
||||||
|
return self.success_count, self.fail_count |
||||||
|
|
||||||
|
def _simulate_single_in_batch(self, record_id: int, alpha: str, index: int, results_list: list) -> None: |
||||||
|
"""批次中模拟单个Alpha表达式""" |
||||||
|
start_time = time.time() |
||||||
|
|
||||||
|
try: |
||||||
|
# 模拟alpha - 直接调用,不传递额外参数 |
||||||
|
result = self.alpha_service.simulate_alpha(alpha) |
||||||
|
end_time = time.time() |
||||||
|
time_consuming = round(end_time - start_time, 2) |
||||||
|
|
||||||
|
simulation_result = SimulationResult( |
||||||
|
alpha=alpha, |
||||||
|
time_consuming=time_consuming, |
||||||
|
status=result["status"], |
||||||
|
timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), |
||||||
|
alpha_id=result.get("alpha_id"), |
||||||
|
message=result.get("message", "") |
||||||
|
) |
||||||
|
|
||||||
|
# 立即保存到数据库 |
||||||
|
self.db_manager.update_alpha_simulation(record_id, simulation_result) |
||||||
|
|
||||||
|
if result["status"] == "ok": |
||||||
|
self.success_count += 1 |
||||||
|
logger.info(f"✅ 批次内第 {index + 1} 个因子模拟成功 - Alpha ID: {result['alpha_id']} - 耗时: {time_consuming}秒") |
||||||
|
else: |
||||||
|
self.fail_count += 1 |
||||||
|
logger.error(f"❌ 批次内第 {index + 1} 个因子模拟失败 - {result.get('message', '未知错误')}") |
||||||
|
|
||||||
|
results_list[index] = simulation_result |
||||||
|
|
||||||
|
except Exception as e: |
||||||
|
end_time = time.time() |
||||||
|
time_consuming = round(end_time - start_time, 2) |
||||||
|
|
||||||
|
simulation_result = SimulationResult( |
||||||
|
alpha=alpha, |
||||||
|
time_consuming=time_consuming, |
||||||
|
status="err", |
||||||
|
message=str(e), |
||||||
|
timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S") |
||||||
|
) |
||||||
|
|
||||||
|
self.db_manager.update_alpha_simulation(record_id, simulation_result) |
||||||
|
self.fail_count += 1 |
||||||
|
logger.error(f"❌ 批次内第 {index + 1} 个因子模拟异常 - {str(e)}") |
||||||
|
results_list[index] = simulation_result |
||||||
|
|
||||||
|
def _print_summary(self) -> None: |
||||||
|
"""打印总结信息""" |
||||||
|
now = datetime.now().strftime('%Y-%m-%d %H:%M:%S') |
||||||
|
logger.info(f"\n总计: 成功 {self.success_count} 个, 失败 {self.fail_count} 个,共处理 {self.total_processed} 个因子") |
||||||
|
logger.info(f"完成时间: {now}") |
||||||
|
logger.info(f"所有结果已保存到 PostgreSQL 数据库 {self.settings.DATABASE_CONFIG['database']} 的 alpha_simulation 表中") |
||||||
|
|
||||||
|
# 发送通知 |
||||||
|
self.notification_service.send_to_gotify(self.success_count, self.fail_count) |
||||||
|
|
||||||
|
|
||||||
def main(): |
def main(): |
||||||
"""主函数""" |
"""主函数""" |
||||||
simulator = AlphaSimulator() |
logger.info("开始运行Alpha因子模拟器...") |
||||||
simulator.initialize() |
|
||||||
|
|
||||||
try: |
try: |
||||||
while True: |
# 创建模拟器实例 |
||||||
main_loop(simulator) |
simulator = AlphaSimulator() |
||||||
print(f"等待{settings.CHECK_INTERVAL // 60}分钟后继续检查...") |
|
||||||
time.sleep(settings.CHECK_INTERVAL) |
# 初始化(包括登录) |
||||||
|
simulator.initialize() |
||||||
|
|
||||||
|
# 加载待处理的alpha表达式 |
||||||
|
alpha_records = simulator.load_alpha_list() |
||||||
|
if not alpha_records: |
||||||
|
logger.info("暂无待处理的alpha表达式") |
||||||
|
return |
||||||
|
|
||||||
|
logger.info(f"共加载 {len(alpha_records)} 个需要模拟的因子表达式") |
||||||
|
|
||||||
|
# 运行批次模拟 |
||||||
|
success_count, fail_count = simulator.run_batch_simulation(alpha_records) |
||||||
|
|
||||||
|
logger.info(f"模拟任务完成!成功: {success_count} 个, 失败: {fail_count} 个,共处理 {len(alpha_records)} 个因子") |
||||||
|
|
||||||
except KeyboardInterrupt: |
except KeyboardInterrupt: |
||||||
print("\n程序被用户中断") |
logger.info("用户中断执行") |
||||||
except Exception as e: |
except Exception as e: |
||||||
print(f"程序执行异常: {e}") |
logger.error(f"模拟过程发生错误: {e}", exc_info=True) |
||||||
raise |
finally: |
||||||
|
logger.info("Alpha因子模拟器运行结束") |
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__": |
if __name__ == "__main__": |
||||||
|
|||||||
@ -1,32 +0,0 @@ |
|||||||
# -*- coding: utf-8 -*- |
|
||||||
"""数据实体模型""" |
|
||||||
|
|
||||||
from dataclasses import dataclass |
|
||||||
from datetime import datetime |
|
||||||
from typing import Optional |
|
||||||
|
|
||||||
|
|
||||||
@dataclass |
|
||||||
class SimulationResult: |
|
||||||
"""模拟结果实体""" |
|
||||||
expression: str |
|
||||||
time_consuming: float |
|
||||||
status: str # 'ok' or 'err' |
|
||||||
timestamp: str |
|
||||||
alpha_id: Optional[str] = None |
|
||||||
message: Optional[str] = None |
|
||||||
|
|
||||||
|
|
||||||
@dataclass |
|
||||||
class AlphaExpression: |
|
||||||
"""Alpha表达式实体""" |
|
||||||
expression: str |
|
||||||
unused: bool = True |
|
||||||
created_time: Optional[datetime] = None |
|
||||||
|
|
||||||
|
|
||||||
@dataclass |
|
||||||
class TokenInfo: |
|
||||||
"""认证令牌信息""" |
|
||||||
token: str |
|
||||||
expiry: int |
|
||||||
@ -1,60 +0,0 @@ |
|||||||
# -*- coding: utf-8 -*- |
|
||||||
"""Alpha表达式服务""" |
|
||||||
|
|
||||||
import time |
|
||||||
import httpx |
|
||||||
from typing import List, Dict, Any |
|
||||||
from config.settings import settings |
|
||||||
|
|
||||||
|
|
||||||
class AlphaService: |
|
||||||
"""Alpha表达式服务类""" |
|
||||||
|
|
||||||
def __init__(self, client: httpx.Client): |
|
||||||
self.client = client |
|
||||||
|
|
||||||
def simulate_alpha(self, expression: str) -> Dict[str, Any]: |
|
||||||
"""模拟单个Alpha表达式""" |
|
||||||
simulation_data = { |
|
||||||
'type': 'REGULAR', |
|
||||||
'settings': settings.SIMULATION_SETTINGS, |
|
||||||
'regular': expression |
|
||||||
} |
|
||||||
|
|
||||||
sim_resp = self.client.post(f'{settings.BRAIN_API_URL}/simulations', json=simulation_data) |
|
||||||
print(f"模拟提交状态: {sim_resp.status_code}") |
|
||||||
|
|
||||||
if 'location' not in sim_resp.headers: |
|
||||||
return {"status": "err", "message": "No location header in response"} |
|
||||||
|
|
||||||
sim_progress_url = sim_resp.headers['location'] |
|
||||||
return self._wait_for_simulation_result(sim_progress_url) |
|
||||||
|
|
||||||
def _wait_for_simulation_result(self, progress_url: str) -> Dict[str, Any]: |
|
||||||
"""等待模拟结果""" |
|
||||||
while True: |
|
||||||
sim_progress_resp = self.client.get(progress_url) |
|
||||||
retry_after_sec = float(sim_progress_resp.headers.get("Retry-After", 0)) |
|
||||||
|
|
||||||
if retry_after_sec == 0: |
|
||||||
break |
|
||||||
|
|
||||||
if sim_progress_resp.json(): |
|
||||||
result = sim_progress_resp.json() |
|
||||||
progress = result.get('progress', 0) |
|
||||||
print(f"模拟进度: {float(progress) * 100}%") |
|
||||||
|
|
||||||
print(f"等待 {retry_after_sec} 秒...") |
|
||||||
time.sleep(retry_after_sec) |
|
||||||
|
|
||||||
result_data = sim_progress_resp.json() |
|
||||||
|
|
||||||
if result_data.get("status") == "ERROR": |
|
||||||
error_message = result_data.get("message", "未知错误") |
|
||||||
print(f"因子模拟失败: {error_message}") |
|
||||||
return {"status": "err", "message": error_message} |
|
||||||
|
|
||||||
alpha_id = result_data.get("alpha") |
|
||||||
print(f"生成的Alpha ID: {alpha_id}") |
|
||||||
|
|
||||||
return {"status": "ok", "alpha_id": alpha_id} |
|
||||||
@ -1,57 +0,0 @@ |
|||||||
# -*- coding: utf-8 -*- |
|
||||||
"""认证服务""" |
|
||||||
|
|
||||||
import os |
|
||||||
import httpx |
|
||||||
from httpx import BasicAuth |
|
||||||
from typing import Tuple, Dict, Any |
|
||||||
from config.settings import settings |
|
||||||
from models.entities import TokenInfo |
|
||||||
|
|
||||||
|
|
||||||
class AuthService: |
|
||||||
"""认证服务类""" |
|
||||||
|
|
||||||
def __init__(self): |
|
||||||
self.credentials_file = settings.credentials_file |
|
||||||
|
|
||||||
def load_credentials(self) -> Tuple[str, str]: |
|
||||||
"""加载凭证""" |
|
||||||
if not os.path.exists(self.credentials_file): |
|
||||||
self._create_credentials_file() |
|
||||||
|
|
||||||
with open(self.credentials_file, 'r', encoding='utf-8') as f: |
|
||||||
credentials = eval(f.read()) |
|
||||||
return credentials[0], credentials[1] |
|
||||||
|
|
||||||
def _create_credentials_file(self) -> None: |
|
||||||
"""创建凭证文件""" |
|
||||||
print("未找到 account.txt 文件") |
|
||||||
with open(self.credentials_file, 'w', encoding='utf-8') as f: |
|
||||||
f.write("") |
|
||||||
print("account.txt 文件已创建,请填写账号密码, 格式: ['username', 'password']") |
|
||||||
exit(1) |
|
||||||
|
|
||||||
def login(self, client: httpx.Client) -> TokenInfo: |
|
||||||
"""登录并获取token""" |
|
||||||
username, password = self.load_credentials() |
|
||||||
|
|
||||||
# 设置 BasicAuth |
|
||||||
client.auth = BasicAuth(username, password) |
|
||||||
|
|
||||||
response = client.post(f'{settings.BRAIN_API_URL}/authentication') |
|
||||||
print(f"登录状态: {response.status_code}") |
|
||||||
|
|
||||||
if response.status_code == 201: |
|
||||||
login_data = response.json() |
|
||||||
print(f"登录成功!: {login_data}") |
|
||||||
return TokenInfo( |
|
||||||
token=login_data['token'], |
|
||||||
expiry=int(login_data['token']['expiry']) |
|
||||||
) |
|
||||||
elif response.status_code == 429: |
|
||||||
print("API rate limit exceeded") |
|
||||||
exit(1) |
|
||||||
else: |
|
||||||
print(f"登录失败: {response.json()}") |
|
||||||
raise Exception(f"登录失败: {response.json()}") |
|
||||||
@ -1,27 +0,0 @@ |
|||||||
# -*- coding: utf-8 -*- |
|
||||||
"""通知服务""" |
|
||||||
|
|
||||||
import httpx |
|
||||||
from datetime import datetime |
|
||||||
from config.settings import settings |
|
||||||
|
|
||||||
|
|
||||||
class NotificationService: |
|
||||||
"""通知服务类""" |
|
||||||
|
|
||||||
@staticmethod |
|
||||||
def send_to_gotify(success_count: int, fail_count: int) -> None: |
|
||||||
"""发送结果到Gotify""" |
|
||||||
now = datetime.now().strftime('%Y-%m-%d %H:%M:%S') |
|
||||||
text = f"总计: 成功 {success_count} 个, 失败 {fail_count} 个\n\n完成时间: {now}" |
|
||||||
title = f"alpha模拟结果 时间: {now}" |
|
||||||
|
|
||||||
try: |
|
||||||
resp = httpx.post( |
|
||||||
settings.GOTIFY_URL, |
|
||||||
json={'title': title, 'message': text}, |
|
||||||
timeout=10 |
|
||||||
) |
|
||||||
print("通知发送成功") |
|
||||||
except Exception as e: |
|
||||||
print(f"通知发送失败: {e}") |
|
||||||
@ -1,23 +0,0 @@ |
|||||||
# -*- coding: utf-8 -*- |
|
||||||
"""工具函数""" |
|
||||||
|
|
||||||
import time |
|
||||||
from typing import Callable, Any |
|
||||||
|
|
||||||
|
|
||||||
def retry_on_exception(max_retries: int = 3, delay: float = 5.0) -> Callable: |
|
||||||
"""异常重试装饰器""" |
|
||||||
def decorator(func: Callable) -> Callable: |
|
||||||
def wrapper(*args, **kwargs) -> Any: |
|
||||||
for attempt in range(max_retries): |
|
||||||
try: |
|
||||||
return func(*args, **kwargs) |
|
||||||
except Exception as e: |
|
||||||
if attempt == max_retries - 1: |
|
||||||
raise e |
|
||||||
print(f"尝试 {attempt + 1}/{max_retries} 失败: {e}") |
|
||||||
print(f"{delay}秒后重试...") |
|
||||||
time.sleep(delay) |
|
||||||
return None |
|
||||||
return wrapper |
|
||||||
return decorator |
|
||||||
Loading…
Reference in new issue