You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
FieldDownloader/get_category.py

228 lines
7.5 KiB

# -*- coding: utf-8 -*-
import os
import json
import random
import time
import sqlite3
import httpx
from httpx import BasicAuth
class CategoryDownloader:
    """Fetch data-set listings from the WorldQuant BRAIN API into SQLite.

    On construction the instance logs in and keeps the authenticated
    ``httpx`` client in ``self.client`` (``None`` when login failed).
    Results are written to ``data_sets.db`` in the current working
    directory.

    Credentials are read from the environment variables ``WQB_USERNAME``
    and ``WQB_PASSWORD`` instead of being embedded in the source file.
    """

    def __init__(self):
        self.base_api_url = 'https://api.worldquantbrain.com'
        self.db_path = os.path.join(os.getcwd(), 'data_sets.db')
        # login() needs base_api_url, so it must run after the assignments above.
        self.client = self.login()

    def create_database_if_not_exists(self):
        """Create the ``category`` and ``data_sets`` tables if they are missing."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            # category table: one row per (category, data set, region, universe).
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS category
                (
                name
                TEXT,
                base_url
                TEXT,
                category_id
                TEXT,
                region
                TEXT,
                universe
                TEXT,
                instrumentType
                TEXT,
                delay
                TEXT,
                downloaded
                INTEGER
                )
            ''')
            # data_sets table: created here but not written to by this class.
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS data_sets
                (
                name
                TEXT,
                description
                TEXT,
                dataset_id
                TEXT,
                dataset_name
                TEXT,
                category_id
                TEXT,
                category_name
                TEXT,
                region
                TEXT,
                delay
                TEXT,
                universe
                TEXT,
                type
                TEXT
                )
            ''')
            conn.commit()
        finally:
            # Close the connection even if a statement above raised.
            conn.close()

    def login(self) -> "httpx.Client | None":
        """Authenticate against the API and return the client.

        Returns:
            The authenticated ``httpx.Client`` when the server answers 200
            or 201; ``None`` when the credentials are not configured, the
            server rejects them, or the request raises.
        """
        username = os.environ.get('WQB_USERNAME')
        password = os.environ.get('WQB_PASSWORD')
        if not username or not password:
            print("未设置环境变量 WQB_USERNAME / WQB_PASSWORD,无法登录")
            return None

        client = httpx.Client(auth=BasicAuth(username, password))
        try:
            response = client.post(f'{self.base_api_url}/authentication')
            print(f"登录状态: {response.status_code}")
            if response.status_code in [200, 201]:
                print("登录成功!")
                return client
            print(f"登录失败: {response.json()}")
        except Exception as e:
            print(f"登录过程中出现错误: {e}")
        # Every failure path ends here: release the unused connection pool.
        client.close()
        return None

    def save_to_database(self, category, category_id, region, universe, instrumentType, delay):
        """Insert one row into ``category`` unless an equal row already exists.

        Two rows count as equal when ``name``, ``category_id``, ``region``
        and ``universe`` all match; ``instrumentType`` and ``delay`` are not
        part of the duplicate check. New rows are stored with
        ``downloaded = 0``.

        Args:
            category: Category name, stored in the ``name`` column.
            category_id: Identifier of the data set returned by the API.
            region: Region value to store.
            universe: Universe value to store.
            instrumentType: Instrument type to store.
            delay: Delay value to store.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            # Look for an existing identical record first.
            cursor.execute('''
                SELECT COUNT(*)
                FROM category
                WHERE name = ?
                AND category_id = ?
                AND region = ?
                AND universe = ?
            ''', (category, category_id, region, universe))
            if cursor.fetchone()[0] == 0:
                cursor.execute('''
                    INSERT INTO category (name, base_url, category_id, region, universe, instrumentType, delay, downloaded)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ''', (category, f'{self.base_api_url}/data-fields',
                      category_id, region, universe, instrumentType, delay, 0))
                print(f"已保存: {category} - {category_id} - {region} - {universe}")
            else:
                print(f"记录已存在: {category} - {category_id} - {region} - {universe}")
            conn.commit()
        finally:
            conn.close()

    def fetch_category_data(self, category, delay, instrumentType, region_list,
                            universe_list, sleep_range=(10, 15)):
        """Query the data-sets endpoint for every region/universe pair and store the hits.

        One GET request is issued per (region, universe) combination. Request
        errors and non-200 answers are printed and the loop continues with
        the next combination.

        Args:
            category: Category name sent as the ``category`` query parameter.
            delay: Value for the ``delay`` query parameter.
            instrumentType: Value for the ``instrumentType`` query parameter.
            region_list: Regions to iterate over.
            universe_list: Universes to iterate over for each region.
            sleep_range: ``(min, max)`` seconds; a random pause drawn from this
                range follows every request. Defaults to ``(10, 15)``.

        Returns:
            An empty list when no client is available; otherwise ``None``.
        """
        # Make sure the tables exist before anything is written.
        self.create_database_if_not_exists()
        if not self.client:
            print("客户端未初始化,无法获取数据")
            return []

        min_sleep, max_sleep = sleep_range
        for region in region_list:
            for universe in universe_list:
                # NOTE(review): only the first page (limit=50, offset=0) is
                # requested; confirm whether larger result sets need paging.
                url = (
                    f'{self.base_api_url}/data-sets?category={category}&delay={delay}'
                    f'&instrumentType={instrumentType}&limit=50&offset=0'
                    f'&region={region}&universe={universe}'
                )
                print(f"请求URL: {url}")
                try:
                    response = self.client.get(url)
                    if response.status_code == 200:
                        data = response.json()
                        if data.get('count', 0) > 0:
                            for item in data.get('results', []):
                                # Use separate names for the values reported by
                                # the API so the loop variables that build the
                                # next request URL are never overwritten.
                                item_id = item.get('id', '')
                                item_region = item.get('region', '')
                                item_universe = item.get('universe', '')
                                self.save_to_database(category, item_id, item_region,
                                                      item_universe, instrumentType, delay)
                        else:
                            print(f"未找到数据: {category} - {region} - {universe}")
                    else:
                        print(f"请求失败: {response.status_code}")
                except Exception as e:
                    # Best effort: report the failure and move on to the next pair.
                    print(f"请求过程中出现错误: {e}")
                # Throttle between requests.
                sleep_time = random.uniform(min_sleep, max_sleep)
                print(f"等待 {sleep_time} 秒后继续请求")
                time.sleep(sleep_time)
        print(f"数据获取完成,已保存到数据库: {self.db_path}")

    def check_database_records(self):
        """Print and return the number of rows in the ``category`` table."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT COUNT(*) FROM category")
            count = cursor.fetchone()[0]
        finally:
            conn.close()
        print(f"category表中总记录数: {count}")
        return count
if __name__ == "__main__":
    # Script entry point: log in, then download every category in turn.
    downloader = CategoryDownloader()
    if downloader.client:
        category_list = [
            'analyst',
            'broker',
            'earnings',
            'fundamental',
            'imbalance',
            'insiders',
            'institutions',
            'macro',
            'model',
            'news',
            'option',
            'other',
            'pv',
            'risk',
            'sentiment',
            'shortinterest',
            'socialmedia',
        ]
        delay = '1'
        instrumentType = 'EQUITY'
        region_list = [
            'USA',
            'GLB',
            'EUR',
            'ASI',
            'CHN',
            'KOR',
            'TWN',
            'IND',
        ]
        universe_list = [
            'TOP3000',
            'TOP1000',
            'TOP500',
            'TOP200',
            'TOPSP500',
            'ILLIQUID_MINVOL1M',
            'MINVOL1M',
        ]
        try:
            for category in category_list:
                downloader.fetch_category_data(category, delay, instrumentType, region_list, universe_list)
                downloader.check_database_records()
                # Pause between categories to keep the request rate low.
                sleep_time = 30
                print(f'{category} 数据已下载, 程序休眠 {sleep_time}')
                time.sleep(sleep_time)
        finally:
            # Release the HTTP connection pool even when a step raises or the
            # run is interrupted from the keyboard.
            downloader.client.close()