feat: Server Mode with WebSocket direct connect (Phase 3A)

- DirectConnectServer: lifecycle management, auth token, port config
- DirectConnectWebSocketHandler: WebSocket handler with session management
- ServerSession: per-connection AgentLoop with permission forwarding
- ServerMessage: 7-type JSON protocol (user/assistant/result/control/interrupt/keep_alive/error)
- ServerModeAutoConfiguration: conditional Spring config with @EnableWebSocket
- WebSocketServerConfig: endpoint registration and container settings
- ClaudeCodeApplication: --server flag detection, web-application-type override
- ClaudeCodeRunner: skip TUI in server mode, block on shutdown signal
- pom.xml: added spring-boot-starter-websocket dependency

Server mode: --server [--server-port=12321] [--server-token=xxx]
Env vars: CLAUDE_CODE_SERVER_MODE, CLAUDE_CODE_SERVER_PORT, CLAUDE_CODE_SERVER_TOKEN

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
pull/1/head
abel533 1 month ago
parent a926b31e34
commit b98675f643
  1. 6
      pom.xml
  2. 31
      src/main/java/com/claudecode/ClaudeCodeApplication.java
  3. 34
      src/main/java/com/claudecode/cli/ClaudeCodeRunner.java
  4. 226
      src/main/java/com/claudecode/server/DirectConnectServer.java
  5. 180
      src/main/java/com/claudecode/server/DirectConnectWebSocketHandler.java
  6. 220
      src/main/java/com/claudecode/server/ServerMessage.java
  7. 93
      src/main/java/com/claudecode/server/ServerModeAutoConfiguration.java
  8. 300
      src/main/java/com/claudecode/server/ServerSession.java
  9. 49
      src/main/java/com/claudecode/server/WebSocketServerConfig.java

@ -44,6 +44,12 @@
<artifactId>spring-boot-starter-web</artifactId> <artifactId>spring-boot-starter-web</artifactId>
</dependency> </dependency>
<!-- Spring Boot WebSocket(Server Mode 使用) -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
<!-- Spring AI Anthropic(Claude 模型调用) --> <!-- Spring AI Anthropic(Claude 模型调用) -->
<dependency> <dependency>
<groupId>org.springframework.ai</groupId> <groupId>org.springframework.ai</groupId>

@ -1,18 +1,45 @@
package com.claudecode; package com.claudecode;
import com.claudecode.server.DirectConnectServer;
import org.springframework.boot.SpringApplication; import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.autoconfigure.SpringBootApplication;
import java.util.*;
/** /**
* Claude Code Java 版主入口 * Claude Code Java 版主入口
* <p> * <p>
* 对应 claude-code/src/entrypoints/cli.tsx * 对应 claude-code/src/entrypoints/cli.tsx
* Spring Boot 应用启动但关闭 Web 服务器 CLI 模式 * <p>
* 支持两种启动模式
* <ul>
* <li>CLI 模式默认 关闭 Web 服务器启动 TUI 交互</li>
* <li>Server 模式--server 启动 WebSocket 服务器 TUI</li>
* </ul>
*/ */
@SpringBootApplication @SpringBootApplication
public class ClaudeCodeApplication { public class ClaudeCodeApplication {
public static void main(String[] args) { public static void main(String[] args) {
SpringApplication.run(ClaudeCodeApplication.class, args); SpringApplication app = new SpringApplication(ClaudeCodeApplication.class);
if (DirectConnectServer.isServerMode(args)) {
// Server Mode: 启用 Web 服务器 + WebSocket
int port = DirectConnectServer.parsePort(args);
Map<String, Object> serverProps = new HashMap<>();
serverProps.put("spring.main.web-application-type", "servlet");
serverProps.put("server.port", port);
serverProps.put("claude-code.server-mode", "true");
app.setDefaultProperties(serverProps);
// 将原始参数保存到系统属性,供 ServerModeAutoConfiguration 解析
System.setProperty("claude-code.server-args", String.join(" ", args));
System.out.println("Starting in Server Mode on port " + port + "...");
}
// CLI 模式使用 application.yml 中的 web-application-type: none
app.run(args);
} }
} }

