# -*- coding: utf-8 -*-
"""Download WorldQuant BRAIN data-field metadata into a local SQLite database.

Rows of the ``category`` table with ``downloaded = 0`` are processed one at a
time: every data field of the matching dataset/region/universe is fetched from
the ``data-fields`` endpoint, stored in the ``data_sets`` table, and the
category row is then flagged ``downloaded = 1`` so that a restart resumes
where the previous run stopped.

Credentials are taken from the ``DataSetDownloader`` constructor arguments or,
if omitted, from the ``WQB_USERNAME`` / ``WQB_PASSWORD`` environment variables.
"""
import os
import random
import sqlite3
import sys
import time
from contextlib import closing

import httpx
from httpx import BasicAuth


class DataSetDownloader:
    """Resumable downloader for the BRAIN ``data-fields`` endpoint."""

    # Log in again once the current session is older than this many hours.
    relogin_after_hours = 3.5
    # Seconds before a single HTTP request is abandoned.
    request_timeout = 30
    # Number of data fields requested per page.
    page_size = 50

    def __init__(self, username=None, password=None):
        """Create the downloader and log in immediately.

        Args:
            username: BRAIN account name; defaults to ``$WQB_USERNAME``.
            password: BRAIN account password; defaults to ``$WQB_PASSWORD``.
        """
        self.base_api_url = 'https://api.worldquantbrain.com'
        self.db_path = os.path.join(os.getcwd(), 'data_sets.db')
        # Kept on the instance because check_and_renew_login() logs in again.
        self.username = username or os.environ.get('WQB_USERNAME')
        self.password = password or os.environ.get('WQB_PASSWORD')
        self.client = None
        self.login_time = None
        self.total_tasks = 0
        self.completed_tasks = 0
        self.current_task_index = 0
        self.login()

    # ------------------------------------------------------------------ #
    # Session handling
    # ------------------------------------------------------------------ #
    def login(self):
        """Authenticate and install a fresh HTTP client on ``self.client``.

        Returns:
            The new authenticated ``httpx.Client``.

        Exits the process with status 1 when credentials are missing, the
        server rejects them, or the request fails. Any previously installed
        client is closed once the new session has been established.
        """
        if not self.username or not self.password:
            print("未提供登录凭据: 请设置环境变量 WQB_USERNAME 和 WQB_PASSWORD")
            sys.exit(1)

        client = httpx.Client(auth=BasicAuth(self.username, self.password),
                              timeout=self.request_timeout)
        try:
            response = client.post(f'{self.base_api_url}/authentication')
            print(f"登录状态: {response.status_code}")
            if response.status_code not in (200, 201):
                print(f"登录失败: {response.json()}")
        except Exception as e:  # any network/decoding failure is fatal here
            print(f"登录过程中出现错误: {e}")
            response = None

        if response is None or response.status_code not in (200, 201):
            client.close()  # do not leak the rejected client
            sys.exit(1)

        print("登录成功!")
        previous, self.client = self.client, client
        self.login_time = time.time()
        if previous is not None:
            previous.close()  # release the connection pool of the old session
        return client

    def check_and_renew_login(self):
        """Log in again if the session is older than ``relogin_after_hours``."""
        if not self.login_time:
            return
        elapsed_hours = (time.time() - self.login_time) / 3600
        if elapsed_hours > self.relogin_after_hours:
            print(f"\n⚠️ 登录已超过{elapsed_hours:.1f}小时,正在重新登录...")
            self.login()
            print("✅ 重新登录成功,继续执行任务")

    # ------------------------------------------------------------------ #
    # Database helpers
    # ------------------------------------------------------------------ #
    def _connect(self):
        """Open the SQLite database; the connection is closed on ``with`` exit."""
        # sqlite3.Connection's own context manager only commits/rolls back;
        # closing() guarantees the handle is released even on errors.
        return closing(sqlite3.connect(self.db_path))

    def get_download_stats(self):
        """Return progress counters computed from the ``category`` table.

        Returns:
            dict with ``total_tasks``, ``remaining_tasks`` (downloaded = 0),
            ``completed_tasks`` (downloaded = 1) and ``progress_percentage``
            (0 when the table has no such rows).
        """
        with self._connect() as conn:
            remaining_tasks = conn.execute(
                'SELECT COUNT(*) FROM category WHERE downloaded = 0').fetchone()[0]
            completed_tasks = conn.execute(
                'SELECT COUNT(*) FROM category WHERE downloaded = 1').fetchone()[0]

        total_tasks = remaining_tasks + completed_tasks
        progress_percentage = (completed_tasks / total_tasks * 100) if total_tasks > 0 else 0
        return {
            'total_tasks': total_tasks,
            'remaining_tasks': remaining_tasks,
            'completed_tasks': completed_tasks,
            'progress_percentage': progress_percentage,
        }

    def print_progress_bar(self, current, total, prefix='', suffix='', length=50, fill='█'):
        """Draw a single-line text progress bar on stdout.

        The bar is redrawn in place with a carriage return; a newline is
        printed only once ``current == total``. Nothing is drawn when
        ``total <= 0`` (there is no meaningful fraction to show).
        """
        if total <= 0:
            return
        percent = f"{100 * (current / float(total)):.1f}"
        filled_length = int(length * current // total)
        bar = fill * filled_length + '-' * (length - filled_length)
        progress_text = f'\r{prefix} |{bar}| {percent}% {suffix}'
        if current == total:
            print(progress_text)
        else:
            print(progress_text, end='', flush=True)

    # ------------------------------------------------------------------ #
    # Smoke test
    # ------------------------------------------------------------------ #
    def test_download(self):
        """Probe one pending row per category name and report the results.

        For every distinct ``name`` among rows with ``downloaded = 0``, the
        row with the lowest rowid is used to request a single data field.

        Returns:
            True if every probed category succeeded.
        """
        if not self.client:
            print("❌ 客户端未初始化")
            sys.exit(1)

        print("\n" + "=" * 60)
        print("开始测试所有category分类...")
        print("=" * 60)

        # SELECT DISTINCT over all six columns would yield every distinct
        # combination rather than one row per name; pick each name's first row.
        with self._connect() as conn:
            test_categories = conn.execute('''
                SELECT name, category_id, region, universe, instrumentType, delay
                FROM category
                WHERE rowid IN (
                    SELECT MIN(rowid) FROM category
                    WHERE downloaded = 0
                    GROUP BY name
                )
                ORDER BY name
            ''').fetchall()

        total = len(test_categories)
        print(f"找到 {total} 个不同的category分类进行测试")

        success_count = 0
        fail_count = 0
        for idx, category in enumerate(test_categories, start=1):
            self.check_and_renew_login()
            category_name, category_id, region, universe, instrumentType, delay = category
            print(f"\n🔍 测试 [{idx}/{total}]: {category_name}")
            print(f" category_id: {category_id}")
            print(f" region: {region}, universe: {universe}")

            try:
                ok = self._test_single_category(
                    category_name, category_id, region, universe, instrumentType, delay
                )
            except Exception as e:
                print(f" ❌ 测试异常: {e}")
                fail_count += 1
            else:
                if ok:
                    print(" ✅ 测试成功")
                    success_count += 1
                else:
                    print(" ❌ 测试失败")
                    fail_count += 1

            self.print_progress_bar(idx, total, prefix='测试进度:',
                                    suffix=f'完成 {success_count} 成功, {fail_count} 失败')
            time.sleep(random.uniform(3, 5))  # throttle requests between probes

        print("\n" + "=" * 60)
        print("测试结果汇总:")
        print("=" * 60)
        print(f"总测试分类数: {total}")
        print(f"成功: {success_count}")
        print(f"失败: {fail_count}")
        if fail_count == 0:
            print("\n🎉 所有category分类测试通过!")
        else:
            print(f"\n⚠️ 有 {fail_count} 个分类测试失败")
        return fail_count == 0

    def _test_single_category(self, category_name, category_id, region, universe, instrumentType, delay):
        """Request one data field for a category and report whether it worked.

        Returns:
            True when a record came back, or when the API reports
            ``count == 0`` (reachable but empty); False on a non-200 status,
            a timeout, any other request error, or a positive count with no
            results.
        """
        url = f"{self.base_api_url}/data-fields"
        params = self._build_params(category_id, region, universe, instrumentType, delay,
                                    offset=0, limit=1)
        try:
            response = self.client.get(url, params=params, timeout=self.request_timeout)
            if response.status_code != 200:
                print(f" HTTP状态码: {response.status_code}")
                return False

            data = response.json()
            total_count = data.get('count', 0)
            results = data.get('results', [])
            print(f" 数据总数: {total_count}, 本页获取: {len(results)}")

            if total_count > 0 and len(results) > 0:
                item = results[0]
                print(f" 示例数据字段 - id: {item.get('id', 'N/A')}, name: {item.get('name', 'N/A')}")
                return True
            print(" ⚠️ 无数据返回")
            return total_count == 0  # an empty category still counts as reachable
        except httpx.TimeoutException:
            print(" ⏰ 请求超时")
            return False
        except Exception as e:
            print(f" ❌ 请求异常: {e}")
            return False

    # ------------------------------------------------------------------ #
    # Download pipeline
    # ------------------------------------------------------------------ #
    def get_categories_to_download(self):
        """Return pending ``category`` rows (downloaded = 0) in rowid order.

        Each row is ``(name, base_url, category_id, region, universe,
        instrumentType, delay)``. Also resets the per-run task counters.
        """
        with self._connect() as conn:
            categories = conn.execute('''
                SELECT name, base_url, category_id, region, universe, instrumentType, delay
                FROM category
                WHERE downloaded = 0
                ORDER BY rowid
            ''').fetchall()

        self.total_tasks = len(categories)
        self.completed_tasks = 0
        self.current_task_index = 0
        print(f"找到 {self.total_tasks} 个需要下载的category")

        if self.total_tasks > 0:
            stats = self.get_download_stats()
            print(f"总体进度: {stats['completed_tasks']}/{stats['total_tasks']} ({stats['progress_percentage']:.1f}%)")
        return categories

    def _build_params(self, category_id, region, universe, instrumentType, delay, offset=0, limit=50):
        """Build the query parameters for a ``data-fields`` request."""
        return {
            'dataset.id': category_id,
            'delay': delay,
            'instrumentType': instrumentType,
            'limit': limit,
            'offset': offset,
            'region': region,
            'universe': universe,
        }

    def _process_item(self, item):
        """Flatten one data-field JSON object into a ``data_sets`` row dict.

        The API's ``id`` is stored in the ``name`` column. Nested ``dataset``
        and ``category`` objects may be missing or null.
        """
        dataset = item.get('dataset') or {}
        category = item.get('category') or {}
        return {
            'name': item.get('id', ''),
            'description': item.get('description', ''),
            'dataset_id': dataset.get('id', ''),
            'dataset_name': dataset.get('name', ''),
            'category_id': category.get('id', ''),
            'category_name': category.get('name', ''),
            'region': item.get('region', ''),
            'delay': str(item.get('delay', '')),
            'universe': item.get('universe', ''),
            'type': item.get('type', ''),
        }

    def save_to_data_sets_table(self, data_items):
        """Insert the given row dicts into ``data_sets`` in one transaction."""
        columns = ('name', 'description', 'dataset_id', 'dataset_name', 'category_id',
                   'category_name', 'region', 'delay', 'universe', 'type')
        with self._connect() as conn:
            with conn:  # commit on success, roll back on error
                conn.executemany('''
                    INSERT INTO data_sets
                    (name, description, dataset_id, dataset_name, category_id,
                     category_name, region, delay, universe, type)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''', [tuple(item[c] for c in columns) for item in data_items])
        print(f"✅ 保存了 {len(data_items)} 条记录到 data_sets 表")

    def update_category_status(self, category_id, region, universe, instrumentType=None, delay=None):
        """Flag matching ``category`` rows as downloaded and print progress.

        ``instrumentType`` and ``delay`` narrow the match when given. Without
        them, sibling rows that differ only in those columns would also be
        flagged even though their data was never fetched.
        """
        sql = 'UPDATE category SET downloaded = 1 WHERE category_id = ? AND region = ? AND universe = ?'
        args = [category_id, region, universe]
        if instrumentType is not None:
            sql += ' AND instrumentType = ?'
            args.append(instrumentType)
        if delay is not None:
            sql += ' AND delay = ?'
            args.append(delay)

        with self._connect() as conn:
            with conn:
                conn.execute(sql, args)

        self.completed_tasks += 1
        stats = self.get_download_stats()
        print(f"✅ 已更新category状态: {category_id} - {region} - {universe}")
        print(f"📊 总体进度: {stats['completed_tasks']}/{stats['total_tasks']} ({stats['progress_percentage']:.1f}%)")

    def _fetch_page(self, url, params, error_label):
        """GET one page and return its decoded JSON.

        Raises:
            RuntimeError: on any non-200 status, so that run() can report the
                failure and the caller can retry later.
        """
        response = self.client.get(url, params=params)
        if response.status_code != 200:
            raise RuntimeError(f"{error_label}: {response.status_code}")
        return response.json()

    def download_category_data_fields(self, category_name, category_id, region, universe, instrumentType, delay):
        """Fetch every data field of one category and store them in ``data_sets``.

        Pages are requested ``page_size`` at a time with a short random pause
        between requests; rows are written only after all pages arrived.

        Raises:
            RuntimeError: when the API answers with a non-200 status.
        """
        url = f"{self.base_api_url}/data-fields"
        print(f"\n开始下载: {category_name} - {category_id} - {region} - {universe}")

        self.current_task_index += 1
        print(f"📋 任务进度: {self.current_task_index}/{self.total_tasks}")

        # A one-item request is enough to learn the total count.
        data = self._fetch_page(
            url, self._build_params(category_id, region, universe, instrumentType, delay, limit=1),
            "获取数据总数失败")
        total_count = data.get('count', 0)
        print(f"📊 数据总数: {total_count}")
        if total_count == 0:
            print("⚠️ 没有找到数据,跳过")
            return

        limit = self.page_size
        all_data_items = []
        for offset in range(0, total_count, limit):
            self.check_and_renew_login()
            time.sleep(random.uniform(2, 2.5))  # throttle page requests

            params = self._build_params(category_id, region, universe, instrumentType, delay, offset, limit)
            print(f"📥 数据下载进度: {offset}/{total_count} ({offset / total_count * 100:.1f}%)")

            data = self._fetch_page(url, params, "下载失败")
            results = data.get('results', [])
            all_data_items.extend(self._process_item(item) for item in results)
            print(f"✅ 本页获取到 {len(results)} 条记录")

            if len(results) < limit:
                print("🎯 到达数据末尾")
                break

        if all_data_items:
            self.save_to_data_sets_table(all_data_items)

    def run(self):
        """Download every pending category, stopping at the first failure.

        Returns:
            True when no pending categories remain, False when processing
            stopped because of an error. Finished categories are already
            flagged in the database, so calling run() again resumes.
        """
        if not self.client:
            print("❌ 客户端未初始化")
            sys.exit(1)

        print("\n" + "=" * 60)
        print("开始下载任务...")
        stats = self.get_download_stats()
        print(f"初始状态: 已完成 {stats['completed_tasks']}/{stats['total_tasks']} ({stats['progress_percentage']:.1f}%)")
        print("=" * 60)

        categories = self.get_categories_to_download()
        if not categories:
            print("✅ 所有category都已下载完成")
            return True

        for category in categories:
            self.check_and_renew_login()
            category_name, _base_url, category_id, region, universe, instrumentType, delay = category

            print(f"\n{'=' * 60}")
            print(f"处理category: {category_name}")
            print(f"category_id: {category_id}")
            print(f"region: {region}, universe: {universe}")
            print(f"{'=' * 60}")

            try:
                self.download_category_data_fields(
                    category_name, category_id, region, universe, instrumentType, delay
                )
                self.update_category_status(category_id, region, universe, instrumentType, delay)
                print(f"✅ 完成: {category_name} - {region} - {universe}")
            except Exception as e:
                print(f"❌ 处理失败: {e}")
                return False

            sleep_time = random.uniform(20, 30)
            print(f"\n等待 {sleep_time} 秒后继续下一个...")
            time.sleep(sleep_time)

        print("\n" + "=" * 60)
        print("🎉 所有category处理完成!")
        final_stats = self.get_download_stats()
        print(f"最终状态: 已完成 {final_stats['completed_tasks']}/{final_stats['total_tasks']} ({final_stats['progress_percentage']:.1f}%)")
        print("=" * 60)
        return True


if __name__ == "__main__":
    # Known category names, kept for reference (not used by the script):
    # analyst, broker, earnings, fundamental, imbalance, insiders,
    # institutions, macro, model, news, option, other, pv, risk,
    # sentiment, shortinterest, socialmedia
    downloader = DataSetDownloader()
    # run() returns False after a failure; completed categories are flagged
    # in the database, so each retry resumes where the previous one stopped.
    while not downloader.run():
        time.sleep(30)