diff --git a/pom.xml b/pom.xml index 1aeec7c..69c0dab 100644 --- a/pom.xml +++ b/pom.xml @@ -44,6 +44,12 @@ spring-boot-starter-web + + + org.springframework.boot + spring-boot-starter-websocket + + org.springframework.ai diff --git a/src/main/java/com/claudecode/ClaudeCodeApplication.java b/src/main/java/com/claudecode/ClaudeCodeApplication.java index a3edb82..10702ed 100644 --- a/src/main/java/com/claudecode/ClaudeCodeApplication.java +++ b/src/main/java/com/claudecode/ClaudeCodeApplication.java @@ -1,18 +1,45 @@ package com.claudecode; +import com.claudecode.server.DirectConnectServer; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; +import java.util.*; + /** * Claude Code Java 版主入口。 *

* 对应 claude-code/src/entrypoints/cli.tsx - * 以 Spring Boot 应用启动,但关闭 Web 服务器(纯 CLI 模式)。 + *

+ * 支持两种启动模式: + *

*/ @SpringBootApplication public class ClaudeCodeApplication { public static void main(String[] args) { - SpringApplication.run(ClaudeCodeApplication.class, args); + SpringApplication app = new SpringApplication(ClaudeCodeApplication.class); + + if (DirectConnectServer.isServerMode(args)) { + // Server Mode: 启用 Web 服务器 + WebSocket + int port = DirectConnectServer.parsePort(args); + + Map serverProps = new HashMap<>(); + serverProps.put("spring.main.web-application-type", "servlet"); + serverProps.put("server.port", port); + serverProps.put("claude-code.server-mode", "true"); + app.setDefaultProperties(serverProps); + + // 将原始参数保存到系统属性,供 ServerModeAutoConfiguration 解析 + System.setProperty("claude-code.server-args", String.join(" ", args)); + + System.out.println("Starting in Server Mode on port " + port + "..."); + } + // CLI 模式使用 application.yml 中的 web-application-type: none + + app.run(args); } } diff --git a/src/main/java/com/claudecode/cli/ClaudeCodeRunner.java b/src/main/java/com/claudecode/cli/ClaudeCodeRunner.java index d37547a..34ceeef 100644 --- a/src/main/java/com/claudecode/cli/ClaudeCodeRunner.java +++ b/src/main/java/com/claudecode/cli/ClaudeCodeRunner.java @@ -1,6 +1,7 @@ package com.claudecode.cli; import com.claudecode.repl.ReplSession; +import com.claudecode.server.DirectConnectServer; import com.claudecode.tui.JinkReplSession; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -10,7 +11,12 @@ import org.springframework.stereotype.Component; /** * 启动编排器 —— 对应 claude-code/src/main.tsx 的初始化逻辑。 *

- * 优先使用 jink TUI 模式,失败时降级到传统 JLine REPL。 + * 支持三种模式: + *

*/ @Component public class ClaudeCodeRunner implements CommandLineRunner { @@ -27,6 +33,14 @@ public class ClaudeCodeRunner implements CommandLineRunner { @Override public void run(String... args) { + // Server Mode: 不启动 TUI,WebSocket 服务器已由 ServerModeAutoConfiguration 启动 + if (DirectConnectServer.isServerMode(args)) { + log.info("Server Mode active — TUI disabled, WebSocket server running"); + // 阻塞主线程,等待 Ctrl+C 或 SIGTERM + waitForShutdown(); + return; + } + log.info("Claude Code (Java) starting..."); // 检查是否强制使用旧模式 @@ -45,4 +59,22 @@ public class ClaudeCodeRunner implements CommandLineRunner { replSession.start(); } } + + /** + * 在 Server Mode 下阻塞主线程直到收到 shutdown 信号。 + */ + private void waitForShutdown() { + Thread shutdownHook = new Thread(() -> { + log.info("Shutdown signal received"); + }); + Runtime.getRuntime().addShutdownHook(shutdownHook); + + try { + // 阻塞直到中断 + Thread.currentThread().join(); + } catch (InterruptedException e) { + log.info("Server mode interrupted"); + Thread.currentThread().interrupt(); + } + } } diff --git a/src/main/java/com/claudecode/server/DirectConnectServer.java b/src/main/java/com/claudecode/server/DirectConnectServer.java new file mode 100644 index 0000000..ac82883 --- /dev/null +++ b/src/main/java/com/claudecode/server/DirectConnectServer.java @@ -0,0 +1,226 @@ +package com.claudecode.server; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.web.socket.config.annotation.EnableWebSocket; + +import java.util.UUID; + +/** + * Server Mode 生命周期管理 —— 对应 claude-code/src/server 的服务端核心。 + *

+ * 管理服务端状态、认证 Token 生成、会话限制等。 + *

+ * 启动流程: + *

    + *
  1. ClaudeCodeApplication 检测 --server 参数
  2. + *
  3. Spring Boot 启用 WebSocket 模式
  4. + *
  5. DirectConnectServer 初始化,生成 auth token
  6. + *
  7. WebSocket 端点开始接受连接
  8. + *
+ * + * 客户端连接方式: + * + */ +public class DirectConnectServer { + + private static final Logger log = LoggerFactory.getLogger(DirectConnectServer.class); + + /** 默认端口 */ + public static final int DEFAULT_PORT = 12321; + + /** 默认最大会话数 */ + public static final int DEFAULT_MAX_SESSIONS = 5; + + private final int port; + private final String authToken; + private final int maxSessions; + private final DirectConnectWebSocketHandler handler; + + private volatile boolean running = false; + + public DirectConnectServer(int port, String authToken, int maxSessions, + DirectConnectWebSocketHandler handler) { + this.port = port; + this.authToken = authToken; + this.maxSessions = maxSessions; + this.handler = handler; + } + + /** + * 生成随机认证 Token。 + */ + public static String generateAuthToken() { + return UUID.randomUUID().toString().replace("-", ""); + } + + /** + * 标记服务器已启动。打印连接信息。 + */ + public void onServerStarted() { + running = true; + printConnectionInfo(); + } + + /** + * 打印连接信息到控制台。 + */ + public void printConnectionInfo() { + String separator = "═".repeat(60); + System.out.println(); + System.out.println("╔" + separator + "╗"); + System.out.println("║ Claude Code Java — Server Mode ║"); + System.out.println("╠" + separator + "╣"); + System.out.printf("║ WebSocket: ws://localhost:%d/ws%-24s║%n", port, ""); + System.out.printf("║ Port: %-48s║%n", port); + if (authToken != null && !authToken.isBlank()) { + System.out.printf("║ Token: %-48s║%n", authToken); + } else { + System.out.printf("║ Auth: %-48s║%n", "disabled (no token)"); + } + System.out.printf("║ Max Sess: %-48s║%n", maxSessions); + System.out.println("╠" + separator + "╣"); + System.out.println("║ Connect with: ║"); + if (authToken != null && !authToken.isBlank()) { + System.out.printf("║ ws://localhost:%d/ws?token=%s ║%n", port, authToken.substring(0, Math.min(12, authToken.length())) + "..."); + } else { + System.out.printf("║ ws://localhost:%d/ws%-38s║%n", port, ""); + } + System.out.println("╚" + separator + "╝"); + System.out.println(); + } + + /** + * 停止服务器,关闭所有会话。 + */ + public void shutdown() { + if (!running) return; + running = false; + log.info("Shutting down server..."); + handler.closeAllSessions(); + log.info("Server stopped. All sessions closed."); + } + + // ==================== 静态工具方法 ==================== + + /** + * 检查命令行参数是否包含 --server 标志。 + */ + public static boolean isServerMode(String[] args) { + if (args == null) return false; + for (String arg : args) { + if ("--server".equals(arg) || arg.startsWith("--server=")) { + return true; + } + } + // 也支持环境变量 + String envMode = System.getenv("CLAUDE_CODE_SERVER_MODE"); + return "true".equalsIgnoreCase(envMode) || "1".equals(envMode); + } + + /** + * 从命令行参数解析端口号。 + */ + public static int parsePort(String[] args) { + if (args != null) { + for (String arg : args) { + if (arg.startsWith("--server-port=")) { + try { + return Integer.parseInt(arg.substring("--server-port=".length())); + } catch (NumberFormatException e) { + log.warn("Invalid port number: {}", arg); + } + } + if (arg.startsWith("--server=")) { + try { + return Integer.parseInt(arg.substring("--server=".length())); + } catch (NumberFormatException e) { + // --server=true 等非数字参数忽略 + } + } + } + } + // 环境变量 + String envPort = System.getenv("CLAUDE_CODE_SERVER_PORT"); + if (envPort != null) { + try { + return Integer.parseInt(envPort); + } catch (NumberFormatException e) { + log.warn("Invalid CLAUDE_CODE_SERVER_PORT: {}", envPort); + } + } + return DEFAULT_PORT; + } + + /** + * 从命令行参数或环境变量获取认证 Token。 + * 如果未指定,生成随机 Token。 + */ + public static String parseAuthToken(String[] args) { + if (args != null) { + for (String arg : args) { + if (arg.startsWith("--server-token=")) { + return arg.substring("--server-token=".length()); + } + } + } + String envToken = System.getenv("CLAUDE_CODE_SERVER_TOKEN"); + if (envToken != null && !envToken.isBlank()) { + return envToken; + } + // 未指定则生成随机 token + return generateAuthToken(); + } + + /** + * 从命令行参数解析最大会话数。 + */ + public static int parseMaxSessions(String[] args) { + if (args != null) { + for (String arg : args) { + if (arg.startsWith("--max-sessions=")) { + try { + return Integer.parseInt(arg.substring("--max-sessions=".length())); + } catch (NumberFormatException e) { + log.warn("Invalid max sessions: {}", arg); + } + } + } + } + String envMax = System.getenv("CLAUDE_CODE_MAX_SESSIONS"); + if (envMax != null) { + try { + return Integer.parseInt(envMax); + } catch (NumberFormatException e) { + log.warn("Invalid CLAUDE_CODE_MAX_SESSIONS: {}", envMax); + } + } + return DEFAULT_MAX_SESSIONS; + } + + // ==================== Getters ==================== + + public int getPort() { + return port; + } + + public String getAuthToken() { + return authToken; + } + + public int getMaxSessions() { + return maxSessions; + } + + public boolean isRunning() { + return running; + } + + public int getActiveSessionCount() { + return handler.getActiveSessionCount(); + } +} diff --git a/src/main/java/com/claudecode/server/DirectConnectWebSocketHandler.java b/src/main/java/com/claudecode/server/DirectConnectWebSocketHandler.java new file mode 100644 index 0000000..1581fe7 --- /dev/null +++ b/src/main/java/com/claudecode/server/DirectConnectWebSocketHandler.java @@ -0,0 +1,180 @@ +package com.claudecode.server; + +import com.claudecode.core.AgentLoop; +import com.claudecode.core.TokenTracker; +import com.claudecode.tool.ToolContext; +import com.claudecode.tool.ToolRegistry; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.handler.TextWebSocketHandler; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; + +/** + * WebSocket Handler —— Server Mode 的核心连接处理器。 + *

+ * 对应 claude-code/src/server/directConnectManager.ts 的服务端实现。 + *

+ * 每个 WebSocket 连接创建一个 {@link ServerSession},包含独立的 AgentLoop。 + * 支持 Bearer Token 认证和多会话管理。 + */ +public class DirectConnectWebSocketHandler extends TextWebSocketHandler { + + private static final Logger log = LoggerFactory.getLogger(DirectConnectWebSocketHandler.class); + + private final ChatModel chatModel; + private final ToolRegistry toolRegistry; + private final ToolContext toolContext; + private final String systemPrompt; + private final String authToken; + private final int maxSessions; + private final String model; + + /** 活跃会话:WebSocket sessionId → ServerSession */ + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + public DirectConnectWebSocketHandler(ChatModel chatModel, ToolRegistry toolRegistry, + ToolContext toolContext, String systemPrompt, + String authToken, int maxSessions, String model) { + this.chatModel = chatModel; + this.toolRegistry = toolRegistry; + this.toolContext = toolContext; + this.systemPrompt = systemPrompt; + this.authToken = authToken; + this.maxSessions = maxSessions; + this.model = model; + } + + @Override + public void afterConnectionEstablished(WebSocketSession wsSession) throws Exception { + log.info("WebSocket connection attempt from: {}", wsSession.getRemoteAddress()); + + // Bearer Token 认证 + if (authToken != null && !authToken.isBlank()) { + if (!validateAuth(wsSession)) { + wsSession.close(new CloseStatus(4001, "Unauthorized")); + return; + } + } + + // 会话数限制 + if (maxSessions > 0 && sessions.size() >= maxSessions) { + String errorMsg = ServerMessage.error(null, "max_sessions", + "Maximum sessions (" + maxSessions + ") reached"); + wsSession.sendMessage(new TextMessage(errorMsg)); + wsSession.close(new CloseStatus(4002, "Max sessions reached")); + return; + } + + // 创建新的 AgentLoop(每个连接独立实例) + String sessionId = UUID.randomUUID().toString(); + AgentLoop agentLoop = new AgentLoop(chatModel, toolRegistry, toolContext, systemPrompt, new TokenTracker()); + + // 创建 ServerSession + ServerSession serverSession = new ServerSession(sessionId, agentLoop, wsSession); + sessions.put(wsSession.getId(), serverSession); + + // 发送初始化消息 + serverSession.sendInitMessage(model); + + log.info("Server session created: {} (ws: {}, total: {})", + sessionId, wsSession.getId(), sessions.size()); + } + + @Override + protected void handleTextMessage(WebSocketSession wsSession, TextMessage message) throws Exception { + ServerSession session = sessions.get(wsSession.getId()); + if (session == null) { + log.warn("Message from unknown session: {}", wsSession.getId()); + wsSession.close(new CloseStatus(4003, "Unknown session")); + return; + } + + session.handleMessage(message.getPayload()); + } + + @Override + public void afterConnectionClosed(WebSocketSession wsSession, CloseStatus status) throws Exception { + ServerSession session = sessions.remove(wsSession.getId()); + if (session != null) { + session.close(); + log.info("Session closed: {} (status: {}, remaining: {})", + session.getSessionId(), status, sessions.size()); + } + } + + @Override + public void handleTransportError(WebSocketSession wsSession, Throwable exception) throws Exception { + log.error("WebSocket transport error for {}: {}", wsSession.getId(), exception.getMessage()); + ServerSession session = sessions.get(wsSession.getId()); + if (session != null) { + try { + String errorMsg = ServerMessage.error(session.getSessionId(), + "transport_error", exception.getMessage()); + wsSession.sendMessage(new TextMessage(errorMsg)); + } catch (Exception ignored) {} + } + } + + /** + * 验证 WebSocket 连接的 Bearer Token。 + *

+ * 支持两种方式: + *

+ */ + private boolean validateAuth(WebSocketSession wsSession) { + // 方式1: 从 HTTP Header 获取 + var headers = wsSession.getHandshakeHeaders(); + String authHeader = headers.getFirst("Authorization"); + if (authHeader != null && authHeader.startsWith("Bearer ")) { + String token = authHeader.substring(7); + if (authToken.equals(token)) { + return true; + } + } + + // 方式2: 从 Query Parameter 获取 + URI uri = wsSession.getUri(); + if (uri != null && uri.getQuery() != null) { + String query = uri.getQuery(); + for (String param : query.split("&")) { + if (param.startsWith("token=")) { + String token = param.substring(6); + if (authToken.equals(token)) { + return true; + } + } + } + } + + log.warn("Authentication failed for connection from: {}", wsSession.getRemoteAddress()); + return false; + } + + /** + * 关闭所有会话。 + */ + public void closeAllSessions() { + sessions.values().forEach(ServerSession::close); + sessions.clear(); + } + + public int getActiveSessionCount() { + return sessions.size(); + } + + public Map getSessions() { + return Map.copyOf(sessions); + } +} diff --git a/src/main/java/com/claudecode/server/ServerMessage.java b/src/main/java/com/claudecode/server/ServerMessage.java new file mode 100644 index 0000000..fa25dad --- /dev/null +++ b/src/main/java/com/claudecode/server/ServerMessage.java @@ -0,0 +1,220 @@ +package com.claudecode.server; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +import java.util.List; +import java.util.Map; +import java.util.UUID; + +/** + * Server Mode 消息协议 —— 对应 claude-code/src/server/types.ts + *

+ * WebSocket 上的 JSON 消息,7 种类型: + *

    + *
  • user — 客户端发送的用户消息
  • + *
  • assistant — 服务端回复的助手消息
  • + *
  • result — 轮次结束的最终结果
  • + *
  • control_request — 权限请求(服务端→客户端)
  • + *
  • control_response — 权限回复(客户端→服务端)
  • + *
  • interrupt — 中断信号
  • + *
  • keep_alive — 心跳
  • + *
+ */ +public class ServerMessage { + + private static final ObjectMapper MAPPER = new ObjectMapper() + .setSerializationInclusion(JsonInclude.Include.NON_NULL); + + // ==================== 消息类型常量 ==================== + + public static final String TYPE_USER = "user"; + public static final String TYPE_ASSISTANT = "assistant"; + public static final String TYPE_RESULT = "result"; + public static final String TYPE_CONTROL_REQUEST = "control_request"; + public static final String TYPE_CONTROL_RESPONSE = "control_response"; + public static final String TYPE_INTERRUPT = "interrupt"; + public static final String TYPE_KEEP_ALIVE = "keep_alive"; + public static final String TYPE_SYSTEM = "system"; + public static final String TYPE_ERROR = "error"; + public static final String TYPE_TOOL_USE = "tool_use"; + + // ==================== 通用消息 ==================== + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Envelope( + String type, + @JsonProperty("session_id") String sessionId, + Object payload + ) { + public String toJson() throws JsonProcessingException { + return MAPPER.writeValueAsString(this); + } + + public static Envelope fromJson(String json) throws JsonProcessingException { + return MAPPER.readValue(json, Envelope.class); + } + } + + // ==================== 客户端→服务端 ==================== + + /** 用户消息 */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record UserMessage( + String content, + @JsonProperty("parent_tool_use_id") String parentToolUseId + ) {} + + /** 权限回复 */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ControlResponse( + @JsonProperty("request_id") String requestId, + String behavior, // "allow" or "deny" + String message, + @JsonProperty("updated_input") Map updatedInput + ) {} + + // ==================== 服务端→客户端 ==================== + + /** 助手消息(流式或完整) */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record AssistantPayload( + String text, + String uuid, + boolean streaming, + @JsonProperty("tool_calls") List toolCalls + ) {} + + /** 工具调用信息 */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ToolCallInfo( + String id, + String name, + @JsonProperty("arguments") String arguments, + String result, + String status // "running", "completed", "error" + ) {} + + /** 轮次结果 */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ResultPayload( + String text, + @JsonProperty("tool_calls_count") int toolCallsCount, + @JsonProperty("prompt_tokens") long promptTokens, + @JsonProperty("completion_tokens") long completionTokens + ) {} + + /** 权限请求 */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ControlRequest( + @JsonProperty("request_id") String requestId, + String subtype, // "can_use_tool" + @JsonProperty("tool_name") String toolName, + @JsonProperty("tool_input") String toolInput, + @JsonProperty("activity_description") String activityDescription + ) {} + + /** 系统事件 */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record SystemEvent( + String subtype, // "init", "session_ready", "error" + String model, + String message, + @JsonProperty("session_id") String sessionId + ) {} + + /** 错误 */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ErrorPayload( + String code, + String message + ) {} + + // ==================== 工厂方法 ==================== + + public static String userMessage(String sessionId, String content) throws JsonProcessingException { + var envelope = new Envelope(TYPE_USER, sessionId, new UserMessage(content, null)); + return envelope.toJson(); + } + + public static String assistantMessage(String sessionId, String text, boolean streaming) throws JsonProcessingException { + var payload = new AssistantPayload(text, UUID.randomUUID().toString(), streaming, null); + var envelope = new Envelope(TYPE_ASSISTANT, sessionId, payload); + return envelope.toJson(); + } + + public static String assistantToolUse(String sessionId, List toolCalls) throws JsonProcessingException { + var payload = new AssistantPayload(null, UUID.randomUUID().toString(), false, toolCalls); + var envelope = new Envelope(TYPE_TOOL_USE, sessionId, payload); + return envelope.toJson(); + } + + public static String resultMessage(String sessionId, String text, int toolCallsCount, + long promptTokens, long completionTokens) throws JsonProcessingException { + var payload = new ResultPayload(text, toolCallsCount, promptTokens, completionTokens); + var envelope = new Envelope(TYPE_RESULT, sessionId, payload); + return envelope.toJson(); + } + + public static String controlRequest(String sessionId, String toolName, String toolInput, + String activityDescription) throws JsonProcessingException { + var payload = new ControlRequest( + UUID.randomUUID().toString(), "can_use_tool", + toolName, toolInput, activityDescription); + var envelope = new Envelope(TYPE_CONTROL_REQUEST, sessionId, payload); + return envelope.toJson(); + } + + public static String controlResponse(String sessionId, String requestId, String behavior, + String message) throws JsonProcessingException { + var payload = new ControlResponse(requestId, behavior, message, null); + var envelope = new Envelope(TYPE_CONTROL_RESPONSE, sessionId, payload); + return envelope.toJson(); + } + + public static String interrupt(String sessionId) throws JsonProcessingException { + var envelope = new Envelope(TYPE_INTERRUPT, sessionId, null); + return envelope.toJson(); + } + + public static String keepAlive(String sessionId) throws JsonProcessingException { + var envelope = new Envelope(TYPE_KEEP_ALIVE, sessionId, null); + return envelope.toJson(); + } + + public static String systemEvent(String sessionId, String subtype, String model, String message) throws JsonProcessingException { + var payload = new SystemEvent(subtype, model, message, sessionId); + var envelope = new Envelope(TYPE_SYSTEM, sessionId, payload); + return envelope.toJson(); + } + + public static String error(String sessionId, String code, String message) throws JsonProcessingException { + var payload = new ErrorPayload(code, message); + var envelope = new Envelope(TYPE_ERROR, sessionId, payload); + return envelope.toJson(); + } + + // ==================== 解析工具 ==================== + + public static String getType(String json) throws JsonProcessingException { + JsonNode node = MAPPER.readTree(json); + return node.has("type") ? node.get("type").asText() : null; + } + + public static JsonNode getPayload(String json) throws JsonProcessingException { + JsonNode node = MAPPER.readTree(json); + return node.get("payload"); + } + + public static String getSessionId(String json) throws JsonProcessingException { + JsonNode node = MAPPER.readTree(json); + return node.has("session_id") ? node.get("session_id").asText() : null; + } + + public static ObjectMapper mapper() { + return MAPPER; + } +} diff --git a/src/main/java/com/claudecode/server/ServerModeAutoConfiguration.java b/src/main/java/com/claudecode/server/ServerModeAutoConfiguration.java new file mode 100644 index 0000000..a193f33 --- /dev/null +++ b/src/main/java/com/claudecode/server/ServerModeAutoConfiguration.java @@ -0,0 +1,93 @@ +package com.claudecode.server; + +import com.claudecode.tool.ToolContext; +import com.claudecode.tool.ToolRegistry; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Lazy; +import org.springframework.web.socket.config.annotation.EnableWebSocket; +import org.springframework.web.socket.config.annotation.WebSocketConfigurer; +import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry; +import org.springframework.web.socket.server.standard.ServletServerContainerFactoryBean; + +/** + * Server Mode 的自动配置 —— 仅在 claude-code.server-mode=true 时激活。 + *

+ * 由 {@link com.claudecode.ClaudeCodeApplication#main(String[])} 在检测到 --server 参数时 + * 设置 {@code claude-code.server-mode=true},触发此配置类加载。 + *

+ * 注册的 Bean: + *

    + *
  • {@link DirectConnectWebSocketHandler} — WebSocket 消息处理
  • + *
  • {@link DirectConnectServer} — 服务器生命周期管理
  • + *
  • {@link WebSocketConfigurer} — WebSocket 端点注册
  • + *
+ */ +@Configuration +@ConditionalOnProperty(name = "claude-code.server-mode", havingValue = "true") +@EnableWebSocket +public class ServerModeAutoConfiguration implements WebSocketConfigurer { + + private static final Logger log = LoggerFactory.getLogger(ServerModeAutoConfiguration.class); + + @Autowired @Lazy + private DirectConnectWebSocketHandler wsHandler; + + @Bean + public DirectConnectWebSocketHandler directConnectWebSocketHandler( + ChatModel activeChatModel, + ToolRegistry toolRegistry, + ToolContext toolContext, + String systemPrompt) { + + String[] args = getApplicationArgs(); + String authToken = DirectConnectServer.parseAuthToken(args); + int maxSessions = DirectConnectServer.parseMaxSessions(args); + String model = System.getenv("AI_MODEL") != null + ? System.getenv("AI_MODEL") : "claude-sonnet-4-20250514"; + + log.info("Creating DirectConnect WebSocket handler (maxSessions={}, auth={})", + maxSessions, authToken != null && !authToken.isBlank() ? "enabled" : "disabled"); + + return new DirectConnectWebSocketHandler( + activeChatModel, toolRegistry, toolContext, + systemPrompt, authToken, maxSessions, model); + } + + @Bean + public DirectConnectServer directConnectServer(DirectConnectWebSocketHandler handler) { + String[] args = getApplicationArgs(); + int port = DirectConnectServer.parsePort(args); + String authToken = DirectConnectServer.parseAuthToken(args); + int maxSessions = DirectConnectServer.parseMaxSessions(args); + + DirectConnectServer server = new DirectConnectServer(port, authToken, maxSessions, handler); + server.onServerStarted(); + return server; + } + + @Bean + public ServletServerContainerFactoryBean createWebSocketContainer() { + return WebSocketServerConfig.createWebSocketContainer(); + } + + @Override + public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { + registry.addHandler(wsHandler, "/ws") + .setAllowedOrigins("*"); + log.info("WebSocket handler registered at /ws endpoint"); + } + + private String[] getApplicationArgs() { + String argsStr = System.getProperty("claude-code.server-args"); + if (argsStr != null && !argsStr.isBlank()) { + return argsStr.split("\\s+"); + } + return new String[0]; + } +} diff --git a/src/main/java/com/claudecode/server/ServerSession.java b/src/main/java/com/claudecode/server/ServerSession.java new file mode 100644 index 0000000..7877ba7 --- /dev/null +++ b/src/main/java/com/claudecode/server/ServerSession.java @@ -0,0 +1,300 @@ +package com.claudecode.server; + +import com.claudecode.core.AgentLoop; +import com.claudecode.permission.PermissionTypes.PermissionChoice; +import com.fasterxml.jackson.databind.JsonNode; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketSession; + +import java.io.IOException; +import java.util.UUID; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * 单个 WebSocket 客户端的会话管理 —— 对应 TS 端 DirectConnectSessionManager 的服务端对应。 + *

+ * 每个 ServerSession 包装一个 AgentLoop 实例, + * 将 WebSocket 消息转化为 AgentLoop 的调用,并将结果回传。 + *

+ * 权限请求通过 WebSocket 转发给客户端: + *

    + *
  1. AgentLoop 触发 onPermissionRequest 回调
  2. + *
  3. ServerSession 发送 control_request 消息给客户端
  4. + *
  5. 客户端回复 control_response 消息
  6. + *
  7. ServerSession 将结果返回给 AgentLoop
  8. + *
+ */ +public class ServerSession { + + private static final Logger log = LoggerFactory.getLogger(ServerSession.class); + + private final String sessionId; + private final AgentLoop agentLoop; + private final WebSocketSession wsSession; + private final AtomicBoolean processing = new AtomicBoolean(false); + + /** 权限请求的异步等待队列:requestId → CompletableFuture */ + private final ConcurrentHashMap> pendingPermissions + = new ConcurrentHashMap<>(); + + /** 会话线程池(处理用户消息,每个会话一个线程) */ + private final ExecutorService sessionExecutor = Executors.newSingleThreadExecutor( + Thread.ofVirtual().name("server-session-", 0).factory() + ); + + /** Keep-alive 调度器 */ + private final ScheduledExecutorService keepAliveScheduler = Executors.newSingleThreadScheduledExecutor( + Thread.ofVirtual().name("keep-alive-", 0).factory() + ); + + private ScheduledFuture keepAliveTask; + + public ServerSession(String sessionId, AgentLoop agentLoop, WebSocketSession wsSession) { + this.sessionId = sessionId; + this.agentLoop = agentLoop; + this.wsSession = wsSession; + + // 注册 AgentLoop 回调:将事件转发到 WebSocket + setupAgentCallbacks(); + startKeepAlive(); + } + + private void setupAgentCallbacks() { + // 助手文本回调 → 发送 assistant 消息 + agentLoop.setOnAssistantMessage(text -> { + try { + sendMessage(ServerMessage.assistantMessage(sessionId, text, false)); + } catch (Exception e) { + log.error("[{}] Failed to send assistant message", sessionId, e); + } + }); + + // 流式输出回调 → 发送流式 assistant 消息(使用 onStreamStart) + agentLoop.setOnStreamStart(() -> { + // 不需要特殊处理,流式 token 通过 streaming consumer 发送 + }); + + // 工具事件回调 → 发送 tool_use 消息 + agentLoop.setOnToolEvent(event -> { + try { + var toolInfo = new ServerMessage.ToolCallInfo( + null, event.toolName(), event.arguments(), event.result(), + switch (event.phase()) { + case START -> "running"; + case END -> "completed"; + case PROGRESS -> "running"; + } + ); + sendMessage(ServerMessage.assistantToolUse(sessionId, java.util.List.of(toolInfo))); + } catch (Exception e) { + log.error("[{}] Failed to send tool event", sessionId, e); + } + }); + + // 权限请求回调 → 转发到 WebSocket 客户端 + agentLoop.setOnPermissionRequest(req -> { + try { + return forwardPermissionRequest(req); + } catch (Exception e) { + log.error("[{}] Permission request failed", sessionId, e); + return PermissionChoice.DENY_ONCE; + } + }); + + // Thinking 内容 → 发送系统消息 + agentLoop.setOnThinkingContent(thinking -> { + try { + sendMessage(ServerMessage.systemEvent(sessionId, "thinking", null, thinking)); + } catch (Exception e) { + log.debug("[{}] Failed to send thinking content", sessionId, e); + } + }); + } + + /** + * 将权限请求转发到 WebSocket 客户端,同步等待响应。 + */ + private PermissionChoice forwardPermissionRequest(AgentLoop.PermissionRequest req) { + String requestId = UUID.randomUUID().toString(); + CompletableFuture future = new CompletableFuture<>(); + pendingPermissions.put(requestId, future); + + try { + // 发送 control_request + String msg = ServerMessage.controlRequest( + sessionId, req.toolName(), req.arguments(), req.activityDescription()); + sendMessage(msg); + + // 等待客户端响应(超时 60 秒) + return future.get(60, TimeUnit.SECONDS); + } catch (TimeoutException e) { + log.warn("[{}] Permission request timed out for {}", sessionId, req.toolName()); + return PermissionChoice.DENY_ONCE; + } catch (Exception e) { + log.error("[{}] Permission request error", sessionId, e); + return PermissionChoice.DENY_ONCE; + } finally { + pendingPermissions.remove(requestId); + } + } + + /** + * 处理客户端发来的消息。 + */ + public void handleMessage(String rawJson) { + try { + String type = ServerMessage.getType(rawJson); + if (type == null) { + sendMessage(ServerMessage.error(sessionId, "invalid_message", "Missing message type")); + return; + } + + switch (type) { + case ServerMessage.TYPE_USER -> handleUserMessage(rawJson); + case ServerMessage.TYPE_CONTROL_RESPONSE -> handleControlResponse(rawJson); + case ServerMessage.TYPE_INTERRUPT -> handleInterrupt(); + case ServerMessage.TYPE_KEEP_ALIVE -> {} // 忽略 + default -> sendMessage(ServerMessage.error(sessionId, "unknown_type", + "Unknown message type: " + type)); + } + } catch (Exception e) { + log.error("[{}] Message handling error", sessionId, e); + try { + sendMessage(ServerMessage.error(sessionId, "internal_error", e.getMessage())); + } catch (Exception ignored) {} + } + } + + private void handleUserMessage(String rawJson) throws Exception { + if (processing.get()) { + sendMessage(ServerMessage.error(sessionId, "busy", + "Session is currently processing a message. Send interrupt first.")); + return; + } + + JsonNode payload = ServerMessage.getPayload(rawJson); + if (payload == null || !payload.has("content")) { + sendMessage(ServerMessage.error(sessionId, "invalid_payload", "Missing content in user message")); + return; + } + + String content = payload.get("content").asText(); + processing.set(true); + + // 在虚拟线程中异步执行 AgentLoop + sessionExecutor.submit(() -> { + try { + // 使用流式模式,将每个 token 实时转发 + String result = agentLoop.runStreaming(content, token -> { + try { + sendMessage(ServerMessage.assistantMessage(sessionId, token, true)); + } catch (Exception e) { + log.error("[{}] Failed to stream token", sessionId, e); + } + }); + + // 发送最终结果 + var tracker = agentLoop.getTokenTracker(); + sendMessage(ServerMessage.resultMessage( + sessionId, result, 0, + tracker.getInputTokens(), tracker.getOutputTokens())); + } catch (Exception e) { + log.error("[{}] Agent loop execution error", sessionId, e); + try { + sendMessage(ServerMessage.error(sessionId, "execution_error", e.getMessage())); + } catch (Exception ignored) {} + } finally { + processing.set(false); + } + }); + } + + private void handleControlResponse(String rawJson) throws Exception { + JsonNode payload = ServerMessage.getPayload(rawJson); + if (payload == null || !payload.has("request_id") || !payload.has("behavior")) { + sendMessage(ServerMessage.error(sessionId, "invalid_payload", + "Missing request_id or behavior in control_response")); + return; + } + + String requestId = payload.get("request_id").asText(); + String behavior = payload.get("behavior").asText(); + + CompletableFuture future = pendingPermissions.get(requestId); + if (future != null) { + PermissionChoice choice = "allow".equals(behavior) + ? PermissionChoice.ALLOW_ONCE + : PermissionChoice.DENY_ONCE; + future.complete(choice); + } else { + log.warn("[{}] Unknown permission request_id: {}", sessionId, requestId); + } + } + + private void handleInterrupt() { + log.info("[{}] Interrupt signal received", sessionId); + agentLoop.cancel(); + // 取消所有挂起的权限请求 + pendingPermissions.values().forEach(f -> f.complete(PermissionChoice.DENY_ONCE)); + pendingPermissions.clear(); + } + + private void startKeepAlive() { + keepAliveTask = keepAliveScheduler.scheduleAtFixedRate(() -> { + try { + if (wsSession.isOpen()) { + sendMessage(ServerMessage.keepAlive(sessionId)); + } + } catch (Exception e) { + log.debug("[{}] Keep-alive failed", sessionId, e); + } + }, 30, 30, TimeUnit.SECONDS); + } + + /** + * 发送 WebSocket 消息(线程安全)。 + */ + private synchronized void sendMessage(String json) throws IOException { + if (wsSession.isOpen()) { + wsSession.sendMessage(new TextMessage(json)); + } + } + + /** + * 发送会话初始化消息。 + */ + public void sendInitMessage(String model) throws IOException { + sendMessage(ServerMessage.systemEvent(sessionId, "init", model, "Session ready")); + } + + /** + * 关闭会话,清理资源。 + */ + public void close() { + log.info("[{}] Closing server session", sessionId); + if (keepAliveTask != null) { + keepAliveTask.cancel(true); + } + keepAliveScheduler.shutdownNow(); + agentLoop.cancel(); + pendingPermissions.values().forEach(f -> f.complete(PermissionChoice.DENY_ONCE)); + pendingPermissions.clear(); + sessionExecutor.shutdownNow(); + } + + public String getSessionId() { + return sessionId; + } + + public boolean isProcessing() { + return processing.get(); + } + + public AgentLoop getAgentLoop() { + return agentLoop; + } +} diff --git a/src/main/java/com/claudecode/server/WebSocketServerConfig.java b/src/main/java/com/claudecode/server/WebSocketServerConfig.java new file mode 100644 index 0000000..d5fdd37 --- /dev/null +++ b/src/main/java/com/claudecode/server/WebSocketServerConfig.java @@ -0,0 +1,49 @@ +package com.claudecode.server; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.web.socket.config.annotation.WebSocketConfigurer; +import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry; +import org.springframework.web.socket.server.standard.ServletServerContainerFactoryBean; + +/** + * Server Mode 的 WebSocket 配置 —— 仅在 --server 模式下激活。 + *

+ * 注册 {@link DirectConnectWebSocketHandler} 到 /ws 端点。 + *

+ * 对应 claude-code 的 Server Mode 功能: + *

    + *
  • WebSocket 端点: ws://localhost:{port}/ws
  • + *
  • 允许所有来源连接(本地开发用,生产环境应限制)
  • + *
  • 最大消息 1MB
  • + *
+ */ +public class WebSocketServerConfig implements WebSocketConfigurer { + + private static final Logger log = LoggerFactory.getLogger(WebSocketServerConfig.class); + + private final DirectConnectWebSocketHandler handler; + + public WebSocketServerConfig(DirectConnectWebSocketHandler handler) { + this.handler = handler; + } + + @Override + public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { + registry.addHandler(handler, "/ws") + .setAllowedOrigins("*"); + + log.info("WebSocket handler registered at /ws"); + } + + /** + * 配置 WebSocket 容器参数。 + */ + public static ServletServerContainerFactoryBean createWebSocketContainer() { + ServletServerContainerFactoryBean container = new ServletServerContainerFactoryBean(); + container.setMaxTextMessageBufferSize(1024 * 1024); // 1MB + container.setMaxBinaryMessageBufferSize(1024 * 1024); + container.setMaxSessionIdleTimeout(300_000L); // 5 分钟空闲超时 + return container; + } +}