@ -1,6 +1,7 @@
package com.claudecode.cli; package com.claudecode.cli;
import com.claudecode.repl.ReplSession; import com.claudecode.repl.ReplSession;
import com.claudecode.server.DirectConnectServer;
import com.claudecode.tui.JinkReplSession; import com.claudecode.tui.JinkReplSession;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -10,7 +11,12 @@ import org.springframework.stereotype.Component;
/** /**
* 启动编排器 对应 claude-code/src/main.tsx 的初始化逻辑 * 启动编排器 对应 claude-code/src/main.tsx 的初始化逻辑
* <p> * <p>
* 优先使用 jink TUI 模式失败时降级到传统 JLine REPL * 支持三种模式
* <ul>
* <li>Server 模式 (--server) 不启动 TUIWebSocket 服务器由 Spring 自动配置</li>
* <li>Jink TUI 模式默认 全屏终端 UI</li>
* <li>Legacy REPL 模式降级或 CLAUDE_CODE_TUI=legacy</li>
* </ul>
*/ */
@Component @Component
public class ClaudeCodeRunner implements CommandLineRunner { public class ClaudeCodeRunner implements CommandLineRunner {
@ -27,6 +33,14 @@ public class ClaudeCodeRunner implements CommandLineRunner {
@Override @Override
public void run(String... args) { public void run(String... args) {
// Server Mode: 不启动 TUI,WebSocket 服务器已由 ServerModeAutoConfiguration 启动
if (DirectConnectServer.isServerMode(args)) {
log.info("Server Mode active — TUI disabled, WebSocket server running");
// 阻塞主线程,等待 Ctrl+C 或 SIGTERM
waitForShutdown();
return;
}
log.info("Claude Code (Java) starting..."); log.info("Claude Code (Java) starting...");
// 检查是否强制使用旧模式 // 检查是否强制使用旧模式
@ -45,4 +59,22 @@ public class ClaudeCodeRunner implements CommandLineRunner {
replSession.start(); replSession.start();
} }
} }
/**
* Server Mode 下阻塞主线程直到收到 shutdown 信号
*/
private void waitForShutdown() {
Thread shutdownHook = new Thread(() -> {
log.info("Shutdown signal received");
});
Runtime.getRuntime().addShutdownHook(shutdownHook);
try {
// 阻塞直到中断
Thread.currentThread().join();
} catch (InterruptedException e) {
log.info("Server mode interrupted");
Thread.currentThread().interrupt();
}
}
} }

@ -0,0 +1,226 @@
package com.claudecode.server;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import java.util.UUID;
/**
* Server Mode 生命周期管理 对应 claude-code/src/server 的服务端核心
* <p>
* 管理服务端状态认证 Token 生成会话限制等
* <p>
* 启动流程
* <ol>
* <li>ClaudeCodeApplication 检测 --server 参数</li>
* <li>Spring Boot 启用 WebSocket 模式</li>
* <li>DirectConnectServer 初始化生成 auth token</li>
* <li>WebSocket 端点开始接受连接</li>
* </ol>
*
* 客户端连接方式
* <ul>
* <li>WebSocket: ws://localhost:{port}/ws?token={auth_token}</li>
* <li>HTTP Header: Authorization: Bearer {auth_token}</li>
* </ul>
*/
public class DirectConnectServer {
private static final Logger log = LoggerFactory.getLogger(DirectConnectServer.class);
/** 默认端口 */
public static final int DEFAULT_PORT = 12321;
/** 默认最大会话数 */
public static final int DEFAULT_MAX_SESSIONS = 5;
private final int port;
private final String authToken;
private final int maxSessions;
private final DirectConnectWebSocketHandler handler;
private volatile boolean running = false;
public DirectConnectServer(int port, String authToken, int maxSessions,
DirectConnectWebSocketHandler handler) {
this.port = port;
this.authToken = authToken;
this.maxSessions = maxSessions;
this.handler = handler;
}
/**
* 生成随机认证 Token
*/
public static String generateAuthToken() {
return UUID.randomUUID().toString().replace("-", "");
}
/**
* 标记服务器已启动打印连接信息
*/
public void onServerStarted() {
running = true;
printConnectionInfo();
}
/**
* 打印连接信息到控制台
*/
public void printConnectionInfo() {
String separator = "═".repeat(60);
System.out.println();
System.out.println("╔" + separator + "╗");
System.out.println("║ Claude Code Java — Server Mode ║");
System.out.println("╠" + separator + "╣");
System.out.printf("║ WebSocket: ws://localhost:%d/ws%-24s║%n", port, "");
System.out.printf("║ Port: %-48s║%n", port);
if (authToken != null && !authToken.isBlank()) {
System.out.printf("║ Token: %-48s║%n", authToken);
} else {
System.out.printf("║ Auth: %-48s║%n", "disabled (no token)");
}
System.out.printf("║ Max Sess: %-48s║%n", maxSessions);
System.out.println("╠" + separator + "╣");
System.out.println("║ Connect with: ║");
if (authToken != null && !authToken.isBlank()) {
System.out.printf("║ ws://localhost:%d/ws?token=%s ║%n", port, authToken.substring(0, Math.min(12, authToken.length())) + "...");
} else {
System.out.printf("║ ws://localhost:%d/ws%-38s║%n", port, "");
}
System.out.println("╚" + separator + "╝");
System.out.println();
}
/**
* 停止服务器关闭所有会话
*/
public void shutdown() {
if (!running) return;
running = false;
log.info("Shutting down server...");
handler.closeAllSessions();
log.info("Server stopped. All sessions closed.");
}
// ==================== 静态工具方法 ====================
/**
* 检查命令行参数是否包含 --server 标志
*/
public static boolean isServerMode(String[] args) {
if (args == null) return false;
for (String arg : args) {
if ("--server".equals(arg) || arg.startsWith("--server=")) {
return true;
}
}
// 也支持环境变量
String envMode = System.getenv("CLAUDE_CODE_SERVER_MODE");
return "true".equalsIgnoreCase(envMode) || "1".equals(envMode);
}
/**
* 从命令行参数解析端口号
*/
public static int parsePort(String[] args) {
if (args != null) {
for (String arg : args) {
if (arg.startsWith("--server-port=")) {
try {
return Integer.parseInt(arg.substring("--server-port=".length()));
} catch (NumberFormatException e) {
log.warn("Invalid port number: {}", arg);
}
}
if (arg.startsWith("--server=")) {
try {
return Integer.parseInt(arg.substring("--server=".length()));
} catch (NumberFormatException e) {
// --server=true 等非数字参数忽略
}
}
}
}
// 环境变量
String envPort = System.getenv("CLAUDE_CODE_SERVER_PORT");
if (envPort != null) {
try {
return Integer.parseInt(envPort);
} catch (NumberFormatException e) {
log.warn("Invalid CLAUDE_CODE_SERVER_PORT: {}", envPort);
}
}
return DEFAULT_PORT;
}
/**
* 从命令行参数或环境变量获取认证 Token
* 如果未指定生成随机 Token
*/
public static String parseAuthToken(String[] args) {
if (args != null) {
for (String arg : args) {
if (arg.startsWith("--server-token=")) {
return arg.substring("--server-token=".length());
}
}
}
String envToken = System.getenv("CLAUDE_CODE_SERVER_TOKEN");
if (envToken != null && !envToken.isBlank()) {
return envToken;
}
// 未指定则生成随机 token
return generateAuthToken();
}
/**
* 从命令行参数解析最大会话数
*/
public static int parseMaxSessions(String[] args) {
if (args != null) {
for (String arg : args) {
if (arg.startsWith("--max-sessions=")) {
try {
return Integer.parseInt(arg.substring("--max-sessions=".length()));
} catch (NumberFormatException e) {
log.warn("Invalid max sessions: {}", arg);
}
}
}
}
String envMax = System.getenv("CLAUDE_CODE_MAX_SESSIONS");
if (envMax != null) {
try {
return Integer.parseInt(envMax);
} catch (NumberFormatException e) {
log.warn("Invalid CLAUDE_CODE_MAX_SESSIONS: {}", envMax);
}
}
return DEFAULT_MAX_SESSIONS;
}
// ==================== Getters ====================
public int getPort() {
return port;
}
public String getAuthToken() {
return authToken;
}
public int getMaxSessions() {
return maxSessions;
}
public boolean isRunning() {
return running;
}
public int getActiveSessionCount() {
return handler.getActiveSessionCount();
}
}

