|
|
|
|
@ -14,17 +14,24 @@ import org.springframework.ai.chat.prompt.ChatOptions; |
|
|
|
|
import org.springframework.ai.chat.prompt.Prompt; |
|
|
|
|
import org.springframework.ai.model.tool.ToolCallingChatOptions; |
|
|
|
|
import org.springframework.ai.tool.ToolCallback; |
|
|
|
|
import reactor.core.publisher.Flux; |
|
|
|
|
|
|
|
|
|
import java.util.*; |
|
|
|
|
import java.util.concurrent.atomic.AtomicReference; |
|
|
|
|
import java.util.function.Consumer; |
|
|
|
|
|
|
|
|
|
/** |
|
|
|
|
* Agent 循环 —— 对应 claude-code/src/core/query.ts 的 agent loop。 |
|
|
|
|
* <p> |
|
|
|
|
* 支持两种模式: |
|
|
|
|
* <ul> |
|
|
|
|
* <li>{@link #run(String)} —— 阻塞模式,等待完整响应后返回</li> |
|
|
|
|
* <li>{@link #runStreaming(String, Consumer)} —— 流式模式,逐 token 实时输出</li> |
|
|
|
|
* </ul> |
|
|
|
|
* 使用 ChatModel(非 ChatClient)的显式循环,完整控制每一轮: |
|
|
|
|
* <ol> |
|
|
|
|
* <li>构建 Prompt(消息历史 + 系统提示 + 工具定义)</li> |
|
|
|
|
* <li>调用 ChatModel.call()</li> |
|
|
|
|
* <li>调用 ChatModel.call() 或 ChatModel.stream()</li> |
|
|
|
|
* <li>检查工具调用 → 执行工具 → 结果回传</li> |
|
|
|
|
* <li>循环直到无工具调用或达到最大迭代</li> |
|
|
|
|
* </ol> |
|
|
|
|
@ -49,9 +56,12 @@ public class AgentLoop { |
|
|
|
|
/** 工具调用事件回调:在每次工具调用前/后通知 UI */ |
|
|
|
|
private Consumer<ToolEvent> onToolEvent; |
|
|
|
|
|
|
|
|
|
/** 助手文本回调:在每次助手回复时通知 UI */ |
|
|
|
|
/** 助手文本回调:在每次助手回复时通知 UI(仅阻塞模式使用) */ |
|
|
|
|
private Consumer<String> onAssistantMessage; |
|
|
|
|
|
|
|
|
|
/** 流式输出开始回调:通知 UI 停止 spinner */ |
|
|
|
|
private Runnable onStreamStart; |
|
|
|
|
|
|
|
|
|
public AgentLoop(ChatModel chatModel, ToolRegistry toolRegistry, |
|
|
|
|
ToolContext toolContext, String systemPrompt) { |
|
|
|
|
this(chatModel, toolRegistry, toolContext, systemPrompt, new TokenTracker()); |
|
|
|
|
@ -64,7 +74,6 @@ public class AgentLoop { |
|
|
|
|
this.toolContext = toolContext; |
|
|
|
|
this.systemPrompt = systemPrompt; |
|
|
|
|
this.tokenTracker = tokenTracker; |
|
|
|
|
// 添加系统提示词到消息历史
|
|
|
|
|
this.messageHistory.add(new SystemMessage(systemPrompt)); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
@ -76,15 +85,39 @@ public class AgentLoop { |
|
|
|
|
this.onAssistantMessage = onAssistantMessage; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
public void setOnStreamStart(Runnable onStreamStart) { |
|
|
|
|
this.onStreamStart = onStreamStart; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// ==================== 阻塞模式 ====================
|
|
|
|
|
|
|
|
|
|
/** |
|
|
|
|
* 阻塞执行一轮用户输入的完整 Agent 循环。 |
|
|
|
|
* 等待完整响应后才返回。 |
|
|
|
|
*/ |
|
|
|
|
public String run(String userInput) { |
|
|
|
|
messageHistory.add(new UserMessage(userInput)); |
|
|
|
|
return executeLoop(false, null); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// ==================== 流式模式 ====================
|
|
|
|
|
|
|
|
|
|
/** |
|
|
|
|
* 执行一轮用户输入的完整 Agent 循环。 |
|
|
|
|
* 流式执行一轮用户输入的完整 Agent 循环。 |
|
|
|
|
* 文本逐 token 通过 onToken 回调实时输出到终端。 |
|
|
|
|
* |
|
|
|
|
* @param userInput 用户输入文本 |
|
|
|
|
* @return 最终助手回复文本 |
|
|
|
|
* @param onToken 每个文本 token 的实时回调(用于终端逐字显示) |
|
|
|
|
* @return 最终完整的助手回复文本 |
|
|
|
|
*/ |
|
|
|
|
public String run(String userInput) { |
|
|
|
|
public String runStreaming(String userInput, Consumer<String> onToken) { |
|
|
|
|
messageHistory.add(new UserMessage(userInput)); |
|
|
|
|
return executeLoop(true, onToken); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// ==================== 核心循环(统一阻塞/流式) ====================
|
|
|
|
|
|
|
|
|
|
private String executeLoop(boolean streaming, Consumer<String> onToken) { |
|
|
|
|
List<ToolCallback> callbacks = toolRegistry.toCallbacks(toolContext); |
|
|
|
|
ChatOptions options = ToolCallingChatOptions.builder() |
|
|
|
|
.toolCallbacks(callbacks) |
|
|
|
|
@ -96,51 +129,141 @@ public class AgentLoop { |
|
|
|
|
|
|
|
|
|
while (iteration < MAX_ITERATIONS) { |
|
|
|
|
iteration++; |
|
|
|
|
log.debug("Agent 循环 第{}轮", iteration); |
|
|
|
|
log.debug("Agent 循环 第{}轮 ({})", iteration, streaming ? "流式" : "阻塞"); |
|
|
|
|
|
|
|
|
|
Prompt prompt = new Prompt(List.copyOf(messageHistory), options); |
|
|
|
|
ChatResponse response = chatModel.call(prompt); |
|
|
|
|
|
|
|
|
|
// 调用 AI 并获取结果
|
|
|
|
|
IterationResult result; |
|
|
|
|
if (streaming) { |
|
|
|
|
result = streamIteration(prompt, onToken); |
|
|
|
|
} else { |
|
|
|
|
result = blockingIteration(prompt); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// 记录 Token 使用量
|
|
|
|
|
if (response.getMetadata() != null && response.getMetadata().getUsage() != null) { |
|
|
|
|
var usage = response.getMetadata().getUsage(); |
|
|
|
|
tokenTracker.recordUsage( |
|
|
|
|
usage.getPromptTokens(), |
|
|
|
|
usage.getCompletionTokens() |
|
|
|
|
); |
|
|
|
|
if (result.promptTokens > 0 || result.completionTokens > 0) { |
|
|
|
|
tokenTracker.recordUsage(result.promptTokens, result.completionTokens); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
AssistantMessage assistant = response.getResult().getOutput(); |
|
|
|
|
messageHistory.add(assistant); |
|
|
|
|
// 将助手消息加入历史
|
|
|
|
|
messageHistory.add(result.assistant); |
|
|
|
|
|
|
|
|
|
// 提取并通知助手文本
|
|
|
|
|
String text = assistant.getText(); |
|
|
|
|
String text = result.assistant.getText(); |
|
|
|
|
if (text != null && !text.isBlank()) { |
|
|
|
|
lastAssistantText = text; |
|
|
|
|
if (onAssistantMessage != null) { |
|
|
|
|
// 阻塞模式通知 UI(流式模式已在回调中实时输出)
|
|
|
|
|
if (!streaming && onAssistantMessage != null) { |
|
|
|
|
onAssistantMessage.accept(text); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// 检查是否有工具调用
|
|
|
|
|
if (!assistant.hasToolCalls()) { |
|
|
|
|
// 无工具调用 → 结束
|
|
|
|
|
if (!result.assistant.hasToolCalls()) { |
|
|
|
|
log.debug("无工具调用,循环结束(共{}轮)", iteration); |
|
|
|
|
break; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// 逐个执行工具调用
|
|
|
|
|
// 执行工具调用
|
|
|
|
|
executeToolCalls(result.assistant.getToolCalls(), callbacks); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if (iteration >= MAX_ITERATIONS) { |
|
|
|
|
log.warn("Agent 循环已达最大迭代次数 {},强制终止", MAX_ITERATIONS); |
|
|
|
|
lastAssistantText += "\n\n[WARNING: 达到最大循环次数限制]"; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return lastAssistantText; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/** 阻塞模式:调用 chatModel.call() 并解析结果 */ |
|
|
|
|
private IterationResult blockingIteration(Prompt prompt) { |
|
|
|
|
ChatResponse response = chatModel.call(prompt); |
|
|
|
|
|
|
|
|
|
long promptTokens = 0, completionTokens = 0; |
|
|
|
|
if (response.getMetadata() != null && response.getMetadata().getUsage() != null) { |
|
|
|
|
var usage = response.getMetadata().getUsage(); |
|
|
|
|
promptTokens = usage.getPromptTokens(); |
|
|
|
|
completionTokens = usage.getCompletionTokens(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return new IterationResult(response.getResult().getOutput(), promptTokens, completionTokens); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/** 流式模式:调用 chatModel.stream() 逐 token 输出,累积完整响应 */ |
|
|
|
|
private IterationResult streamIteration(Prompt prompt, Consumer<String> onToken) { |
|
|
|
|
StringBuilder textBuffer = new StringBuilder(); |
|
|
|
|
// 工具调用按 ID 去重累积(流式分片可能多次发送同一工具调用)
|
|
|
|
|
Map<String, AssistantMessage.ToolCall> toolCallMap = new LinkedHashMap<>(); |
|
|
|
|
long[] tokenUsage = {0, 0}; |
|
|
|
|
boolean[] firstToken = {true}; |
|
|
|
|
|
|
|
|
|
try { |
|
|
|
|
Flux<ChatResponse> flux = chatModel.stream(prompt); |
|
|
|
|
|
|
|
|
|
flux.doOnNext(chunk -> { |
|
|
|
|
// 记录 token 使用量(通常出现在最后一个 chunk)
|
|
|
|
|
if (chunk.getMetadata() != null && chunk.getMetadata().getUsage() != null) { |
|
|
|
|
var usage = chunk.getMetadata().getUsage(); |
|
|
|
|
if (usage.getPromptTokens() > 0) tokenUsage[0] = usage.getPromptTokens(); |
|
|
|
|
if (usage.getCompletionTokens() > 0) tokenUsage[1] = usage.getCompletionTokens(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if (chunk.getResult() == null || chunk.getResult().getOutput() == null) return; |
|
|
|
|
AssistantMessage output = chunk.getResult().getOutput(); |
|
|
|
|
|
|
|
|
|
// 实时输出文本 token
|
|
|
|
|
String text = output.getText(); |
|
|
|
|
if (text != null && !text.isEmpty()) { |
|
|
|
|
// 第一个 token 到达时通知 UI(停止 spinner)
|
|
|
|
|
if (firstToken[0]) { |
|
|
|
|
firstToken[0] = false; |
|
|
|
|
if (onStreamStart != null) onStreamStart.run(); |
|
|
|
|
} |
|
|
|
|
textBuffer.append(text); |
|
|
|
|
if (onToken != null) onToken.accept(text); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// 累积工具调用(按 ID 去重)
|
|
|
|
|
if (output.hasToolCalls()) { |
|
|
|
|
for (var tc : output.getToolCalls()) { |
|
|
|
|
if (tc.id() != null) { |
|
|
|
|
toolCallMap.putIfAbsent(tc.id(), tc); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
}).blockLast(); |
|
|
|
|
|
|
|
|
|
} catch (Exception e) { |
|
|
|
|
// 流式调用失败 → 降级到阻塞模式
|
|
|
|
|
log.warn("流式调用失败,降级到阻塞模式: {}", e.getMessage()); |
|
|
|
|
return blockingIteration(prompt); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// 使用 Builder 构建 AssistantMessage(构造器是 protected 的)
|
|
|
|
|
List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>(toolCallMap.values()); |
|
|
|
|
AssistantMessage assistant = AssistantMessage.builder() |
|
|
|
|
.content(textBuffer.toString()) |
|
|
|
|
.toolCalls(toolCalls) |
|
|
|
|
.build(); |
|
|
|
|
|
|
|
|
|
return new IterationResult(assistant, tokenUsage[0], tokenUsage[1]); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/** 执行工具调用列表并将结果加入消息历史 */ |
|
|
|
|
private void executeToolCalls(List<AssistantMessage.ToolCall> toolCalls, |
|
|
|
|
List<ToolCallback> callbacks) { |
|
|
|
|
List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>(); |
|
|
|
|
for (AssistantMessage.ToolCall toolCall : assistant.getToolCalls()) { |
|
|
|
|
|
|
|
|
|
for (AssistantMessage.ToolCall toolCall : toolCalls) { |
|
|
|
|
String toolName = toolCall.name(); |
|
|
|
|
String toolArgs = toolCall.arguments(); |
|
|
|
|
String callId = toolCall.id(); |
|
|
|
|
|
|
|
|
|
// 通知 UI 工具调用开始
|
|
|
|
|
if (onToolEvent != null) { |
|
|
|
|
onToolEvent.accept(new ToolEvent(toolName, ToolEvent.Phase.START, toolArgs, null)); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// 查找并执行工具
|
|
|
|
|
String result; |
|
|
|
|
ToolCallbackAdapter adapter = findCallbackByName(callbacks, toolName); |
|
|
|
|
if (adapter != null) { |
|
|
|
|
@ -150,7 +273,6 @@ public class AgentLoop { |
|
|
|
|
log.warn("未知工具: {}", toolName); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// 通知 UI 工具调用完成
|
|
|
|
|
if (onToolEvent != null) { |
|
|
|
|
onToolEvent.accept(new ToolEvent(toolName, ToolEvent.Phase.END, toolArgs, result)); |
|
|
|
|
} |
|
|
|
|
@ -158,18 +280,9 @@ public class AgentLoop { |
|
|
|
|
toolResponses.add(new ToolResponseMessage.ToolResponse(callId, toolName, result)); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// 将工具结果加入消息历史
|
|
|
|
|
messageHistory.add(ToolResponseMessage.builder().responses(toolResponses).build()); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if (iteration >= MAX_ITERATIONS) { |
|
|
|
|
log.warn("Agent 循环已达最大迭代次数 {},强制终止", MAX_ITERATIONS); |
|
|
|
|
lastAssistantText += "\n\n[WARNING: 达到最大循环次数限制]"; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return lastAssistantText; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/** 从 ToolCallback 列表中查找匹配名称的适配器 */ |
|
|
|
|
private ToolCallbackAdapter findCallbackByName(List<ToolCallback> callbacks, String name) { |
|
|
|
|
for (ToolCallback cb : callbacks) { |
|
|
|
|
@ -195,12 +308,26 @@ public class AgentLoop { |
|
|
|
|
return systemPrompt; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/** 获取 ChatModel(用于上下文压缩等需要直接调用模型的场景) */ |
|
|
|
|
public ChatModel getChatModel() { |
|
|
|
|
return chatModel; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/** 重置历史(保留系统提示词) */ |
|
|
|
|
public void reset() { |
|
|
|
|
messageHistory.clear(); |
|
|
|
|
messageHistory.add(new SystemMessage(systemPrompt)); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/** 替换消息历史(用于上下文压缩后替换) */ |
|
|
|
|
public void replaceHistory(List<Message> newHistory) { |
|
|
|
|
messageHistory.clear(); |
|
|
|
|
messageHistory.addAll(newHistory); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/** 单次迭代结果 */ |
|
|
|
|
private record IterationResult(AssistantMessage assistant, long promptTokens, long completionTokens) {} |
|
|
|
|
|
|
|
|
|
/** 工具事件,用于 UI 展示 */ |
|
|
|
|
public record ToolEvent(String toolName, Phase phase, String arguments, String result) { |
|
|
|
|
public enum Phase { START, END } |
|
|
|
|
|