From da26f024982236efcb6016863154306220ef8c81 Mon Sep 17 00:00:00 2001 From: liuzh Date: Wed, 1 Apr 2026 22:32:49 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20P1=E5=AE=8C=E6=88=90=20-=20Hook?= =?UTF-8?q?=E7=B3=BB=E7=BB=9F+Vim=E6=A8=A1=E5=BC=8F+Banner=E5=A2=9E?= =?UTF-8?q?=E5=BC=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 HookManager: 支持 PreToolUse/PostToolUse/PrePrompt/PostResponse 4种钩子 - Hook 优先级排序, PreToolUse 可中止工具执行, PostToolUse 可修改结果 - AgentLoop 集成 Hook 系统到工具调用流程 - ReplSession 支持 Vim 编辑模式 (CLAUDE_CODE_VIM=1 启用) - Banner 显示命令数量和 Vim 模式标识 - 修复重复 isDumb 变量 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../java/com/claudecode/core/AgentLoop.java | 43 +++-- .../java/com/claudecode/core/HookManager.java | 160 ++++++++++++++++++ .../java/com/claudecode/repl/ReplSession.java | 19 ++- 3 files changed, 209 insertions(+), 13 deletions(-) create mode 100644 src/main/java/com/claudecode/core/HookManager.java diff --git a/src/main/java/com/claudecode/core/AgentLoop.java b/src/main/java/com/claudecode/core/AgentLoop.java index d0dceb3..58a833f 100644 --- a/src/main/java/com/claudecode/core/AgentLoop.java +++ b/src/main/java/com/claudecode/core/AgentLoop.java @@ -50,6 +50,7 @@ public class AgentLoop { private final ToolContext toolContext; private final String systemPrompt; private final TokenTracker tokenTracker; + private final HookManager hookManager; /** 消息历史 —— 自行管理,不依赖 Spring AI ChatMemory */ private final List messageHistory = new ArrayList<>(); @@ -81,6 +82,7 @@ public class AgentLoop { this.toolContext = toolContext; this.systemPrompt = systemPrompt; this.tokenTracker = tokenTracker; + this.hookManager = new HookManager(); this.messageHistory.add(new SystemMessage(systemPrompt)); } @@ -279,6 +281,20 @@ public class AgentLoop { String toolArgs = toolCall.arguments(); String callId = toolCall.id(); + // 解析参数用于 Hook 和权限检查 + Map parsedArgs = Map.of(); + try { + parsedArgs = MAPPER.readValue(toolArgs, Map.class); + } catch (Exception ignored) {} + + // PreToolUse Hook + var preHookCtx = new HookManager.HookContext(toolName, parsedArgs); + if (hookManager.execute(HookManager.HookType.PRE_TOOL_USE, preHookCtx) == HookManager.HookResult.ABORT) { + log.info("[{}] PreToolUse Hook 中止了执行", toolName); + toolResponses.add(new ToolResponseMessage.ToolResponse(callId, toolName, "Aborted by hook")); + continue; + } + if (onToolEvent != null) { onToolEvent.accept(new ToolEvent(toolName, ToolEvent.Phase.START, toolArgs, null)); } @@ -289,16 +305,9 @@ public class AgentLoop { // 权限确认:非只读工具需要用户确认 boolean permitted = true; if (!adapter.getTool().isReadOnly() && onPermissionRequest != null) { - try { - Map parsedArgs = MAPPER.readValue(toolArgs, Map.class); - String activity = adapter.getTool().activityDescription(parsedArgs); - PermissionRequest req = new PermissionRequest(toolName, toolArgs, activity); - permitted = onPermissionRequest.apply(req); - } catch (Exception e) { - // JSON 解析失败时仍然请求确认 - PermissionRequest req = new PermissionRequest(toolName, toolArgs, "执行 " + toolName); - permitted = onPermissionRequest.apply(req); - } + String activity = adapter.getTool().activityDescription(parsedArgs); + PermissionRequest req = new PermissionRequest(toolName, toolArgs, activity); + permitted = onPermissionRequest.apply(req); } if (permitted) { @@ -312,6 +321,15 @@ public class AgentLoop { log.warn("未知工具: {}", toolName); } + // PostToolUse Hook + var postHookCtx = new HookManager.HookContext(toolName, parsedArgs); + postHookCtx.setResult(result); + hookManager.execute(HookManager.HookType.POST_TOOL_USE, postHookCtx); + // Hook 可能修改了结果 + if (postHookCtx.getResult() != null) { + result = postHookCtx.getResult(); + } + if (onToolEvent != null) { onToolEvent.accept(new ToolEvent(toolName, ToolEvent.Phase.END, toolArgs, result)); } @@ -357,6 +375,11 @@ public class AgentLoop { return toolContext; } + /** 获取 Hook 管理器 */ + public HookManager getHookManager() { + return hookManager; + } + /** 重置历史(保留系统提示词) */ public void reset() { messageHistory.clear(); diff --git a/src/main/java/com/claudecode/core/HookManager.java b/src/main/java/com/claudecode/core/HookManager.java new file mode 100644 index 0000000..0211380 --- /dev/null +++ b/src/main/java/com/claudecode/core/HookManager.java @@ -0,0 +1,160 @@ +package com.claudecode.core; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; + +/** + * Hook 系统 —— 对应 claude-code/src/hooks/ 模块。 + *

+ * 提供工具调用前后的钩子机制,允许用户通过配置文件 + * 或编程方式注册拦截器,在工具执行的各个阶段介入。 + *

+ * 支持的 Hook 类型: + *

    + *
  • {@link HookType#PRE_TOOL_USE} —— 工具执行前,可修改参数或阻止执行
  • + *
  • {@link HookType#POST_TOOL_USE} —— 工具执行后,可修改结果或触发后续操作
  • + *
  • {@link HookType#PRE_PROMPT} —— 发送 prompt 前,可修改消息内容
  • + *
  • {@link HookType#POST_RESPONSE} —— 收到响应后,可进行后处理
  • + *
+ */ +public class HookManager { + + private static final Logger log = LoggerFactory.getLogger(HookManager.class); + + /** 所有已注册的 Hook 列表(线程安全) */ + private final List hooks = new CopyOnWriteArrayList<>(); + + /** + * 注册一个 Hook。 + * + * @param type Hook 类型 + * @param name Hook 名称(用于日志/调试) + * @param handler Hook 处理器 + */ + public void register(HookType type, String name, HookHandler handler) { + hooks.add(new HookRegistration(type, name, handler, 0)); + log.debug("注册 Hook: {} [{}]", name, type); + } + + /** + * 注册一个带优先级的 Hook(数字越小优先级越高)。 + */ + public void register(HookType type, String name, HookHandler handler, int priority) { + hooks.add(new HookRegistration(type, name, handler, priority)); + log.debug("注册 Hook: {} [{}] priority={}", name, type, priority); + } + + /** + * 执行指定类型的所有 Hook。 + *

+ * Hook 按优先级顺序执行。如果任一 Hook 返回 {@link HookResult#ABORT}, + * 后续 Hook 将不再执行,并返回 ABORT 结果。 + * + * @param type Hook 类型 + * @param context Hook 执行上下文 + * @return 聚合的 Hook 结果 + */ + public HookResult execute(HookType type, HookContext context) { + List matching = hooks.stream() + .filter(h -> h.type() == type) + .sorted((a, b) -> Integer.compare(a.priority(), b.priority())) + .toList(); + + if (matching.isEmpty()) { + return HookResult.CONTINUE; + } + + for (HookRegistration reg : matching) { + try { + log.debug("执行 Hook: {} [{}]", reg.name(), type); + HookResult result = reg.handler().handle(context); + + if (result == HookResult.ABORT) { + log.info("Hook [{}] 中止了操作", reg.name()); + return HookResult.ABORT; + } + } catch (Exception e) { + log.warn("Hook [{}] 执行异常: {}", reg.name(), e.getMessage()); + // Hook 异常不影响主流程 + } + } + + return HookResult.CONTINUE; + } + + /** 移除指定名称的 Hook */ + public void unregister(String name) { + hooks.removeIf(h -> h.name().equals(name)); + } + + /** 获取所有已注册的 Hook */ + public List getHooks() { + return Collections.unmodifiableList(hooks); + } + + /** 清除所有 Hook */ + public void clear() { + hooks.clear(); + } + + // ==================== 内部类型 ==================== + + /** Hook 类型 */ + public enum HookType { + /** 工具执行前 —— 可阻止执行或修改参数 */ + PRE_TOOL_USE, + /** 工具执行后 —— 可修改结果 */ + POST_TOOL_USE, + /** 发送 prompt 前 */ + PRE_PROMPT, + /** 收到响应后 */ + POST_RESPONSE + } + + /** Hook 执行结果 */ + public enum HookResult { + /** 继续执行 */ + CONTINUE, + /** 中止操作 */ + ABORT + } + + /** Hook 处理器接口 */ + @FunctionalInterface + public interface HookHandler { + HookResult handle(HookContext context); + } + + /** Hook 执行上下文 —— 携带当前操作的相关信息 */ + public static class HookContext { + private final String toolName; + private final Map arguments; + private String result; + private final Map metadata; + + public HookContext(String toolName, Map arguments) { + this.toolName = toolName; + this.arguments = arguments != null ? arguments : Map.of(); + this.metadata = new java.util.HashMap<>(); + } + + public String getToolName() { return toolName; } + public Map getArguments() { return arguments; } + public String getResult() { return result; } + public void setResult(String result) { this.result = result; } + + /** 自定义元数据 */ + public void put(String key, Object value) { metadata.put(key, value); } + @SuppressWarnings("unchecked") + public T get(String key) { return (T) metadata.get(key); } + } + + /** Hook 注册记录 */ + public record HookRegistration(HookType type, String name, HookHandler handler, int priority) {} +} diff --git a/src/main/java/com/claudecode/repl/ReplSession.java b/src/main/java/com/claudecode/repl/ReplSession.java index 7d36429..f7219c5 100644 --- a/src/main/java/com/claudecode/repl/ReplSession.java +++ b/src/main/java/com/claudecode/repl/ReplSession.java @@ -168,6 +168,13 @@ public class ReplSession { .option(LineReader.Option.AUTO_LIST, true) .build(); + // Vim 模式支持:通过环境变量 CLAUDE_CODE_VIM=1 或配置启用 + String vimMode = System.getenv("CLAUDE_CODE_VIM"); + if ("1".equals(vimMode) || "true".equalsIgnoreCase(vimMode)) { + reader.setVariable(LineReader.EDITING_MODE, "vi"); + log.info("已启用 Vim 编辑模式"); + } + // 主提示符 String prompt = new AttributedStringBuilder() .style(AttributedStyle.BOLD.foreground(AttributedStyle.CYAN)) @@ -181,8 +188,7 @@ public class ReplSession { this.activeReader = reader; // 非 dumb 终端启用底部状态行 - boolean isDumb2 = "dumb".equals(terminal.getType()); - if (!isDumb2) { + if (!isDumb) { statusLine.enable(providerInfo.model(), agentLoop.getTokenTracker()); } @@ -222,7 +228,7 @@ public class ReplSession { out.println(AnsiStyle.dim(" API URL: ") + AnsiStyle.cyan(providerInfo.baseUrl())); out.println(AnsiStyle.dim(" Work Dir: " + System.getProperty("user.dir"))); - out.println(AnsiStyle.dim(" Tools: " + toolRegistry.size() + " registered")); + out.println(AnsiStyle.dim(" Tools: " + toolRegistry.size() + " | Commands: " + commandRegistry.getCommands().size())); boolean isDumb = "dumb".equals(terminal.getType()); int w = terminal.getWidth(); @@ -231,6 +237,13 @@ public class ReplSession { if (w > 0 && h > 0) { termInfo += " (" + w + "×" + h + ")"; } + + // Vim 模式标识 + String vimMode = System.getenv("CLAUDE_CODE_VIM"); + if ("1".equals(vimMode) || "true".equalsIgnoreCase(vimMode)) { + termInfo += " [vim]"; + } + out.println(AnsiStyle.dim(" Terminal: " + termInfo)); if (isDumb) {