@ -0,0 +1,180 @@
package com.claudecode.server;
import com.claudecode.core.AgentLoop;
import com.claudecode.core.TokenTracker;
import com.claudecode.tool.ToolContext;
import com.claudecode.tool.ToolRegistry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import java.io.IOException;
import java.net.URI;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
/**
* WebSocket Handler Server Mode 的核心连接处理器
* <p>
* 对应 claude-code/src/server/directConnectManager.ts 的服务端实现
* <p>
* 每个 WebSocket 连接创建一个 {@link ServerSession}包含独立的 AgentLoop
* 支持 Bearer Token 认证和多会话管理
*/
public class DirectConnectWebSocketHandler extends TextWebSocketHandler {
private static final Logger log = LoggerFactory.getLogger(DirectConnectWebSocketHandler.class);
private final ChatModel chatModel;
private final ToolRegistry toolRegistry;
private final ToolContext toolContext;
private final String systemPrompt;
private final String authToken;
private final int maxSessions;
private final String model;
/** 活跃会话:WebSocket sessionId → ServerSession */
private final ConcurrentHashMap<String, ServerSession> sessions = new ConcurrentHashMap<>();
public DirectConnectWebSocketHandler(ChatModel chatModel, ToolRegistry toolRegistry,
ToolContext toolContext, String systemPrompt,
String authToken, int maxSessions, String model) {
this.chatModel = chatModel;
this.toolRegistry = toolRegistry;
this.toolContext = toolContext;
this.systemPrompt = systemPrompt;
this.authToken = authToken;
this.maxSessions = maxSessions;
this.model = model;
}
@Override
public void afterConnectionEstablished(WebSocketSession wsSession) throws Exception {
log.info("WebSocket connection attempt from: {}", wsSession.getRemoteAddress());
// Bearer Token 认证
if (authToken != null && !authToken.isBlank()) {
if (!validateAuth(wsSession)) {
wsSession.close(new CloseStatus(4001, "Unauthorized"));
return;
}
}
// 会话数限制
if (maxSessions > 0 && sessions.size() >= maxSessions) {
String errorMsg = ServerMessage.error(null, "max_sessions",
"Maximum sessions (" + maxSessions + ") reached");
wsSession.sendMessage(new TextMessage(errorMsg));
wsSession.close(new CloseStatus(4002, "Max sessions reached"));
return;
}
// 创建新的 AgentLoop(每个连接独立实例)
String sessionId = UUID.randomUUID().toString();
AgentLoop agentLoop = new AgentLoop(chatModel, toolRegistry, toolContext, systemPrompt, new TokenTracker());
// 创建 ServerSession
ServerSession serverSession = new ServerSession(sessionId, agentLoop, wsSession);
sessions.put(wsSession.getId(), serverSession);
// 发送初始化消息
serverSession.sendInitMessage(model);
log.info("Server session created: {} (ws: {}, total: {})",
sessionId, wsSession.getId(), sessions.size());
}
@Override
protected void handleTextMessage(WebSocketSession wsSession, TextMessage message) throws Exception {
ServerSession session = sessions.get(wsSession.getId());
if (session == null) {
log.warn("Message from unknown session: {}", wsSession.getId());
wsSession.close(new CloseStatus(4003, "Unknown session"));
return;
}
session.handleMessage(message.getPayload());
}
@Override
public void afterConnectionClosed(WebSocketSession wsSession, CloseStatus status) throws Exception {
ServerSession session = sessions.remove(wsSession.getId());
if (session != null) {
session.close();
log.info("Session closed: {} (status: {}, remaining: {})",
session.getSessionId(), status, sessions.size());
}
}
@Override
public void handleTransportError(WebSocketSession wsSession, Throwable exception) throws Exception {
log.error("WebSocket transport error for {}: {}", wsSession.getId(), exception.getMessage());
ServerSession session = sessions.get(wsSession.getId());
if (session != null) {
try {
String errorMsg = ServerMessage.error(session.getSessionId(),
"transport_error", exception.getMessage());
wsSession.sendMessage(new TextMessage(errorMsg));
} catch (Exception ignored) {}
}
}
/**
* 验证 WebSocket 连接的 Bearer Token
* <p>
* 支持两种方式
* <ul>
* <li>HTTP Header: {@code Authorization: Bearer <token>}</li>
* <li>Query Parameter: {@code ?token=<token>}</li>
* </ul>
*/
private boolean validateAuth(WebSocketSession wsSession) {
// 方式1: 从 HTTP Header 获取
var headers = wsSession.getHandshakeHeaders();
String authHeader = headers.getFirst("Authorization");
if (authHeader != null && authHeader.startsWith("Bearer ")) {
String token = authHeader.substring(7);
if (authToken.equals(token)) {
return true;
}
}
// 方式2: 从 Query Parameter 获取
URI uri = wsSession.getUri();
if (uri != null && uri.getQuery() != null) {
String query = uri.getQuery();
for (String param : query.split("&")) {
if (param.startsWith("token=")) {
String token = param.substring(6);
if (authToken.equals(token)) {
return true;
}
}
}
}
log.warn("Authentication failed for connection from: {}", wsSession.getRemoteAddress());
return false;
}
/**
* 关闭所有会话
*/
public void closeAllSessions() {
sessions.values().forEach(ServerSession::close);
sessions.clear();
}
public int getActiveSessionCount() {
return sessions.size();
}
public Map<String, ServerSession> getSessions() {
return Map.copyOf(sessions);
}
}

