You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
144 lines
4.9 KiB
144 lines
4.9 KiB
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
Quickly verify the LLM template-decoding feature.

Usage:
    1. Place llm_received.md and prompt.md in the current directory.
    2. Run: python decode_template.py

Output:
    The decoded templates and their expression contents.
|
|
"""
|
|
|
|
import sys
|
|
import json
|
|
import re
|
|
from pathlib import Path
|
|
|
|
import decode_template
|
|
|
|
|
|
def _find_fields_array(content):
    """Return the text of the first balanced ``[...]`` following a ``fields`` key.

    The key may be written with double or single quotes. Brackets that occur
    inside quoted strings are ignored while balancing. Returns ``None`` when
    no key is present or no array after a key is ever closed.
    """
    for key_match in re.finditer(r'''["']fields["']\s*:\s*\[''', content):
        start = key_match.end() - 1  # index of the opening "["
        depth = 0
        quote = None     # quote character of the string being scanned, if any
        escaped = False  # True right after a backslash inside a string
        for pos in range(start, len(content)):
            char = content[pos]
            if quote is not None:
                if escaped:
                    escaped = False
                elif char == '\\':
                    escaped = True
                elif char == quote:
                    quote = None
            elif char in ('"', "'"):
                quote = char
            elif char == '[':
                depth += 1
            elif char == ']':
                depth -= 1
                if depth == 0:
                    return content[start:pos + 1]
    return None


def _parse_fields_literal(text):
    """Parse *text* as JSON, then as a Python literal; return ``None`` on failure."""
    import ast  # only needed for the Python-literal fallback

    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    try:
        # literal_eval only evaluates literals, so it is safe for the
        # single-quoted (Python dict style) variant of the array.
        return ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError):
        return None


def extract_fields_from_prompt(prompt_file: Path) -> list:
    """Extract the ``fields`` array from a prompt file.

    Two strategies are tried in order:

    1. The original pattern: ``"fields": [...]`` whose closing bracket is
       followed by a newline and the enclosing ``}``.
    2. A balanced-bracket scan starting at the first ``fields`` key (double or
       single quoted). This covers arrays that contain nested arrays, arrays
       followed by further keys, and Python-dict style text, all of which the
       non-greedy pattern either truncates or misses.

    Args:
        prompt_file: Path of the prompt file; it is read as UTF-8.

    Returns:
        A list of ``{'id': ...}`` dicts, one per entry of the ``fields`` array.

    Raises:
        ValueError: If no ``fields`` array is found, or if the text found
            cannot be parsed as JSON or as a Python literal.
        KeyError: If a parsed entry is a dict without an ``'id'`` key.
    """
    content = prompt_file.read_text(encoding='utf-8')

    match = re.search(r'"fields":\s*(\[.*?\])\s*\n\s*\}', content, re.DOTALL)
    primary_text = match.group(1) if match else None

    fields = None
    if primary_text is not None:
        fields = _parse_fields_literal(primary_text)

    if fields is None:
        # Either the primary pattern did not match, or its non-greedy capture
        # was cut short at an inner "]"; retry with proper bracket balancing.
        balanced_text = _find_fields_array(content)
        if primary_text is None and balanced_text is None:
            raise ValueError("无法在 prompt.md 中找到 fields 字段")
        if balanced_text is not None:
            fields = _parse_fields_literal(balanced_text)

    if fields is None:
        raise ValueError("无法解析 prompt.md 中的 fields 字段")

    # Reduce each entry to the shape decode_template expects.
    return [{'id': field['id']} for field in fields]
|
|
|
|
|
|
def _exit_if_missing(path):
    """Print an error and exit with status 1 when *path* does not exist."""
    if not path.exists():
        print(f"错误: 找不到文件 {path.absolute()}")
        print(f"请确保 {path} 文件存在于当前目录")
        sys.exit(1)


def _print_template_report(templates):
    """Print every decoded template with at most three sample expressions."""
    print(f"\n📝 生成的模板 ({len(templates)} 个):")
    for idx, template_item in enumerate(templates, 1):
        print(f"\n [{idx}] 模板: {template_item.get('template', '')}")
        print(f" 原始模板: {template_item.get('original_template', '')}")
        print(f" 表达式数量: {template_item.get('expression_count', 0)}")
        idea = template_item.get('idea', '')
        # Long ideas are cut to 100 characters to keep the report readable.
        print(f" Idea: {idea[:100]}..." if len(idea) > 100 else f" Idea: {idea}")
        print(" 表达式示例:")
        expressions = template_item.get('expressions', [])
        for expr in expressions[:3]:
            print(f" - {expr}")
        if len(expressions) > 3:
            print(f" ... 还有 {len(expressions) - 3} 个表达式")


def _print_expression_list(expressions):
    """Print the first twenty final expressions and a count of the remainder."""
    print(f"\n🎯 最终表达式列表 ({len(expressions)} 个):")
    for idx, expr in enumerate(expressions[:20], 1):
        print(f" {idx}. {expr}")
    if len(expressions) > 20:
        print(f" ... 还有 {len(expressions) - 20} 个表达式")


def main():
    """Decode llm_received.md against the fields in prompt.md and report the result.

    Both files are looked up in the current working directory. The script
    exits with status 1 when either file is missing or when the fields cannot
    be extracted from prompt.md. On a successful decode the full result dict
    is also written to decode_result.json; a failed decode only prints the
    error and returns normally.
    """
    llm_file = Path('llm_received.md')
    prompt_file = Path('prompt.md')
    _exit_if_missing(llm_file)
    _exit_if_missing(prompt_file)

    llm_template = llm_file.read_text(encoding='utf-8')
    print(f"✅ 成功读取 {llm_file.absolute()}")
    print(f" 文件大小: {len(llm_template)} 字符")
    print("-" * 60)

    # Top-level boundary: any extraction failure is reported and turned into
    # a non-zero exit instead of a traceback.
    try:
        data_sets_list = extract_fields_from_prompt(prompt_file)
    except Exception as e:
        print(f"\n❌ 提取字段失败: {e}")
        # sys.exit rather than the site-provided exit(), which is not
        # guaranteed to exist (e.g. python -S) and matches the checks above.
        sys.exit(1)
    print(f"\n✅ 成功从 {prompt_file.absolute()} 提取字段")
    print(f" 字段数量: {len(data_sets_list)}")

    print("\n📊 数据集字段 (前20个):")
    for item in data_sets_list[:20]:
        print(f" - {item['id']}")
    if len(data_sets_list) > 20:
        print(f" ... 还有 {len(data_sets_list) - 20} 个字段")
    print("-" * 60)

    print("\n🔍 开始解码模板...")
    print("=" * 60)

    result = decode_template.process(data_sets_list, llm_template)

    print("\n" + "=" * 60)
    print("📋 解码结果:")
    print("=" * 60)

    # .get() so a result without a 'success' key is reported as a failure
    # instead of raising KeyError.
    if not result.get('success'):
        print("\n❌ 解码失败!")
        print(f"错误信息: {result.get('error', '未知错误')}")
        return

    print("\n✅ 解码成功!")
    print("\n📊 统计信息:")
    for key, value in result.get('summary', {}).items():
        print(f" {key}: {value}")

    _print_template_report(result.get('templates', []))
    _print_expression_list(result.get('expressions', []))

    # Persist the complete result next to the inputs.
    output_file = Path('decode_result.json')
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(result, f, ensure_ascii=False, indent=2)
    print(f"\n💾 完整结果已保存到: {output_file.absolute()}")
|
|
|
|
|
|
# Run the verification only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|