You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
431 lines
12 KiB
431 lines
12 KiB
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// BrainClient 是WorldQuant Brain API的客户端
|
|
// 类似于Python中的WorldQuantBrainSimulate类
|
|
type BrainClient struct {
|
|
httpClient *http.Client // HTTP客户端,负责发送请求
|
|
baseURL string // API的基础URL
|
|
username string // 用户名
|
|
password string // 密码
|
|
isLoggedIn bool // 登录状态
|
|
}
|
|
|
|
// NewBrainClient 创建一个新的Brain API客户端
|
|
// 这是Go中的构造函数,类似于Python的__init__
|
|
func NewBrainClient(username, password string) *BrainClient {
|
|
return &BrainClient{
|
|
// 创建HTTP客户端,设置30秒超时
|
|
httpClient: &http.Client{
|
|
Timeout: 30 * time.Second,
|
|
},
|
|
baseURL: "https://api.worldquantbrain.com",
|
|
username: username,
|
|
password: password,
|
|
isLoggedIn: false,
|
|
}
|
|
}
|
|
|
|
// Login 登录Brain API
|
|
// 对应Python中的login方法
|
|
func (c *BrainClient) Login() error {
|
|
// 1. 创建POST请求
|
|
req, err := http.NewRequest("POST", c.baseURL+"/authentication", nil)
|
|
if err != nil {
|
|
return fmt.Errorf("创建登录请求失败: %v", err)
|
|
}
|
|
|
|
// 2. 设置Basic认证(用户名和密码)
|
|
req.SetBasicAuth(c.username, c.password)
|
|
|
|
// 3. 发送请求
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("发送登录请求失败: %v", err)
|
|
}
|
|
defer resp.Body.Close() // 重要:确保函数结束时关闭响应体
|
|
|
|
fmt.Printf("登录状态: %d\n", resp.StatusCode)
|
|
|
|
// 4. 检查响应状态码
|
|
if resp.StatusCode == 201 {
|
|
fmt.Println("登录成功!")
|
|
c.isLoggedIn = true
|
|
return nil
|
|
}
|
|
|
|
// 5. 登录失败,读取错误信息
|
|
var errorResp map[string]interface{}
|
|
if err := json.NewDecoder(resp.Body).Decode(&errorResp); err != nil {
|
|
return fmt.Errorf("登录失败,状态码: %d", resp.StatusCode)
|
|
}
|
|
return fmt.Errorf("登录失败: %v", errorResp)
|
|
}
|
|
|
|
// SimulateAlpha 模拟Alpha因子
|
|
// 对应Python中的simulate_alpha方法
|
|
func (c *BrainClient) SimulateAlpha(expression string, settings map[string]interface{}) (*SimulationResult, error) {
|
|
// 检查是否已登录
|
|
if !c.isLoggedIn {
|
|
return nil, fmt.Errorf("请先调用Login()登录")
|
|
}
|
|
|
|
startTime := time.Now() // 记录开始时间
|
|
|
|
// ========== 1. 准备模拟设置 ==========
|
|
// 创建默认设置(对应Python中的default_settings)
|
|
defaultSettings := SimulationSettings{
|
|
InstrumentType: "EQUITY",
|
|
Region: "USA",
|
|
Universe: "TOP3000",
|
|
Delay: 1,
|
|
Decay: 12,
|
|
Truncation: 0.05,
|
|
Neutralization: "FAST",
|
|
Pasteurization: "ON",
|
|
UnitHandling: "VERIFY",
|
|
NanHandling: "ON",
|
|
Language: "FASTEXPR",
|
|
Visualization: false,
|
|
}
|
|
|
|
// 用传入的设置覆盖默认设置
|
|
// 注意:这里需要安全地进行类型断言
|
|
if region, ok := settings["region"]; ok {
|
|
if regionStr, ok := region.(string); ok {
|
|
defaultSettings.Region = regionStr
|
|
}
|
|
}
|
|
if universe, ok := settings["universe"]; ok {
|
|
if universeStr, ok := universe.(string); ok {
|
|
defaultSettings.Universe = universeStr
|
|
}
|
|
}
|
|
if instrumentType, ok := settings["instrumentType"]; ok {
|
|
if instrumentTypeStr, ok := instrumentType.(string); ok {
|
|
defaultSettings.InstrumentType = instrumentTypeStr
|
|
}
|
|
}
|
|
if decay, ok := settings["decay"]; ok {
|
|
// 注意:JSON中的数字可能是float64类型
|
|
if decayFloat, ok := decay.(float64); ok {
|
|
defaultSettings.Decay = int(decayFloat)
|
|
} else if decayInt, ok := decay.(int); ok {
|
|
defaultSettings.Decay = decayInt
|
|
}
|
|
}
|
|
if truncation, ok := settings["truncation"]; ok {
|
|
if truncationFloat, ok := truncation.(float64); ok {
|
|
defaultSettings.Truncation = truncationFloat
|
|
}
|
|
}
|
|
if neutralization, ok := settings["neutralization"]; ok {
|
|
if neutralizationStr, ok := neutralization.(string); ok {
|
|
defaultSettings.Neutralization = neutralizationStr
|
|
}
|
|
}
|
|
// 可以根据需要添加更多字段的覆盖逻辑
|
|
|
|
// ========== 2. 构建请求数据 ==========
|
|
simRequest := SimulationRequest{
|
|
Type: "REGULAR",
|
|
Settings: defaultSettings,
|
|
Regular: expression,
|
|
}
|
|
|
|
// 将结构体转换为JSON字节数组
|
|
simData, err := json.Marshal(simRequest)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("构建模拟请求数据失败: %v", err)
|
|
}
|
|
|
|
// ========== 3. 发送模拟请求 ==========
|
|
req, err := http.NewRequest("POST", c.baseURL+"/simulations", bytes.NewBuffer(simData))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("创建模拟请求失败: %v", err)
|
|
}
|
|
|
|
// 设置请求头
|
|
req.SetBasicAuth(c.username, c.password)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
// 发送请求
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("发送模拟请求失败: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
fmt.Printf("模拟提交状态: %d\n", resp.StatusCode)
|
|
|
|
// 检查响应状态
|
|
if resp.StatusCode != 201 && resp.StatusCode != 200 {
|
|
var errorResp map[string]interface{}
|
|
json.NewDecoder(resp.Body).Decode(&errorResp)
|
|
return nil, fmt.Errorf("模拟请求失败: %v", errorResp)
|
|
}
|
|
|
|
// ========== 4. 获取轮询URL ==========
|
|
location := resp.Header.Get("Location")
|
|
if location == "" {
|
|
return nil, fmt.Errorf("未获取到模拟进度URL")
|
|
}
|
|
fmt.Printf("进度URL: %s\n", location)
|
|
|
|
// ========== 5. 轮询等待结果 ==========
|
|
// 对应Python中的while循环
|
|
var finalResp map[string]interface{}
|
|
for {
|
|
// 创建轮询请求
|
|
pollReq, err := http.NewRequest("GET", location, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("创建轮询请求失败: %v", err)
|
|
}
|
|
pollReq.SetBasicAuth(c.username, c.password)
|
|
|
|
// 发送轮询请求
|
|
pollResp, err := c.httpClient.Do(pollReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("轮询请求失败: %v", err)
|
|
}
|
|
|
|
// 检查Retry-After头(告诉我们需要等待多久)
|
|
retryAfter := pollResp.Header.Get("Retry-After")
|
|
|
|
if retryAfter != "" {
|
|
// 需要等待,解析等待时间
|
|
waitSeconds, err := strconv.ParseFloat(retryAfter, 64)
|
|
if err != nil {
|
|
waitSeconds = 1.0 // 解析失败则默认等待1秒
|
|
}
|
|
fmt.Printf("等待 %.2f 秒...\n", waitSeconds)
|
|
|
|
// 先关闭当前响应体
|
|
pollResp.Body.Close()
|
|
// 等待指定时间
|
|
time.Sleep(time.Duration(waitSeconds) * time.Second)
|
|
// 继续轮询
|
|
continue
|
|
}
|
|
|
|
// 解析响应JSON
|
|
err = json.NewDecoder(pollResp.Body).Decode(&finalResp)
|
|
pollResp.Body.Close()
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("解析轮询响应失败: %v", err)
|
|
}
|
|
|
|
// 检查状态是否为ERROR
|
|
if status, ok := finalResp["status"]; ok && status == "ERROR" {
|
|
message := "未知错误"
|
|
if msg, ok := finalResp["message"].(string); ok {
|
|
message = msg
|
|
}
|
|
return nil, fmt.Errorf("因子模拟失败: %s", message)
|
|
}
|
|
|
|
// 检查是否已经完成(包含alpha字段)
|
|
if alphaID, ok := finalResp["alpha"]; ok && alphaID != nil {
|
|
// 模拟完成!
|
|
alphaIDStr := fmt.Sprintf("%v", alphaID)
|
|
fmt.Printf("生成的Alpha ID: %s\n", alphaIDStr)
|
|
|
|
// 获取详细指标
|
|
metrics, err := c.getAlphaMetrics(alphaIDStr)
|
|
if err != nil {
|
|
fmt.Printf("警告: 获取Alpha指标失败: %v\n", err)
|
|
// 即使获取指标失败,也继续返回基础信息
|
|
}
|
|
|
|
// 计算耗时
|
|
elapsed := time.Since(startTime).Seconds()
|
|
|
|
// 返回成功结果
|
|
return &SimulationResult{
|
|
Status: "success",
|
|
Expression: expression,
|
|
AlphaID: alphaIDStr,
|
|
TimeCost: elapsed,
|
|
FormattedTime: formatTime(elapsed),
|
|
Timestamp: time.Now().Format("2006-01-02 15:04:05"),
|
|
Metrics: metrics,
|
|
Message: "",
|
|
}, nil
|
|
}
|
|
|
|
// 如果既没有Retry-After,也没有alpha字段,可能是其他状态
|
|
// 打印当前状态以便调试
|
|
fmt.Printf("当前状态: %v\n", finalResp)
|
|
|
|
// 为了避免无限循环,如果没有Retry-After也没有完成,就退出
|
|
// 但在实际应用中,可能需要继续等待或处理其他状态
|
|
break
|
|
}
|
|
|
|
// 如果走到这里,说明轮询异常退出
|
|
elapsed := time.Since(startTime).Seconds()
|
|
return &SimulationResult{
|
|
Status: "failed",
|
|
Expression: expression,
|
|
AlphaID: "",
|
|
TimeCost: elapsed,
|
|
FormattedTime: formatTime(elapsed),
|
|
Timestamp: time.Now().Format("2006-01-02 15:04:05"),
|
|
Metrics: nil,
|
|
Message: "轮询超时或状态异常",
|
|
}, nil
|
|
}
|
|
|
|
// getAlphaMetrics 获取Alpha因子的详细指标
|
|
// 对应Python中的get_alpha_metrics和_parse_alpha_metrics
|
|
func (c *BrainClient) getAlphaMetrics(alphaID string) (*PerformanceMetrics, error) {
|
|
// 构建请求URL
|
|
url := fmt.Sprintf("%s/alphas/%s", c.baseURL, alphaID)
|
|
|
|
// 创建GET请求
|
|
req, err := http.NewRequest("GET", url, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("创建请求失败: %v", err)
|
|
}
|
|
req.SetBasicAuth(c.username, c.password)
|
|
|
|
// 发送请求
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("请求失败: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// 检查响应状态
|
|
if resp.StatusCode != 200 {
|
|
return nil, fmt.Errorf("HTTP状态码: %d", resp.StatusCode)
|
|
}
|
|
|
|
// 解析JSON
|
|
var alphaData map[string]interface{}
|
|
err = json.NewDecoder(resp.Body).Decode(&alphaData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("解析JSON失败: %v", err)
|
|
}
|
|
|
|
// 创建指标对象
|
|
metrics := &PerformanceMetrics{}
|
|
|
|
// ========== 解析returns字段 ==========
|
|
if returns, ok := alphaData["returns"].(map[string]interface{}); ok {
|
|
if sharpe, ok := returns["sharpe"].(float64); ok {
|
|
metrics.SharpeRatio = &sharpe
|
|
}
|
|
if annualReturn, ok := returns["annualReturn"].(float64); ok {
|
|
metrics.AnnualReturn = &annualReturn
|
|
}
|
|
if annualVolatility, ok := returns["annualVolatility"].(float64); ok {
|
|
metrics.AnnualVolatility = &annualVolatility
|
|
}
|
|
if maxDrawdown, ok := returns["maxDrawdown"].(float64); ok {
|
|
metrics.MaxDrawdown = &maxDrawdown
|
|
}
|
|
if informationRatio, ok := returns["informationRatio"].(float64); ok {
|
|
metrics.InformationRatio = &informationRatio
|
|
}
|
|
if tailRatio, ok := returns["tailRatio"].(float64); ok {
|
|
metrics.TailRatio = &tailRatio
|
|
}
|
|
if commonRatio, ok := returns["commonRatio"].(float64); ok {
|
|
metrics.CommonRatio = &commonRatio
|
|
}
|
|
}
|
|
|
|
// ========== 解析riskAdjustment字段 ==========
|
|
if riskAdj, ok := alphaData["riskAdjustment"].(map[string]interface{}); ok {
|
|
if score, ok := riskAdj["score"].(float64); ok {
|
|
metrics.Score = &score
|
|
}
|
|
if turnover, ok := riskAdj["turnover"].(float64); ok {
|
|
metrics.Turnover = &turnover
|
|
}
|
|
if specificReturn, ok := riskAdj["specificReturn"].(float64); ok {
|
|
metrics.SpecificReturn = &specificReturn
|
|
}
|
|
if specificRisk, ok := riskAdj["specificRisk"].(float64); ok {
|
|
metrics.SpecificRisk = &specificRisk
|
|
}
|
|
}
|
|
|
|
// ========== 解析quantiles字段 ==========
|
|
if quantiles, ok := alphaData["quantiles"].(map[string]interface{}); ok {
|
|
if topMinusBottom, ok := quantiles["topMinusBottom"].(float64); ok {
|
|
metrics.TopMinusBottom = &topMinusBottom
|
|
}
|
|
if topDecileReturn, ok := quantiles["topDecileReturn"].(float64); ok {
|
|
metrics.TopDecileReturn = &topDecileReturn
|
|
}
|
|
if bottomDecileReturn, ok := quantiles["bottomDecileReturn"].(float64); ok {
|
|
metrics.BottomDecileReturn = &bottomDecileReturn
|
|
}
|
|
if ic, ok := quantiles["ic"].(float64); ok {
|
|
metrics.IC = &ic
|
|
}
|
|
if icDecay, ok := quantiles["icDecay"].(float64); ok {
|
|
metrics.ICDecay = &icDecay
|
|
}
|
|
}
|
|
|
|
// ========== 解析其他字段 ==========
|
|
if totalReturn, ok := alphaData["totalReturn"].(float64); ok {
|
|
metrics.TotalReturn = &totalReturn
|
|
}
|
|
if capacity, ok := alphaData["capacity"].(float64); ok {
|
|
metrics.Capacity = &capacity
|
|
}
|
|
if fitness, ok := alphaData["fitness"].(float64); ok {
|
|
metrics.Fitness = &fitness
|
|
}
|
|
if instrumentCount, ok := alphaData["instrumentCount"].(float64); ok {
|
|
metrics.InstrumentCount = &instrumentCount
|
|
}
|
|
if startDate, ok := alphaData["startDate"].(string); ok {
|
|
metrics.StartDate = &startDate
|
|
}
|
|
if endDate, ok := alphaData["endDate"].(string); ok {
|
|
metrics.EndDate = &endDate
|
|
}
|
|
|
|
return metrics, nil
|
|
}
|
|
|
|
// formatTime 将秒数格式化为"xx分xx秒"或"xx秒"的格式
|
|
// 对应Python中的format_time方法
|
|
func formatTime(seconds float64) string {
|
|
if seconds < 60 {
|
|
return fmt.Sprintf("%.2f秒", seconds)
|
|
}
|
|
minutes := int(seconds / 60)
|
|
remainingSeconds := seconds - float64(minutes*60)
|
|
return fmt.Sprintf("%d分%.2f秒", minutes, remainingSeconds)
|
|
}
|
|
|
|
// Close 关闭客户端,释放资源
|
|
// 对应Python中的close方法
|
|
func (c *BrainClient) Close() error {
|
|
// Go的http.Client不需要显式关闭
|
|
// 但如果需要清理其他资源,可以在这里添加
|
|
c.httpClient = nil
|
|
c.isLoggedIn = false
|
|
return nil
|
|
}
|
|
|
|
// 辅助函数:打印分隔线(方便使用)
|
|
func printSeparator() {
|
|
fmt.Println(strings.Repeat("=", 60))
|
|
}
|
|
|