@ -0,0 +1,220 @@
package com.claudecode.server;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.List;
import java.util.Map;
import java.util.UUID;
/**
* Server Mode 消息协议 对应 claude-code/src/server/types.ts
* <p>
* WebSocket 上的 JSON 消息7 种类型
* <ul>
* <li><b>user</b> 客户端发送的用户消息</li>
* <li><b>assistant</b> 服务端回复的助手消息</li>
* <li><b>result</b> 轮次结束的最终结果</li>
* <li><b>control_request</b> 权限请求服务端客户端</li>
* <li><b>control_response</b> 权限回复客户端服务端</li>
* <li><b>interrupt</b> 中断信号</li>
* <li><b>keep_alive</b> 心跳</li>
* </ul>
*/
public class ServerMessage {
private static final ObjectMapper MAPPER = new ObjectMapper()
.setSerializationInclusion(JsonInclude.Include.NON_NULL);
// ==================== 消息类型常量 ====================
public static final String TYPE_USER = "user";
public static final String TYPE_ASSISTANT = "assistant";
public static final String TYPE_RESULT = "result";
public static final String TYPE_CONTROL_REQUEST = "control_request";
public static final String TYPE_CONTROL_RESPONSE = "control_response";
public static final String TYPE_INTERRUPT = "interrupt";
public static final String TYPE_KEEP_ALIVE = "keep_alive";
public static final String TYPE_SYSTEM = "system";
public static final String TYPE_ERROR = "error";
public static final String TYPE_TOOL_USE = "tool_use";
// ==================== 通用消息 ====================
@JsonInclude(JsonInclude.Include.NON_NULL)
public record Envelope(
String type,
@JsonProperty("session_id") String sessionId,
Object payload
) {
public String toJson() throws JsonProcessingException {
return MAPPER.writeValueAsString(this);
}
public static Envelope fromJson(String json) throws JsonProcessingException {
return MAPPER.readValue(json, Envelope.class);
}
}
// ==================== 客户端→服务端 ====================
/** 用户消息 */
@JsonInclude(JsonInclude.Include.NON_NULL)
public record UserMessage(
String content,
@JsonProperty("parent_tool_use_id") String parentToolUseId
) {}
/** 权限回复 */
@JsonInclude(JsonInclude.Include.NON_NULL)
public record ControlResponse(
@JsonProperty("request_id") String requestId,
String behavior, // "allow" or "deny"
String message,
@JsonProperty("updated_input") Map<String, Object> updatedInput
) {}
// ==================== 服务端→客户端 ====================
/** 助手消息(流式或完整) */
@JsonInclude(JsonInclude.Include.NON_NULL)
public record AssistantPayload(
String text,
String uuid,
boolean streaming,
@JsonProperty("tool_calls") List<ToolCallInfo> toolCalls
) {}
/** 工具调用信息 */
@JsonInclude(JsonInclude.Include.NON_NULL)
public record ToolCallInfo(
String id,
String name,
@JsonProperty("arguments") String arguments,
String result,
String status // "running", "completed", "error"
) {}
/** 轮次结果 */
@JsonInclude(JsonInclude.Include.NON_NULL)
public record ResultPayload(
String text,
@JsonProperty("tool_calls_count") int toolCallsCount,
@JsonProperty("prompt_tokens") long promptTokens,
@JsonProperty("completion_tokens") long completionTokens
) {}
/** 权限请求 */
@JsonInclude(JsonInclude.Include.NON_NULL)
public record ControlRequest(
@JsonProperty("request_id") String requestId,
String subtype, // "can_use_tool"
@JsonProperty("tool_name") String toolName,
@JsonProperty("tool_input") String toolInput,
@JsonProperty("activity_description") String activityDescription
) {}
/** 系统事件 */
@JsonInclude(JsonInclude.Include.NON_NULL)
public record SystemEvent(
String subtype, // "init", "session_ready", "error"
String model,
String message,
@JsonProperty("session_id") String sessionId
) {}
/** 错误 */
@JsonInclude(JsonInclude.Include.NON_NULL)
public record ErrorPayload(
String code,
String message
) {}
// ==================== 工厂方法 ====================
public static String userMessage(String sessionId, String content) throws JsonProcessingException {
var envelope = new Envelope(TYPE_USER, sessionId, new UserMessage(content, null));
return envelope.toJson();
}
public static String assistantMessage(String sessionId, String text, boolean streaming) throws JsonProcessingException {
var payload = new AssistantPayload(text, UUID.randomUUID().toString(), streaming, null);
var envelope = new Envelope(TYPE_ASSISTANT, sessionId, payload);
return envelope.toJson();
}
public static String assistantToolUse(String sessionId, List<ToolCallInfo> toolCalls) throws JsonProcessingException {
var payload = new AssistantPayload(null, UUID.randomUUID().toString(), false, toolCalls);
var envelope = new Envelope(TYPE_TOOL_USE, sessionId, payload);
return envelope.toJson();
}
public static String resultMessage(String sessionId, String text, int toolCallsCount,
long promptTokens, long completionTokens) throws JsonProcessingException {
var payload = new ResultPayload(text, toolCallsCount, promptTokens, completionTokens);
var envelope = new Envelope(TYPE_RESULT, sessionId, payload);
return envelope.toJson();
}
public static String controlRequest(String sessionId, String toolName, String toolInput,
String activityDescription) throws JsonProcessingException {
var payload = new ControlRequest(
UUID.randomUUID().toString(), "can_use_tool",
toolName, toolInput, activityDescription);
var envelope = new Envelope(TYPE_CONTROL_REQUEST, sessionId, payload);
return envelope.toJson();
}
public static String controlResponse(String sessionId, String requestId, String behavior,
String message) throws JsonProcessingException {
var payload = new ControlResponse(requestId, behavior, message, null);
var envelope = new Envelope(TYPE_CONTROL_RESPONSE, sessionId, payload);
return envelope.toJson();
}
public static String interrupt(String sessionId) throws JsonProcessingException {
var envelope = new Envelope(TYPE_INTERRUPT, sessionId, null);
return envelope.toJson();
}
public static String keepAlive(String sessionId) throws JsonProcessingException {
var envelope = new Envelope(TYPE_KEEP_ALIVE, sessionId, null);
return envelope.toJson();
}
public static String systemEvent(String sessionId, String subtype, String model, String message) throws JsonProcessingException {
var payload = new SystemEvent(subtype, model, message, sessionId);
var envelope = new Envelope(TYPE_SYSTEM, sessionId, payload);
return envelope.toJson();
}
public static String error(String sessionId, String code, String message) throws JsonProcessingException {
var payload = new ErrorPayload(code, message);
var envelope = new Envelope(TYPE_ERROR, sessionId, payload);
return envelope.toJson();
}
// ==================== 解析工具 ====================
public static String getType(String json) throws JsonProcessingException {
JsonNode node = MAPPER.readTree(json);
return node.has("type") ? node.get("type").asText() : null;
}
public static JsonNode getPayload(String json) throws JsonProcessingException {
JsonNode node = MAPPER.readTree(json);
return node.get("payload");
}
public static String getSessionId(String json) throws JsonProcessingException {
JsonNode node = MAPPER.readTree(json);
return node.has("session_id") ? node.get("session_id").asText() : null;
}
public static ObjectMapper mapper() {
return MAPPER;
}
}

