You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
AlphaGenerator/upload_to_pg.py

176 lines
5.8 KiB

import psycopg2
import sys
import os
class AlphaImporter:
    """Load alpha expressions from a text file into the ``alpha_simulation`` table.

    Rows are inserted with ``ON CONFLICT (alpha) DO NOTHING``, so expressions
    that already exist in the table are counted as duplicates, not errors.
    """

    # NOTE(review): these credentials are committed in source. They stay as the
    # defaults so existing callers keep working; pass ``db_config`` to override
    # them and consider moving the secret out of the repository.
    _DEFAULT_DB_CONFIG = {
        'host': '192.168.31.201',
        'port': 5432,
        'database': 'alpha',
        'user': 'jack',
        'password': 'aaaAAA111',
    }

    _INSERT_SQL = (
        "INSERT INTO alpha_simulation (alpha, unused) "
        "VALUES (%s, %s) ON CONFLICT (alpha) DO NOTHING RETURNING id"
    )

    # Name of the savepoint that isolates each row inside the import transaction.
    _ROW_SAVEPOINT = "alpha_import_row"

    def __init__(self, db_config=None):
        """Create an importer.

        Args:
            db_config: Optional mapping of connection keyword arguments. Its
                entries override the built-in defaults key by key; omitted
                keys keep their default value.
        """
        self.db_config = dict(self._DEFAULT_DB_CONFIG)
        if db_config:
            self.db_config.update(db_config)

    def get_db_connection(self):
        """Open a database connection using ``self.db_config``.

        Returns:
            An open psycopg2 connection, or ``None`` if connecting failed
            (the error is printed rather than raised).
        """
        try:
            return psycopg2.connect(**self.db_config)
        except Exception as e:
            # Deliberately broad: callers rely on None instead of an exception.
            print(f"❌ 数据库连接失败: {e}")
            return None

    def read_alpha_file(self, filename="planning_post_alpha.txt"):
        """Read alpha expressions from a UTF-8 text file, one per line.

        Args:
            filename: Path of the file to read.

        Returns:
            The stripped, non-empty lines in file order. An empty list is
            returned when the file is missing or cannot be read.
        """
        # Open directly instead of checking os.path.exists first, so the file
        # cannot disappear between the check and the open.
        try:
            with open(filename, 'r', encoding='utf-8') as f:
                stripped = (line.strip() for line in f)
                lines = [text for text in stripped if text]
        except FileNotFoundError:
            print(f"❌ 文件不存在: {filename}")
            return []
        except (OSError, UnicodeError) as e:
            print(f"❌ 读取文件失败: {e}")
            return []
        print(f"📄 从文件 '{filename}' 读取到 {len(lines)} 个alpha表达式")
        return lines

    def import_to_database(self, alpha_list):
        """Insert alpha expressions into ``alpha_simulation`` in one transaction.

        Each row runs inside its own savepoint, so a failing row is rolled
        back on its own and does not abort the rows around it.

        Args:
            alpha_list: Iterable-sized collection of expression strings.

        Returns:
            ``(imported_count, duplicate_count)``. ``imported_count`` is 0
            when the transaction could not be committed, because nothing was
            persisted in that case. ``(0, 0)`` is returned when there is no
            input or no connection.
        """
        if not alpha_list:
            print(" 没有数据需要导入")
            return 0, 0

        conn = self.get_db_connection()
        if not conn:
            return 0, 0

        # Nested try/finally guarantees cursor and connection are released
        # even if an unexpected error (or Ctrl-C) interrupts the import.
        try:
            cursor = conn.cursor()
            try:
                # 1. Check the id sequence and repair it if it lags behind.
                self._repair_sequence(conn, cursor)

                # 2. Insert the rows.
                print(f"\n开始导入 {len(alpha_list)} 个alpha表达式...")
                print("-" * 60)
                imported_count, duplicate_count, failed_count, usable = (
                    self._insert_rows(cursor, alpha_list)
                )

                # 3. Commit (or roll back if the transaction is unusable).
                if not self._finish_transaction(conn, usable):
                    imported_count = 0
            finally:
                cursor.close()
        finally:
            conn.close()

        print(f"\n" + "=" * 50)
        print(f"📊 导入完成统计:")
        print(f" 📈 成功导入: {imported_count}")
        print(f" 跳过重复: {duplicate_count}")
        print(f" ❌ 导入失败: {failed_count}")
        print(f" 📄 文件总数: {len(alpha_list)}")
        print("=" * 50)
        return imported_count, duplicate_count

    def _repair_sequence(self, conn, cursor):
        """Move ``alpha_simulation_id_seq`` past ``MAX(id)`` when it lags behind.

        Any failure is reported as a warning and rolled back; the import then
        continues, since the table may be empty or the sequence may not exist.
        """
        try:
            cursor.execute("SELECT MAX(id) FROM alpha_simulation")
            max_id = cursor.fetchone()[0] or 0
            cursor.execute("SELECT last_value FROM alpha_simulation_id_seq")
            current_seq_val = cursor.fetchone()[0]
            print(f"🔍 诊断信息: 表中最大ID={max_id}, 序列当前值={current_seq_val}")
            if current_seq_val <= max_id:
                new_seq_val = max_id + 1
                # is_called=false makes the next nextval() return new_seq_val itself.
                cursor.execute(
                    "SELECT setval('alpha_simulation_id_seq', %s, false)",
                    (new_seq_val,),
                )
                conn.commit()
                print(f"🔄 已修复序列,新起点: {new_seq_val}")
        except Exception as e:
            print(f" 序列诊断时出现警告: {e}")
            # Leave the connection in a clean state for the inserts that follow.
            conn.rollback()

    def _insert_rows(self, cursor, alpha_list):
        """Insert every expression, isolating each one in a savepoint.

        Returns:
            ``(imported, duplicates, failed, usable)``. ``usable`` is False
            when a failed row could not be rolled back to its savepoint, in
            which case the loop stops early and the transaction must be
            discarded.
        """
        imported = duplicates = failed = 0
        total = len(alpha_list)
        savepoint = self._ROW_SAVEPOINT

        for i, alpha in enumerate(alpha_list, 1):
            try:
                clean_alpha = alpha.strip()
            except AttributeError as e:
                # Non-string entry: report it like any other failed row.
                failed += 1
                print(f"❌ 第 {i} 行导入失败: {str(e)[:100]}")
                continue
            if not clean_alpha:
                continue

            if i % 100 == 0 or i == total:
                print(f"处理第 {i}/{total} 行...")

            try:
                cursor.execute(f"SAVEPOINT {savepoint}")
                cursor.execute(self._INSERT_SQL, (clean_alpha, True))
                # RETURNING yields a row only when the insert actually happened.
                inserted = cursor.fetchone() is not None
                cursor.execute(f"RELEASE SAVEPOINT {savepoint}")
            except Exception as e:
                failed += 1
                print(f"❌ 第 {i} 行导入失败: {str(e)[:100]}")
                # Without this, PostgreSQL keeps the transaction in the aborted
                # state and rejects every following statement.
                try:
                    cursor.execute(f"ROLLBACK TO SAVEPOINT {savepoint}")
                except Exception as recover_error:
                    print(f"❌ 事务无法恢复,终止导入: {recover_error}")
                    return imported, duplicates, failed, False
                continue

            if inserted:
                imported += 1
                if imported % 100 == 0:
                    print(f" 已成功导入 {imported}")
            else:
                duplicates += 1

        return imported, duplicates, failed, True

    def _finish_transaction(self, conn, usable):
        """Commit when the transaction is usable, otherwise roll it back.

        Returns:
            True only if the commit succeeded.
        """
        if usable:
            try:
                conn.commit()
                print(f"✅ 事务已提交")
                return True
            except Exception as e:
                print(f"❌ 提交事务时出错: {e}")
        try:
            conn.rollback()
        except Exception as e:
            # The connection may already be gone; it is closed by the caller.
            print(f"❌ 回滚事务时出错: {e}")
        return False
