package main import ( "bytes" "encoding/json" "fmt" "net/http" "strconv" "strings" "time" ) // BrainClient 是WorldQuant Brain API的客户端 // 类似于Python中的WorldQuantBrainSimulate类 type BrainClient struct { httpClient *http.Client // HTTP客户端,负责发送请求 baseURL string // API的基础URL username string // 用户名 password string // 密码 isLoggedIn bool // 登录状态 } // NewBrainClient 创建一个新的Brain API客户端 // 这是Go中的构造函数,类似于Python的__init__ func NewBrainClient(username, password string) *BrainClient { return &BrainClient{ // 创建HTTP客户端,设置30秒超时 httpClient: &http.Client{ Timeout: 30 * time.Second, }, baseURL: "https://api.worldquantbrain.com", username: username, password: password, isLoggedIn: false, } } // Login 登录Brain API // 对应Python中的login方法 func (c *BrainClient) Login() error { // 1. 创建POST请求 req, err := http.NewRequest("POST", c.baseURL+"/authentication", nil) if err != nil { return fmt.Errorf("创建登录请求失败: %v", err) } // 2. 设置Basic认证(用户名和密码) req.SetBasicAuth(c.username, c.password) // 3. 发送请求 resp, err := c.httpClient.Do(req) if err != nil { return fmt.Errorf("发送登录请求失败: %v", err) } defer resp.Body.Close() // 重要:确保函数结束时关闭响应体 fmt.Printf("登录状态: %d\n", resp.StatusCode) // 4. 检查响应状态码 if resp.StatusCode == 201 { fmt.Println("登录成功!") c.isLoggedIn = true return nil } // 5. 登录失败,读取错误信息 var errorResp map[string]interface{} if err := json.NewDecoder(resp.Body).Decode(&errorResp); err != nil { return fmt.Errorf("登录失败,状态码: %d", resp.StatusCode) } return fmt.Errorf("登录失败: %v", errorResp) } // SimulateAlpha 模拟Alpha因子 // 对应Python中的simulate_alpha方法 func (c *BrainClient) SimulateAlpha(expression string, settings map[string]interface{}) (*SimulationResult, error) { // 检查是否已登录 if !c.isLoggedIn { return nil, fmt.Errorf("请先调用Login()登录") } startTime := time.Now() // 记录开始时间 // ========== 1. 准备模拟设置 ========== // 创建默认设置(对应Python中的default_settings) defaultSettings := SimulationSettings{ InstrumentType: "EQUITY", Region: "USA", Universe: "TOP3000", Delay: 1, Decay: 12, Truncation: 0.05, Neutralization: "FAST", Pasteurization: "ON", UnitHandling: "VERIFY", NanHandling: "ON", Language: "FASTEXPR", Visualization: false, } // 用传入的设置覆盖默认设置 // 注意:这里需要安全地进行类型断言 if region, ok := settings["region"]; ok { if regionStr, ok := region.(string); ok { defaultSettings.Region = regionStr } } if universe, ok := settings["universe"]; ok { if universeStr, ok := universe.(string); ok { defaultSettings.Universe = universeStr } } if instrumentType, ok := settings["instrumentType"]; ok { if instrumentTypeStr, ok := instrumentType.(string); ok { defaultSettings.InstrumentType = instrumentTypeStr } } if decay, ok := settings["decay"]; ok { // 注意:JSON中的数字可能是float64类型 if decayFloat, ok := decay.(float64); ok { defaultSettings.Decay = int(decayFloat) } else if decayInt, ok := decay.(int); ok { defaultSettings.Decay = decayInt } } if truncation, ok := settings["truncation"]; ok { if truncationFloat, ok := truncation.(float64); ok { defaultSettings.Truncation = truncationFloat } } if neutralization, ok := settings["neutralization"]; ok { if neutralizationStr, ok := neutralization.(string); ok { defaultSettings.Neutralization = neutralizationStr } } // 可以根据需要添加更多字段的覆盖逻辑 // ========== 2. 构建请求数据 ========== simRequest := SimulationRequest{ Type: "REGULAR", Settings: defaultSettings, Regular: expression, } // 将结构体转换为JSON字节数组 simData, err := json.Marshal(simRequest) if err != nil { return nil, fmt.Errorf("构建模拟请求数据失败: %v", err) } // ========== 3. 发送模拟请求 ========== req, err := http.NewRequest("POST", c.baseURL+"/simulations", bytes.NewBuffer(simData)) if err != nil { return nil, fmt.Errorf("创建模拟请求失败: %v", err) } // 设置请求头 req.SetBasicAuth(c.username, c.password) req.Header.Set("Content-Type", "application/json") // 发送请求 resp, err := c.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("发送模拟请求失败: %v", err) } defer resp.Body.Close() fmt.Printf("模拟提交状态: %d\n", resp.StatusCode) // 检查响应状态 if resp.StatusCode != 201 && resp.StatusCode != 200 { var errorResp map[string]interface{} json.NewDecoder(resp.Body).Decode(&errorResp) return nil, fmt.Errorf("模拟请求失败: %v", errorResp) } // ========== 4. 获取轮询URL ========== location := resp.Header.Get("Location") if location == "" { return nil, fmt.Errorf("未获取到模拟进度URL") } fmt.Printf("进度URL: %s\n", location) // ========== 5. 轮询等待结果 ========== // 对应Python中的while循环 var finalResp map[string]interface{} for { // 创建轮询请求 pollReq, err := http.NewRequest("GET", location, nil) if err != nil { return nil, fmt.Errorf("创建轮询请求失败: %v", err) } pollReq.SetBasicAuth(c.username, c.password) // 发送轮询请求 pollResp, err := c.httpClient.Do(pollReq) if err != nil { return nil, fmt.Errorf("轮询请求失败: %v", err) } // 检查Retry-After头(告诉我们需要等待多久) retryAfter := pollResp.Header.Get("Retry-After") if retryAfter != "" { // 需要等待,解析等待时间 waitSeconds, err := strconv.ParseFloat(retryAfter, 64) if err != nil { waitSeconds = 1.0 // 解析失败则默认等待1秒 } fmt.Printf("等待 %.2f 秒...\n", waitSeconds) // 先关闭当前响应体 pollResp.Body.Close() // 等待指定时间 time.Sleep(time.Duration(waitSeconds) * time.Second) // 继续轮询 continue } // 解析响应JSON err = json.NewDecoder(pollResp.Body).Decode(&finalResp) pollResp.Body.Close() if err != nil { return nil, fmt.Errorf("解析轮询响应失败: %v", err) } // 检查状态是否为ERROR if status, ok := finalResp["status"]; ok && status == "ERROR" { message := "未知错误" if msg, ok := finalResp["message"].(string); ok { message = msg } return nil, fmt.Errorf("因子模拟失败: %s", message) } // 检查是否已经完成(包含alpha字段) if alphaID, ok := finalResp["alpha"]; ok && alphaID != nil { // 模拟完成! alphaIDStr := fmt.Sprintf("%v", alphaID) fmt.Printf("生成的Alpha ID: %s\n", alphaIDStr) // 获取详细指标 metrics, err := c.getAlphaMetrics(alphaIDStr) if err != nil { fmt.Printf("警告: 获取Alpha指标失败: %v\n", err) // 即使获取指标失败,也继续返回基础信息 } // 计算耗时 elapsed := time.Since(startTime).Seconds() // 返回成功结果 return &SimulationResult{ Status: "success", Expression: expression, AlphaID: alphaIDStr, TimeCost: elapsed, FormattedTime: formatTime(elapsed), Timestamp: time.Now().Format("2006-01-02 15:04:05"), Metrics: metrics, Message: "", }, nil } // 如果既没有Retry-After,也没有alpha字段,可能是其他状态 // 打印当前状态以便调试 fmt.Printf("当前状态: %v\n", finalResp) // 为了避免无限循环,如果没有Retry-After也没有完成,就退出 // 但在实际应用中,可能需要继续等待或处理其他状态 break } // 如果走到这里,说明轮询异常退出 elapsed := time.Since(startTime).Seconds() return &SimulationResult{ Status: "failed", Expression: expression, AlphaID: "", TimeCost: elapsed, FormattedTime: formatTime(elapsed), Timestamp: time.Now().Format("2006-01-02 15:04:05"), Metrics: nil, Message: "轮询超时或状态异常", }, nil } // getAlphaMetrics 获取Alpha因子的详细指标 // 对应Python中的get_alpha_metrics和_parse_alpha_metrics func (c *BrainClient) getAlphaMetrics(alphaID string) (*PerformanceMetrics, error) { // 构建请求URL url := fmt.Sprintf("%s/alphas/%s", c.baseURL, alphaID) // 创建GET请求 req, err := http.NewRequest("GET", url, nil) if err != nil { return nil, fmt.Errorf("创建请求失败: %v", err) } req.SetBasicAuth(c.username, c.password) // 发送请求 resp, err := c.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("请求失败: %v", err) } defer resp.Body.Close() // 检查响应状态 if resp.StatusCode != 200 { return nil, fmt.Errorf("HTTP状态码: %d", resp.StatusCode) } // 解析JSON var alphaData map[string]interface{} err = json.NewDecoder(resp.Body).Decode(&alphaData) if err != nil { return nil, fmt.Errorf("解析JSON失败: %v", err) } // 创建指标对象 metrics := &PerformanceMetrics{} // ========== 解析returns字段 ========== if returns, ok := alphaData["returns"].(map[string]interface{}); ok { if sharpe, ok := returns["sharpe"].(float64); ok { metrics.SharpeRatio = &sharpe } if annualReturn, ok := returns["annualReturn"].(float64); ok { metrics.AnnualReturn = &annualReturn } if annualVolatility, ok := returns["annualVolatility"].(float64); ok { metrics.AnnualVolatility = &annualVolatility } if maxDrawdown, ok := returns["maxDrawdown"].(float64); ok { metrics.MaxDrawdown = &maxDrawdown } if informationRatio, ok := returns["informationRatio"].(float64); ok { metrics.InformationRatio = &informationRatio } if tailRatio, ok := returns["tailRatio"].(float64); ok { metrics.TailRatio = &tailRatio } if commonRatio, ok := returns["commonRatio"].(float64); ok { metrics.CommonRatio = &commonRatio } } // ========== 解析riskAdjustment字段 ========== if riskAdj, ok := alphaData["riskAdjustment"].(map[string]interface{}); ok { if score, ok := riskAdj["score"].(float64); ok { metrics.Score = &score } if turnover, ok := riskAdj["turnover"].(float64); ok { metrics.Turnover = &turnover } if specificReturn, ok := riskAdj["specificReturn"].(float64); ok { metrics.SpecificReturn = &specificReturn } if specificRisk, ok := riskAdj["specificRisk"].(float64); ok { metrics.SpecificRisk = &specificRisk } } // ========== 解析quantiles字段 ========== if quantiles, ok := alphaData["quantiles"].(map[string]interface{}); ok { if topMinusBottom, ok := quantiles["topMinusBottom"].(float64); ok { metrics.TopMinusBottom = &topMinusBottom } if topDecileReturn, ok := quantiles["topDecileReturn"].(float64); ok { metrics.TopDecileReturn = &topDecileReturn } if bottomDecileReturn, ok := quantiles["bottomDecileReturn"].(float64); ok { metrics.BottomDecileReturn = &bottomDecileReturn } if ic, ok := quantiles["ic"].(float64); ok { metrics.IC = &ic } if icDecay, ok := quantiles["icDecay"].(float64); ok { metrics.ICDecay = &icDecay } } // ========== 解析其他字段 ========== if totalReturn, ok := alphaData["totalReturn"].(float64); ok { metrics.TotalReturn = &totalReturn } if capacity, ok := alphaData["capacity"].(float64); ok { metrics.Capacity = &capacity } if fitness, ok := alphaData["fitness"].(float64); ok { metrics.Fitness = &fitness } if instrumentCount, ok := alphaData["instrumentCount"].(float64); ok { metrics.InstrumentCount = &instrumentCount } if startDate, ok := alphaData["startDate"].(string); ok { metrics.StartDate = &startDate } if endDate, ok := alphaData["endDate"].(string); ok { metrics.EndDate = &endDate } return metrics, nil } // formatTime 将秒数格式化为"xx分xx秒"或"xx秒"的格式 // 对应Python中的format_time方法 func formatTime(seconds float64) string { if seconds < 60 { return fmt.Sprintf("%.2f秒", seconds) } minutes := int(seconds / 60) remainingSeconds := seconds - float64(minutes*60) return fmt.Sprintf("%d分%.2f秒", minutes, remainingSeconds) } // Close 关闭客户端,释放资源 // 对应Python中的close方法 func (c *BrainClient) Close() error { // Go的http.Client不需要显式关闭 // 但如果需要清理其他资源,可以在这里添加 c.httpClient = nil c.isLoggedIn = false return nil } // 辅助函数:打印分隔线(方便使用) func printSeparator() { fmt.Println(strings.Repeat("=", 60)) }