@ -0,0 +1,93 @@
package com.claudecode.server;
import com.claudecode.tool.ToolContext;
import com.claudecode.tool.ToolRegistry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Lazy;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
import org.springframework.web.socket.server.standard.ServletServerContainerFactoryBean;
/**
* Server Mode 的自动配置 仅在 claude-code.server-mode=true 时激活
* <p>
* {@link com.claudecode.ClaudeCodeApplication#main(String[])} 在检测到 --server 参数时
* 设置 {@code claude-code.server-mode=true}触发此配置类加载
* <p>
* 注册的 Bean:
* <ul>
* <li>{@link DirectConnectWebSocketHandler} WebSocket 消息处理</li>
* <li>{@link DirectConnectServer} 服务器生命周期管理</li>
* <li>{@link WebSocketConfigurer} WebSocket 端点注册</li>
* </ul>
*/
@Configuration
@ConditionalOnProperty(name = "claude-code.server-mode", havingValue = "true")
@EnableWebSocket
public class ServerModeAutoConfiguration implements WebSocketConfigurer {
private static final Logger log = LoggerFactory.getLogger(ServerModeAutoConfiguration.class);
@Autowired @Lazy
private DirectConnectWebSocketHandler wsHandler;
@Bean
public DirectConnectWebSocketHandler directConnectWebSocketHandler(
ChatModel activeChatModel,
ToolRegistry toolRegistry,
ToolContext toolContext,
String systemPrompt) {
String[] args = getApplicationArgs();
String authToken = DirectConnectServer.parseAuthToken(args);
int maxSessions = DirectConnectServer.parseMaxSessions(args);
String model = System.getenv("AI_MODEL") != null
? System.getenv("AI_MODEL") : "claude-sonnet-4-20250514";
log.info("Creating DirectConnect WebSocket handler (maxSessions={}, auth={})",
maxSessions, authToken != null && !authToken.isBlank() ? "enabled" : "disabled");
return new DirectConnectWebSocketHandler(
activeChatModel, toolRegistry, toolContext,
systemPrompt, authToken, maxSessions, model);
}
@Bean
public DirectConnectServer directConnectServer(DirectConnectWebSocketHandler handler) {
String[] args = getApplicationArgs();
int port = DirectConnectServer.parsePort(args);
String authToken = DirectConnectServer.parseAuthToken(args);
int maxSessions = DirectConnectServer.parseMaxSessions(args);
DirectConnectServer server = new DirectConnectServer(port, authToken, maxSessions, handler);
server.onServerStarted();
return server;
}
@Bean
public ServletServerContainerFactoryBean createWebSocketContainer() {
return WebSocketServerConfig.createWebSocketContainer();
}
@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
registry.addHandler(wsHandler, "/ws")
.setAllowedOrigins("*");
log.info("WebSocket handler registered at /ws endpoint");
}
private String[] getApplicationArgs() {
String argsStr = System.getProperty("claude-code.server-args");
if (argsStr != null && !argsStr.isBlank()) {
return argsStr.split("\\s+");
}
return new String[0];
}
}

