- DirectConnectServer: lifecycle management, auth token, port config - DirectConnectWebSocketHandler: WebSocket handler with session management - ServerSession: per-connection AgentLoop with permission forwarding - ServerMessage: 7-type JSON protocol (user/assistant/result/control/interrupt/keep_alive/error) - ServerModeAutoConfiguration: conditional Spring config with @EnableWebSocket - WebSocketServerConfig: endpoint registration and container settings - ClaudeCodeApplication: --server flag detection, web-application-type override - ClaudeCodeRunner: skip TUI in server mode, block on shutdown signal - pom.xml: added spring-boot-starter-websocket dependency Server mode: --server [--server-port=12321] [--server-token=xxx] Env vars: CLAUDE_CODE_SERVER_MODE, CLAUDE_CODE_SERVER_PORT, CLAUDE_CODE_SERVER_TOKEN Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>pull/1/head
parent
a926b31e34
commit
b98675f643
@ -1,18 +1,45 @@ |
|||||||
package com.claudecode; |
package com.claudecode; |
||||||
|
|
||||||
|
import com.claudecode.server.DirectConnectServer; |
||||||
import org.springframework.boot.SpringApplication; |
import org.springframework.boot.SpringApplication; |
||||||
import org.springframework.boot.autoconfigure.SpringBootApplication; |
import org.springframework.boot.autoconfigure.SpringBootApplication; |
||||||
|
|
||||||
|
import java.util.*; |
||||||
|
|
||||||
/** |
/** |
||||||
* Claude Code Java 版主入口。 |
* Claude Code Java 版主入口。 |
||||||
* <p> |
* <p> |
||||||
* 对应 claude-code/src/entrypoints/cli.tsx |
* 对应 claude-code/src/entrypoints/cli.tsx |
||||||
* 以 Spring Boot 应用启动,但关闭 Web 服务器(纯 CLI 模式)。 |
* <p> |
||||||
|
* 支持两种启动模式: |
||||||
|
* <ul> |
||||||
|
* <li>CLI 模式(默认)—— 关闭 Web 服务器,启动 TUI 交互</li> |
||||||
|
* <li>Server 模式(--server)—— 启动 WebSocket 服务器,无 TUI</li> |
||||||
|
* </ul> |
||||||
*/ |
*/ |
||||||
@SpringBootApplication |
@SpringBootApplication |
||||||
public class ClaudeCodeApplication { |
public class ClaudeCodeApplication { |
||||||
|
|
||||||
public static void main(String[] args) { |
public static void main(String[] args) { |
||||||
SpringApplication.run(ClaudeCodeApplication.class, args); |
SpringApplication app = new SpringApplication(ClaudeCodeApplication.class); |
||||||
|
|
||||||
|
if (DirectConnectServer.isServerMode(args)) { |
||||||
|
// Server Mode: 启用 Web 服务器 + WebSocket
|
||||||
|
int port = DirectConnectServer.parsePort(args); |
||||||
|
|
||||||
|
Map<String, Object> serverProps = new HashMap<>(); |
||||||
|
serverProps.put("spring.main.web-application-type", "servlet"); |
||||||
|
serverProps.put("server.port", port); |
||||||
|
serverProps.put("claude-code.server-mode", "true"); |
||||||
|
app.setDefaultProperties(serverProps); |
||||||
|
|
||||||
|
// 将原始参数保存到系统属性,供 ServerModeAutoConfiguration 解析
|
||||||
|
System.setProperty("claude-code.server-args", String.join(" ", args)); |
||||||
|
|
||||||
|
System.out.println("Starting in Server Mode on port " + port + "..."); |
||||||
|
} |
||||||
|
// CLI 模式使用 application.yml 中的 web-application-type: none
|
||||||
|
|
||||||
|
app.run(args); |
||||||
} |
} |
||||||
} |
} |
||||||
|
|||||||
@ -0,0 +1,226 @@ |
|||||||
|
package com.claudecode.server; |
||||||
|
|
||||||
|
import org.slf4j.Logger; |
||||||
|
import org.slf4j.LoggerFactory; |
||||||
|
import org.springframework.ai.chat.model.ChatModel; |
||||||
|
import org.springframework.web.socket.config.annotation.EnableWebSocket; |
||||||
|
|
||||||
|
import java.util.UUID; |
||||||
|
|
||||||
|
/** |
||||||
|
* Server Mode 生命周期管理 —— 对应 claude-code/src/server 的服务端核心。 |
||||||
|
* <p> |
||||||
|
* 管理服务端状态、认证 Token 生成、会话限制等。 |
||||||
|
* <p> |
||||||
|
* 启动流程: |
||||||
|
* <ol> |
||||||
|
* <li>ClaudeCodeApplication 检测 --server 参数</li> |
||||||
|
* <li>Spring Boot 启用 WebSocket 模式</li> |
||||||
|
* <li>DirectConnectServer 初始化,生成 auth token</li> |
||||||
|
* <li>WebSocket 端点开始接受连接</li> |
||||||
|
* </ol> |
||||||
|
* |
||||||
|
* 客户端连接方式: |
||||||
|
* <ul> |
||||||
|
* <li>WebSocket: ws://localhost:{port}/ws?token={auth_token}</li>
|
||||||
|
* <li>HTTP Header: Authorization: Bearer {auth_token}</li> |
||||||
|
* </ul> |
||||||
|
*/ |
||||||
|
public class DirectConnectServer { |
||||||
|
|
||||||
|
private static final Logger log = LoggerFactory.getLogger(DirectConnectServer.class); |
||||||
|
|
||||||
|
/** 默认端口 */ |
||||||
|
public static final int DEFAULT_PORT = 12321; |
||||||
|
|
||||||
|
/** 默认最大会话数 */ |
||||||
|
public static final int DEFAULT_MAX_SESSIONS = 5; |
||||||
|
|
||||||
|
private final int port; |
||||||
|
private final String authToken; |
||||||
|
private final int maxSessions; |
||||||
|
private final DirectConnectWebSocketHandler handler; |
||||||
|
|
||||||
|
private volatile boolean running = false; |
||||||
|
|
||||||
|
public DirectConnectServer(int port, String authToken, int maxSessions, |
||||||
|
DirectConnectWebSocketHandler handler) { |
||||||
|
this.port = port; |
||||||
|
this.authToken = authToken; |
||||||
|
this.maxSessions = maxSessions; |
||||||
|
this.handler = handler; |
||||||
|
} |
||||||
|
|
||||||
|
/** |
||||||
|
* 生成随机认证 Token。 |
||||||
|
*/ |
||||||
|
public static String generateAuthToken() { |
||||||
|
return UUID.randomUUID().toString().replace("-", ""); |
||||||
|
} |
||||||
|
|
||||||
|
/** |
||||||
|
* 标记服务器已启动。打印连接信息。 |
||||||
|
*/ |
||||||
|
public void onServerStarted() { |
||||||
|
running = true; |
||||||
|
printConnectionInfo(); |
||||||
|
} |
||||||
|
|
||||||
|
/** |
||||||
|
* 打印连接信息到控制台。 |
||||||
|
*/ |
||||||
|
public void printConnectionInfo() { |
||||||
|
String separator = "═".repeat(60); |
||||||
|
System.out.println(); |
||||||
|
System.out.println("╔" + separator + "╗"); |
||||||
|
System.out.println("║ Claude Code Java — Server Mode ║"); |
||||||
|
System.out.println("╠" + separator + "╣"); |
||||||
|
System.out.printf("║ WebSocket: ws://localhost:%d/ws%-24s║%n", port, ""); |
||||||
|
System.out.printf("║ Port: %-48s║%n", port); |
||||||
|
if (authToken != null && !authToken.isBlank()) { |
||||||
|
System.out.printf("║ Token: %-48s║%n", authToken); |
||||||
|
} else { |
||||||
|
System.out.printf("║ Auth: %-48s║%n", "disabled (no token)"); |
||||||
|
} |
||||||
|
System.out.printf("║ Max Sess: %-48s║%n", maxSessions); |
||||||
|
System.out.println("╠" + separator + "╣"); |
||||||
|
System.out.println("║ Connect with: ║"); |
||||||
|
if (authToken != null && !authToken.isBlank()) { |
||||||
|
System.out.printf("║ ws://localhost:%d/ws?token=%s ║%n", port, authToken.substring(0, Math.min(12, authToken.length())) + "..."); |
||||||
|
} else { |
||||||
|
System.out.printf("║ ws://localhost:%d/ws%-38s║%n", port, ""); |
||||||
|
} |
||||||
|
System.out.println("╚" + separator + "╝"); |
||||||
|
System.out.println(); |
||||||
|
} |
||||||
|
|
||||||
|
/** |
||||||
|
* 停止服务器,关闭所有会话。 |
||||||
|
*/ |
||||||
|
public void shutdown() { |
||||||
|
if (!running) return; |
||||||
|
running = false; |
||||||
|
log.info("Shutting down server..."); |
||||||
|
handler.closeAllSessions(); |
||||||
|
log.info("Server stopped. All sessions closed."); |
||||||
|
} |
||||||
|
|
||||||
|
// ==================== 静态工具方法 ====================
|
||||||
|
|
||||||
|
/** |
||||||
|
* 检查命令行参数是否包含 --server 标志。 |
||||||
|
*/ |
||||||
|
public static boolean isServerMode(String[] args) { |
||||||
|
if (args == null) return false; |
||||||
|
for (String arg : args) { |
||||||
|
if ("--server".equals(arg) || arg.startsWith("--server=")) { |
||||||
|
return true; |
||||||
|
} |
||||||
|
} |
||||||
|
// 也支持环境变量
|
||||||
|
String envMode = System.getenv("CLAUDE_CODE_SERVER_MODE"); |
||||||
|
return "true".equalsIgnoreCase(envMode) || "1".equals(envMode); |
||||||
|
} |
||||||
|
|
||||||
|
/** |
||||||
|
* 从命令行参数解析端口号。 |
||||||
|
*/ |
||||||
|
public static int parsePort(String[] args) { |
||||||
|
if (args != null) { |
||||||
|
for (String arg : args) { |
||||||
|
if (arg.startsWith("--server-port=")) { |
||||||
|
try { |
||||||
|
return Integer.parseInt(arg.substring("--server-port=".length())); |
||||||
|
} catch (NumberFormatException e) { |
||||||
|
log.warn("Invalid port number: {}", arg); |
||||||
|
} |
||||||
|
} |
||||||
|
if (arg.startsWith("--server=")) { |
||||||
|
try { |
||||||
|
return Integer.parseInt(arg.substring("--server=".length())); |
||||||
|
} catch (NumberFormatException e) { |
||||||
|
// --server=true 等非数字参数忽略
|
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
// 环境变量
|
||||||
|
String envPort = System.getenv("CLAUDE_CODE_SERVER_PORT"); |
||||||
|
if (envPort != null) { |
||||||
|
try { |
||||||
|
return Integer.parseInt(envPort); |
||||||
|
} catch (NumberFormatException e) { |
||||||
|
log.warn("Invalid CLAUDE_CODE_SERVER_PORT: {}", envPort); |
||||||
|
} |
||||||
|
} |
||||||
|
return DEFAULT_PORT; |
||||||
|
} |
||||||
|
|
||||||
|
/** |
||||||
|
* 从命令行参数或环境变量获取认证 Token。 |
||||||
|
* 如果未指定,生成随机 Token。 |
||||||
|
*/ |
||||||
|
public static String parseAuthToken(String[] args) { |
||||||
|
if (args != null) { |
||||||
|
for (String arg : args) { |
||||||
|
if (arg.startsWith("--server-token=")) { |
||||||
|
return arg.substring("--server-token=".length()); |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
String envToken = System.getenv("CLAUDE_CODE_SERVER_TOKEN"); |
||||||
|
if (envToken != null && !envToken.isBlank()) { |
||||||
|
return envToken; |
||||||
|
} |
||||||
|
// 未指定则生成随机 token
|
||||||
|
return generateAuthToken(); |
||||||
|
} |
||||||
|
|
||||||
|
/** |
||||||
|
* 从命令行参数解析最大会话数。 |
||||||
|
*/ |
||||||
|
public static int parseMaxSessions(String[] args) { |
||||||
|
if (args != null) { |
||||||
|
for (String arg : args) { |
||||||
|
if (arg.startsWith("--max-sessions=")) { |
||||||
|
try { |
||||||
|
return Integer.parseInt(arg.substring("--max-sessions=".length())); |
||||||
|
} catch (NumberFormatException e) { |
||||||
|
log.warn("Invalid max sessions: {}", arg); |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
String envMax = System.getenv("CLAUDE_CODE_MAX_SESSIONS"); |
||||||
|
if (envMax != null) { |
||||||
|
try { |
||||||
|
return Integer.parseInt(envMax); |
||||||
|
} catch (NumberFormatException e) { |
||||||
|
log.warn("Invalid CLAUDE_CODE_MAX_SESSIONS: {}", envMax); |
||||||
|
} |
||||||
|
} |
||||||
|
return DEFAULT_MAX_SESSIONS; |
||||||
|
} |
||||||
|
|
||||||
|
// ==================== Getters ====================
|
||||||
|
|
||||||
|
public int getPort() { |
||||||
|
return port; |
||||||
|
} |
||||||
|
|
||||||
|
public String getAuthToken() { |
||||||
|
return authToken; |
||||||
|
} |
||||||
|
|
||||||
|
public int getMaxSessions() { |
||||||
|
return maxSessions; |
||||||
|
} |
||||||
|
|
||||||
|
public boolean isRunning() { |
||||||
|
return running; |
||||||
|
} |
||||||
|
|
||||||
|
public int getActiveSessionCount() { |
||||||
|
return handler.getActiveSessionCount(); |
||||||
|
} |
||||||
|
} |
||||||
@ -0,0 +1,180 @@ |
|||||||
|
package com.claudecode.server; |
||||||
|
|
||||||
|
import com.claudecode.core.AgentLoop; |
||||||
|
import com.claudecode.core.TokenTracker; |
||||||
|
import com.claudecode.tool.ToolContext; |
||||||
|
import com.claudecode.tool.ToolRegistry; |
||||||
|
import org.slf4j.Logger; |
||||||
|
import org.slf4j.LoggerFactory; |
||||||
|
import org.springframework.ai.chat.model.ChatModel; |
||||||
|
import org.springframework.web.socket.CloseStatus; |
||||||
|
import org.springframework.web.socket.TextMessage; |
||||||
|
import org.springframework.web.socket.WebSocketSession; |
||||||
|
import org.springframework.web.socket.handler.TextWebSocketHandler; |
||||||
|
|
||||||
|
import java.io.IOException; |
||||||
|
import java.net.URI; |
||||||
|
import java.util.Map; |
||||||
|
import java.util.UUID; |
||||||
|
import java.util.concurrent.ConcurrentHashMap; |
||||||
|
|
||||||
|
/** |
||||||
|
* WebSocket Handler —— Server Mode 的核心连接处理器。 |
||||||
|
* <p> |
||||||
|
* 对应 claude-code/src/server/directConnectManager.ts 的服务端实现。 |
||||||
|
* <p> |
||||||
|
* 每个 WebSocket 连接创建一个 {@link ServerSession},包含独立的 AgentLoop。 |
||||||
|
* 支持 Bearer Token 认证和多会话管理。 |
||||||
|
*/ |
||||||
|
public class DirectConnectWebSocketHandler extends TextWebSocketHandler { |
||||||
|
|
||||||
|
private static final Logger log = LoggerFactory.getLogger(DirectConnectWebSocketHandler.class); |
||||||
|
|
||||||
|
private final ChatModel chatModel; |
||||||
|
private final ToolRegistry toolRegistry; |
||||||
|
private final ToolContext toolContext; |
||||||
|
private final String systemPrompt; |
||||||
|
private final String authToken; |
||||||
|
private final int maxSessions; |
||||||
|
private final String model; |
||||||
|
|
||||||
|
/** 活跃会话:WebSocket sessionId → ServerSession */ |
||||||
|
private final ConcurrentHashMap<String, ServerSession> sessions = new ConcurrentHashMap<>(); |
||||||
|
|
||||||
|
public DirectConnectWebSocketHandler(ChatModel chatModel, ToolRegistry toolRegistry, |
||||||
|
ToolContext toolContext, String systemPrompt, |
||||||
|
String authToken, int maxSessions, String model) { |
||||||
|
this.chatModel = chatModel; |
||||||
|
this.toolRegistry = toolRegistry; |
||||||
|
this.toolContext = toolContext; |
||||||
|
this.systemPrompt = systemPrompt; |
||||||
|
this.authToken = authToken; |
||||||
|
this.maxSessions = maxSessions; |
||||||
|
this.model = model; |
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
public void afterConnectionEstablished(WebSocketSession wsSession) throws Exception { |
||||||
|
log.info("WebSocket connection attempt from: {}", wsSession.getRemoteAddress()); |
||||||
|
|
||||||
|
// Bearer Token 认证
|
||||||
|
if (authToken != null && !authToken.isBlank()) { |
||||||
|
if (!validateAuth(wsSession)) { |
||||||
|
wsSession.close(new CloseStatus(4001, "Unauthorized")); |
||||||
|
return; |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// 会话数限制
|
||||||
|
if (maxSessions > 0 && sessions.size() >= maxSessions) { |
||||||
|
String errorMsg = ServerMessage.error(null, "max_sessions", |
||||||
|
"Maximum sessions (" + maxSessions + ") reached"); |
||||||
|
wsSession.sendMessage(new TextMessage(errorMsg)); |
||||||
|
wsSession.close(new CloseStatus(4002, "Max sessions reached")); |
||||||
|
return; |
||||||
|
} |
||||||
|
|
||||||
|
// 创建新的 AgentLoop(每个连接独立实例)
|
||||||
|
String sessionId = UUID.randomUUID().toString(); |
||||||
|
AgentLoop agentLoop = new AgentLoop(chatModel, toolRegistry, toolContext, systemPrompt, new TokenTracker()); |
||||||
|
|
||||||
|
// 创建 ServerSession
|
||||||
|
ServerSession serverSession = new ServerSession(sessionId, agentLoop, wsSession); |
||||||
|
sessions.put(wsSession.getId(), serverSession); |
||||||
|
|
||||||
|
// 发送初始化消息
|
||||||
|
serverSession.sendInitMessage(model); |
||||||
|
|
||||||
|
log.info("Server session created: {} (ws: {}, total: {})", |
||||||
|
sessionId, wsSession.getId(), sessions.size()); |
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
protected void handleTextMessage(WebSocketSession wsSession, TextMessage message) throws Exception { |
||||||
|
ServerSession session = sessions.get(wsSession.getId()); |
||||||
|
if (session == null) { |
||||||
|
log.warn("Message from unknown session: {}", wsSession.getId()); |
||||||
|
wsSession.close(new CloseStatus(4003, "Unknown session")); |
||||||
|
return; |
||||||
|
} |
||||||
|
|
||||||
|
session.handleMessage(message.getPayload()); |
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
public void afterConnectionClosed(WebSocketSession wsSession, CloseStatus status) throws Exception { |
||||||
|
ServerSession session = sessions.remove(wsSession.getId()); |
||||||
|
if (session != null) { |
||||||
|
session.close(); |
||||||
|
log.info("Session closed: {} (status: {}, remaining: {})", |
||||||
|
session.getSessionId(), status, sessions.size()); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
public void handleTransportError(WebSocketSession wsSession, Throwable exception) throws Exception { |
||||||
|
log.error("WebSocket transport error for {}: {}", wsSession.getId(), exception.getMessage()); |
||||||
|
ServerSession session = sessions.get(wsSession.getId()); |
||||||
|
if (session != null) { |
||||||
|
try { |
||||||
|
String errorMsg = ServerMessage.error(session.getSessionId(), |
||||||
|
"transport_error", exception.getMessage()); |
||||||
|
wsSession.sendMessage(new TextMessage(errorMsg)); |
||||||
|
} catch (Exception ignored) {} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
/** |
||||||
|
* 验证 WebSocket 连接的 Bearer Token。 |
||||||
|
* <p> |
||||||
|
* 支持两种方式: |
||||||
|
* <ul> |
||||||
|
* <li>HTTP Header: {@code Authorization: Bearer <token>}</li> |
||||||
|
* <li>Query Parameter: {@code ?token=<token>}</li> |
||||||
|
* </ul> |
||||||
|
*/ |
||||||
|
private boolean validateAuth(WebSocketSession wsSession) { |
||||||
|
// 方式1: 从 HTTP Header 获取
|
||||||
|
var headers = wsSession.getHandshakeHeaders(); |
||||||
|
String authHeader = headers.getFirst("Authorization"); |
||||||
|
if (authHeader != null && authHeader.startsWith("Bearer ")) { |
||||||
|
String token = authHeader.substring(7); |
||||||
|
if (authToken.equals(token)) { |
||||||
|
return true; |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// 方式2: 从 Query Parameter 获取
|
||||||
|
URI uri = wsSession.getUri(); |
||||||
|
if (uri != null && uri.getQuery() != null) { |
||||||
|
String query = uri.getQuery(); |
||||||
|
for (String param : query.split("&")) { |
||||||
|
if (param.startsWith("token=")) { |
||||||
|
String token = param.substring(6); |
||||||
|
if (authToken.equals(token)) { |
||||||
|
return true; |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
log.warn("Authentication failed for connection from: {}", wsSession.getRemoteAddress()); |
||||||
|
return false; |
||||||
|
} |
||||||
|
|
||||||
|
/** |
||||||
|
* 关闭所有会话。 |
||||||
|
*/ |
||||||
|
public void closeAllSessions() { |
||||||
|
sessions.values().forEach(ServerSession::close); |
||||||
|
sessions.clear(); |
||||||
|
} |
||||||
|
|
||||||
|
public int getActiveSessionCount() { |
||||||
|
return sessions.size(); |
||||||
|
} |
||||||
|
|
||||||
|
public Map<String, ServerSession> getSessions() { |
||||||
|
return Map.copyOf(sessions); |
||||||
|
} |
||||||
|
} |
||||||
@ -0,0 +1,220 @@ |
|||||||
|
package com.claudecode.server; |
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonInclude; |
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty; |
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException; |
||||||
|
import com.fasterxml.jackson.databind.JsonNode; |
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper; |
||||||
|
|
||||||
|
import java.util.List; |
||||||
|
import java.util.Map; |
||||||
|
import java.util.UUID; |
||||||
|
|
||||||
|
/** |
||||||
|
* Server Mode 消息协议 —— 对应 claude-code/src/server/types.ts |
||||||
|
* <p> |
||||||
|
* WebSocket 上的 JSON 消息,7 种类型: |
||||||
|
* <ul> |
||||||
|
* <li><b>user</b> — 客户端发送的用户消息</li> |
||||||
|
* <li><b>assistant</b> — 服务端回复的助手消息</li> |
||||||
|
* <li><b>result</b> — 轮次结束的最终结果</li> |
||||||
|
* <li><b>control_request</b> — 权限请求(服务端→客户端)</li> |
||||||
|
* <li><b>control_response</b> — 权限回复(客户端→服务端)</li> |
||||||
|
* <li><b>interrupt</b> — 中断信号</li> |
||||||
|
* <li><b>keep_alive</b> — 心跳</li> |
||||||
|
* </ul> |
||||||
|
*/ |
||||||
|
public class ServerMessage { |
||||||
|
|
||||||
|
private static final ObjectMapper MAPPER = new ObjectMapper() |
||||||
|
.setSerializationInclusion(JsonInclude.Include.NON_NULL); |
||||||
|
|
||||||
|
// ==================== 消息类型常量 ====================
|
||||||
|
|
||||||
|
public static final String TYPE_USER = "user"; |
||||||
|
public static final String TYPE_ASSISTANT = "assistant"; |
||||||
|
public static final String TYPE_RESULT = "result"; |
||||||
|
public static final String TYPE_CONTROL_REQUEST = "control_request"; |
||||||
|
public static final String TYPE_CONTROL_RESPONSE = "control_response"; |
||||||
|
public static final String TYPE_INTERRUPT = "interrupt"; |
||||||
|
public static final String TYPE_KEEP_ALIVE = "keep_alive"; |
||||||
|
public static final String TYPE_SYSTEM = "system"; |
||||||
|
public static final String TYPE_ERROR = "error"; |
||||||
|
public static final String TYPE_TOOL_USE = "tool_use"; |
||||||
|
|
||||||
|
// ==================== 通用消息 ====================
|
||||||
|
|
||||||
|
@JsonInclude(JsonInclude.Include.NON_NULL) |
||||||
|
public record Envelope( |
||||||
|
String type, |
||||||
|
@JsonProperty("session_id") String sessionId, |
||||||
|
Object payload |
||||||
|
) { |
||||||
|
public String toJson() throws JsonProcessingException { |
||||||
|
return MAPPER.writeValueAsString(this); |
||||||
|
} |
||||||
|
|
||||||
|
public static Envelope fromJson(String json) throws JsonProcessingException { |
||||||
|
return MAPPER.readValue(json, Envelope.class); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// ==================== 客户端→服务端 ====================
|
||||||
|
|
||||||
|
/** 用户消息 */ |
||||||
|
@JsonInclude(JsonInclude.Include.NON_NULL) |
||||||
|
public record UserMessage( |
||||||
|
String content, |
||||||
|
@JsonProperty("parent_tool_use_id") String parentToolUseId |
||||||
|
) {} |
||||||
|
|
||||||
|
/** 权限回复 */ |
||||||
|
@JsonInclude(JsonInclude.Include.NON_NULL) |
||||||
|
public record ControlResponse( |
||||||
|
@JsonProperty("request_id") String requestId, |
||||||
|
String behavior, // "allow" or "deny"
|
||||||
|
String message, |
||||||
|
@JsonProperty("updated_input") Map<String, Object> updatedInput |
||||||
|
) {} |
||||||
|
|
||||||
|
// ==================== 服务端→客户端 ====================
|
||||||
|
|
||||||
|
/** 助手消息(流式或完整) */ |
||||||
|
@JsonInclude(JsonInclude.Include.NON_NULL) |
||||||
|
public record AssistantPayload( |
||||||
|
String text, |
||||||
|
String uuid, |
||||||
|
boolean streaming, |
||||||
|
@JsonProperty("tool_calls") List<ToolCallInfo> toolCalls |
||||||
|
) {} |
||||||
|
|
||||||
|
/** 工具调用信息 */ |
||||||
|
@JsonInclude(JsonInclude.Include.NON_NULL) |
||||||
|
public record ToolCallInfo( |
||||||
|
String id, |
||||||
|
String name, |
||||||
|
@JsonProperty("arguments") String arguments, |
||||||
|
String result, |
||||||
|
String status // "running", "completed", "error"
|
||||||
|
) {} |
||||||
|
|
||||||
|
/** 轮次结果 */ |
||||||
|
@JsonInclude(JsonInclude.Include.NON_NULL) |
||||||
|
public record ResultPayload( |
||||||
|
String text, |
||||||
|
@JsonProperty("tool_calls_count") int toolCallsCount, |
||||||
|
@JsonProperty("prompt_tokens") long promptTokens, |
||||||
|
@JsonProperty("completion_tokens") long completionTokens |
||||||
|
) {} |
||||||
|
|
||||||
|
/** 权限请求 */ |
||||||
|
@JsonInclude(JsonInclude.Include.NON_NULL) |
||||||
|
public record ControlRequest( |
||||||
|
@JsonProperty("request_id") String requestId, |
||||||
|
String subtype, // "can_use_tool"
|
||||||
|
@JsonProperty("tool_name") String toolName, |
||||||
|
@JsonProperty("tool_input") String toolInput, |
||||||
|
@JsonProperty("activity_description") String activityDescription |
||||||
|
) {} |
||||||
|
|
||||||
|
/** 系统事件 */ |
||||||
|
@JsonInclude(JsonInclude.Include.NON_NULL) |
||||||
|
public record SystemEvent( |
||||||
|
String subtype, // "init", "session_ready", "error"
|
||||||
|
String model, |
||||||
|
String message, |
||||||
|
@JsonProperty("session_id") String sessionId |
||||||
|
) {} |
||||||
|
|
||||||
|
/** 错误 */ |
||||||
|
@JsonInclude(JsonInclude.Include.NON_NULL) |
||||||
|
public record ErrorPayload( |
||||||
|
String code, |
||||||
|
String message |
||||||
|
) {} |
||||||
|
|
||||||
|
// ==================== 工厂方法 ====================
|
||||||
|
|
||||||
|
public static String userMessage(String sessionId, String content) throws JsonProcessingException { |
||||||
|
var envelope = new Envelope(TYPE_USER, sessionId, new UserMessage(content, null)); |
||||||
|
return envelope.toJson(); |
||||||
|
} |
||||||
|
|
||||||
|
public static String assistantMessage(String sessionId, String text, boolean streaming) throws JsonProcessingException { |
||||||
|
var payload = new AssistantPayload(text, UUID.randomUUID().toString(), streaming, null); |
||||||
|
var envelope = new Envelope(TYPE_ASSISTANT, sessionId, payload); |
||||||
|
return envelope.toJson(); |
||||||
|
} |
||||||
|
|
||||||
|
public static String assistantToolUse(String sessionId, List<ToolCallInfo> toolCalls) throws JsonProcessingException { |
||||||
|
var payload = new AssistantPayload(null, UUID.randomUUID().toString(), false, toolCalls); |
||||||
|
var envelope = new Envelope(TYPE_TOOL_USE, sessionId, payload); |
||||||
|
return envelope.toJson(); |
||||||
|
} |
||||||
|
|
||||||
|
public static String resultMessage(String sessionId, String text, int toolCallsCount, |
||||||
|
long promptTokens, long completionTokens) throws JsonProcessingException { |
||||||
|
var payload = new ResultPayload(text, toolCallsCount, promptTokens, completionTokens); |
||||||
|
var envelope = new Envelope(TYPE_RESULT, sessionId, payload); |
||||||
|
return envelope.toJson(); |
||||||
|
} |
||||||
|
|
||||||
|
public static String controlRequest(String sessionId, String toolName, String toolInput, |
||||||
|
String activityDescription) throws JsonProcessingException { |
||||||
|
var payload = new ControlRequest( |
||||||
|
UUID.randomUUID().toString(), "can_use_tool", |
||||||
|
toolName, toolInput, activityDescription); |
||||||
|
var envelope = new Envelope(TYPE_CONTROL_REQUEST, sessionId, payload); |
||||||
|
return envelope.toJson(); |
||||||
|
} |
||||||
|
|
||||||
|
public static String controlResponse(String sessionId, String requestId, String behavior, |
||||||
|
String message) throws JsonProcessingException { |
||||||
|
var payload = new ControlResponse(requestId, behavior, message, null); |
||||||
|
var envelope = new Envelope(TYPE_CONTROL_RESPONSE, sessionId, payload); |
||||||
|
return envelope.toJson(); |
||||||
|
} |
||||||
|
|
||||||
|
public static String interrupt(String sessionId) throws JsonProcessingException { |
||||||
|
var envelope = new Envelope(TYPE_INTERRUPT, sessionId, null); |
||||||
|
return envelope.toJson(); |
||||||
|
} |
||||||
|
|
||||||
|
public static String keepAlive(String sessionId) throws JsonProcessingException { |
||||||
|
var envelope = new Envelope(TYPE_KEEP_ALIVE, sessionId, null); |
||||||
|
return envelope.toJson(); |
||||||
|
} |
||||||
|
|
||||||
|
public static String systemEvent(String sessionId, String subtype, String model, String message) throws JsonProcessingException { |
||||||
|
var payload = new SystemEvent(subtype, model, message, sessionId); |
||||||
|
var envelope = new Envelope(TYPE_SYSTEM, sessionId, payload); |
||||||
|
return envelope.toJson(); |
||||||
|
} |
||||||
|
|
||||||
|
public static String error(String sessionId, String code, String message) throws JsonProcessingException { |
||||||
|
var payload = new ErrorPayload(code, message); |
||||||
|
var envelope = new Envelope(TYPE_ERROR, sessionId, payload); |
||||||
|
return envelope.toJson(); |
||||||
|
} |
||||||
|
|
||||||
|
// ==================== 解析工具 ====================
|
||||||
|
|
||||||
|
public static String getType(String json) throws JsonProcessingException { |
||||||
|
JsonNode node = MAPPER.readTree(json); |
||||||
|
return node.has("type") ? node.get("type").asText() : null; |
||||||
|
} |
||||||
|
|
||||||
|
public static JsonNode getPayload(String json) throws JsonProcessingException { |
||||||
|
JsonNode node = MAPPER.readTree(json); |
||||||
|
return node.get("payload"); |
||||||
|
} |
||||||
|
|
||||||
|
public static String getSessionId(String json) throws JsonProcessingException { |
||||||
|
JsonNode node = MAPPER.readTree(json); |
||||||
|
return node.has("session_id") ? node.get("session_id").asText() : null; |
||||||
|
} |
||||||
|
|
||||||
|
public static ObjectMapper mapper() { |
||||||
|
return MAPPER; |
||||||
|
} |
||||||
|
} |
||||||
@ -0,0 +1,93 @@ |
|||||||
|
package com.claudecode.server; |
||||||
|
|
||||||
|
import com.claudecode.tool.ToolContext; |
||||||
|
import com.claudecode.tool.ToolRegistry; |
||||||
|
import org.slf4j.Logger; |
||||||
|
import org.slf4j.LoggerFactory; |
||||||
|
import org.springframework.ai.chat.model.ChatModel; |
||||||
|
import org.springframework.beans.factory.annotation.Autowired; |
||||||
|
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; |
||||||
|
import org.springframework.context.annotation.Bean; |
||||||
|
import org.springframework.context.annotation.Configuration; |
||||||
|
import org.springframework.context.annotation.Lazy; |
||||||
|
import org.springframework.web.socket.config.annotation.EnableWebSocket; |
||||||
|
import org.springframework.web.socket.config.annotation.WebSocketConfigurer; |
||||||
|
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry; |
||||||
|
import org.springframework.web.socket.server.standard.ServletServerContainerFactoryBean; |
||||||
|
|
||||||
|
/** |
||||||
|
* Server Mode 的自动配置 —— 仅在 claude-code.server-mode=true 时激活。 |
||||||
|
* <p> |
||||||
|
* 由 {@link com.claudecode.ClaudeCodeApplication#main(String[])} 在检测到 --server 参数时 |
||||||
|
* 设置 {@code claude-code.server-mode=true},触发此配置类加载。 |
||||||
|
* <p> |
||||||
|
* 注册的 Bean: |
||||||
|
* <ul> |
||||||
|
* <li>{@link DirectConnectWebSocketHandler} — WebSocket 消息处理</li> |
||||||
|
* <li>{@link DirectConnectServer} — 服务器生命周期管理</li> |
||||||
|
* <li>{@link WebSocketConfigurer} — WebSocket 端点注册</li> |
||||||
|
* </ul> |
||||||
|
*/ |
||||||
|
@Configuration |
||||||
|
@ConditionalOnProperty(name = "claude-code.server-mode", havingValue = "true") |
||||||
|
@EnableWebSocket |
||||||
|
public class ServerModeAutoConfiguration implements WebSocketConfigurer { |
||||||
|
|
||||||
|
private static final Logger log = LoggerFactory.getLogger(ServerModeAutoConfiguration.class); |
||||||
|
|
||||||
|
@Autowired @Lazy |
||||||
|
private DirectConnectWebSocketHandler wsHandler; |
||||||
|
|
||||||
|
@Bean |
||||||
|
public DirectConnectWebSocketHandler directConnectWebSocketHandler( |
||||||
|
ChatModel activeChatModel, |
||||||
|
ToolRegistry toolRegistry, |
||||||
|
ToolContext toolContext, |
||||||
|
String systemPrompt) { |
||||||
|
|
||||||
|
String[] args = getApplicationArgs(); |
||||||
|
String authToken = DirectConnectServer.parseAuthToken(args); |
||||||
|
int maxSessions = DirectConnectServer.parseMaxSessions(args); |
||||||
|
String model = System.getenv("AI_MODEL") != null |
||||||
|
? System.getenv("AI_MODEL") : "claude-sonnet-4-20250514"; |
||||||
|
|
||||||
|
log.info("Creating DirectConnect WebSocket handler (maxSessions={}, auth={})", |
||||||
|
maxSessions, authToken != null && !authToken.isBlank() ? "enabled" : "disabled"); |
||||||
|
|
||||||
|
return new DirectConnectWebSocketHandler( |
||||||
|
activeChatModel, toolRegistry, toolContext, |
||||||
|
systemPrompt, authToken, maxSessions, model); |
||||||
|
} |
||||||
|
|
||||||
|
@Bean |
||||||
|
public DirectConnectServer directConnectServer(DirectConnectWebSocketHandler handler) { |
||||||
|
String[] args = getApplicationArgs(); |
||||||
|
int port = DirectConnectServer.parsePort(args); |
||||||
|
String authToken = DirectConnectServer.parseAuthToken(args); |
||||||
|
int maxSessions = DirectConnectServer.parseMaxSessions(args); |
||||||
|
|
||||||
|
DirectConnectServer server = new DirectConnectServer(port, authToken, maxSessions, handler); |
||||||
|
server.onServerStarted(); |
||||||
|
return server; |
||||||
|
} |
||||||
|
|
||||||
|
@Bean |
||||||
|
public ServletServerContainerFactoryBean createWebSocketContainer() { |
||||||
|
return WebSocketServerConfig.createWebSocketContainer(); |
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { |
||||||
|
registry.addHandler(wsHandler, "/ws") |
||||||
|
.setAllowedOrigins("*"); |
||||||
|
log.info("WebSocket handler registered at /ws endpoint"); |
||||||
|
} |
||||||
|
|
||||||
|
private String[] getApplicationArgs() { |
||||||
|
String argsStr = System.getProperty("claude-code.server-args"); |
||||||
|
if (argsStr != null && !argsStr.isBlank()) { |
||||||
|
return argsStr.split("\\s+"); |
||||||
|
} |
||||||
|
return new String[0]; |
||||||
|
} |
||||||
|
} |
||||||
@ -0,0 +1,300 @@ |
|||||||
|
package com.claudecode.server; |
||||||
|
|
||||||
|
import com.claudecode.core.AgentLoop; |
||||||
|
import com.claudecode.permission.PermissionTypes.PermissionChoice; |
||||||
|
import com.fasterxml.jackson.databind.JsonNode; |
||||||
|
import org.slf4j.Logger; |
||||||
|
import org.slf4j.LoggerFactory; |
||||||
|
import org.springframework.web.socket.CloseStatus; |
||||||
|
import org.springframework.web.socket.TextMessage; |
||||||
|
import org.springframework.web.socket.WebSocketSession; |
||||||
|
|
||||||
|
import java.io.IOException; |
||||||
|
import java.util.UUID; |
||||||
|
import java.util.concurrent.*; |
||||||
|
import java.util.concurrent.atomic.AtomicBoolean; |
||||||
|
|
||||||
|
/** |
||||||
|
* 单个 WebSocket 客户端的会话管理 —— 对应 TS 端 DirectConnectSessionManager 的服务端对应。 |
||||||
|
* <p> |
||||||
|
* 每个 ServerSession 包装一个 AgentLoop 实例, |
||||||
|
* 将 WebSocket 消息转化为 AgentLoop 的调用,并将结果回传。 |
||||||
|
* <p> |
||||||
|
* 权限请求通过 WebSocket 转发给客户端: |
||||||
|
* <ol> |
||||||
|
* <li>AgentLoop 触发 onPermissionRequest 回调</li> |
||||||
|
* <li>ServerSession 发送 control_request 消息给客户端</li> |
||||||
|
* <li>客户端回复 control_response 消息</li> |
||||||
|
* <li>ServerSession 将结果返回给 AgentLoop</li> |
||||||
|
* </ol> |
||||||
|
*/ |
||||||
|
public class ServerSession { |
||||||
|
|
||||||
|
private static final Logger log = LoggerFactory.getLogger(ServerSession.class); |
||||||
|
|
||||||
|
private final String sessionId; |
||||||
|
private final AgentLoop agentLoop; |
||||||
|
private final WebSocketSession wsSession; |
||||||
|
private final AtomicBoolean processing = new AtomicBoolean(false); |
||||||
|
|
||||||
|
/** 权限请求的异步等待队列:requestId → CompletableFuture<PermissionChoice> */ |
||||||
|
private final ConcurrentHashMap<String, CompletableFuture<PermissionChoice>> pendingPermissions |
||||||
|
= new ConcurrentHashMap<>(); |
||||||
|
|
||||||
|
/** 会话线程池(处理用户消息,每个会话一个线程) */ |
||||||
|
private final ExecutorService sessionExecutor = Executors.newSingleThreadExecutor( |
||||||
|
Thread.ofVirtual().name("server-session-", 0).factory() |
||||||
|
); |
||||||
|
|
||||||
|
/** Keep-alive 调度器 */ |
||||||
|
private final ScheduledExecutorService keepAliveScheduler = Executors.newSingleThreadScheduledExecutor( |
||||||
|
Thread.ofVirtual().name("keep-alive-", 0).factory() |
||||||
|
); |
||||||
|
|
||||||
|
private ScheduledFuture<?> keepAliveTask; |
||||||
|
|
||||||
|
public ServerSession(String sessionId, AgentLoop agentLoop, WebSocketSession wsSession) { |
||||||
|
this.sessionId = sessionId; |
||||||
|
this.agentLoop = agentLoop; |
||||||
|
this.wsSession = wsSession; |
||||||
|
|
||||||
|
// 注册 AgentLoop 回调:将事件转发到 WebSocket
|
||||||
|
setupAgentCallbacks(); |
||||||
|
startKeepAlive(); |
||||||
|
} |
||||||
|
|
||||||
|
private void setupAgentCallbacks() { |
||||||
|
// 助手文本回调 → 发送 assistant 消息
|
||||||
|
agentLoop.setOnAssistantMessage(text -> { |
||||||
|
try { |
||||||
|
sendMessage(ServerMessage.assistantMessage(sessionId, text, false)); |
||||||
|
} catch (Exception e) { |
||||||
|
log.error("[{}] Failed to send assistant message", sessionId, e); |
||||||
|
} |
||||||
|
}); |
||||||
|
|
||||||
|
// 流式输出回调 → 发送流式 assistant 消息(使用 onStreamStart)
|
||||||
|
agentLoop.setOnStreamStart(() -> { |
||||||
|
// 不需要特殊处理,流式 token 通过 streaming consumer 发送
|
||||||
|
}); |
||||||
|
|
||||||
|
// 工具事件回调 → 发送 tool_use 消息
|
||||||
|
agentLoop.setOnToolEvent(event -> { |
||||||
|
try { |
||||||
|
var toolInfo = new ServerMessage.ToolCallInfo( |
||||||
|
null, event.toolName(), event.arguments(), event.result(), |
||||||
|
switch (event.phase()) { |
||||||
|
case START -> "running"; |
||||||
|
case END -> "completed"; |
||||||
|
case PROGRESS -> "running"; |
||||||
|
} |
||||||
|
); |
||||||
|
sendMessage(ServerMessage.assistantToolUse(sessionId, java.util.List.of(toolInfo))); |
||||||
|
} catch (Exception e) { |
||||||
|
log.error("[{}] Failed to send tool event", sessionId, e); |
||||||
|
} |
||||||
|
}); |
||||||
|
|
||||||
|
// 权限请求回调 → 转发到 WebSocket 客户端
|
||||||
|
agentLoop.setOnPermissionRequest(req -> { |
||||||
|
try { |
||||||
|
return forwardPermissionRequest(req); |
||||||
|
} catch (Exception e) { |
||||||
|
log.error("[{}] Permission request failed", sessionId, e); |
||||||
|
return PermissionChoice.DENY_ONCE; |
||||||
|
} |
||||||
|
}); |
||||||
|
|
||||||
|
// Thinking 内容 → 发送系统消息
|
||||||
|
agentLoop.setOnThinkingContent(thinking -> { |
||||||
|
try { |
||||||
|
sendMessage(ServerMessage.systemEvent(sessionId, "thinking", null, thinking)); |
||||||
|
} catch (Exception e) { |
||||||
|
log.debug("[{}] Failed to send thinking content", sessionId, e); |
||||||
|
} |
||||||
|
}); |
||||||
|
} |
||||||
|
|
||||||
|
/** |
||||||
|
* 将权限请求转发到 WebSocket 客户端,同步等待响应。 |
||||||
|
*/ |
||||||
|
private PermissionChoice forwardPermissionRequest(AgentLoop.PermissionRequest req) { |
||||||
|
String requestId = UUID.randomUUID().toString(); |
||||||
|
CompletableFuture<PermissionChoice> future = new CompletableFuture<>(); |
||||||
|
pendingPermissions.put(requestId, future); |
||||||
|
|
||||||
|
try { |
||||||
|
// 发送 control_request
|
||||||
|
String msg = ServerMessage.controlRequest( |
||||||
|
sessionId, req.toolName(), req.arguments(), req.activityDescription()); |
||||||
|
sendMessage(msg); |
||||||
|
|
||||||
|
// 等待客户端响应(超时 60 秒)
|
||||||
|
return future.get(60, TimeUnit.SECONDS); |
||||||
|
} catch (TimeoutException e) { |
||||||
|
log.warn("[{}] Permission request timed out for {}", sessionId, req.toolName()); |
||||||
|
return PermissionChoice.DENY_ONCE; |
||||||
|
} catch (Exception e) { |
||||||
|
log.error("[{}] Permission request error", sessionId, e); |
||||||
|
return PermissionChoice.DENY_ONCE; |
||||||
|
} finally { |
||||||
|
pendingPermissions.remove(requestId); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
/** |
||||||
|
* 处理客户端发来的消息。 |
||||||
|
*/ |
||||||
|
public void handleMessage(String rawJson) { |
||||||
|
try { |
||||||
|
String type = ServerMessage.getType(rawJson); |
||||||
|
if (type == null) { |
||||||
|
sendMessage(ServerMessage.error(sessionId, "invalid_message", "Missing message type")); |
||||||
|
return; |
||||||
|
} |
||||||
|
|
||||||
|
switch (type) { |
||||||
|
case ServerMessage.TYPE_USER -> handleUserMessage(rawJson); |
||||||
|
case ServerMessage.TYPE_CONTROL_RESPONSE -> handleControlResponse(rawJson); |
||||||
|
case ServerMessage.TYPE_INTERRUPT -> handleInterrupt(); |
||||||
|
case ServerMessage.TYPE_KEEP_ALIVE -> {} // 忽略
|
||||||
|
default -> sendMessage(ServerMessage.error(sessionId, "unknown_type", |
||||||
|
"Unknown message type: " + type)); |
||||||
|
} |
||||||
|
} catch (Exception e) { |
||||||
|
log.error("[{}] Message handling error", sessionId, e); |
||||||
|
try { |
||||||
|
sendMessage(ServerMessage.error(sessionId, "internal_error", e.getMessage())); |
||||||
|
} catch (Exception ignored) {} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
private void handleUserMessage(String rawJson) throws Exception { |
||||||
|
if (processing.get()) { |
||||||
|
sendMessage(ServerMessage.error(sessionId, "busy", |
||||||
|
"Session is currently processing a message. Send interrupt first.")); |
||||||
|
return; |
||||||
|
} |
||||||
|
|
||||||
|
JsonNode payload = ServerMessage.getPayload(rawJson); |
||||||
|
if (payload == null || !payload.has("content")) { |
||||||
|
sendMessage(ServerMessage.error(sessionId, "invalid_payload", "Missing content in user message")); |
||||||
|
return; |
||||||
|
} |
||||||
|
|
||||||
|
String content = payload.get("content").asText(); |
||||||
|
processing.set(true); |
||||||
|
|
||||||
|
// 在虚拟线程中异步执行 AgentLoop
|
||||||
|
sessionExecutor.submit(() -> { |
||||||
|
try { |
||||||
|
// 使用流式模式,将每个 token 实时转发
|
||||||
|
String result = agentLoop.runStreaming(content, token -> { |
||||||
|
try { |
||||||
|
sendMessage(ServerMessage.assistantMessage(sessionId, token, true)); |
||||||
|
} catch (Exception e) { |
||||||
|
log.error("[{}] Failed to stream token", sessionId, e); |
||||||
|
} |
||||||
|
}); |
||||||
|
|
||||||
|
// 发送最终结果
|
||||||
|
var tracker = agentLoop.getTokenTracker(); |
||||||
|
sendMessage(ServerMessage.resultMessage( |
||||||
|
sessionId, result, 0, |
||||||
|
tracker.getInputTokens(), tracker.getOutputTokens())); |
||||||
|
} catch (Exception e) { |
||||||
|
log.error("[{}] Agent loop execution error", sessionId, e); |
||||||
|
try { |
||||||
|
sendMessage(ServerMessage.error(sessionId, "execution_error", e.getMessage())); |
||||||
|
} catch (Exception ignored) {} |
||||||
|
} finally { |
||||||
|
processing.set(false); |
||||||
|
} |
||||||
|
}); |
||||||
|
} |
||||||
|
|
||||||
|
private void handleControlResponse(String rawJson) throws Exception { |
||||||
|
JsonNode payload = ServerMessage.getPayload(rawJson); |
||||||
|
if (payload == null || !payload.has("request_id") || !payload.has("behavior")) { |
||||||
|
sendMessage(ServerMessage.error(sessionId, "invalid_payload", |
||||||
|
"Missing request_id or behavior in control_response")); |
||||||
|
return; |
||||||
|
} |
||||||
|
|
||||||
|
String requestId = payload.get("request_id").asText(); |
||||||
|
String behavior = payload.get("behavior").asText(); |
||||||
|
|
||||||
|
CompletableFuture<PermissionChoice> future = pendingPermissions.get(requestId); |
||||||
|
if (future != null) { |
||||||
|
PermissionChoice choice = "allow".equals(behavior) |
||||||
|
? PermissionChoice.ALLOW_ONCE |
||||||
|
: PermissionChoice.DENY_ONCE; |
||||||
|
future.complete(choice); |
||||||
|
} else { |
||||||
|
log.warn("[{}] Unknown permission request_id: {}", sessionId, requestId); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
private void handleInterrupt() { |
||||||
|
log.info("[{}] Interrupt signal received", sessionId); |
||||||
|
agentLoop.cancel(); |
||||||
|
// 取消所有挂起的权限请求
|
||||||
|
pendingPermissions.values().forEach(f -> f.complete(PermissionChoice.DENY_ONCE)); |
||||||
|
pendingPermissions.clear(); |
||||||
|
} |
||||||
|
|
||||||
|
private void startKeepAlive() { |
||||||
|
keepAliveTask = keepAliveScheduler.scheduleAtFixedRate(() -> { |
||||||
|
try { |
||||||
|
if (wsSession.isOpen()) { |
||||||
|
sendMessage(ServerMessage.keepAlive(sessionId)); |
||||||
|
} |
||||||
|
} catch (Exception e) { |
||||||
|
log.debug("[{}] Keep-alive failed", sessionId, e); |
||||||
|
} |
||||||
|
}, 30, 30, TimeUnit.SECONDS); |
||||||
|
} |
||||||
|
|
||||||
|
/** |
||||||
|
* 发送 WebSocket 消息(线程安全)。 |
||||||
|
*/ |
||||||
|
private synchronized void sendMessage(String json) throws IOException { |
||||||
|
if (wsSession.isOpen()) { |
||||||
|
wsSession.sendMessage(new TextMessage(json)); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
/** |
||||||
|
* 发送会话初始化消息。 |
||||||
|
*/ |
||||||
|
public void sendInitMessage(String model) throws IOException { |
||||||
|
sendMessage(ServerMessage.systemEvent(sessionId, "init", model, "Session ready")); |
||||||
|
} |
||||||
|
|
||||||
|
/** |
||||||
|
* 关闭会话,清理资源。 |
||||||
|
*/ |
||||||
|
public void close() { |
||||||
|
log.info("[{}] Closing server session", sessionId); |
||||||
|
if (keepAliveTask != null) { |
||||||
|
keepAliveTask.cancel(true); |
||||||
|
} |
||||||
|
keepAliveScheduler.shutdownNow(); |
||||||
|
agentLoop.cancel(); |
||||||
|
pendingPermissions.values().forEach(f -> f.complete(PermissionChoice.DENY_ONCE)); |
||||||
|
pendingPermissions.clear(); |
||||||
|
sessionExecutor.shutdownNow(); |
||||||
|
} |
||||||
|
|
||||||
|
public String getSessionId() { |
||||||
|
return sessionId; |
||||||
|
} |
||||||
|
|
||||||
|
public boolean isProcessing() { |
||||||
|
return processing.get(); |
||||||
|
} |
||||||
|
|
||||||
|
public AgentLoop getAgentLoop() { |
||||||
|
return agentLoop; |
||||||
|
} |
||||||
|
} |
||||||
@ -0,0 +1,49 @@ |
|||||||
|
package com.claudecode.server; |
||||||
|
|
||||||
|
import org.slf4j.Logger; |
||||||
|
import org.slf4j.LoggerFactory; |
||||||
|
import org.springframework.web.socket.config.annotation.WebSocketConfigurer; |
||||||
|
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry; |
||||||
|
import org.springframework.web.socket.server.standard.ServletServerContainerFactoryBean; |
||||||
|
|
||||||
|
/** |
||||||
|
* Server Mode 的 WebSocket 配置 —— 仅在 --server 模式下激活。 |
||||||
|
* <p> |
||||||
|
* 注册 {@link DirectConnectWebSocketHandler} 到 /ws 端点。 |
||||||
|
* <p> |
||||||
|
* 对应 claude-code 的 Server Mode 功能: |
||||||
|
* <ul> |
||||||
|
* <li>WebSocket 端点: ws://localhost:{port}/ws</li>
|
||||||
|
* <li>允许所有来源连接(本地开发用,生产环境应限制)</li> |
||||||
|
* <li>最大消息 1MB</li> |
||||||
|
* </ul> |
||||||
|
*/ |
||||||
|
public class WebSocketServerConfig implements WebSocketConfigurer { |
||||||
|
|
||||||
|
private static final Logger log = LoggerFactory.getLogger(WebSocketServerConfig.class); |
||||||
|
|
||||||
|
private final DirectConnectWebSocketHandler handler; |
||||||
|
|
||||||
|
public WebSocketServerConfig(DirectConnectWebSocketHandler handler) { |
||||||
|
this.handler = handler; |
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { |
||||||
|
registry.addHandler(handler, "/ws") |
||||||
|
.setAllowedOrigins("*"); |
||||||
|
|
||||||
|
log.info("WebSocket handler registered at /ws"); |
||||||
|
} |
||||||
|
|
||||||
|
/** |
||||||
|
* 配置 WebSocket 容器参数。 |
||||||
|
*/ |
||||||
|
public static ServletServerContainerFactoryBean createWebSocketContainer() { |
||||||
|
ServletServerContainerFactoryBean container = new ServletServerContainerFactoryBean(); |
||||||
|
container.setMaxTextMessageBufferSize(1024 * 1024); // 1MB
|
||||||
|
container.setMaxBinaryMessageBufferSize(1024 * 1024); |
||||||
|
container.setMaxSessionIdleTimeout(300_000L); // 5 分钟空闲超时
|
||||||
|
return container; |
||||||
|
} |
||||||
|
} |
||||||
Loading…
Reference in new issue