From 12a443c9a92fec85d52f88d196282d40d78d3f40 Mon Sep 17 00:00:00 2001 From: abel533 Date: Sun, 5 Apr 2026 13:14:38 +0800 Subject: [PATCH] refactor: extract AgentToolExecutor from AgentLoop - Extracted tool execution + permission logic into AgentToolExecutor (212 lines) - AgentLoop reduced from 597 to 469 lines (-21%) - Clear separation: AgentLoop owns the chat loop, AgentToolExecutor owns tool dispatch - 87 tests still passing Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../java/com/claudecode/core/AgentLoop.java | 150 +------------ .../claudecode/core/AgentToolExecutor.java | 212 ++++++++++++++++++ 2 files changed, 223 insertions(+), 139 deletions(-) create mode 100644 src/main/java/com/claudecode/core/AgentToolExecutor.java diff --git a/src/main/java/com/claudecode/core/AgentLoop.java b/src/main/java/com/claudecode/core/AgentLoop.java index 71de831..917e6ba 100644 --- a/src/main/java/com/claudecode/core/AgentLoop.java +++ b/src/main/java/com/claudecode/core/AgentLoop.java @@ -5,10 +5,8 @@ import com.claudecode.permission.DenialTracker; import com.claudecode.permission.PermissionRuleEngine; import com.claudecode.permission.PermissionTypes.PermissionChoice; import com.claudecode.permission.PermissionTypes.PermissionDecision; -import com.claudecode.tool.ToolCallbackAdapter; import com.claudecode.tool.ToolContext; import com.claudecode.tool.ToolRegistry; -import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -22,7 +20,6 @@ import org.springframework.ai.tool.ToolCallback; import reactor.core.publisher.Flux; import java.util.*; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; @@ -69,6 +66,9 @@ public class AgentLoop { /** 拒绝追踪器 */ private final DenialTracker denialTracker = new DenialTracker(); + /** 工具执行器(拆分出的权限+Hook+执行逻辑) */ + private final AgentToolExecutor toolExecutor; + /** 中断标志 —— 用于取消当前运行中的 Agent 循环 */ private volatile boolean cancelled = false; @@ -103,11 +103,13 @@ public class AgentLoop { this.systemPrompt = systemPrompt; this.tokenTracker = tokenTracker; this.hookManager = new HookManager(); + this.toolExecutor = new AgentToolExecutor(hookManager, toolContext, denialTracker); this.messageHistory.add(new SystemMessage(systemPrompt)); } public void setOnToolEvent(Consumer onToolEvent) { this.onToolEvent = onToolEvent; + this.toolExecutor.setOnToolEvent(onToolEvent); } public void setOnAssistantMessage(Consumer onAssistantMessage) { @@ -120,10 +122,12 @@ public class AgentLoop { public void setOnPermissionRequest(Function onPermissionRequest) { this.onPermissionRequest = onPermissionRequest; + this.toolExecutor.setOnPermissionRequest(onPermissionRequest); } public void setPermissionEngine(PermissionRuleEngine engine) { this.permissionEngine = engine; + this.toolExecutor.setPermissionEngine(engine); } public void setAutoCompactManager(AutoCompactManager manager) { @@ -241,8 +245,10 @@ public class AgentLoop { break; } - // 执行工具调用 - executeToolCalls(result.assistant.getToolCalls(), callbacks); + // 执行工具调用(委托给 AgentToolExecutor) + var toolResponseMsg = toolExecutor.executeToolCalls( + result.assistant.getToolCalls(), callbacks, cancelled); + messageHistory.add(toolResponseMsg); // 自动压缩检查(在工具调用后,下次 API 调用前) if (autoCompactManager != null) { @@ -349,140 +355,6 @@ public class AgentLoop { return new IterationResult(assistant, tokenUsage[0], tokenUsage[1]); } - /** 执行工具调用列表并将结果加入消息历史 */ - @SuppressWarnings("unchecked") - private void executeToolCalls(List toolCalls, - List callbacks) { - List toolResponses = new ArrayList<>(); - - for (AssistantMessage.ToolCall toolCall : toolCalls) { - // 检查取消标志 - if (cancelled) { - toolResponses.add(new ToolResponseMessage.ToolResponse( - toolCall.id(), toolCall.name(), "Cancelled by user")); - continue; - } - - String toolName = toolCall.name(); - String toolArgs = toolCall.arguments(); - String callId = toolCall.id(); - - // 解析参数用于 Hook 和权限检查 - Map parsedArgs = Map.of(); - try { - parsedArgs = MAPPER.readValue(toolArgs, Map.class); - } catch (Exception e) { - log.debug("Failed to parse tool arguments for {}: {}", toolName, e.getMessage()); - } - - // PreToolUse Hook - var preHookCtx = new HookManager.HookContext(toolName, parsedArgs); - if (hookManager.execute(HookManager.HookType.PRE_TOOL_USE, preHookCtx) == HookManager.HookResult.ABORT) { - log.info("[{}] PreToolUse Hook aborted execution", toolName); - toolResponses.add(new ToolResponseMessage.ToolResponse(callId, toolName, "Aborted by hook")); - continue; - } - - if (onToolEvent != null) { - onToolEvent.accept(new ToolEvent(toolName, ToolEvent.Phase.START, toolArgs, null)); - } - - String result; - ToolCallbackAdapter adapter = findCallbackByName(callbacks, toolName); - if (adapter != null) { - // 权限检查:优先使用规则引擎,回退到传统回调 - boolean permitted = true; - if (permissionEngine != null) { - PermissionDecision decision = permissionEngine.evaluate( - toolName, parsedArgs, adapter.getTool().isReadOnly()); - if (decision.isAllowed()) { - permitted = true; - denialTracker.recordSuccess(); - } else if (decision.isDenied()) { - permitted = false; - denialTracker.recordDenial(); - log.info("[{}] Denied by rule: {}", toolName, decision.reason()); - } else if (decision.needsAsk() && onPermissionRequest != null) { - // 拒绝追踪:连续拒绝过多时强制回退到手动提示 - if (denialTracker.shouldFallbackToPrompting()) { - log.info("[{}] Denial threshold reached, forcing manual prompt", toolName); - } - String activity = adapter.getTool().activityDescription(parsedArgs); - PermissionRequest req = new PermissionRequest(toolName, toolArgs, activity); - req.setDecision(decision); - PermissionChoice choice = onPermissionRequest.apply(req); - permitted = (choice == PermissionChoice.ALLOW_ONCE || choice == PermissionChoice.ALWAYS_ALLOW); - if (permitted) { - denialTracker.recordSuccess(); - } else { - denialTracker.recordDenial(); - } - // 持久化用户选择 - String command = parsedArgs != null ? (String) parsedArgs.get("command") : null; - permissionEngine.applyChoice(choice, toolName, command); - } else { - permitted = false; - denialTracker.recordDenial(); - } - } else if (!adapter.getTool().isReadOnly() && onPermissionRequest != null) { - // 传统回调模式(向后兼容) - String activity = adapter.getTool().activityDescription(parsedArgs); - PermissionRequest req = new PermissionRequest(toolName, toolArgs, activity); - PermissionChoice choice = onPermissionRequest.apply(req); - permitted = (choice == PermissionChoice.ALLOW_ONCE || choice == PermissionChoice.ALWAYS_ALLOW); - } - - if (permitted) { - // 设置进度回调,将工具输出行转发为 PROGRESS 事件 - final String tn = toolName; - final String ta = toolArgs; - if (onToolEvent != null) { - toolContext.setProgressCallback(line -> - onToolEvent.accept(new ToolEvent(tn, ToolEvent.Phase.PROGRESS, ta, line))); - } - try { - result = adapter.call(toolArgs); - } finally { - toolContext.setProgressCallback(null); - } - } else { - result = "Permission denied: User rejected this operation"; - log.info("[{}] User denied tool execution", toolName); - } - } else { - result = "Error: Unknown tool '" + toolName + "'"; - log.warn("Unknown tool: {}", toolName); - } - - // PostToolUse Hook - var postHookCtx = new HookManager.HookContext(toolName, parsedArgs); - postHookCtx.setResult(result); - hookManager.execute(HookManager.HookType.POST_TOOL_USE, postHookCtx); - // Hook 可能修改了结果 - if (postHookCtx.getResult() != null) { - result = postHookCtx.getResult(); - } - - if (onToolEvent != null) { - onToolEvent.accept(new ToolEvent(toolName, ToolEvent.Phase.END, toolArgs, result)); - } - - toolResponses.add(new ToolResponseMessage.ToolResponse(callId, toolName, result)); - } - - messageHistory.add(ToolResponseMessage.builder().responses(toolResponses).build()); - } - - /** 从 ToolCallback 列表中查找匹配名称的适配器 */ - private ToolCallbackAdapter findCallbackByName(List callbacks, String name) { - for (ToolCallback cb : callbacks) { - if (cb instanceof ToolCallbackAdapter adapter && adapter.getTool().name().equals(name)) { - return adapter; - } - } - return null; - } - /** 获取消息历史(用于上下文压缩等场景) */ public List getMessageHistory() { return Collections.unmodifiableList(messageHistory); diff --git a/src/main/java/com/claudecode/core/AgentToolExecutor.java b/src/main/java/com/claudecode/core/AgentToolExecutor.java new file mode 100644 index 0000000..8b5dfcd --- /dev/null +++ b/src/main/java/com/claudecode/core/AgentToolExecutor.java @@ -0,0 +1,212 @@ +package com.claudecode.core; + +import com.claudecode.permission.DenialTracker; +import com.claudecode.permission.PermissionRuleEngine; +import com.claudecode.permission.PermissionTypes.PermissionChoice; +import com.claudecode.permission.PermissionTypes.PermissionDecision; +import com.claudecode.tool.ToolCallbackAdapter; +import com.claudecode.tool.ToolContext; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.tool.ToolCallback; + +import java.util.*; +import java.util.function.Consumer; +import java.util.function.Function; + +/** + * 工具执行器 —— 从 AgentLoop 拆分出的工具调用执行逻辑。 + *

+ * 职责: + *

    + *
  • 解析工具参数
  • + *
  • PreToolUse / PostToolUse Hook 执行
  • + *
  • 权限检查(规则引擎 + 传统回调)
  • + *
  • 工具调用执行与结果收集
  • + *
+ */ +public class AgentToolExecutor { + + private static final Logger log = LoggerFactory.getLogger(AgentToolExecutor.class); + private static final ObjectMapper MAPPER = new ObjectMapper(); + + private final HookManager hookManager; + private final ToolContext toolContext; + private final DenialTracker denialTracker; + + private PermissionRuleEngine permissionEngine; + private Consumer onToolEvent; + private Function onPermissionRequest; + + public AgentToolExecutor(HookManager hookManager, ToolContext toolContext, DenialTracker denialTracker) { + this.hookManager = hookManager; + this.toolContext = toolContext; + this.denialTracker = denialTracker; + } + + public void setPermissionEngine(PermissionRuleEngine engine) { + this.permissionEngine = engine; + } + + public void setOnToolEvent(Consumer onToolEvent) { + this.onToolEvent = onToolEvent; + } + + public void setOnPermissionRequest(Function onPermissionRequest) { + this.onPermissionRequest = onPermissionRequest; + } + + /** + * 执行工具调用列表并返回 ToolResponseMessage 加入消息历史。 + */ + @SuppressWarnings("unchecked") + public ToolResponseMessage executeToolCalls(List toolCalls, + List callbacks, + boolean cancelled) { + List toolResponses = new ArrayList<>(); + + for (AssistantMessage.ToolCall toolCall : toolCalls) { + if (cancelled) { + toolResponses.add(new ToolResponseMessage.ToolResponse( + toolCall.id(), toolCall.name(), "Cancelled by user")); + continue; + } + + String toolName = toolCall.name(); + String toolArgs = toolCall.arguments(); + String callId = toolCall.id(); + + Map parsedArgs = parseArguments(toolName, toolArgs); + + // PreToolUse Hook + var preHookCtx = new HookManager.HookContext(toolName, parsedArgs); + if (hookManager.execute(HookManager.HookType.PRE_TOOL_USE, preHookCtx) == HookManager.HookResult.ABORT) { + log.info("[{}] PreToolUse Hook aborted execution", toolName); + toolResponses.add(new ToolResponseMessage.ToolResponse(callId, toolName, "Aborted by hook")); + continue; + } + + if (onToolEvent != null) { + onToolEvent.accept(new AgentLoop.ToolEvent(toolName, AgentLoop.ToolEvent.Phase.START, toolArgs, null)); + } + + String result = executeOneTool(toolName, toolArgs, parsedArgs, callbacks); + + // PostToolUse Hook + var postHookCtx = new HookManager.HookContext(toolName, parsedArgs); + postHookCtx.setResult(result); + hookManager.execute(HookManager.HookType.POST_TOOL_USE, postHookCtx); + if (postHookCtx.getResult() != null) { + result = postHookCtx.getResult(); + } + + if (onToolEvent != null) { + onToolEvent.accept(new AgentLoop.ToolEvent(toolName, AgentLoop.ToolEvent.Phase.END, toolArgs, result)); + } + + toolResponses.add(new ToolResponseMessage.ToolResponse(callId, toolName, result)); + } + + return ToolResponseMessage.builder().responses(toolResponses).build(); + } + + /** + * 执行单个工具调用(含权限检查)。 + */ + private String executeOneTool(String toolName, String toolArgs, + Map parsedArgs, + List callbacks) { + ToolCallbackAdapter adapter = findCallbackByName(callbacks, toolName); + if (adapter == null) { + log.warn("Unknown tool: {}", toolName); + return "Error: Unknown tool '" + toolName + "'"; + } + + boolean permitted = checkPermission(toolName, toolArgs, parsedArgs, adapter); + if (!permitted) { + log.info("[{}] User denied tool execution", toolName); + return "Permission denied: User rejected this operation"; + } + + // 设置进度回调 + if (onToolEvent != null) { + toolContext.setProgressCallback(line -> + onToolEvent.accept(new AgentLoop.ToolEvent( + toolName, AgentLoop.ToolEvent.Phase.PROGRESS, toolArgs, line))); + } + try { + return adapter.call(toolArgs); + } finally { + toolContext.setProgressCallback(null); + } + } + + /** + * 权限检查:规则引擎优先,回退到传统回调。 + */ + private boolean checkPermission(String toolName, String toolArgs, + Map parsedArgs, + ToolCallbackAdapter adapter) { + if (permissionEngine != null) { + PermissionDecision decision = permissionEngine.evaluate( + toolName, parsedArgs, adapter.getTool().isReadOnly()); + + if (decision.isAllowed()) { + denialTracker.recordSuccess(); + return true; + } else if (decision.isDenied()) { + denialTracker.recordDenial(); + log.info("[{}] Denied by rule: {}", toolName, decision.reason()); + return false; + } else if (decision.needsAsk() && onPermissionRequest != null) { + if (denialTracker.shouldFallbackToPrompting()) { + log.info("[{}] Denial threshold reached, forcing manual prompt", toolName); + } + String activity = adapter.getTool().activityDescription(parsedArgs); + AgentLoop.PermissionRequest req = new AgentLoop.PermissionRequest(toolName, toolArgs, activity); + req.setDecision(decision); + PermissionChoice choice = onPermissionRequest.apply(req); + boolean allowed = (choice == PermissionChoice.ALLOW_ONCE || choice == PermissionChoice.ALWAYS_ALLOW); + if (allowed) denialTracker.recordSuccess(); else denialTracker.recordDenial(); + String command = parsedArgs != null ? (String) parsedArgs.get("command") : null; + permissionEngine.applyChoice(choice, toolName, command); + return allowed; + } else { + denialTracker.recordDenial(); + return false; + } + } + + // 传统回调模式 + if (!adapter.getTool().isReadOnly() && onPermissionRequest != null) { + String activity = adapter.getTool().activityDescription(parsedArgs); + AgentLoop.PermissionRequest req = new AgentLoop.PermissionRequest(toolName, toolArgs, activity); + PermissionChoice choice = onPermissionRequest.apply(req); + return (choice == PermissionChoice.ALLOW_ONCE || choice == PermissionChoice.ALWAYS_ALLOW); + } + return true; + } + + @SuppressWarnings("unchecked") + private Map parseArguments(String toolName, String toolArgs) { + try { + return MAPPER.readValue(toolArgs, Map.class); + } catch (Exception e) { + log.debug("Failed to parse tool arguments for {}: {}", toolName, e.getMessage()); + return Map.of(); + } + } + + private ToolCallbackAdapter findCallbackByName(List callbacks, String name) { + for (ToolCallback cb : callbacks) { + if (cb instanceof ToolCallbackAdapter adapter && adapter.getTool().name().equals(name)) { + return adapter; + } + } + return null; + } +}