@ -0,0 +1,300 @@
package com.claudecode.server;
import com.claudecode.core.AgentLoop;
import com.claudecode.permission.PermissionTypes.PermissionChoice;
import com.fasterxml.jackson.databind.JsonNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import java.io.IOException;
import java.util.UUID;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* 单个 WebSocket 客户端的会话管理 对应 TS DirectConnectSessionManager 的服务端对应
* <p>
* 每个 ServerSession 包装一个 AgentLoop 实例
* WebSocket 消息转化为 AgentLoop 的调用并将结果回传
* <p>
* 权限请求通过 WebSocket 转发给客户端
* <ol>
* <li>AgentLoop 触发 onPermissionRequest 回调</li>
* <li>ServerSession 发送 control_request 消息给客户端</li>
* <li>客户端回复 control_response 消息</li>
* <li>ServerSession 将结果返回给 AgentLoop</li>
* </ol>
*/
public class ServerSession {
private static final Logger log = LoggerFactory.getLogger(ServerSession.class);
private final String sessionId;
private final AgentLoop agentLoop;
private final WebSocketSession wsSession;
private final AtomicBoolean processing = new AtomicBoolean(false);
/** 权限请求的异步等待队列:requestId → CompletableFuture<PermissionChoice> */
private final ConcurrentHashMap<String, CompletableFuture<PermissionChoice>> pendingPermissions
= new ConcurrentHashMap<>();
/** 会话线程池(处理用户消息,每个会话一个线程) */
private final ExecutorService sessionExecutor = Executors.newSingleThreadExecutor(
Thread.ofVirtual().name("server-session-", 0).factory()
);
/** Keep-alive 调度器 */
private final ScheduledExecutorService keepAliveScheduler = Executors.newSingleThreadScheduledExecutor(
Thread.ofVirtual().name("keep-alive-", 0).factory()
);
private ScheduledFuture<?> keepAliveTask;
public ServerSession(String sessionId, AgentLoop agentLoop, WebSocketSession wsSession) {
this.sessionId = sessionId;
this.agentLoop = agentLoop;
this.wsSession = wsSession;
// 注册 AgentLoop 回调:将事件转发到 WebSocket
setupAgentCallbacks();
startKeepAlive();
}
private void setupAgentCallbacks() {
// 助手文本回调 → 发送 assistant 消息
agentLoop.setOnAssistantMessage(text -> {
try {
sendMessage(ServerMessage.assistantMessage(sessionId, text, false));
} catch (Exception e) {
log.error("[{}] Failed to send assistant message", sessionId, e);
}
});
// 流式输出回调 → 发送流式 assistant 消息(使用 onStreamStart)
agentLoop.setOnStreamStart(() -> {
// 不需要特殊处理,流式 token 通过 streaming consumer 发送
});
// 工具事件回调 → 发送 tool_use 消息
agentLoop.setOnToolEvent(event -> {
try {
var toolInfo = new ServerMessage.ToolCallInfo(
null, event.toolName(), event.arguments(), event.result(),
switch (event.phase()) {
case START -> "running";
case END -> "completed";
case PROGRESS -> "running";
}
);
sendMessage(ServerMessage.assistantToolUse(sessionId, java.util.List.of(toolInfo)));
} catch (Exception e) {
log.error("[{}] Failed to send tool event", sessionId, e);
}
});
// 权限请求回调 → 转发到 WebSocket 客户端
agentLoop.setOnPermissionRequest(req -> {
try {
return forwardPermissionRequest(req);
} catch (Exception e) {
log.error("[{}] Permission request failed", sessionId, e);
return PermissionChoice.DENY_ONCE;
}
});
// Thinking 内容 → 发送系统消息
agentLoop.setOnThinkingContent(thinking -> {
try {
sendMessage(ServerMessage.systemEvent(sessionId, "thinking", null, thinking));
} catch (Exception e) {
log.debug("[{}] Failed to send thinking content", sessionId, e);
}
});
}
/**
* 将权限请求转发到 WebSocket 客户端同步等待响应
*/
private PermissionChoice forwardPermissionRequest(AgentLoop.PermissionRequest req) {
String requestId = UUID.randomUUID().toString();
CompletableFuture<PermissionChoice> future = new CompletableFuture<>();
pendingPermissions.put(requestId, future);
try {
// 发送 control_request
String msg = ServerMessage.controlRequest(
sessionId, req.toolName(), req.arguments(), req.activityDescription());
sendMessage(msg);
// 等待客户端响应(超时 60 秒)
return future.get(60, TimeUnit.SECONDS);
} catch (TimeoutException e) {
log.warn("[{}] Permission request timed out for {}", sessionId, req.toolName());
return PermissionChoice.DENY_ONCE;
} catch (Exception e) {
log.error("[{}] Permission request error", sessionId, e);
return PermissionChoice.DENY_ONCE;
} finally {
pendingPermissions.remove(requestId);
}
}
/**
* 处理客户端发来的消息
*/
public void handleMessage(String rawJson) {
try {
String type = ServerMessage.getType(rawJson);
if (type == null) {
sendMessage(ServerMessage.error(sessionId, "invalid_message", "Missing message type"));
return;
}
switch (type) {
case ServerMessage.TYPE_USER -> handleUserMessage(rawJson);
case ServerMessage.TYPE_CONTROL_RESPONSE -> handleControlResponse(rawJson);
case ServerMessage.TYPE_INTERRUPT -> handleInterrupt();
case ServerMessage.TYPE_KEEP_ALIVE -> {} // 忽略
default -> sendMessage(ServerMessage.error(sessionId, "unknown_type",
"Unknown message type: " + type));
}
} catch (Exception e) {
log.error("[{}] Message handling error", sessionId, e);
try {
sendMessage(ServerMessage.error(sessionId, "internal_error", e.getMessage()));
} catch (Exception ignored) {}
}
}
private void handleUserMessage(String rawJson) throws Exception {
if (processing.get()) {
sendMessage(ServerMessage.error(sessionId, "busy",
"Session is currently processing a message. Send interrupt first."));
return;
}
JsonNode payload = ServerMessage.getPayload(rawJson);
if (payload == null || !payload.has("content")) {
sendMessage(ServerMessage.error(sessionId, "invalid_payload", "Missing content in user message"));
return;
}
String content = payload.get("content").asText();
processing.set(true);
// 在虚拟线程中异步执行 AgentLoop
sessionExecutor.submit(() -> {
try {
// 使用流式模式,将每个 token 实时转发
String result = agentLoop.runStreaming(content, token -> {
try {
sendMessage(ServerMessage.assistantMessage(sessionId, token, true));
} catch (Exception e) {
log.error("[{}] Failed to stream token", sessionId, e);
}
});
// 发送最终结果
var tracker = agentLoop.getTokenTracker();
sendMessage(ServerMessage.resultMessage(
sessionId, result, 0,
tracker.getInputTokens(), tracker.getOutputTokens()));
} catch (Exception e) {
log.error("[{}] Agent loop execution error", sessionId, e);
try {
sendMessage(ServerMessage.error(sessionId, "execution_error", e.getMessage()));
} catch (Exception ignored) {}
} finally {
processing.set(false);
}
});
}
private void handleControlResponse(String rawJson) throws Exception {
JsonNode payload = ServerMessage.getPayload(rawJson);
if (payload == null || !payload.has("request_id") || !payload.has("behavior")) {
sendMessage(ServerMessage.error(sessionId, "invalid_payload",
"Missing request_id or behavior in control_response"));
return;
}
String requestId = payload.get("request_id").asText();
String behavior = payload.get("behavior").asText();
CompletableFuture<PermissionChoice> future = pendingPermissions.get(requestId);
if (future != null) {
PermissionChoice choice = "allow".equals(behavior)
? PermissionChoice.ALLOW_ONCE
: PermissionChoice.DENY_ONCE;
future.complete(choice);
} else {
log.warn("[{}] Unknown permission request_id: {}", sessionId, requestId);
}
}
private void handleInterrupt() {
log.info("[{}] Interrupt signal received", sessionId);
agentLoop.cancel();
// 取消所有挂起的权限请求
pendingPermissions.values().forEach(f -> f.complete(PermissionChoice.DENY_ONCE));
pendingPermissions.clear();
}
private void startKeepAlive() {
keepAliveTask = keepAliveScheduler.scheduleAtFixedRate(() -> {
try {
if (wsSession.isOpen()) {
sendMessage(ServerMessage.keepAlive(sessionId));
}
} catch (Exception e) {
log.debug("[{}] Keep-alive failed", sessionId, e);
}
}, 30, 30, TimeUnit.SECONDS);
}
/**
* 发送 WebSocket 消息线程安全
*/
private synchronized void sendMessage(String json) throws IOException {
if (wsSession.isOpen()) {
wsSession.sendMessage(new TextMessage(json));
}
}
/**
* 发送会话初始化消息
*/
public void sendInitMessage(String model) throws IOException {
sendMessage(ServerMessage.systemEvent(sessionId, "init", model, "Session ready"));
}
/**
* 关闭会话清理资源
*/
public void close() {
log.info("[{}] Closing server session", sessionId);
if (keepAliveTask != null) {
keepAliveTask.cancel(true);
}
keepAliveScheduler.shutdownNow();
agentLoop.cancel();
pendingPermissions.values().forEach(f -> f.complete(PermissionChoice.DENY_ONCE));
pendingPermissions.clear();
sessionExecutor.shutdownNow();
}
public String getSessionId() {
return sessionId;
}
public boolean isProcessing() {
return processing.get();
}
public AgentLoop getAgentLoop() {
return agentLoop;
}
}

