You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
228 lines
7.5 KiB
228 lines
7.5 KiB
# -*- coding: utf-8 -*-
|
|
import os
|
|
import json
|
|
import random
|
|
import time
|
|
import sqlite3
|
|
import httpx
|
|
from httpx import BasicAuth
|
|
|
|
|
|
class CategoryDownloader:
    """Download data-set category listings from the WorldQuant BRAIN API into SQLite.

    An authenticated ``httpx.Client`` is created at construction time and the
    results of ``/data-sets`` queries are stored in the ``category`` table of a
    local ``data_sets.db`` file located in the current working directory.
    """

    def __init__(self, username=None, password=None):
        """Set up API/database locations and log in.

        Args:
            username: Optional account name. When omitted, the
                ``WQB_USERNAME`` environment variable is used.
            password: Optional account password. When omitted, the
                ``WQB_PASSWORD`` environment variable is used.
        """
        self.base_api_url = 'https://api.worldquantbrain.com'
        self.db_path = os.path.join(os.getcwd(), 'data_sets.db')
        # ``client`` is None when login fails; callers test it before use.
        self.client = self.login(username, password)

    def create_database_if_not_exists(self):
        """Create the ``category`` and ``data_sets`` tables if they are missing."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()

            # One row per data set discovered for a category/region/universe.
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS category
                (
                    name           TEXT,
                    base_url       TEXT,
                    category_id    TEXT,
                    region         TEXT,
                    universe       TEXT,
                    instrumentType TEXT,
                    delay          TEXT,
                    downloaded     INTEGER
                )
            ''')

            # Created here but not written to by this class.
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS data_sets
                (
                    name          TEXT,
                    description   TEXT,
                    dataset_id    TEXT,
                    dataset_name  TEXT,
                    category_id   TEXT,
                    category_name TEXT,
                    region        TEXT,
                    delay         TEXT,
                    universe      TEXT,
                    type          TEXT
                )
            ''')

            conn.commit()
        finally:
            # Release the connection even if a statement above raised.
            conn.close()

    def login(self, username=None, password=None):
        """Authenticate against the API and return the client.

        Credentials are taken from the arguments or, when those are not given,
        from the ``WQB_USERNAME`` / ``WQB_PASSWORD`` environment variables, so
        that no secret has to be stored in the source file.

        Args:
            username: Optional account name.
            password: Optional account password.

        Returns:
            An ``httpx.Client`` configured with basic auth when the
            ``/authentication`` request answers 200 or 201, otherwise ``None``.
        """
        username = username or os.environ.get('WQB_USERNAME')
        password = password or os.environ.get('WQB_PASSWORD')
        if not username or not password:
            print("登录失败: 未提供登录凭据, 请设置环境变量 WQB_USERNAME 和 WQB_PASSWORD")
            return None

        client = httpx.Client(auth=BasicAuth(username, password))

        try:
            response = client.post(f'{self.base_api_url}/authentication')
            print(f"登录状态: {response.status_code}")

            if response.status_code in (200, 201):
                print("登录成功!")
                return client

            print(f"登录失败: {response.json()}")
        except Exception as e:
            # Boundary handler: report any transport/decoding problem and
            # fall through to the failure path instead of crashing.
            print(f"登录过程中出现错误: {e}")

        # Failure path: do not leave the unused client open.
        client.close()
        return None

    def save_to_database(self, category, category_id, region, universe, instrumentType, delay):
        """Insert one row into ``category`` unless an identical one exists.

        A row counts as existing when ``name``, ``category_id``, ``region``,
        ``universe``, ``instrumentType`` and ``delay`` all match, so the same
        data set fetched with a different delay or instrument type is stored
        as its own row.

        Args:
            category: Category name, stored in the ``name`` column.
            category_id: Identifier of the data set returned by the API.
            region: Region code of the data set.
            universe: Universe code of the data set.
            instrumentType: Instrument type used for the query.
            delay: Delay value used for the query.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()

            cursor.execute('''
                SELECT COUNT(*)
                FROM category
                WHERE name = ?
                  AND category_id = ?
                  AND region = ?
                  AND universe = ?
                  AND instrumentType = ?
                  AND delay = ?
            ''', (category, category_id, region, universe, instrumentType, delay))
            already_stored = cursor.fetchone()[0] > 0

            if already_stored:
                print(f"记录已存在: {category} - {category_id} - {region} - {universe}")
            else:
                # ``downloaded`` starts at 0 for every new row.
                cursor.execute('''
                    INSERT INTO category (name, base_url, category_id, region, universe, instrumentType, delay, downloaded)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ''', (category, f'{self.base_api_url}/data-fields',
                      category_id, region, universe, instrumentType, delay, 0))
                print(f"已保存: {category} - {category_id} - {region} - {universe}")

            conn.commit()
        finally:
            conn.close()

    def fetch_category_data(self, category, delay, instrumentType, region_list, universe_list,
                            sleep_range=(10, 15)):
        """Query ``/data-sets`` for every region/universe pair and store the results.

        One GET request is issued per (region, universe) combination; each
        entry of the response's ``results`` list is passed to
        ``save_to_database``. After every request the method sleeps for a
        random duration drawn from ``sleep_range``.

        Args:
            category: Category name sent as the ``category`` query parameter.
            delay: Value of the ``delay`` query parameter.
            instrumentType: Value of the ``instrumentType`` query parameter.
            region_list: Iterable of region codes to query.
            universe_list: Iterable of universe codes to query.
            sleep_range: ``(low, high)`` bounds in seconds for the pause after
                each request. Defaults to ``(10, 15)``.

        Returns:
            ``[]`` when no authenticated client is available, otherwise
            ``None``.
        """
        # Make sure the target tables exist before anything is written.
        self.create_database_if_not_exists()

        if not self.client:
            print("客户端未初始化,无法获取数据")
            return []

        for region in region_list:
            for universe in universe_list:
                # NOTE(review): only the first page (limit=50, offset=0) is
                # requested; if the API reports count > 50 the remaining
                # entries are not fetched - confirm whether paging is needed.
                url = (
                    f'{self.base_api_url}/data-sets?category={category}'
                    f'&delay={delay}&instrumentType={instrumentType}'
                    f'&limit=50&offset=0&region={region}&universe={universe}'
                )
                print(f"请求URL: {url}")

                try:
                    response = self.client.get(url)

                    if response.status_code == 200:
                        data = response.json()

                        if data.get('count', 0) > 0:
                            for item in data.get('results', []):
                                # Use the values reported by the API for the
                                # stored row, without clobbering the loop
                                # variables of the enclosing loops.
                                self.save_to_database(
                                    category,
                                    item.get('id', ''),
                                    item.get('region', ''),
                                    item.get('universe', ''),
                                    instrumentType,
                                    delay,
                                )
                        else:
                            print(f"未找到数据: {category} - {region} - {universe}")
                    else:
                        print(f"请求失败: {response.status_code}")

                except Exception as e:
                    # Best effort: report the failure and continue with the
                    # next region/universe pair.
                    print(f"请求过程中出现错误: {e}")

                # Pause between requests.
                sleep_time = random.uniform(*sleep_range)
                print(f"等待 {sleep_time} 秒后继续请求")
                time.sleep(sleep_time)

        print(f"数据获取完成,已保存到数据库: {self.db_path}")

    def check_database_records(self):
        """Print and return the number of rows in the ``category`` table.

        Returns:
            The row count as an ``int``.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT COUNT(*) FROM category")
            count = cursor.fetchone()[0]
        finally:
            conn.close()

        print(f"category表中总记录数: {count}")
        return count
|
|
|
|
|
|
if __name__ == "__main__":
    # Constructing the downloader also performs the login.
    downloader = CategoryDownloader()

    # Nothing is fetched unless the login produced a client.
    if downloader.client:
        # Data-set categories to query, one fetch pass per entry.
        category_list = [
            'analyst', 'broker', 'earnings', 'fundamental', 'imbalance',
            'insiders', 'institutions', 'macro', 'model', 'news', 'option',
            'other', 'pv', 'risk', 'sentiment', 'shortinterest', 'socialmedia',
        ]

        # Query parameters shared by every request.
        delay = '1'
        instrumentType = 'EQUITY'

        # Every region is combined with every universe.
        region_list = ['USA', 'GLB', 'EUR', 'ASI', 'CHN', 'KOR', 'TWN', 'IND']
        universe_list = [
            'TOP3000', 'TOP1000', 'TOP500', 'TOP200', 'TOPSP500',
            'ILLIQUID_MINVOL1M', 'MINVOL1M',
        ]

        for category in category_list:
            downloader.fetch_category_data(
                category, delay, instrumentType, region_list, universe_list
            )
            downloader.check_database_records()

            # Fixed pause between categories.
            sleep_time = 30
            print(f'{category} 数据已下载, 程序休眠 {sleep_time} 秒')
            time.sleep(sleep_time)