diff --git a/src/main/java/com/claudecode/core/AgentLoop.java b/src/main/java/com/claudecode/core/AgentLoop.java index 5f8275b..b976073 100644 --- a/src/main/java/com/claudecode/core/AgentLoop.java +++ b/src/main/java/com/claudecode/core/AgentLoop.java @@ -14,17 +14,24 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.ToolCallback; +import reactor.core.publisher.Flux; import java.util.*; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; /** * Agent 循环 —— 对应 claude-code/src/core/query.ts 的 agent loop。 *

+ * 支持两种模式: + *

* 使用 ChatModel(非 ChatClient)的显式循环,完整控制每一轮: *
    *
  1. 构建 Prompt(消息历史 + 系统提示 + 工具定义)
  2. - *
  3. 调用 ChatModel.call()
  4. + *
  5. 调用 ChatModel.call() 或 ChatModel.stream()
  6. *
  7. 检查工具调用 → 执行工具 → 结果回传
  8. *
  9. 循环直到无工具调用或达到最大迭代
  10. *
@@ -49,9 +56,12 @@ public class AgentLoop { /** 工具调用事件回调:在每次工具调用前/后通知 UI */ private Consumer onToolEvent; - /** 助手文本回调:在每次助手回复时通知 UI */ + /** 助手文本回调:在每次助手回复时通知 UI(仅阻塞模式使用) */ private Consumer onAssistantMessage; + /** 流式输出开始回调:通知 UI 停止 spinner */ + private Runnable onStreamStart; + public AgentLoop(ChatModel chatModel, ToolRegistry toolRegistry, ToolContext toolContext, String systemPrompt) { this(chatModel, toolRegistry, toolContext, systemPrompt, new TokenTracker()); @@ -64,7 +74,6 @@ public class AgentLoop { this.toolContext = toolContext; this.systemPrompt = systemPrompt; this.tokenTracker = tokenTracker; - // 添加系统提示词到消息历史 this.messageHistory.add(new SystemMessage(systemPrompt)); } @@ -76,15 +85,39 @@ public class AgentLoop { this.onAssistantMessage = onAssistantMessage; } + public void setOnStreamStart(Runnable onStreamStart) { + this.onStreamStart = onStreamStart; + } + + // ==================== 阻塞模式 ==================== + /** - * 执行一轮用户输入的完整 Agent 循环。 + * 阻塞执行一轮用户输入的完整 Agent 循环。 + * 等待完整响应后才返回。 + */ + public String run(String userInput) { + messageHistory.add(new UserMessage(userInput)); + return executeLoop(false, null); + } + + // ==================== 流式模式 ==================== + + /** + * 流式执行一轮用户输入的完整 Agent 循环。 + * 文本逐 token 通过 onToken 回调实时输出到终端。 * * @param userInput 用户输入文本 - * @return 最终助手回复文本 + * @param onToken 每个文本 token 的实时回调(用于终端逐字显示) + * @return 最终完整的助手回复文本 */ - public String run(String userInput) { + public String runStreaming(String userInput, Consumer onToken) { messageHistory.add(new UserMessage(userInput)); + return executeLoop(true, onToken); + } + // ==================== 核心循环(统一阻塞/流式) ==================== + + private String executeLoop(boolean streaming, Consumer onToken) { List callbacks = toolRegistry.toCallbacks(toolContext); ChatOptions options = ToolCallingChatOptions.builder() .toolCallbacks(callbacks) @@ -96,78 +129,158 @@ public class AgentLoop { while (iteration < MAX_ITERATIONS) { iteration++; - log.debug("Agent 循环 第{}轮", iteration); + log.debug("Agent 循环 第{}轮 ({})", iteration, streaming ? "流式" : "阻塞"); Prompt prompt = new Prompt(List.copyOf(messageHistory), options); - ChatResponse response = chatModel.call(prompt); + + // 调用 AI 并获取结果 + IterationResult result; + if (streaming) { + result = streamIteration(prompt, onToken); + } else { + result = blockingIteration(prompt); + } // 记录 Token 使用量 - if (response.getMetadata() != null && response.getMetadata().getUsage() != null) { - var usage = response.getMetadata().getUsage(); - tokenTracker.recordUsage( - usage.getPromptTokens(), - usage.getCompletionTokens() - ); + if (result.promptTokens > 0 || result.completionTokens > 0) { + tokenTracker.recordUsage(result.promptTokens, result.completionTokens); } - AssistantMessage assistant = response.getResult().getOutput(); - messageHistory.add(assistant); + // 将助手消息加入历史 + messageHistory.add(result.assistant); - // 提取并通知助手文本 - String text = assistant.getText(); + String text = result.assistant.getText(); if (text != null && !text.isBlank()) { lastAssistantText = text; - if (onAssistantMessage != null) { + // 阻塞模式通知 UI(流式模式已在回调中实时输出) + if (!streaming && onAssistantMessage != null) { onAssistantMessage.accept(text); } } - // 检查是否有工具调用 - if (!assistant.hasToolCalls()) { + // 无工具调用 → 结束 + if (!result.assistant.hasToolCalls()) { log.debug("无工具调用,循环结束(共{}轮)", iteration); break; } - // 逐个执行工具调用 - List toolResponses = new ArrayList<>(); - for (AssistantMessage.ToolCall toolCall : assistant.getToolCalls()) { - String toolName = toolCall.name(); - String toolArgs = toolCall.arguments(); - String callId = toolCall.id(); + // 执行工具调用 + executeToolCalls(result.assistant.getToolCalls(), callbacks); + } + + if (iteration >= MAX_ITERATIONS) { + log.warn("Agent 循环已达最大迭代次数 {},强制终止", MAX_ITERATIONS); + lastAssistantText += "\n\n[WARNING: 达到最大循环次数限制]"; + } + + return lastAssistantText; + } + + /** 阻塞模式:调用 chatModel.call() 并解析结果 */ + private IterationResult blockingIteration(Prompt prompt) { + ChatResponse response = chatModel.call(prompt); - // 通知 UI 工具调用开始 - if (onToolEvent != null) { - onToolEvent.accept(new ToolEvent(toolName, ToolEvent.Phase.START, toolArgs, null)); + long promptTokens = 0, completionTokens = 0; + if (response.getMetadata() != null && response.getMetadata().getUsage() != null) { + var usage = response.getMetadata().getUsage(); + promptTokens = usage.getPromptTokens(); + completionTokens = usage.getCompletionTokens(); + } + + return new IterationResult(response.getResult().getOutput(), promptTokens, completionTokens); + } + + /** 流式模式:调用 chatModel.stream() 逐 token 输出,累积完整响应 */ + private IterationResult streamIteration(Prompt prompt, Consumer onToken) { + StringBuilder textBuffer = new StringBuilder(); + // 工具调用按 ID 去重累积(流式分片可能多次发送同一工具调用) + Map toolCallMap = new LinkedHashMap<>(); + long[] tokenUsage = {0, 0}; + boolean[] firstToken = {true}; + + try { + Flux flux = chatModel.stream(prompt); + + flux.doOnNext(chunk -> { + // 记录 token 使用量(通常出现在最后一个 chunk) + if (chunk.getMetadata() != null && chunk.getMetadata().getUsage() != null) { + var usage = chunk.getMetadata().getUsage(); + if (usage.getPromptTokens() > 0) tokenUsage[0] = usage.getPromptTokens(); + if (usage.getCompletionTokens() > 0) tokenUsage[1] = usage.getCompletionTokens(); } - // 查找并执行工具 - String result; - ToolCallbackAdapter adapter = findCallbackByName(callbacks, toolName); - if (adapter != null) { - result = adapter.call(toolArgs); - } else { - result = "Error: Unknown tool '" + toolName + "'"; - log.warn("未知工具: {}", toolName); + if (chunk.getResult() == null || chunk.getResult().getOutput() == null) return; + AssistantMessage output = chunk.getResult().getOutput(); + + // 实时输出文本 token + String text = output.getText(); + if (text != null && !text.isEmpty()) { + // 第一个 token 到达时通知 UI(停止 spinner) + if (firstToken[0]) { + firstToken[0] = false; + if (onStreamStart != null) onStreamStart.run(); + } + textBuffer.append(text); + if (onToken != null) onToken.accept(text); } - // 通知 UI 工具调用完成 - if (onToolEvent != null) { - onToolEvent.accept(new ToolEvent(toolName, ToolEvent.Phase.END, toolArgs, result)); + // 累积工具调用(按 ID 去重) + if (output.hasToolCalls()) { + for (var tc : output.getToolCalls()) { + if (tc.id() != null) { + toolCallMap.putIfAbsent(tc.id(), tc); + } + } } + }).blockLast(); - toolResponses.add(new ToolResponseMessage.ToolResponse(callId, toolName, result)); + } catch (Exception e) { + // 流式调用失败 → 降级到阻塞模式 + log.warn("流式调用失败,降级到阻塞模式: {}", e.getMessage()); + return blockingIteration(prompt); + } + + // 使用 Builder 构建 AssistantMessage(构造器是 protected 的) + List toolCalls = new ArrayList<>(toolCallMap.values()); + AssistantMessage assistant = AssistantMessage.builder() + .content(textBuffer.toString()) + .toolCalls(toolCalls) + .build(); + + return new IterationResult(assistant, tokenUsage[0], tokenUsage[1]); + } + + /** 执行工具调用列表并将结果加入消息历史 */ + private void executeToolCalls(List toolCalls, + List callbacks) { + List toolResponses = new ArrayList<>(); + + for (AssistantMessage.ToolCall toolCall : toolCalls) { + String toolName = toolCall.name(); + String toolArgs = toolCall.arguments(); + String callId = toolCall.id(); + + if (onToolEvent != null) { + onToolEvent.accept(new ToolEvent(toolName, ToolEvent.Phase.START, toolArgs, null)); } - // 将工具结果加入消息历史 - messageHistory.add(ToolResponseMessage.builder().responses(toolResponses).build()); - } + String result; + ToolCallbackAdapter adapter = findCallbackByName(callbacks, toolName); + if (adapter != null) { + result = adapter.call(toolArgs); + } else { + result = "Error: Unknown tool '" + toolName + "'"; + log.warn("未知工具: {}", toolName); + } - if (iteration >= MAX_ITERATIONS) { - log.warn("Agent 循环已达最大迭代次数 {},强制终止", MAX_ITERATIONS); - lastAssistantText += "\n\n[WARNING: 达到最大循环次数限制]"; + if (onToolEvent != null) { + onToolEvent.accept(new ToolEvent(toolName, ToolEvent.Phase.END, toolArgs, result)); + } + + toolResponses.add(new ToolResponseMessage.ToolResponse(callId, toolName, result)); } - return lastAssistantText; + messageHistory.add(ToolResponseMessage.builder().responses(toolResponses).build()); } /** 从 ToolCallback 列表中查找匹配名称的适配器 */ @@ -195,12 +308,26 @@ public class AgentLoop { return systemPrompt; } + /** 获取 ChatModel(用于上下文压缩等需要直接调用模型的场景) */ + public ChatModel getChatModel() { + return chatModel; + } + /** 重置历史(保留系统提示词) */ public void reset() { messageHistory.clear(); messageHistory.add(new SystemMessage(systemPrompt)); } + /** 替换消息历史(用于上下文压缩后替换) */ + public void replaceHistory(List newHistory) { + messageHistory.clear(); + messageHistory.addAll(newHistory); + } + + /** 单次迭代结果 */ + private record IterationResult(AssistantMessage assistant, long promptTokens, long completionTokens) {} + /** 工具事件,用于 UI 展示 */ public record ToolEvent(String toolName, Phase phase, String arguments, String result) { public enum Phase { START, END } diff --git a/src/main/java/com/claudecode/repl/ReplSession.java b/src/main/java/com/claudecode/repl/ReplSession.java index 9fc8cd5..e3d9be0 100644 --- a/src/main/java/com/claudecode/repl/ReplSession.java +++ b/src/main/java/com/claudecode/repl/ReplSession.java @@ -78,8 +78,11 @@ public class ReplSession { } }); + // 流式输出第一个 token 到达时停止 spinner + agentLoop.setOnStreamStart(() -> spinner.stop()); + agentLoop.setOnAssistantMessage(text -> { - // 助手文本在 agent 循环结束后由 REPL 统一渲染 + // 阻塞模式回调:流式模式下由 onToken 实时输出,此回调不触发 }); } @@ -241,14 +244,19 @@ public class ReplSession { return; } - // Agent 循环 + // Agent 循环(流式输出) try { spinner.start("Thinking..."); - String response = agentLoop.run(input); - spinner.stop(); + out.println(); // 换行准备输出区域 - out.println(); - markdownRenderer.render(response); + // 流式回调:逐 token 输出到终端 + String response = agentLoop.runStreaming(input, token -> { + out.print(token); + out.flush(); + }); + + spinner.stop(); + out.println(); // 流式输出结束后换行 out.println(); } catch (Exception e) { spinner.stop();