@ -0,0 +1,49 @@
package com.claudecode.server;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
import org.springframework.web.socket.server.standard.ServletServerContainerFactoryBean;
/**
* Server Mode WebSocket 配置 仅在 --server 模式下激活
* <p>
* 注册 {@link DirectConnectWebSocketHandler} /ws 端点
* <p>
* 对应 claude-code Server Mode 功能
* <ul>
* <li>WebSocket 端点: ws://localhost:{port}/ws</li>
* <li>允许所有来源连接本地开发用生产环境应限制</li>
* <li>最大消息 1MB</li>
* </ul>
*/
public class WebSocketServerConfig implements WebSocketConfigurer {
private static final Logger log = LoggerFactory.getLogger(WebSocketServerConfig.class);
private final DirectConnectWebSocketHandler handler;
public WebSocketServerConfig(DirectConnectWebSocketHandler handler) {
this.handler = handler;
}
@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
registry.addHandler(handler, "/ws")
.setAllowedOrigins("*");
log.info("WebSocket handler registered at /ws");
}
/**
* 配置 WebSocket 容器参数
*/
public static ServletServerContainerFactoryBean createWebSocketContainer() {
ServletServerContainerFactoryBean container = new ServletServerContainerFactoryBean();
container.setMaxTextMessageBufferSize(1024 * 1024); // 1MB
container.setMaxBinaryMessageBufferSize(1024 * 1024);
container.setMaxSessionIdleTimeout(300_000L); // 5 分钟空闲超时
return container;
}
}
Loading…
Cancel
Save