feat: Phase5A 流式输出支持,逐token实时显示

AgentLoop 重构:
- 新增 runStreaming(input, onToken) 流式模式
- 使用 chatModel.stream(Prompt) -> Flux<ChatResponse>
- 统一 executeLoop() 核心循环支持阻塞/流式两种模式
- 流式分片工具调用按ID去重累积
- 流式失败自动降级到阻塞模式
- AssistantMessage 使用 Builder 模式构建(构造器是protected)
- 新增 onStreamStart 回调(第一个token到达时停止spinner)
- 新增 getChatModel() / replaceHistory() 方法(为后续compact准备)

ReplSession 更新:
- handleInput 使用 runStreaming 替代 run
- 逐token直接输出到终端(out.print + flush)
- spinner在第一个token到达时自动停止

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
pull/1/head
liuzh 1 month ago
parent de8349079f
commit fd262bf98d
  1. 225
      src/main/java/com/claudecode/core/AgentLoop.java
  2. 20
      src/main/java/com/claudecode/repl/ReplSession.java

@ -14,17 +14,24 @@ import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.ToolCallback;
import reactor.core.publisher.Flux;
import java.util.*; import java.util.*;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer; import java.util.function.Consumer;
/** /**
* Agent 循环 对应 claude-code/src/core/query.ts agent loop * Agent 循环 对应 claude-code/src/core/query.ts agent loop
* <p> * <p>
* 支持两种模式
* <ul>
* <li>{@link #run(String)} 阻塞模式等待完整响应后返回</li>
* <li>{@link #runStreaming(String, Consumer)} 流式模式 token 实时输出</li>
* </ul>
* 使用 ChatModel ChatClient的显式循环完整控制每一轮 * 使用 ChatModel ChatClient的显式循环完整控制每一轮
* <ol> * <ol>
* <li>构建 Prompt消息历史 + 系统提示 + 工具定义</li> * <li>构建 Prompt消息历史 + 系统提示 + 工具定义</li>
* <li>调用 ChatModel.call()</li> * <li>调用 ChatModel.call() ChatModel.stream()</li>
* <li>检查工具调用 执行工具 结果回传</li> * <li>检查工具调用 执行工具 结果回传</li>
* <li>循环直到无工具调用或达到最大迭代</li> * <li>循环直到无工具调用或达到最大迭代</li>
* </ol> * </ol>
@ -49,9 +56,12 @@ public class AgentLoop {
/** 工具调用事件回调:在每次工具调用前/后通知 UI */ /** 工具调用事件回调:在每次工具调用前/后通知 UI */
private Consumer<ToolEvent> onToolEvent; private Consumer<ToolEvent> onToolEvent;
/** 助手文本回调:在每次助手回复时通知 UI */ /** 助手文本回调:在每次助手回复时通知 UI(仅阻塞模式使用) */
private Consumer<String> onAssistantMessage; private Consumer<String> onAssistantMessage;
/** 流式输出开始回调:通知 UI 停止 spinner */
private Runnable onStreamStart;
public AgentLoop(ChatModel chatModel, ToolRegistry toolRegistry, public AgentLoop(ChatModel chatModel, ToolRegistry toolRegistry,
ToolContext toolContext, String systemPrompt) { ToolContext toolContext, String systemPrompt) {
this(chatModel, toolRegistry, toolContext, systemPrompt, new TokenTracker()); this(chatModel, toolRegistry, toolContext, systemPrompt, new TokenTracker());
@ -64,7 +74,6 @@ public class AgentLoop {
this.toolContext = toolContext; this.toolContext = toolContext;
this.systemPrompt = systemPrompt; this.systemPrompt = systemPrompt;
this.tokenTracker = tokenTracker; this.tokenTracker = tokenTracker;
// 添加系统提示词到消息历史
this.messageHistory.add(new SystemMessage(systemPrompt)); this.messageHistory.add(new SystemMessage(systemPrompt));
} }
@ -76,15 +85,39 @@ public class AgentLoop {
this.onAssistantMessage = onAssistantMessage; this.onAssistantMessage = onAssistantMessage;
} }
public void setOnStreamStart(Runnable onStreamStart) {
this.onStreamStart = onStreamStart;
}
// ==================== 阻塞模式 ====================
/** /**
* 执行一轮用户输入的完整 Agent 循环 * 阻塞执行一轮用户输入的完整 Agent 循环
* 等待完整响应后才返回
*/
public String run(String userInput) {
messageHistory.add(new UserMessage(userInput));
return executeLoop(false, null);
}
// ==================== 流式模式 ====================
/**
* 流式执行一轮用户输入的完整 Agent 循环
* 文本逐 token 通过 onToken 回调实时输出到终端
* *
* @param userInput 用户输入文本 * @param userInput 用户输入文本
* @return 最终助手回复文本 * @param onToken 每个文本 token 的实时回调用于终端逐字显示
* @return 最终完整的助手回复文本
*/ */
public String run(String userInput) { public String runStreaming(String userInput, Consumer<String> onToken) {
messageHistory.add(new UserMessage(userInput)); messageHistory.add(new UserMessage(userInput));
return executeLoop(true, onToken);
}
// ==================== 核心循环(统一阻塞/流式) ====================
private String executeLoop(boolean streaming, Consumer<String> onToken) {
List<ToolCallback> callbacks = toolRegistry.toCallbacks(toolContext); List<ToolCallback> callbacks = toolRegistry.toCallbacks(toolContext);
ChatOptions options = ToolCallingChatOptions.builder() ChatOptions options = ToolCallingChatOptions.builder()
.toolCallbacks(callbacks) .toolCallbacks(callbacks)
@ -96,78 +129,158 @@ public class AgentLoop {
while (iteration < MAX_ITERATIONS) { while (iteration < MAX_ITERATIONS) {
iteration++; iteration++;
log.debug("Agent 循环 第{}轮", iteration); log.debug("Agent 循环 第{}轮 ({})", iteration, streaming ? "流式" : "阻塞");
Prompt prompt = new Prompt(List.copyOf(messageHistory), options); Prompt prompt = new Prompt(List.copyOf(messageHistory), options);
ChatResponse response = chatModel.call(prompt);
// 调用 AI 并获取结果
IterationResult result;
if (streaming) {
result = streamIteration(prompt, onToken);
} else {
result = blockingIteration(prompt);
}
// 记录 Token 使用量 // 记录 Token 使用量
if (response.getMetadata() != null && response.getMetadata().getUsage() != null) { if (result.promptTokens > 0 || result.completionTokens > 0) {
var usage = response.getMetadata().getUsage(); tokenTracker.recordUsage(result.promptTokens, result.completionTokens);
tokenTracker.recordUsage(
usage.getPromptTokens(),
usage.getCompletionTokens()
);
} }
AssistantMessage assistant = response.getResult().getOutput(); // 将助手消息加入历史
messageHistory.add(assistant); messageHistory.add(result.assistant);
// 提取并通知助手文本 String text = result.assistant.getText();
String text = assistant.getText();
if (text != null && !text.isBlank()) { if (text != null && !text.isBlank()) {
lastAssistantText = text; lastAssistantText = text;
if (onAssistantMessage != null) { // 阻塞模式通知 UI(流式模式已在回调中实时输出)
if (!streaming && onAssistantMessage != null) {
onAssistantMessage.accept(text); onAssistantMessage.accept(text);
} }
} }
// 检查是否有工具调用 // 无工具调用 → 结束
if (!assistant.hasToolCalls()) { if (!result.assistant.hasToolCalls()) {
log.debug("无工具调用,循环结束(共{}轮)", iteration); log.debug("无工具调用,循环结束(共{}轮)", iteration);
break; break;
} }
// 逐个执行工具调用 // 执行工具调用
List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>(); executeToolCalls(result.assistant.getToolCalls(), callbacks);
for (AssistantMessage.ToolCall toolCall : assistant.getToolCalls()) { }
String toolName = toolCall.name();
String toolArgs = toolCall.arguments(); if (iteration >= MAX_ITERATIONS) {
String callId = toolCall.id(); log.warn("Agent 循环已达最大迭代次数 {},强制终止", MAX_ITERATIONS);
lastAssistantText += "\n\n[WARNING: 达到最大循环次数限制]";
}
return lastAssistantText;
}
/** 阻塞模式:调用 chatModel.call() 并解析结果 */
private IterationResult blockingIteration(Prompt prompt) {
ChatResponse response = chatModel.call(prompt);
// 通知 UI 工具调用开始 long promptTokens = 0, completionTokens = 0;
if (onToolEvent != null) { if (response.getMetadata() != null && response.getMetadata().getUsage() != null) {
onToolEvent.accept(new ToolEvent(toolName, ToolEvent.Phase.START, toolArgs, null)); var usage = response.getMetadata().getUsage();
promptTokens = usage.getPromptTokens();
completionTokens = usage.getCompletionTokens();
}
return new IterationResult(response.getResult().getOutput(), promptTokens, completionTokens);
}
/** 流式模式:调用 chatModel.stream() 逐 token 输出,累积完整响应 */
private IterationResult streamIteration(Prompt prompt, Consumer<String> onToken) {
StringBuilder textBuffer = new StringBuilder();
// 工具调用按 ID 去重累积(流式分片可能多次发送同一工具调用)
Map<String, AssistantMessage.ToolCall> toolCallMap = new LinkedHashMap<>();
long[] tokenUsage = {0, 0};
boolean[] firstToken = {true};
try {
Flux<ChatResponse> flux = chatModel.stream(prompt);
flux.doOnNext(chunk -> {
// 记录 token 使用量(通常出现在最后一个 chunk)
if (chunk.getMetadata() != null && chunk.getMetadata().getUsage() != null) {
var usage = chunk.getMetadata().getUsage();
if (usage.getPromptTokens() > 0) tokenUsage[0] = usage.getPromptTokens();
if (usage.getCompletionTokens() > 0) tokenUsage[1] = usage.getCompletionTokens();
} }
// 查找并执行工具 if (chunk.getResult() == null || chunk.getResult().getOutput() == null) return;
String result; AssistantMessage output = chunk.getResult().getOutput();
ToolCallbackAdapter adapter = findCallbackByName(callbacks, toolName);
if (adapter != null) { // 实时输出文本 token
result = adapter.call(toolArgs); String text = output.getText();
} else { if (text != null && !text.isEmpty()) {
result = "Error: Unknown tool '" + toolName + "'"; // 第一个 token 到达时通知 UI(停止 spinner)
log.warn("未知工具: {}", toolName); if (firstToken[0]) {
firstToken[0] = false;
if (onStreamStart != null) onStreamStart.run();
}
textBuffer.append(text);
if (onToken != null) onToken.accept(text);
} }
// 通知 UI 工具调用完成 // 累积工具调用(按 ID 去重)
if (onToolEvent != null) { if (output.hasToolCalls()) {
onToolEvent.accept(new ToolEvent(toolName, ToolEvent.Phase.END, toolArgs, result)); for (var tc : output.getToolCalls()) {
if (tc.id() != null) {
toolCallMap.putIfAbsent(tc.id(), tc);
}
}
} }
}).blockLast();
toolResponses.add(new ToolResponseMessage.ToolResponse(callId, toolName, result)); } catch (Exception e) {
// 流式调用失败 → 降级到阻塞模式
log.warn("流式调用失败,降级到阻塞模式: {}", e.getMessage());
return blockingIteration(prompt);
}
// 使用 Builder 构建 AssistantMessage(构造器是 protected 的)
List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>(toolCallMap.values());
AssistantMessage assistant = AssistantMessage.builder()
.content(textBuffer.toString())
.toolCalls(toolCalls)
.build();
return new IterationResult(assistant, tokenUsage[0], tokenUsage[1]);
}
/** 执行工具调用列表并将结果加入消息历史 */
private void executeToolCalls(List<AssistantMessage.ToolCall> toolCalls,
List<ToolCallback> callbacks) {
List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();
for (AssistantMessage.ToolCall toolCall : toolCalls) {
String toolName = toolCall.name();
String toolArgs = toolCall.arguments();
String callId = toolCall.id();
if (onToolEvent != null) {
onToolEvent.accept(new ToolEvent(toolName, ToolEvent.Phase.START, toolArgs, null));
} }
// 将工具结果加入消息历史 String result;
messageHistory.add(ToolResponseMessage.builder().responses(toolResponses).build()); ToolCallbackAdapter adapter = findCallbackByName(callbacks, toolName);
} if (adapter != null) {
result = adapter.call(toolArgs);
} else {
result = "Error: Unknown tool '" + toolName + "'";
log.warn("未知工具: {}", toolName);
}
if (iteration >= MAX_ITERATIONS) { if (onToolEvent != null) {
log.warn("Agent 循环已达最大迭代次数 {},强制终止", MAX_ITERATIONS); onToolEvent.accept(new ToolEvent(toolName, ToolEvent.Phase.END, toolArgs, result));
lastAssistantText += "\n\n[WARNING: 达到最大循环次数限制]"; }
toolResponses.add(new ToolResponseMessage.ToolResponse(callId, toolName, result));
} }
return lastAssistantText; messageHistory.add(ToolResponseMessage.builder().responses(toolResponses).build());
} }
/** 从 ToolCallback 列表中查找匹配名称的适配器 */ /** 从 ToolCallback 列表中查找匹配名称的适配器 */
@ -195,12 +308,26 @@ public class AgentLoop {
return systemPrompt; return systemPrompt;
} }
/** 获取 ChatModel(用于上下文压缩等需要直接调用模型的场景) */
public ChatModel getChatModel() {
return chatModel;
}
/** 重置历史(保留系统提示词) */ /** 重置历史(保留系统提示词) */
public void reset() { public void reset() {
messageHistory.clear(); messageHistory.clear();
messageHistory.add(new SystemMessage(systemPrompt)); messageHistory.add(new SystemMessage(systemPrompt));
} }
/** 替换消息历史(用于上下文压缩后替换) */
public void replaceHistory(List<Message> newHistory) {
messageHistory.clear();
messageHistory.addAll(newHistory);
}
/** 单次迭代结果 */
private record IterationResult(AssistantMessage assistant, long promptTokens, long completionTokens) {}
/** 工具事件,用于 UI 展示 */ /** 工具事件,用于 UI 展示 */
public record ToolEvent(String toolName, Phase phase, String arguments, String result) { public record ToolEvent(String toolName, Phase phase, String arguments, String result) {
public enum Phase { START, END } public enum Phase { START, END }

@ -78,8 +78,11 @@ public class ReplSession {
} }
}); });
// 流式输出第一个 token 到达时停止 spinner
agentLoop.setOnStreamStart(() -> spinner.stop());
agentLoop.setOnAssistantMessage(text -> { agentLoop.setOnAssistantMessage(text -> {
// 助手文本在 agent 循环结束后由 REPL 统一渲染 // 阻塞模式回调:流式模式下由 onToken 实时输出,此回调不触发
}); });
} }
@ -241,14 +244,19 @@ public class ReplSession {
return; return;
} }
// Agent 循环 // Agent 循环(流式输出)
try { try {
spinner.start("Thinking..."); spinner.start("Thinking...");
String response = agentLoop.run(input); out.println(); // 换行准备输出区域
spinner.stop();
out.println(); // 流式回调:逐 token 输出到终端
markdownRenderer.render(response); String response = agentLoop.runStreaming(input, token -> {
out.print(token);
out.flush();
});
spinner.stop();
out.println(); // 流式输出结束后换行
out.println(); out.println();
} catch (Exception e) { } catch (Exception e) {
spinner.stop(); spinner.stop();

Loading…
Cancel
Save