refactor: extract AgentToolExecutor from AgentLoop

- Extracted tool execution + permission logic into AgentToolExecutor (212 lines)
- AgentLoop reduced from 597 to 469 lines (-21%)
- Clear separation: AgentLoop owns the chat loop, AgentToolExecutor owns tool dispatch
- 87 tests still passing

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
pull/1/head
abel533 1 month ago
parent 80f43480c1
commit 12a443c9a9
  1. 150
      src/main/java/com/claudecode/core/AgentLoop.java
  2. 212
      src/main/java/com/claudecode/core/AgentToolExecutor.java

@ -5,10 +5,8 @@ import com.claudecode.permission.DenialTracker;
import com.claudecode.permission.PermissionRuleEngine; import com.claudecode.permission.PermissionRuleEngine;
import com.claudecode.permission.PermissionTypes.PermissionChoice; import com.claudecode.permission.PermissionTypes.PermissionChoice;
import com.claudecode.permission.PermissionTypes.PermissionDecision; import com.claudecode.permission.PermissionTypes.PermissionDecision;
import com.claudecode.tool.ToolCallbackAdapter;
import com.claudecode.tool.ToolContext; import com.claudecode.tool.ToolContext;
import com.claudecode.tool.ToolRegistry; import com.claudecode.tool.ToolRegistry;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -22,7 +20,6 @@ import org.springframework.ai.tool.ToolCallback;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import java.util.*; import java.util.*;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Function; import java.util.function.Function;
@ -69,6 +66,9 @@ public class AgentLoop {
/** 拒绝追踪器 */ /** 拒绝追踪器 */
private final DenialTracker denialTracker = new DenialTracker(); private final DenialTracker denialTracker = new DenialTracker();
/** 工具执行器(拆分出的权限+Hook+执行逻辑) */
private final AgentToolExecutor toolExecutor;
/** 中断标志 —— 用于取消当前运行中的 Agent 循环 */ /** 中断标志 —— 用于取消当前运行中的 Agent 循环 */
private volatile boolean cancelled = false; private volatile boolean cancelled = false;
@ -103,11 +103,13 @@ public class AgentLoop {
this.systemPrompt = systemPrompt; this.systemPrompt = systemPrompt;
this.tokenTracker = tokenTracker; this.tokenTracker = tokenTracker;
this.hookManager = new HookManager(); this.hookManager = new HookManager();
this.toolExecutor = new AgentToolExecutor(hookManager, toolContext, denialTracker);
this.messageHistory.add(new SystemMessage(systemPrompt)); this.messageHistory.add(new SystemMessage(systemPrompt));
} }
public void setOnToolEvent(Consumer<ToolEvent> onToolEvent) { public void setOnToolEvent(Consumer<ToolEvent> onToolEvent) {
this.onToolEvent = onToolEvent; this.onToolEvent = onToolEvent;
this.toolExecutor.setOnToolEvent(onToolEvent);
} }
public void setOnAssistantMessage(Consumer<String> onAssistantMessage) { public void setOnAssistantMessage(Consumer<String> onAssistantMessage) {
@ -120,10 +122,12 @@ public class AgentLoop {
public void setOnPermissionRequest(Function<PermissionRequest, PermissionChoice> onPermissionRequest) { public void setOnPermissionRequest(Function<PermissionRequest, PermissionChoice> onPermissionRequest) {
this.onPermissionRequest = onPermissionRequest; this.onPermissionRequest = onPermissionRequest;
this.toolExecutor.setOnPermissionRequest(onPermissionRequest);
} }
public void setPermissionEngine(PermissionRuleEngine engine) { public void setPermissionEngine(PermissionRuleEngine engine) {
this.permissionEngine = engine; this.permissionEngine = engine;
this.toolExecutor.setPermissionEngine(engine);
} }
public void setAutoCompactManager(AutoCompactManager manager) { public void setAutoCompactManager(AutoCompactManager manager) {
@ -241,8 +245,10 @@ public class AgentLoop {
break; break;
} }
// 执行工具调用 // 执行工具调用(委托给 AgentToolExecutor)
executeToolCalls(result.assistant.getToolCalls(), callbacks); var toolResponseMsg = toolExecutor.executeToolCalls(
result.assistant.getToolCalls(), callbacks, cancelled);
messageHistory.add(toolResponseMsg);
// 自动压缩检查(在工具调用后,下次 API 调用前) // 自动压缩检查(在工具调用后,下次 API 调用前)
if (autoCompactManager != null) { if (autoCompactManager != null) {
@ -349,140 +355,6 @@ public class AgentLoop {
return new IterationResult(assistant, tokenUsage[0], tokenUsage[1]); return new IterationResult(assistant, tokenUsage[0], tokenUsage[1]);
} }
/** 执行工具调用列表并将结果加入消息历史 */
@SuppressWarnings("unchecked")
private void executeToolCalls(List<AssistantMessage.ToolCall> toolCalls,
List<ToolCallback> callbacks) {
List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();
for (AssistantMessage.ToolCall toolCall : toolCalls) {
// 检查取消标志
if (cancelled) {
toolResponses.add(new ToolResponseMessage.ToolResponse(
toolCall.id(), toolCall.name(), "Cancelled by user"));
continue;
}
String toolName = toolCall.name();
String toolArgs = toolCall.arguments();
String callId = toolCall.id();
// 解析参数用于 Hook 和权限检查
Map<String, Object> parsedArgs = Map.of();
try {
parsedArgs = MAPPER.readValue(toolArgs, Map.class);
} catch (Exception e) {
log.debug("Failed to parse tool arguments for {}: {}", toolName, e.getMessage());
}
// PreToolUse Hook
var preHookCtx = new HookManager.HookContext(toolName, parsedArgs);
if (hookManager.execute(HookManager.HookType.PRE_TOOL_USE, preHookCtx) == HookManager.HookResult.ABORT) {
log.info("[{}] PreToolUse Hook aborted execution", toolName);
toolResponses.add(new ToolResponseMessage.ToolResponse(callId, toolName, "Aborted by hook"));
continue;
}
if (onToolEvent != null) {
onToolEvent.accept(new ToolEvent(toolName, ToolEvent.Phase.START, toolArgs, null));
}
String result;
ToolCallbackAdapter adapter = findCallbackByName(callbacks, toolName);
if (adapter != null) {
// 权限检查:优先使用规则引擎,回退到传统回调
boolean permitted = true;
if (permissionEngine != null) {
PermissionDecision decision = permissionEngine.evaluate(
toolName, parsedArgs, adapter.getTool().isReadOnly());
if (decision.isAllowed()) {
permitted = true;
denialTracker.recordSuccess();
} else if (decision.isDenied()) {
permitted = false;
denialTracker.recordDenial();
log.info("[{}] Denied by rule: {}", toolName, decision.reason());
} else if (decision.needsAsk() && onPermissionRequest != null) {
// 拒绝追踪:连续拒绝过多时强制回退到手动提示
if (denialTracker.shouldFallbackToPrompting()) {
log.info("[{}] Denial threshold reached, forcing manual prompt", toolName);
}
String activity = adapter.getTool().activityDescription(parsedArgs);
PermissionRequest req = new PermissionRequest(toolName, toolArgs, activity);
req.setDecision(decision);
PermissionChoice choice = onPermissionRequest.apply(req);
permitted = (choice == PermissionChoice.ALLOW_ONCE || choice == PermissionChoice.ALWAYS_ALLOW);
if (permitted) {
denialTracker.recordSuccess();
} else {
denialTracker.recordDenial();
}
// 持久化用户选择
String command = parsedArgs != null ? (String) parsedArgs.get("command") : null;
permissionEngine.applyChoice(choice, toolName, command);
} else {
permitted = false;
denialTracker.recordDenial();
}
} else if (!adapter.getTool().isReadOnly() && onPermissionRequest != null) {
// 传统回调模式(向后兼容)
String activity = adapter.getTool().activityDescription(parsedArgs);
PermissionRequest req = new PermissionRequest(toolName, toolArgs, activity);
PermissionChoice choice = onPermissionRequest.apply(req);
permitted = (choice == PermissionChoice.ALLOW_ONCE || choice == PermissionChoice.ALWAYS_ALLOW);
}
if (permitted) {
// 设置进度回调,将工具输出行转发为 PROGRESS 事件
final String tn = toolName;
final String ta = toolArgs;
if (onToolEvent != null) {
toolContext.setProgressCallback(line ->
onToolEvent.accept(new ToolEvent(tn, ToolEvent.Phase.PROGRESS, ta, line)));
}
try {
result = adapter.call(toolArgs);
} finally {
toolContext.setProgressCallback(null);
}
} else {
result = "Permission denied: User rejected this operation";
log.info("[{}] User denied tool execution", toolName);
}
} else {
result = "Error: Unknown tool '" + toolName + "'";
log.warn("Unknown tool: {}", toolName);
}
// PostToolUse Hook
var postHookCtx = new HookManager.HookContext(toolName, parsedArgs);
postHookCtx.setResult(result);
hookManager.execute(HookManager.HookType.POST_TOOL_USE, postHookCtx);
// Hook 可能修改了结果
if (postHookCtx.getResult() != null) {
result = postHookCtx.getResult();
}
if (onToolEvent != null) {
onToolEvent.accept(new ToolEvent(toolName, ToolEvent.Phase.END, toolArgs, result));
}
toolResponses.add(new ToolResponseMessage.ToolResponse(callId, toolName, result));
}
messageHistory.add(ToolResponseMessage.builder().responses(toolResponses).build());
}
/** 从 ToolCallback 列表中查找匹配名称的适配器 */
private ToolCallbackAdapter findCallbackByName(List<ToolCallback> callbacks, String name) {
for (ToolCallback cb : callbacks) {
if (cb instanceof ToolCallbackAdapter adapter && adapter.getTool().name().equals(name)) {
return adapter;
}
}
return null;
}
/** 获取消息历史(用于上下文压缩等场景) */ /** 获取消息历史(用于上下文压缩等场景) */
public List<Message> getMessageHistory() { public List<Message> getMessageHistory() {
return Collections.unmodifiableList(messageHistory); return Collections.unmodifiableList(messageHistory);

@ -0,0 +1,212 @@
package com.claudecode.core;
import com.claudecode.permission.DenialTracker;
import com.claudecode.permission.PermissionRuleEngine;
import com.claudecode.permission.PermissionTypes.PermissionChoice;
import com.claudecode.permission.PermissionTypes.PermissionDecision;
import com.claudecode.tool.ToolCallbackAdapter;
import com.claudecode.tool.ToolContext;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.tool.ToolCallback;
import java.util.*;
import java.util.function.Consumer;
import java.util.function.Function;
/**
* 工具执行器 AgentLoop 拆分出的工具调用执行逻辑
* <p>
* 职责
* <ul>
* <li>解析工具参数</li>
* <li>PreToolUse / PostToolUse Hook 执行</li>
* <li>权限检查规则引擎 + 传统回调</li>
* <li>工具调用执行与结果收集</li>
* </ul>
*/
public class AgentToolExecutor {
private static final Logger log = LoggerFactory.getLogger(AgentToolExecutor.class);
private static final ObjectMapper MAPPER = new ObjectMapper();
private final HookManager hookManager;
private final ToolContext toolContext;
private final DenialTracker denialTracker;
private PermissionRuleEngine permissionEngine;
private Consumer<AgentLoop.ToolEvent> onToolEvent;
private Function<AgentLoop.PermissionRequest, PermissionChoice> onPermissionRequest;
public AgentToolExecutor(HookManager hookManager, ToolContext toolContext, DenialTracker denialTracker) {
this.hookManager = hookManager;
this.toolContext = toolContext;
this.denialTracker = denialTracker;
}
public void setPermissionEngine(PermissionRuleEngine engine) {
this.permissionEngine = engine;
}
public void setOnToolEvent(Consumer<AgentLoop.ToolEvent> onToolEvent) {
this.onToolEvent = onToolEvent;
}
public void setOnPermissionRequest(Function<AgentLoop.PermissionRequest, PermissionChoice> onPermissionRequest) {
this.onPermissionRequest = onPermissionRequest;
}
/**
* 执行工具调用列表并返回 ToolResponseMessage 加入消息历史
*/
@SuppressWarnings("unchecked")
public ToolResponseMessage executeToolCalls(List<AssistantMessage.ToolCall> toolCalls,
List<ToolCallback> callbacks,
boolean cancelled) {
List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();
for (AssistantMessage.ToolCall toolCall : toolCalls) {
if (cancelled) {
toolResponses.add(new ToolResponseMessage.ToolResponse(
toolCall.id(), toolCall.name(), "Cancelled by user"));
continue;
}
String toolName = toolCall.name();
String toolArgs = toolCall.arguments();
String callId = toolCall.id();
Map<String, Object> parsedArgs = parseArguments(toolName, toolArgs);
// PreToolUse Hook
var preHookCtx = new HookManager.HookContext(toolName, parsedArgs);
if (hookManager.execute(HookManager.HookType.PRE_TOOL_USE, preHookCtx) == HookManager.HookResult.ABORT) {
log.info("[{}] PreToolUse Hook aborted execution", toolName);
toolResponses.add(new ToolResponseMessage.ToolResponse(callId, toolName, "Aborted by hook"));
continue;
}
if (onToolEvent != null) {
onToolEvent.accept(new AgentLoop.ToolEvent(toolName, AgentLoop.ToolEvent.Phase.START, toolArgs, null));
}
String result = executeOneTool(toolName, toolArgs, parsedArgs, callbacks);
// PostToolUse Hook
var postHookCtx = new HookManager.HookContext(toolName, parsedArgs);
postHookCtx.setResult(result);
hookManager.execute(HookManager.HookType.POST_TOOL_USE, postHookCtx);
if (postHookCtx.getResult() != null) {
result = postHookCtx.getResult();
}
if (onToolEvent != null) {
onToolEvent.accept(new AgentLoop.ToolEvent(toolName, AgentLoop.ToolEvent.Phase.END, toolArgs, result));
}
toolResponses.add(new ToolResponseMessage.ToolResponse(callId, toolName, result));
}
return ToolResponseMessage.builder().responses(toolResponses).build();
}
/**
* 执行单个工具调用含权限检查
*/
private String executeOneTool(String toolName, String toolArgs,
Map<String, Object> parsedArgs,
List<ToolCallback> callbacks) {
ToolCallbackAdapter adapter = findCallbackByName(callbacks, toolName);
if (adapter == null) {
log.warn("Unknown tool: {}", toolName);
return "Error: Unknown tool '" + toolName + "'";
}
boolean permitted = checkPermission(toolName, toolArgs, parsedArgs, adapter);
if (!permitted) {
log.info("[{}] User denied tool execution", toolName);
return "Permission denied: User rejected this operation";
}
// 设置进度回调
if (onToolEvent != null) {
toolContext.setProgressCallback(line ->
onToolEvent.accept(new AgentLoop.ToolEvent(
toolName, AgentLoop.ToolEvent.Phase.PROGRESS, toolArgs, line)));
}
try {
return adapter.call(toolArgs);
} finally {
toolContext.setProgressCallback(null);
}
}
/**
* 权限检查规则引擎优先回退到传统回调
*/
private boolean checkPermission(String toolName, String toolArgs,
Map<String, Object> parsedArgs,
ToolCallbackAdapter adapter) {
if (permissionEngine != null) {
PermissionDecision decision = permissionEngine.evaluate(
toolName, parsedArgs, adapter.getTool().isReadOnly());
if (decision.isAllowed()) {
denialTracker.recordSuccess();
return true;
} else if (decision.isDenied()) {
denialTracker.recordDenial();
log.info("[{}] Denied by rule: {}", toolName, decision.reason());
return false;
} else if (decision.needsAsk() && onPermissionRequest != null) {
if (denialTracker.shouldFallbackToPrompting()) {
log.info("[{}] Denial threshold reached, forcing manual prompt", toolName);
}
String activity = adapter.getTool().activityDescription(parsedArgs);
AgentLoop.PermissionRequest req = new AgentLoop.PermissionRequest(toolName, toolArgs, activity);
req.setDecision(decision);
PermissionChoice choice = onPermissionRequest.apply(req);
boolean allowed = (choice == PermissionChoice.ALLOW_ONCE || choice == PermissionChoice.ALWAYS_ALLOW);
if (allowed) denialTracker.recordSuccess(); else denialTracker.recordDenial();
String command = parsedArgs != null ? (String) parsedArgs.get("command") : null;
permissionEngine.applyChoice(choice, toolName, command);
return allowed;
} else {
denialTracker.recordDenial();
return false;
}
}
// 传统回调模式
if (!adapter.getTool().isReadOnly() && onPermissionRequest != null) {
String activity = adapter.getTool().activityDescription(parsedArgs);
AgentLoop.PermissionRequest req = new AgentLoop.PermissionRequest(toolName, toolArgs, activity);
PermissionChoice choice = onPermissionRequest.apply(req);
return (choice == PermissionChoice.ALLOW_ONCE || choice == PermissionChoice.ALWAYS_ALLOW);
}
return true;
}
@SuppressWarnings("unchecked")
private Map<String, Object> parseArguments(String toolName, String toolArgs) {
try {
return MAPPER.readValue(toolArgs, Map.class);
} catch (Exception e) {
log.debug("Failed to parse tool arguments for {}: {}", toolName, e.getMessage());
return Map.of();
}
}
private ToolCallbackAdapter findCallbackByName(List<ToolCallback> callbacks, String name) {
for (ToolCallback cb : callbacks) {
if (cb instanceof ToolCallbackAdapter adapter && adapter.getTool().name().equals(name)) {
return adapter;
}
}
return null;
}
}
Loading…
Cancel
Save