You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

144 lines
4.9 KiB

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
快速验证 LLM 解码模板功能
使用方法:
1. 将 llm_received.md 和 prompt.md 放在当前目录
2. 运行: python decode_template.py
输出:
解码后的模板和表达式内容
"""
import sys
import json
import re
from pathlib import Path
import decode_template
def extract_fields_from_prompt(prompt_file: Path) -> list:
"""
从 prompt.md 文件中提取 fields 字段列表
"""
content = prompt_file.read_text(encoding='utf-8')
# 查找 "fields": [...] 的部分
# 使用正则表达式匹配 fields 数组
pattern = r'"fields":\s*(\[.*?\])\s*\n\s*\}'
match = re.search(pattern, content, re.DOTALL)
if not match:
# 尝试另一种格式:直接查找 fields 数组
pattern = r'"fields":\s*(\[.*?\])\s*\n\s*}'
match = re.search(pattern, content, re.DOTALL)
if not match:
raise ValueError("无法在 prompt.md 中找到 fields 字段")
fields_json = match.group(1)
# 处理可能的 Python 字典格式(单引号)
# 将单引号替换为双引号,但要注意已经包含的字符串
try:
fields = json.loads(fields_json)
except json.JSONDecodeError:
# 尝试使用 ast.literal_eval 解析 Python 字面量
import ast
fields = ast.literal_eval(fields_json)
# 转换为 decode_template 需要的格式
data_sets_list = [{'id': field['id']} for field in fields]
return data_sets_list
def main():
# 检查文件是否存在
llm_file = Path('llm_received.md')
prompt_file = Path('prompt.md')
if not llm_file.exists():
print(f"错误: 找不到文件 {llm_file.absolute()}")
print("请确保 llm_received.md 文件存在于当前目录")
sys.exit(1)
if not prompt_file.exists():
print(f"错误: 找不到文件 {prompt_file.absolute()}")
print("请确保 prompt.md 文件存在于当前目录")
sys.exit(1)
# 读取 llm_received.md 文件
llm_template = llm_file.read_text(encoding='utf-8')
print(f"✅ 成功读取 {llm_file.absolute()}")
print(f" 文件大小: {len(llm_template)} 字符")
print("-" * 60)
# 从 prompt.md 提取 fields
try:
data_sets_list = extract_fields_from_prompt(prompt_file)
print(f"\n✅ 成功从 {prompt_file.absolute()} 提取字段")
print(f" 字段数量: {len(data_sets_list)}")
except Exception as e:
print(f"\n❌ 提取字段失败: {e}")
exit(1)
print("\n📊 数据集字段 (前20个):")
for item in data_sets_list[:20]:
print(f" - {item['id']}")
if len(data_sets_list) > 20:
print(f" ... 还有 {len(data_sets_list) - 20} 个字段")
print("-" * 60)
# 调用解码函数
print("\n🔍 开始解码模板...")
print("=" * 60)
result = decode_template.process(data_sets_list, llm_template)
print("\n" + "=" * 60)
print("📋 解码结果:")
print("=" * 60)
if result['success']:
print(f"\n✅ 解码成功!")
print(f"\n📊 统计信息:")
for key, value in result.get('summary', {}).items():
print(f" {key}: {value}")
print(f"\n📝 生成的模板 ({len(result.get('templates', []))} 个):")
for idx, template_item in enumerate(result.get('templates', []), 1):
print(f"\n [{idx}] 模板: {template_item.get('template', '')}")
print(f" 原始模板: {template_item.get('original_template', '')}")
print(f" 表达式数量: {template_item.get('expression_count', 0)}")
idea = template_item.get('idea', '')
print(f" Idea: {idea[:100]}..." if len(
idea) > 100 else f" Idea: {idea}")
print(f" 表达式示例:")
for expr_idx, expr in enumerate(template_item.get('expressions', [])[:3], 1):
print(f" - {expr}")
if len(template_item.get('expressions', [])) > 3:
print(
f" ... 还有 {len(template_item.get('expressions', [])) - 3} 个表达式")
print(f"\n🎯 最终表达式列表 ({len(result.get('expressions', []))} 个):")
for idx, expr in enumerate(result.get('expressions', [])[:20], 1):
print(f" {idx}. {expr}")
if len(result.get('expressions', [])) > 20:
print(f" ... 还有 {len(result.get('expressions', [])) - 20} 个表达式")
# 保存结果到文件
output_file = Path('decode_result.json')
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"\n💾 完整结果已保存到: {output_file.absolute()}")
else:
print(f"\n❌ 解码失败!")
print(f"错误信息: {result.get('error', '未知错误')}")
if __name__ == '__main__':
main()