def main():
    """Run the import workflow: read the alpha file, then load it into the database.

    Returns early, after printing a notice, when the file yields no expressions.
    """
    banner = "=" * 60
    print(banner)
    print("Alpha表达式导入工具")
    print(banner)

    loader = AlphaImporter()

    # Step 1: collect the expressions from the input file.
    print("\n步骤1: 读取文件...")
    expressions = loader.read_alpha_file("planning_post_alpha.txt")
    if not expressions:
        print("程序退出:没有找到可导入的数据")
        return

    # Step 2: push them into the database.
    print("\n步骤2: 开始导入到数据库...")
    outcome = loader.import_to_database(expressions)
    new_rows, repeated_rows = outcome

    # Step 3: report the final tally.
    print(f"\n🎉 导入完成!")
    print(f" 新增: {new_rows}")
    print(f" 重复: {repeated_rows}")
if __name__ == "__main__":
    # Script entry point. The exit status reflects the outcome so shells and
    # schedulers can detect failures: 0 on success, 130 when interrupted with
    # Ctrl-C, 1 on any other error.
    exit_code = 0
    try:
        main()
    except KeyboardInterrupt:
        print("\n\n程序被用户中断")
        exit_code = 130
    except Exception as e:
        print(f"\n❌ 程序执行出错: {e}")
        import traceback
        traceback.print_exc()
        exit_code = 1
    # sys.exit is used instead of the interactive-only ``exit`` helper, which
    # is injected by the ``site`` module and is not guaranteed to exist.
    sys.exit(exit_code)