数据集爬取完成

main
Jack 2 weeks ago
parent 129e328f8a
commit 5b15972069
  1. BIN
      data_sets.db
  2. BIN
      data_sets.db.bak
  3. 222
      get_category.py
  4. 515
      main.py
  5. 6
      new_db.py

Binary file not shown.

Binary file not shown.

@ -3,6 +3,7 @@ import os
import json
import random
import time
import sqlite3
import httpx
from httpx import BasicAuth
@ -10,8 +11,67 @@ from httpx import BasicAuth
class CategoryDownloader:
def __init__(self):
self.base_api_url = 'https://api.worldquantbrain.com'
self.db_path = os.path.join(os.getcwd(), 'data_sets.db')
self.client = self.login()
def create_database_if_not_exists(self):
"""创建数据库和表(如果不存在)"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# 创建category表
cursor.execute('''
CREATE TABLE IF NOT EXISTS category
(
name
TEXT,
base_url
TEXT,
category_id
TEXT,
region
TEXT,
universe
TEXT,
instrumentType
TEXT,
delay
TEXT,
downloaded
INTEGER
)
''')
# 创建data_sets表
cursor.execute('''
CREATE TABLE IF NOT EXISTS data_sets
(
name
TEXT,
description
TEXT,
dataset_id
TEXT,
dataset_name
TEXT,
category_id
TEXT,
category_name
TEXT,
region
TEXT,
delay
TEXT,
universe
TEXT,
type
TEXT
)
''')
conn.commit()
conn.close()
def login(self):
"""登录并返回客户端实例"""
username, password = "jack0210_@hotmail.com", "!QAZ2wsx+0913"
@ -31,79 +91,112 @@ class CategoryDownloader:
print(f"登录过程中出现错误: {e}")
return None
def save_to_database(self, category, category_id, region, universe, instrumentType, delay):
"""保存数据到SQLite数据库"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# 检查是否已存在相同记录
cursor.execute('''
SELECT COUNT(*)
FROM category
WHERE name = ?
AND category_id = ?
AND region = ?
AND universe = ?
''', (category, category_id, region, universe))
if cursor.fetchone()[0] == 0:
cursor.execute('''
INSERT INTO category (name, base_url, category_id, region, universe, instrumentType, delay, downloaded)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
''', (category, 'https://api.worldquantbrain.com/data-fields',
category_id, region, universe, instrumentType, delay, 0))
print(f"已保存: {category} - {category_id} - {region} - {universe}")
else:
print(f"记录已存在: {category} - {category_id} - {region} - {universe}")
conn.commit()
conn.close()
def fetch_category_data(self, category, delay, instrumentType, region_list, universe_list):
"""获取分类数据并保存到JSON文件"""
results = []
# 创建category_files文件夹
output_dir = "category_files"
if not os.path.exists(output_dir):
os.makedirs(output_dir)
print(f"已创建文件夹: {output_dir}")
if self.client:
for region in region_list:
for universe in universe_list:
url = f'https://api.worldquantbrain.com/data-sets?category={category}&delay={delay}&instrumentType={instrumentType}&limit=50&offset=0&region={region}&universe={universe}'
print(f"请求URL: {url}")
try:
response = self.client.get(url)
if response.status_code == 200:
data = response.json()
if data.get('count', 0) > 0:
for item in data.get('results', []):
result_item = {
'id': item.get('id', ''),
'region': item.get('region', ''),
'universe': item.get('universe', '')
}
results.append(result_item)
"""获取分类数据并保存到数据库"""
# 确保数据库已创建
self.create_database_if_not_exists()
if not self.client:
print("客户端未初始化,无法获取数据")
return []
for region in region_list:
for universe in universe_list:
url = f'https://api.worldquantbrain.com/data-sets?category={category}&delay={delay}&instrumentType={instrumentType}&limit=50&offset=0&region={region}&universe={universe}'
print(f"请求URL: {url}")
try:
response = self.client.get(url)
if response.status_code == 200:
data = response.json()
if data.get('count', 0) > 0:
for item in data.get('results', []):
category_id = item.get('id', '')
region = item.get('region', '')
universe = item.get('universe', '')
# 保存到数据库
self.save_to_database(category, category_id, region, universe, instrumentType, delay)
else:
print(f"请求失败: {response.status_code}")
print(f"未找到数据: {category} - {region} - {universe}")
else:
print(f"请求失败: {response.status_code}")
except Exception as e:
print(f"请求过程中出现错误: {e}")
except Exception as e:
print(f"请求过程中出现错误: {e}")
time.sleep(random.uniform(5, 8))
sleep_time = random.uniform(10, 15)
print(f"等待 {sleep_time} 秒后继续请求")
time.sleep(sleep_time)
# 保存到JSON文件
filename = os.path.join(output_dir, f"{category}.json")
with open(filename, 'w', encoding='utf-8') as f:
json.dump(results, f, indent=2, ensure_ascii=False)
print(f"数据获取完成,已保存到数据库: {self.db_path}")
print(f"数据已保存到: {filename}")
print(f"总共找到 {len(results)} 条记录")
def check_database_records(self):
"""检查数据库中的记录"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
return results
cursor.execute("SELECT COUNT(*) FROM category")
count = cursor.fetchone()[0]
print(f"category表中总记录数: {count}")
conn.close()
if __name__ == "__main__":
downloader = CategoryDownloader()
if downloader.client:
# category_list = [
# 'analyst',
# 'broker',
# 'earnings',
# 'fundamental',
# 'imbalance',
# 'insiders',
# 'institutions',
# 'macro',
# 'model',
# 'news',
# 'option',
# 'other',
# 'pv',
# 'risk',
# 'sentiment',
# 'shortinterest',
# 'socialmedia'
# ]
category = 'socialmedia'
category_list = [
'analyst',
'broker',
'earnings',
'fundamental',
'imbalance',
'insiders',
'institutions',
'macro',
'model',
'news',
'option',
'other',
'pv',
'risk',
'sentiment',
'shortinterest',
'socialmedia'
]
delay = '1'
instrumentType = 'EQUITY'
@ -124,7 +217,12 @@ if __name__ == "__main__":
'TOP200',
'TOPSP500',
'ILLIQUID_MINVOL1M',
# 'MINVOL1M'
'MINVOL1M'
]
downloader.fetch_category_data(category, delay, instrumentType, region_list, universe_list)
for category in category_list:
downloader.fetch_category_data(category, delay, instrumentType, region_list, universe_list)
downloader.check_database_records()
sleep_time = 30
print(f'{category} 数据已下载, 程序休眠 {sleep_time}')
time.sleep(sleep_time)

@ -1,40 +1,22 @@
# -*- coding: utf-8 -*-
import os
import json
import random
import time
import csv
import sqlite3
import httpx
from httpx import BasicAuth
def read_category_json(filename):
# 构建文件路径
file_path = os.path.join('category_files', filename)
try:
# 读取并解析 JSON 文件
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
print(f"成功读取文件:{file_path}")
return data
except FileNotFoundError:
print(f"错误:文件不存在 - {file_path}")
return None
except json.JSONDecodeError:
print(f"错误:JSON 格式错误 - {file_path}")
return None
except PermissionError:
print(f"错误:没有文件读取权限 - {file_path}")
return None
except Exception as e:
print(f"读取文件时发生未知错误:{e}")
return None
class DataSetDownloader:
def __init__(self):
self.base_api_url = 'https://api.worldquantbrain.com'
self.client = self.login()
self.db_path = os.path.join(os.getcwd(), 'data_sets.db')
self.client = None
self.login_time = None
self.total_tasks = 0
self.completed_tasks = 0
self.current_task_index = 0
self.login()
def login(self):
"""登录并返回客户端实例"""
@ -47,34 +29,226 @@ class DataSetDownloader:
if response.status_code in [200, 201]:
print("登录成功!")
self.client = client
self.login_time = time.time()
return client
else:
print(f"登录失败: {response.json()}")
return None
exit(1)
except Exception as e:
print(f"登录过程中出现错误: {e}")
return None
exit(1)
def check_and_renew_login(self):
"""检查登录状态,超过3.5小时则重新登录"""
if not self.login_time:
return
def _debug_response(self, endpoint, data_set_id, offset=0, limit=20):
"""调试请求响应"""
print(f"\n=== 调试请求: {endpoint} ===")
url = f"{self.base_api_url}/{data_set_id}"
params = self._build_params(data_set_id, offset, limit)
current_time = time.time()
elapsed_hours = (current_time - self.login_time) / 3600
response = self.client.get(url, params=params)
if elapsed_hours > 3.5:
print(f"\n 登录已超过{elapsed_hours:.1f}小时,正在重新登录...")
self.login()
print("✅ 重新登录成功,继续执行任务")
def get_download_stats(self):
"""获取下载统计信息"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# 获取总任务数
cursor.execute('SELECT COUNT(*) FROM category WHERE downloaded = 0')
remaining_tasks = cursor.fetchone()[0]
# 获取已完成任务数
cursor.execute('SELECT COUNT(*) FROM category WHERE downloaded = 1')
completed_tasks = cursor.fetchone()[0]
conn.close()
total_tasks = remaining_tasks + completed_tasks
progress_percentage = (completed_tasks / total_tasks * 100) if total_tasks > 0 else 0
return {
'total_tasks': total_tasks,
'remaining_tasks': remaining_tasks,
'completed_tasks': completed_tasks,
'progress_percentage': progress_percentage
}
def print_progress_bar(self, current, total, prefix='', suffix='', length=50, fill=''):
"""打印进度条"""
percent = f"{100 * (current / float(total)):.1f}"
filled_length = int(length * current // total)
bar = fill * filled_length + '-' * (length - filled_length)
progress_text = f'\r{prefix} |{bar}| {percent}% {suffix}'
if current == total:
print(progress_text)
else:
print(progress_text, end='', flush=True)
def test_download(self):
"""测试每个category分类是否能成功下载数据"""
if not self.client:
print("❌ 客户端未初始化")
exit(1)
print("\n" + "=" * 60)
print("开始测试所有category分类...")
print("=" * 60)
# 从数据库获取每个category分类的一条测试记录
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# 使用DISTINCT获取每个category.name的第一条记录
cursor.execute('''
SELECT DISTINCT name, category_id, region, universe, instrumentType, delay
FROM category
WHERE downloaded = 0
ORDER BY name
''')
test_categories = cursor.fetchall()
conn.close()
print(f"找到 {len(test_categories)} 个不同的category分类进行测试")
success_count = 0
fail_count = 0
for idx, category in enumerate(test_categories):
# 检查登录状态
self.check_and_renew_login()
category_name, category_id, region, universe, instrumentType, delay = category
print(f"\n🔍 测试 [{idx + 1}/{len(test_categories)}]: {category_name}")
print(f" category_id: {category_id}")
print(f" region: {region}, universe: {universe}")
try:
# 测试下载(只下载第一页的1条数据)
result = self._test_single_category(
category_name, category_id, region, universe, instrumentType, delay
)
if result:
print(f" ✅ 测试成功")
success_count += 1
else:
print(f" ❌ 测试失败")
fail_count += 1
except Exception as e:
print(f" ❌ 测试异常: {e}")
fail_count += 1
# 打印测试进度
self.print_progress_bar(idx + 1, len(test_categories),
prefix='测试进度:',
suffix=f'完成 {success_count} 成功, {fail_count} 失败')
# 测试间隔
time.sleep(random.uniform(3, 5))
# 输出测试结果
print("\n" + "=" * 60)
print("测试结果汇总:")
print("=" * 60)
print(f"总测试分类数: {len(test_categories)}")
print(f"成功: {success_count}")
print(f"失败: {fail_count}")
if fail_count == 0:
print("\n🎉 所有category分类测试通过!")
else:
print(f"\n{fail_count} 个分类测试失败")
return fail_count == 0
def _test_single_category(self, category_name, category_id, region, universe, instrumentType, delay):
"""测试单个category是否能成功下载"""
endpoint = 'data-fields'
url = f"{self.base_api_url}/{endpoint}"
# 构建参数,只请求1条数据
params = {
'dataset.id': category_id,
'delay': delay,
'instrumentType': instrumentType,
'limit': 1,
'offset': 0,
'region': region,
'universe': universe
}
try:
response = self.client.get(url, params=params, timeout=30)
if response.status_code != 200:
print(f" HTTP状态码: {response.status_code}")
return False
if response.status_code == 200:
data = response.json()
print(f"count: {data.get('count')}")
print(f"results 长度: {len(data.get('results', []))}")
print(f"响应键: {list(data.keys())}")
total_count = data.get('count', 0)
results = data.get('results', [])
print(f" 数据总数: {total_count}, 本页获取: {len(results)}")
if total_count > 0 and len(results) > 0:
# 打印第一条数据的部分信息用于验证
item = results[0]
print(f" 示例数据字段 - id: {item.get('id', 'N/A')}, name: {item.get('name', 'N/A')}")
return True
else:
print(f" 无数据返回")
return total_count == 0 # 如果count=0也算测试成功(只是没数据)
except httpx.TimeoutException:
print(" ⏰ 请求超时")
return False
except Exception as e:
print(f" ❌ 请求异常: {e}")
return False
def get_categories_to_download(self):
"""从数据库获取需要下载的category记录"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute('''
SELECT name, base_url, category_id, region, universe, instrumentType, delay
FROM category
WHERE downloaded = 0
ORDER BY rowid
''')
categories = cursor.fetchall()
conn.close()
# 更新统计信息
self.total_tasks = len(categories)
self.completed_tasks = 0
self.current_task_index = 0
print(f"找到 {self.total_tasks} 个需要下载的category")
# 打印总体进度
if self.total_tasks > 0:
stats = self.get_download_stats()
print(f"总体进度: {stats['completed_tasks']}/{stats['total_tasks']} ({stats['progress_percentage']:.1f}%)")
return categories
def _build_params(self, data_set_id, region, universe, offset=0, limit=50):
"""构建请求参数"""
def _build_params(self, category_id, region, universe, instrumentType, delay, offset=0, limit=50):
"""构建data-fields请求参数"""
return {
'dataset.id': data_set_id,
'delay': 1,
'instrumentType': 'EQUITY',
'dataset.id': category_id,
'delay': delay,
'instrumentType': instrumentType,
'limit': limit,
'offset': offset,
'region': region,
@ -82,152 +256,211 @@ class DataSetDownloader:
}
def _process_item(self, item):
"""处理单个数据项"""
"""处理单个data-field项 - name字段对应返回的id字段"""
return {
'id': item.get('id', ''),
'name': item.get('id', ''), # name字段对应返回的id字段
'description': item.get('description', ''),
'dataset_id': item.get('dataset', {}).get('id', ''),
'dataset_name': item.get('dataset', {}).get('name', ''),
'category_id': item.get('category', {}).get('id', ''),
'category_name': item.get('category', {}).get('name', ''),
'region': item.get('region', ''),
'delay': item.get('delay', ''),
'delay': str(item.get('delay', '')),
'universe': item.get('universe', ''),
'type': item.get('type', '')
}
def _process_data(self, raw_data):
"""批量处理数据"""
return [self._process_item(item) for item in raw_data]
def download_data_set(self, data_set_id, region, universe):
def save_to_data_sets_table(self, data_items):
"""批量保存到data_sets表"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
for item in data_items:
cursor.execute('''
INSERT INTO data_sets (name, description, dataset_id, dataset_name,
category_id, category_name, region, delay,
universe, type)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
''', (
item['name'], item['description'], item['dataset_id'], item['dataset_name'],
item['category_id'], item['category_name'], item['region'], item['delay'],
item['universe'], item['type']
))
conn.commit()
conn.close()
print(f"✅ 保存了 {len(data_items)} 条记录到 data_sets 表")
def update_category_status(self, category_id, region, universe):
"""更新category表的下载状态"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute('''
UPDATE category
SET downloaded = 1
WHERE category_id = ?
AND region = ?
AND universe = ?
''', (category_id, region, universe))
conn.commit()
conn.close()
# 更新进度统计
self.completed_tasks += 1
stats = self.get_download_stats()
print(f"✅ 已更新category状态: {category_id} - {region} - {universe}")
print(f"📊 总体进度: {stats['completed_tasks']}/{stats['total_tasks']} ({stats['progress_percentage']:.1f}%)")
def download_category_data_fields(self, category_name, category_id, region, universe, instrumentType, delay):
"""下载指定category的所有data-fields"""
endpoint = 'data-fields'
"""下载数据集"""
# 检查登录状态
if not self.client:
print("❌ 客户端未初始化,无法下载数据")
return
url = f"{self.base_api_url}/{endpoint}"
# 调试请求
self._debug_response(endpoint, data_set_id, offset=0, limit=20)
print(f"\n开始下载: {category_name} - {category_id} - {region} - {universe}")
# 获取数据总数
url = f"{self.base_api_url}/{endpoint}"
params = self._build_params(data_set_id, region, universe, limit=1)
# 打印当前任务进度
self.current_task_index += 1
print(f"📋 任务进度: {self.current_task_index}/{self.total_tasks}")
# 获取数据总数
params = self._build_params(category_id, region, universe, instrumentType, delay, limit=1)
response = self.client.get(url, params=params)
if response.status_code != 200:
print(f"❌ 获取数据总数失败: {response.status_code}")
exit(1)
data = response.json()
total_count = data.get('count', 0)
print(f"📊 数据集总数: {total_count}")
print(f"📊 数据总数: {total_count}")
if total_count == 0:
print("❌ 没有找到数据")
print(" 没有找到数据,跳过")
return
# 下载所有数据
limit = 50
all_data = []
print("🚀 开始下载数据...")
all_data_items = []
for offset in range(0, total_count, limit):
time.sleep(random.uniform(1.0, 1.5))
# 每次循环开始前检查登录状态
self.check_and_renew_login()
params = self._build_params(data_set_id, region, universe, offset, limit) # 修正参数
print(f"📥 下载进度: {offset}/{total_count} ({offset / total_count * 100:.1f}%)")
time.sleep(random.uniform(2, 2.5))
results = []
params = self._build_params(category_id, region, universe, instrumentType, delay, offset, limit)
retry = 3
while retry > 0:
try:
response = self.client.get(url, params=params)
# 打印数据下载进度
data_progress = f" {offset}/{total_count} ({offset / total_count * 100:.1f}%)"
print(f"📥 数据下载进度:{data_progress}")
if response.status_code == 200:
data = response.json()
results = data.get('results', [])
response = self.client.get(url, params=params)
print(f"✅ 本页获取到 {len(results)} 条记录")
all_data.extend(results)
if response.status_code != 200:
print(f"❌ 下载失败: {response.status_code}")
exit(1)
# 成功时退出重试循环
break
else:
print(f"❌ 请求失败: {response.status_code}")
retry -= 1
if retry > 0:
print(f"🔄 重试中... ({retry}次剩余)")
time.sleep(random.uniform(2, 3))
data = response.json()
results = data.get('results', [])
except Exception as e:
print(f"❌ 下载过程中出错: {e}")
retry -= 1
if retry > 0:
print(f"🔄 重试中... ({retry}次剩余)")
time.sleep(random.uniform(2, 3))
# 处理数据
processed_items = [self._process_item(item) for item in results]
all_data_items.extend(processed_items)
# 如果重试用完仍失败,跳过当前offset继续下一个
if retry == 0:
print(f" 跳过 offset {offset}")
continue
print(f"✅ 本页获取到 {len(results)} 条记录")
if len(results) < limit:
print("🎯 到达数据末尾")
break
time.sleep(random.uniform(10, 15))
# 保存到数据库
if all_data_items:
self.save_to_data_sets_table(all_data_items)
def run(self):
"""主运行函数"""
if not self.client:
print("❌ 客户端未初始化")
exit(1)
# 打印开始时的总体进度
print("\n" + "=" * 60)
print("开始下载任务...")
stats = self.get_download_stats()
print(f"初始状态: 已完成 {stats['completed_tasks']}/{stats['total_tasks']} ({stats['progress_percentage']:.1f}%)")
print("=" * 60)
categories = self.get_categories_to_download()
if not categories:
print("✅ 所有category都已下载完成")
return
for category in categories:
# 每次循环开始前检查登录状态
self.check_and_renew_login()
category_name, base_url, category_id, region, universe, instrumentType, delay = category
print(f"\n{'=' * 60}")
print(f"处理category: {category_name}")
print(f"category_id: {category_id}")
print(f"region: {region}, universe: {universe}")
print(f"{'=' * 60}")
try:
# 下载data-fields
self.download_category_data_fields(
category_name, category_id, region, universe, instrumentType, delay
)
# 处理数据
print("🔄 处理数据中...")
processed_data = self._process_data(all_data)
# 更新下载状态
self.update_category_status(category_id, region, universe)
# 确保输出目录存在
output_dir = 'reference_fields'
os.makedirs(output_dir, exist_ok=True)
print(f"✅ 完成: {category_name} - {region} - {universe}")
# 保存为CSV
output_file = os.path.join(output_dir, f"{data_set_id}_{region.lower()}_{universe.lower()}.csv")
except Exception as e:
print(f"❌ 处理失败: {e}")
return
if processed_data:
fieldnames = list(processed_data[0].keys())
with open(output_file, 'w', encoding='utf-8', newline='') as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(processed_data)
sleep_time = random.uniform(20, 30)
print(f"\n等待 {sleep_time} 秒后继续下一个...")
time.sleep(sleep_time)
print(f"💾 处理后的数据已保存到: {output_file}")
print(f"🎉 总共处理了 {len(processed_data)} 条记录")
# 打印最终进度
print("\n" + "=" * 60)
print("🎉 所有category处理完成!")
final_stats = self.get_download_stats()
print(f"最终状态: 已完成 {final_stats['completed_tasks']}/{final_stats['total_tasks']} ({final_stats['progress_percentage']:.1f}%)")
print("=" * 60)
exit(0)
if __name__ == "__main__":
# category_list = [
# 'analyst',
# 'broker',
# 'earnings',
# 'fundamental',
# 'imbalance',
# 'insiders',
# 'institutions',
# 'macro',
# 'model',
# 'news',
# 'option',
# 'other',
# 'pv',
# 'risk',
# 'sentiment',
# 'shortinterest',
# 'socialmedia'
# ]
plan_to_download = read_category_json('analyst.json')
# 'analyst',
# 'broker',
# 'earnings',
# 'fundamental',
# 'imbalance',
# 'insiders',
# 'institutions',
# 'macro',
# 'model',
# 'news',
# 'option',
# 'other',
# 'pv',
# 'risk',
# 'sentiment',
# 'shortinterest',
# 'socialmedia'
# ]
downloader = DataSetDownloader()
if downloader.client:
for item in plan_to_download:
downloader.download_data_set(item['id'], item['region'], item['universe'])
time.sleep(random.uniform(20, 30))
else:
print("❌ 登录失败,无法下载数据")
while True:
downloader.run()
time.sleep(30)

@ -21,7 +21,7 @@ def create_category_table(cursor):
region TEXT,
universe TEXT,
delay TEXT,
downloaded TEXT
downloaded INTEGER
)
''')
@ -42,4 +42,6 @@ def create_datasets_table(cursor):
''')
if __name__ == "__main__":
create_database()
create_database()
# git config --global user.email "jack0210_@hotmail.com"
Loading…
Cancel
Save