You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

431 lines
12 KiB

package main
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"time"
)
// BrainClient 是WorldQuant Brain API的客户端
// 类似于Python中的WorldQuantBrainSimulate类
type BrainClient struct {
httpClient *http.Client // HTTP客户端,负责发送请求
baseURL string // API的基础URL
username string // 用户名
password string // 密码
isLoggedIn bool // 登录状态
}
// NewBrainClient 创建一个新的Brain API客户端
// 这是Go中的构造函数,类似于Python的__init__
func NewBrainClient(username, password string) *BrainClient {
return &BrainClient{
// 创建HTTP客户端,设置30秒超时
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
baseURL: "https://api.worldquantbrain.com",
username: username,
password: password,
isLoggedIn: false,
}
}
// Login 登录Brain API
// 对应Python中的login方法
func (c *BrainClient) Login() error {
// 1. 创建POST请求
req, err := http.NewRequest("POST", c.baseURL+"/authentication", nil)
if err != nil {
return fmt.Errorf("创建登录请求失败: %v", err)
}
// 2. 设置Basic认证(用户名和密码)
req.SetBasicAuth(c.username, c.password)
// 3. 发送请求
resp, err := c.httpClient.Do(req)
if err != nil {
return fmt.Errorf("发送登录请求失败: %v", err)
}
defer resp.Body.Close() // 重要:确保函数结束时关闭响应体
fmt.Printf("登录状态: %d\n", resp.StatusCode)
// 4. 检查响应状态码
if resp.StatusCode == 201 {
fmt.Println("登录成功!")
c.isLoggedIn = true
return nil
}
// 5. 登录失败,读取错误信息
var errorResp map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&errorResp); err != nil {
return fmt.Errorf("登录失败,状态码: %d", resp.StatusCode)
}
return fmt.Errorf("登录失败: %v", errorResp)
}
// SimulateAlpha 模拟Alpha因子
// 对应Python中的simulate_alpha方法
func (c *BrainClient) SimulateAlpha(expression string, settings map[string]interface{}) (*SimulationResult, error) {
// 检查是否已登录
if !c.isLoggedIn {
return nil, fmt.Errorf("请先调用Login()登录")
}
startTime := time.Now() // 记录开始时间
// ========== 1. 准备模拟设置 ==========
// 创建默认设置(对应Python中的default_settings)
defaultSettings := SimulationSettings{
InstrumentType: "EQUITY",
Region: "USA",
Universe: "TOP3000",
Delay: 1,
Decay: 12,
Truncation: 0.05,
Neutralization: "FAST",
Pasteurization: "ON",
UnitHandling: "VERIFY",
NanHandling: "ON",
Language: "FASTEXPR",
Visualization: false,
}
// 用传入的设置覆盖默认设置
// 注意:这里需要安全地进行类型断言
if region, ok := settings["region"]; ok {
if regionStr, ok := region.(string); ok {
defaultSettings.Region = regionStr
}
}
if universe, ok := settings["universe"]; ok {
if universeStr, ok := universe.(string); ok {
defaultSettings.Universe = universeStr
}
}
if instrumentType, ok := settings["instrumentType"]; ok {
if instrumentTypeStr, ok := instrumentType.(string); ok {
defaultSettings.InstrumentType = instrumentTypeStr
}
}
if decay, ok := settings["decay"]; ok {
// 注意:JSON中的数字可能是float64类型
if decayFloat, ok := decay.(float64); ok {
defaultSettings.Decay = int(decayFloat)
} else if decayInt, ok := decay.(int); ok {
defaultSettings.Decay = decayInt
}
}
if truncation, ok := settings["truncation"]; ok {
if truncationFloat, ok := truncation.(float64); ok {
defaultSettings.Truncation = truncationFloat
}
}
if neutralization, ok := settings["neutralization"]; ok {
if neutralizationStr, ok := neutralization.(string); ok {
defaultSettings.Neutralization = neutralizationStr
}
}
// 可以根据需要添加更多字段的覆盖逻辑
// ========== 2. 构建请求数据 ==========
simRequest := SimulationRequest{
Type: "REGULAR",
Settings: defaultSettings,
Regular: expression,
}
// 将结构体转换为JSON字节数组
simData, err := json.Marshal(simRequest)
if err != nil {
return nil, fmt.Errorf("构建模拟请求数据失败: %v", err)
}
// ========== 3. 发送模拟请求 ==========
req, err := http.NewRequest("POST", c.baseURL+"/simulations", bytes.NewBuffer(simData))
if err != nil {
return nil, fmt.Errorf("创建模拟请求失败: %v", err)
}
// 设置请求头
req.SetBasicAuth(c.username, c.password)
req.Header.Set("Content-Type", "application/json")
// 发送请求
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("发送模拟请求失败: %v", err)
}
defer resp.Body.Close()
fmt.Printf("模拟提交状态: %d\n", resp.StatusCode)
// 检查响应状态
if resp.StatusCode != 201 && resp.StatusCode != 200 {
var errorResp map[string]interface{}
json.NewDecoder(resp.Body).Decode(&errorResp)
return nil, fmt.Errorf("模拟请求失败: %v", errorResp)
}
// ========== 4. 获取轮询URL ==========
location := resp.Header.Get("Location")
if location == "" {
return nil, fmt.Errorf("未获取到模拟进度URL")
}
fmt.Printf("进度URL: %s\n", location)
// ========== 5. 轮询等待结果 ==========
// 对应Python中的while循环
var finalResp map[string]interface{}
for {
// 创建轮询请求
pollReq, err := http.NewRequest("GET", location, nil)
if err != nil {
return nil, fmt.Errorf("创建轮询请求失败: %v", err)
}
pollReq.SetBasicAuth(c.username, c.password)
// 发送轮询请求
pollResp, err := c.httpClient.Do(pollReq)
if err != nil {
return nil, fmt.Errorf("轮询请求失败: %v", err)
}
// 检查Retry-After头(告诉我们需要等待多久)
retryAfter := pollResp.Header.Get("Retry-After")
if retryAfter != "" {
// 需要等待,解析等待时间
waitSeconds, err := strconv.ParseFloat(retryAfter, 64)
if err != nil {
waitSeconds = 1.0 // 解析失败则默认等待1秒
}
fmt.Printf("等待 %.2f 秒...\n", waitSeconds)
// 先关闭当前响应体
pollResp.Body.Close()
// 等待指定时间
time.Sleep(time.Duration(waitSeconds) * time.Second)
// 继续轮询
continue
}
// 解析响应JSON
err = json.NewDecoder(pollResp.Body).Decode(&finalResp)
pollResp.Body.Close()
if err != nil {
return nil, fmt.Errorf("解析轮询响应失败: %v", err)
}
// 检查状态是否为ERROR
if status, ok := finalResp["status"]; ok && status == "ERROR" {
message := "未知错误"
if msg, ok := finalResp["message"].(string); ok {
message = msg
}
return nil, fmt.Errorf("因子模拟失败: %s", message)
}
// 检查是否已经完成(包含alpha字段)
if alphaID, ok := finalResp["alpha"]; ok && alphaID != nil {
// 模拟完成!
alphaIDStr := fmt.Sprintf("%v", alphaID)
fmt.Printf("生成的Alpha ID: %s\n", alphaIDStr)
// 获取详细指标
metrics, err := c.getAlphaMetrics(alphaIDStr)
if err != nil {
fmt.Printf("警告: 获取Alpha指标失败: %v\n", err)
// 即使获取指标失败,也继续返回基础信息
}
// 计算耗时
elapsed := time.Since(startTime).Seconds()
// 返回成功结果
return &SimulationResult{
Status: "success",
Expression: expression,
AlphaID: alphaIDStr,
TimeCost: elapsed,
FormattedTime: formatTime(elapsed),
Timestamp: time.Now().Format("2006-01-02 15:04:05"),
Metrics: metrics,
Message: "",
}, nil
}
// 如果既没有Retry-After,也没有alpha字段,可能是其他状态
// 打印当前状态以便调试
fmt.Printf("当前状态: %v\n", finalResp)
// 为了避免无限循环,如果没有Retry-After也没有完成,就退出
// 但在实际应用中,可能需要继续等待或处理其他状态
break
}
// 如果走到这里,说明轮询异常退出
elapsed := time.Since(startTime).Seconds()
return &SimulationResult{
Status: "failed",
Expression: expression,
AlphaID: "",
TimeCost: elapsed,
FormattedTime: formatTime(elapsed),
Timestamp: time.Now().Format("2006-01-02 15:04:05"),
Metrics: nil,
Message: "轮询超时或状态异常",
}, nil
}
// getAlphaMetrics 获取Alpha因子的详细指标
// 对应Python中的get_alpha_metrics和_parse_alpha_metrics
func (c *BrainClient) getAlphaMetrics(alphaID string) (*PerformanceMetrics, error) {
// 构建请求URL
url := fmt.Sprintf("%s/alphas/%s", c.baseURL, alphaID)
// 创建GET请求
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %v", err)
}
req.SetBasicAuth(c.username, c.password)
// 发送请求
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("请求失败: %v", err)
}
defer resp.Body.Close()
// 检查响应状态
if resp.StatusCode != 200 {
return nil, fmt.Errorf("HTTP状态码: %d", resp.StatusCode)
}
// 解析JSON
var alphaData map[string]interface{}
err = json.NewDecoder(resp.Body).Decode(&alphaData)
if err != nil {
return nil, fmt.Errorf("解析JSON失败: %v", err)
}
// 创建指标对象
metrics := &PerformanceMetrics{}
// ========== 解析returns字段 ==========
if returns, ok := alphaData["returns"].(map[string]interface{}); ok {
if sharpe, ok := returns["sharpe"].(float64); ok {
metrics.SharpeRatio = &sharpe
}
if annualReturn, ok := returns["annualReturn"].(float64); ok {
metrics.AnnualReturn = &annualReturn
}
if annualVolatility, ok := returns["annualVolatility"].(float64); ok {
metrics.AnnualVolatility = &annualVolatility
}
if maxDrawdown, ok := returns["maxDrawdown"].(float64); ok {
metrics.MaxDrawdown = &maxDrawdown
}
if informationRatio, ok := returns["informationRatio"].(float64); ok {
metrics.InformationRatio = &informationRatio
}
if tailRatio, ok := returns["tailRatio"].(float64); ok {
metrics.TailRatio = &tailRatio
}
if commonRatio, ok := returns["commonRatio"].(float64); ok {
metrics.CommonRatio = &commonRatio
}
}
// ========== 解析riskAdjustment字段 ==========
if riskAdj, ok := alphaData["riskAdjustment"].(map[string]interface{}); ok {
if score, ok := riskAdj["score"].(float64); ok {
metrics.Score = &score
}
if turnover, ok := riskAdj["turnover"].(float64); ok {
metrics.Turnover = &turnover
}
if specificReturn, ok := riskAdj["specificReturn"].(float64); ok {
metrics.SpecificReturn = &specificReturn
}
if specificRisk, ok := riskAdj["specificRisk"].(float64); ok {
metrics.SpecificRisk = &specificRisk
}
}
// ========== 解析quantiles字段 ==========
if quantiles, ok := alphaData["quantiles"].(map[string]interface{}); ok {
if topMinusBottom, ok := quantiles["topMinusBottom"].(float64); ok {
metrics.TopMinusBottom = &topMinusBottom
}
if topDecileReturn, ok := quantiles["topDecileReturn"].(float64); ok {
metrics.TopDecileReturn = &topDecileReturn
}
if bottomDecileReturn, ok := quantiles["bottomDecileReturn"].(float64); ok {
metrics.BottomDecileReturn = &bottomDecileReturn
}
if ic, ok := quantiles["ic"].(float64); ok {
metrics.IC = &ic
}
if icDecay, ok := quantiles["icDecay"].(float64); ok {
metrics.ICDecay = &icDecay
}
}
// ========== 解析其他字段 ==========
if totalReturn, ok := alphaData["totalReturn"].(float64); ok {
metrics.TotalReturn = &totalReturn
}
if capacity, ok := alphaData["capacity"].(float64); ok {
metrics.Capacity = &capacity
}
if fitness, ok := alphaData["fitness"].(float64); ok {
metrics.Fitness = &fitness
}
if instrumentCount, ok := alphaData["instrumentCount"].(float64); ok {
metrics.InstrumentCount = &instrumentCount
}
if startDate, ok := alphaData["startDate"].(string); ok {
metrics.StartDate = &startDate
}
if endDate, ok := alphaData["endDate"].(string); ok {
metrics.EndDate = &endDate
}
return metrics, nil
}
// formatTime 将秒数格式化为"xx分xx秒"或"xx秒"的格式
// 对应Python中的format_time方法
func formatTime(seconds float64) string {
if seconds < 60 {
return fmt.Sprintf("%.2f秒", seconds)
}
minutes := int(seconds / 60)
remainingSeconds := seconds - float64(minutes*60)
return fmt.Sprintf("%d分%.2f秒", minutes, remainingSeconds)
}
// Close 关闭客户端,释放资源
// 对应Python中的close方法
func (c *BrainClient) Close() error {
// Go的http.Client不需要显式关闭
// 但如果需要清理其他资源,可以在这里添加
c.httpClient = nil
c.isLoggedIn = false
return nil
}
// 辅助函数:打印分隔线(方便使用)
func printSeparator() {
fmt.Println(strings.Repeat("=", 60))
}