You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 
alpha_tools/data_sets/load_data_sets.py

129 lines
3.5 KiB

# -*- coding: utf-8 -*-
import os
import csv
from collections import defaultdict
def load_csv_data(file_path):
"""
加载CSV文件数据
Args:
file_path: CSV文件路径
Returns:
数据列表,每个元素是一个字典
"""
if not os.path.exists(file_path):
print(f"文件不存在: {file_path}")
return []
data = []
with open(file_path, 'r', encoding='utf-8') as f:
reader = csv.DictReader(f) # 使用DictReader方便按字段名访问
for row in reader:
data.append(row)
print(f"成功加载 {len(data)} 条数据")
return data
def group_by_dataset(data):
"""
按dataset_name字段对数据进行分组
Args:
data: 数据列表
Returns:
字典,键为dataset_name,值为该数据集下的数据列表
"""
grouped_data = defaultdict(list)
for item in data:
dataset = item.get('dataset_name', '未分类数据集')
grouped_data[dataset].append(item)
return grouped_data
def create_dataset_files(grouped_data):
"""
为每个数据集创建txt文件
Args:
grouped_data: 按数据集分组的数据
"""
# 创建dataset文件夹
dataset_dir = "dataset"
if not os.path.exists(dataset_dir):
os.makedirs(dataset_dir)
print(f"创建文件夹: {dataset_dir}")
# 按数据集名称排序
sorted_datasets = sorted(grouped_data.keys())
for dataset in sorted_datasets:
# 生成安全的文件名(替换特殊字符)
safe_dataset_name = dataset.replace('/', '_').replace('\\', '_').replace(':', '_').replace('*', '_').replace('?', '_').replace('"', '_').replace('<', '_').replace('>', '_').replace('|', '_')
file_path = os.path.join(dataset_dir, f"{safe_dataset_name}.txt")
count = len(grouped_data[dataset])
with open(file_path, 'w', encoding='utf-8') as f:
f.write(f"数据集: {dataset} (共 {count} 条)\n")
f.write("=" * 60 + "\n\n")
# 写入该数据集下的所有name,按name排序
sorted_items = sorted(grouped_data[dataset], key=lambda x: x['name'])
for item in sorted_items:
f.write(f"- {item['name']}\n")
print(f"已创建: {file_path}")
def print_datasets_with_names(grouped_data):
"""
打印数据集和对应的name到控制台
Args:
grouped_data: 按数据集分组的数据
"""
# 按数据集名称排序
sorted_datasets = sorted(grouped_data.keys())
for dataset in sorted_datasets:
count = len(grouped_data[dataset])
print(f"\n数据集: {dataset} (共 {count} 条)")
print("-" * 60)
# 打印该数据集下的所有name
sorted_items = sorted(grouped_data[dataset], key=lambda x: x['name'])
for item in sorted_items:
print(f" - {item['name']}")
def main():
"""
主函数
"""
# CSV文件路径
csv_file_path = "all_data_combined.csv" # 请替换为实际的文件路径
# 1. 加载数据
data = load_csv_data(csv_file_path)
if not data:
return
# 2. 按数据集分组
grouped_data = group_by_dataset(data)
# 3. 创建数据集文件
create_dataset_files(grouped_data)
# 4. 打印到控制台
print_datasets_with_names(grouped_data)
if __name__ == "__main__":
main()