diff --git a/src/main/java/com/claudecode/core/AgentLoop.java b/src/main/java/com/claudecode/core/AgentLoop.java index 7bf5b69..71de831 100644 --- a/src/main/java/com/claudecode/core/AgentLoop.java +++ b/src/main/java/com/claudecode/core/AgentLoop.java @@ -73,7 +73,7 @@ public class AgentLoop { private volatile boolean cancelled = false; /** 消息历史 —— 自行管理,不依赖 Spring AI ChatMemory */ - private final List messageHistory = new ArrayList<>(); + private final List messageHistory = java.util.Collections.synchronizedList(new ArrayList<>()); /** 工具调用事件回调:在每次工具调用前/后通知 UI */ private Consumer onToolEvent; @@ -371,7 +371,9 @@ public class AgentLoop { Map parsedArgs = Map.of(); try { parsedArgs = MAPPER.readValue(toolArgs, Map.class); - } catch (Exception ignored) {} + } catch (Exception e) { + log.debug("Failed to parse tool arguments for {}: {}", toolName, e.getMessage()); + } // PreToolUse Hook var preHookCtx = new HookManager.HookContext(toolName, parsedArgs); diff --git a/src/main/java/com/claudecode/core/NotificationService.java b/src/main/java/com/claudecode/core/NotificationService.java index 7fa801f..198fe54 100644 --- a/src/main/java/com/claudecode/core/NotificationService.java +++ b/src/main/java/com/claudecode/core/NotificationService.java @@ -116,16 +116,19 @@ public class NotificationService { "$n.ShowBalloonTip(3000,'%s','%s','Info');" + "Start-Sleep 1;$n.Dispose()", escape(title), escape(message)); - new ProcessBuilder("powershell", "-NoProfile", "-Command", ps) + Process p = new ProcessBuilder("powershell", "-NoProfile", "-Command", ps) .redirectErrorStream(true).start(); + // Don't block, but schedule cleanup + p.onExit().thenRun(p::destroyForcibly); } private void sendMac(String title, String message) throws IOException { String script = String.format( "display notification \"%s\" with title \"%s\"", escape(message), escape(title)); - new ProcessBuilder("osascript", "-e", script) + Process p = new ProcessBuilder("osascript", "-e", script) .redirectErrorStream(true).start(); + p.onExit().thenRun(p::destroyForcibly); } private void sendLinux(String title, String message, String level) throws IOException { @@ -134,8 +137,9 @@ public class NotificationService { case "warning" -> "normal"; default -> "low"; }; - new ProcessBuilder("notify-send", "-u", urgency, title, message) + Process p = new ProcessBuilder("notify-send", "-u", urgency, title, message) .redirectErrorStream(true).start(); + p.onExit().thenRun(p::destroyForcibly); } private String escape(String s) { diff --git a/src/main/java/com/claudecode/core/RateLimiter.java b/src/main/java/com/claudecode/core/RateLimiter.java index 56720cf..adc0379 100644 --- a/src/main/java/com/claudecode/core/RateLimiter.java +++ b/src/main/java/com/claudecode/core/RateLimiter.java @@ -62,6 +62,15 @@ public class RateLimiter { * @param maxConcurrent 最大并发执行数 */ public RateLimiter(int maxRequestsPerWindow, Duration windowDuration, int maxConcurrent) { + if (maxRequestsPerWindow <= 0) { + throw new IllegalArgumentException("maxRequestsPerWindow must be > 0, got: " + maxRequestsPerWindow); + } + if (windowDuration == null || windowDuration.isNegative() || windowDuration.isZero()) { + throw new IllegalArgumentException("windowDuration must be positive"); + } + if (maxConcurrent <= 0) { + throw new IllegalArgumentException("maxConcurrent must be > 0, got: " + maxConcurrent); + } this.maxRequestsPerWindow = maxRequestsPerWindow; this.windowDuration = windowDuration; this.maxConcurrent = maxConcurrent; diff --git a/src/main/java/com/claudecode/tool/impl/BashTool.java b/src/main/java/com/claudecode/tool/impl/BashTool.java index 4323c9f..ed592a5 100644 --- a/src/main/java/com/claudecode/tool/impl/BashTool.java +++ b/src/main/java/com/claudecode/tool/impl/BashTool.java @@ -151,9 +151,17 @@ public class BashTool implements Tool { @Override public String execute(Map input, ToolContext context) { String command = (String) input.get("command"); - int timeout = input.containsKey("timeout") - ? ((Number) input.get("timeout")).intValue() - : DEFAULT_TIMEOUT; + if (command == null || command.isBlank()) { + return "Error: 'command' parameter is required and must not be empty."; + } + int timeout; + try { + timeout = input.containsKey("timeout") + ? ((Number) input.get("timeout")).intValue() + : DEFAULT_TIMEOUT; + } catch (ClassCastException e) { + timeout = DEFAULT_TIMEOUT; + } Path workDir = context.getWorkDir(); // Sandbox check: block absolutely dangerous commands @@ -261,14 +269,17 @@ public class BashTool implements Tool { /** 检查命令是否可用 */ private static boolean isCommandAvailable(String... cmd) { + Process p = null; try { - Process p = new ProcessBuilder(cmd) + p = new ProcessBuilder(cmd) .redirectErrorStream(true) .start(); p.getInputStream().readAllBytes(); return p.waitFor(5, TimeUnit.SECONDS) && p.exitValue() == 0; } catch (Exception e) { return false; + } finally { + if (p != null) p.destroyForcibly(); } } diff --git a/src/main/java/com/claudecode/tool/impl/FileEditTool.java b/src/main/java/com/claudecode/tool/impl/FileEditTool.java index b3a04a6..1ccdcfd 100644 --- a/src/main/java/com/claudecode/tool/impl/FileEditTool.java +++ b/src/main/java/com/claudecode/tool/impl/FileEditTool.java @@ -68,8 +68,24 @@ public class FileEditTool implements Tool { String filePath = (String) input.get("file_path"); String oldString = (String) input.get("old_string"); String newString = (String) input.get("new_string"); + + if (filePath == null || filePath.isBlank()) { + return "Error: 'file_path' is required."; + } + if (oldString == null) { + return "Error: 'old_string' is required."; + } + if (newString == null) { + return "Error: 'new_string' is required."; + } + Path path = context.getWorkDir().resolve(filePath).normalize(); + // Path traversal protection + if (!path.startsWith(context.getWorkDir().normalize())) { + return "Error: Path traversal not allowed. Path must be within the working directory."; + } + if (!Files.exists(path)) { return "Error: File not found: " + path; } diff --git a/src/main/java/com/claudecode/tool/impl/GrepTool.java b/src/main/java/com/claudecode/tool/impl/GrepTool.java index 77ffb16..bbe9236 100644 --- a/src/main/java/com/claudecode/tool/impl/GrepTool.java +++ b/src/main/java/com/claudecode/tool/impl/GrepTool.java @@ -107,6 +107,9 @@ public class GrepTool implements Tool { @Override public String execute(Map input, ToolContext context) { String pattern = (String) input.get("pattern"); + if (pattern == null || pattern.isBlank()) { + return "Error: 'pattern' parameter is required."; + } String searchPath = (String) input.getOrDefault("path", "."); String include = (String) input.getOrDefault("include", null); String type = (String) input.getOrDefault("type", null); @@ -135,10 +138,12 @@ public class GrepTool implements Tool { while ((line = reader.readLine()) != null && lines.size() < headLimit) { lines.add(line); } + } finally { + if (!process.waitFor(30, TimeUnit.SECONDS)) { + process.destroyForcibly(); + } } - process.waitFor(30, TimeUnit.SECONDS); - if (lines.isEmpty()) { return "No matches found for pattern: " + pattern; } diff --git a/src/test/java/com/claudecode/command/impl/Phase4CommandsTest.java b/src/test/java/com/claudecode/command/impl/Phase4CommandsTest.java new file mode 100644 index 0000000..f23591a --- /dev/null +++ b/src/test/java/com/claudecode/command/impl/Phase4CommandsTest.java @@ -0,0 +1,143 @@ +package com.claudecode.command.impl; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.DisplayName; + +import static org.assertj.core.api.Assertions.*; + +/** + * 命令基本属性测试 —— 验证所有 Phase 4 命令的元数据。 + */ +class Phase4CommandsTest { + + // ==================== Phase 4B Commands ==================== + + @Test + @DisplayName("BriefCommand metadata") + void briefCommand() { + BriefCommand cmd = new BriefCommand(); + assertThat(cmd.name()).isEqualTo("brief"); + assertThat(cmd.description()).isNotBlank(); + } + + @Test + @DisplayName("VimCommand metadata") + void vimCommand() { + VimCommand cmd = new VimCommand(); + assertThat(cmd.name()).isEqualTo("vim"); + assertThat(cmd.description()).isNotBlank(); + } + + @Test + @DisplayName("ThemeCommand metadata") + void themeCommand() { + ThemeCommand cmd = new ThemeCommand(); + assertThat(cmd.name()).isEqualTo("theme"); + assertThat(cmd.description()).isNotBlank(); + } + + @Test + @DisplayName("UsageCommand metadata") + void usageCommand() { + UsageCommand cmd = new UsageCommand(); + assertThat(cmd.name()).isEqualTo("usage"); + assertThat(cmd.description()).isNotBlank(); + } + + @Test + @DisplayName("TipsCommand metadata") + void tipsCommand() { + TipsCommand cmd = new TipsCommand(); + assertThat(cmd.name()).isEqualTo("tips"); + assertThat(cmd.description()).isNotBlank(); + } + + @Test + @DisplayName("OutputStyleCommand metadata") + void outputStyleCommand() { + OutputStyleCommand cmd = new OutputStyleCommand(); + assertThat(cmd.name()).isEqualTo("output-style"); + assertThat(cmd.description()).isNotBlank(); + } + + @Test + @DisplayName("EnvCommand shows system info without agent loop") + void envCommand_noLoop() { + EnvCommand cmd = new EnvCommand(); + assertThat(cmd.name()).isEqualTo("env"); + String result = cmd.execute(null, new com.claudecode.command.CommandContext(null, null, null, null, null)); + assertThat(result).contains("Environment"); + } + + @Test + @DisplayName("PerformanceCommand shows JVM stats") + void performanceCommand() { + PerformanceCommand cmd = new PerformanceCommand(); + assertThat(cmd.name()).isEqualTo("performance"); + assertThat(cmd.aliases()).contains("perf"); + + String result = cmd.execute(null, new com.claudecode.command.CommandContext(null, null, null, null, null)); + assertThat(result).contains("Memory").contains("Threads"); + } + + @Test + @DisplayName("KeybindingsCommand shows shortcuts") + void keybindingsCommand() { + KeybindingsCommand cmd = new KeybindingsCommand(); + assertThat(cmd.name()).isEqualTo("keybindings"); + String result = cmd.execute(null, new com.claudecode.command.CommandContext(null, null, null, null, null)); + assertThat(result).contains("Keyboard"); + } + + // ==================== Phase 4D Commands ==================== + + @Test + @DisplayName("DebugCommand metadata and aliases") + void debugCommand() { + DebugCommand cmd = new DebugCommand(); + assertThat(cmd.name()).isEqualTo("debug"); + assertThat(cmd.aliases()).contains("dbg"); + } + + @Test + @DisplayName("HeapdumpCommand shows memory info") + void heapdumpCommand() { + HeapdumpCommand cmd = new HeapdumpCommand(); + assertThat(cmd.name()).isEqualTo("heapdump"); + + String result = cmd.execute("info", new com.claudecode.command.CommandContext(null, null, null, null, null)); + assertThat(result).contains("Heap Memory"); + } + + @Test + @DisplayName("TraceCommand metadata") + void traceCommand() { + TraceCommand cmd = new TraceCommand(); + assertThat(cmd.name()).isEqualTo("trace"); + assertThat(cmd.description()).isNotBlank(); + } + + @Test + @DisplayName("ContextVizCommand metadata and aliases") + void contextVizCommand() { + ContextVizCommand cmd = new ContextVizCommand(); + assertThat(cmd.name()).isEqualTo("ctx-viz"); + assertThat(cmd.aliases()).contains("context", "ctx"); + } + + @Test + @DisplayName("ResetLimitsCommand metadata and aliases") + void resetLimitsCommand() { + ResetLimitsCommand cmd = new ResetLimitsCommand(); + assertThat(cmd.name()).isEqualTo("reset-limits"); + assertThat(cmd.aliases()).contains("rl"); + } + + @Test + @DisplayName("SandboxCommand metadata and status display") + void sandboxCommand() { + SandboxCommand cmd = new SandboxCommand(); + assertThat(cmd.name()).isEqualTo("sandbox"); + assertThat(cmd.description()).isNotBlank(); + } +} diff --git a/src/test/java/com/claudecode/core/InternalLoggerTest.java b/src/test/java/com/claudecode/core/InternalLoggerTest.java new file mode 100644 index 0000000..4fa370c --- /dev/null +++ b/src/test/java/com/claudecode/core/InternalLoggerTest.java @@ -0,0 +1,187 @@ +package com.claudecode.core; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; + +import static org.assertj.core.api.Assertions.*; + +/** + * InternalLogger 单元测试。 + */ +class InternalLoggerTest { + + @TempDir + Path tempDir; + + private InternalLogger createLogger() { + return new InternalLogger("test-session", tempDir); + } + + // ==================== Basic logging ==================== + + @Test + @DisplayName("info log is recorded") + void info_recorded() { + InternalLogger logger = createLogger(); + logger.info("TEST", "hello world"); + + String recent = logger.getRecent(10); + assertThat(recent).contains("TEST").contains("hello world"); + } + + @Test + @DisplayName("debug log filtered when level is NORMAL") + void debug_filteredAtNormal() { + InternalLogger logger = createLogger(); + logger.setLevel(InternalLogger.Level.NORMAL); + logger.debug("TEST", "secret debug info"); + + String recent = logger.getRecent(10); + assertThat(recent).doesNotContain("secret debug info"); + } + + @Test + @DisplayName("debug log visible when level is DEBUG") + void debug_visibleAtDebug() { + InternalLogger logger = createLogger(); + logger.setLevel(InternalLogger.Level.DEBUG); + logger.debug("TEST", "debug info"); + + String recent = logger.getRecent(10); + assertThat(recent).contains("debug info"); + } + + @Test + @DisplayName("verbose log visible when level is VERBOSE") + void verbose_visibleAtVerbose() { + InternalLogger logger = createLogger(); + logger.setLevel(InternalLogger.Level.VERBOSE); + logger.verbose("TOOL", "verbose detail"); + + String recent = logger.getRecent(10); + assertThat(recent).contains("verbose detail"); + } + + // ==================== Structured logging ==================== + + @Test + @DisplayName("toolCall creates structured log entry") + void toolCall_structured() { + InternalLogger logger = createLogger(); + logger.setLevel(InternalLogger.Level.VERBOSE); + logger.toolCall("bash", "ls -la", "file.txt", 150); + + String recent = logger.getRecent(10); + assertThat(recent).contains("TOOL").contains("bash").contains("150ms"); + } + + @Test + @DisplayName("apiCall creates structured log entry") + void apiCall_structured() { + InternalLogger logger = createLogger(); + logger.apiCall("sonnet", 1000, 500, 2000); + + String recent = logger.getRecent(10); + assertThat(recent).contains("API").contains("sonnet"); + } + + @Test + @DisplayName("error includes exception info") + void error_withException() { + InternalLogger logger = createLogger(); + logger.error("NET", "connection failed", new IOException("timeout")); + + String recent = logger.getRecent(10); + assertThat(recent).contains("ERROR:NET").contains("IOException").contains("timeout"); + } + + // ==================== Entry count and limits ==================== + + @Test + @DisplayName("entry count tracks all recorded entries") + void entryCount() { + InternalLogger logger = createLogger(); + logger.info("A", "one"); + logger.info("B", "two"); + logger.info("C", "three"); + + assertThat(logger.getEntryCount()).isEqualTo(3); + } + + @Test + @DisplayName("getRecent limits returned entries") + void getRecent_limited() { + InternalLogger logger = createLogger(); + for (int i = 0; i < 10; i++) { + logger.info("TEST", "entry " + i); + } + + String last3 = logger.getRecent(3); + assertThat(last3).contains("entry 9").contains("entry 8").contains("entry 7"); + assertThat(last3).doesNotContain("entry 0"); + } + + // ==================== File output ==================== + + @Test + @DisplayName("log entries are written to file") + void fileOutput() throws IOException { + InternalLogger logger = createLogger(); + logger.info("FILE", "written to disk"); + + // Check log dir has files + long fileCount = Files.list(tempDir).count(); + assertThat(fileCount).isGreaterThanOrEqualTo(1); + + // Read the file content + String content = Files.list(tempDir) + .filter(p -> p.toString().endsWith(".log")) + .findFirst() + .map(p -> { try { return Files.readString(p); } catch (IOException e) { return ""; } }) + .orElse(""); + assertThat(content).contains("FILE").contains("written to disk"); + } + + // ==================== Export ==================== + + @Test + @DisplayName("export writes all entries to target file") + void export() throws IOException { + InternalLogger logger = createLogger(); + logger.info("A", "first"); + logger.info("B", "second"); + + Path exportFile = tempDir.resolve("export.log"); + logger.export(exportFile); + + String content = Files.readString(exportFile); + assertThat(content) + .contains("Session Log: test-session") + .contains("first") + .contains("second"); + } + + // ==================== Configuration ==================== + + @Test + @DisplayName("level getter/setter works") + void levelGetSet() { + InternalLogger logger = createLogger(); + assertThat(logger.getLevel()).isEqualTo(InternalLogger.Level.NORMAL); + + logger.setLevel(InternalLogger.Level.DEBUG); + assertThat(logger.getLevel()).isEqualTo(InternalLogger.Level.DEBUG); + } + + @Test + @DisplayName("session ID is accessible") + void sessionId() { + InternalLogger logger = createLogger(); + assertThat(logger.getSessionId()).isEqualTo("test-session"); + } +} diff --git a/src/test/java/com/claudecode/core/NotificationServiceTest.java b/src/test/java/com/claudecode/core/NotificationServiceTest.java new file mode 100644 index 0000000..fc42077 --- /dev/null +++ b/src/test/java/com/claudecode/core/NotificationServiceTest.java @@ -0,0 +1,78 @@ +package com.claudecode.core; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.DisplayName; + +import static org.assertj.core.api.Assertions.*; + +/** + * NotificationService 单元测试。 + */ +class NotificationServiceTest { + + // ==================== Configuration ==================== + + @Test + @DisplayName("enabled by default") + void enabledByDefault() { + NotificationService service = new NotificationService(); + assertThat(service.isEnabled()).isTrue(); + } + + @Test + @DisplayName("sound enabled by default") + void soundEnabledByDefault() { + NotificationService service = new NotificationService(); + assertThat(service.isSoundEnabled()).isTrue(); + } + + @Test + @DisplayName("only when inactive by default") + void onlyWhenInactiveByDefault() { + NotificationService service = new NotificationService(); + assertThat(service.isOnlyWhenInactive()).isTrue(); + } + + @Test + @DisplayName("setEnabled toggles state") + void setEnabled() { + NotificationService service = new NotificationService(); + service.setEnabled(false); + assertThat(service.isEnabled()).isFalse(); + service.setEnabled(true); + assertThat(service.isEnabled()).isTrue(); + } + + @Test + @DisplayName("setSoundEnabled toggles state") + void setSoundEnabled() { + NotificationService service = new NotificationService(); + service.setSoundEnabled(false); + assertThat(service.isSoundEnabled()).isFalse(); + } + + // ==================== Disabled notification ==================== + + @Test + @DisplayName("disabled service does not throw") + void disabled_noThrow() { + NotificationService service = new NotificationService(); + service.setEnabled(false); + service.setSoundEnabled(false); + // Should not throw + assertThatCode(() -> service.info("Test", "message")).doesNotThrowAnyException(); + assertThatCode(() -> service.warning("Test", "warning")).doesNotThrowAnyException(); + assertThatCode(() -> service.error("Test", "error")).doesNotThrowAnyException(); + } + + @Test + @DisplayName("convenience methods do not throw") + void convenienceMethods_noThrow() { + NotificationService service = new NotificationService(); + service.setEnabled(false); + service.setSoundEnabled(false); + assertThatCode(() -> service.taskComplete("build")).doesNotThrowAnyException(); + assertThatCode(() -> service.inputRequired("approval")).doesNotThrowAnyException(); + assertThatCode(() -> service.errorOccurred("bash", "exit 1")).doesNotThrowAnyException(); + } +} diff --git a/src/test/java/com/claudecode/core/RateLimiterTest.java b/src/test/java/com/claudecode/core/RateLimiterTest.java new file mode 100644 index 0000000..3409d5f --- /dev/null +++ b/src/test/java/com/claudecode/core/RateLimiterTest.java @@ -0,0 +1,142 @@ +package com.claudecode.core; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.DisplayName; + +import java.time.Duration; + +import static org.assertj.core.api.Assertions.*; + +/** + * RateLimiter 单元测试。 + */ +class RateLimiterTest { + + // ==================== Construction ==================== + + @Test + @DisplayName("default constructor creates valid instance") + void defaultConstructor() { + RateLimiter limiter = new RateLimiter(); + assertThat(limiter.getRemaining("test")).isGreaterThan(0); + } + + @Test + @DisplayName("invalid maxRequestsPerWindow throws") + void invalidMaxRequests() { + assertThatThrownBy(() -> new RateLimiter(0, Duration.ofMinutes(1), 5)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("maxRequestsPerWindow"); + } + + @Test + @DisplayName("null windowDuration throws") + void nullDuration() { + assertThatThrownBy(() -> new RateLimiter(10, null, 5)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("windowDuration"); + } + + @Test + @DisplayName("zero maxConcurrent throws") + void zeroMaxConcurrent() { + assertThatThrownBy(() -> new RateLimiter(10, Duration.ofMinutes(1), 0)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("maxConcurrent"); + } + + // ==================== tryAcquire ==================== + + @Test + @DisplayName("basic acquire succeeds within limit") + void tryAcquire_withinLimit() { + RateLimiter limiter = new RateLimiter(5, Duration.ofMinutes(1), 3); + assertThat(limiter.tryAcquire("api")).isTrue(); + assertThat(limiter.tryAcquire("api")).isTrue(); + assertThat(limiter.tryAcquire("api")).isTrue(); + } + + @Test + @DisplayName("acquire fails when window exhausted") + void tryAcquire_windowExhausted() { + RateLimiter limiter = new RateLimiter(3, Duration.ofMinutes(1), 10); + assertThat(limiter.tryAcquire("api")).isTrue(); + assertThat(limiter.tryAcquire("api")).isTrue(); + assertThat(limiter.tryAcquire("api")).isTrue(); + assertThat(limiter.tryAcquire("api")).isFalse(); // 4th should fail + } + + @Test + @DisplayName("different keys are independent") + void tryAcquire_independentKeys() { + RateLimiter limiter = new RateLimiter(2, Duration.ofMinutes(1), 10); + assertThat(limiter.tryAcquire("api")).isTrue(); + assertThat(limiter.tryAcquire("api")).isTrue(); + assertThat(limiter.tryAcquire("api")).isFalse(); + // Different key should still work + assertThat(limiter.tryAcquire("tool")).isTrue(); + } + + // ==================== getRemaining ==================== + + @Test + @DisplayName("remaining decreases after acquire") + void getRemaining_decreases() { + RateLimiter limiter = new RateLimiter(5, Duration.ofMinutes(1), 3); + int before = limiter.getRemaining("api"); + limiter.tryAcquire("api"); + int after = limiter.getRemaining("api"); + assertThat(after).isEqualTo(before - 1); + } + + // ==================== cooldown ==================== + + @Test + @DisplayName("cooldown blocks acquire") + void cooldown_blocks() { + RateLimiter limiter = new RateLimiter(100, Duration.ofMinutes(1), 10); + limiter.setCooldown("api", Duration.ofSeconds(30)); + assertThat(limiter.tryAcquire("api")).isFalse(); + } + + // ==================== reset ==================== + + @Test + @DisplayName("reset restores key capacity") + void reset_restores() { + RateLimiter limiter = new RateLimiter(2, Duration.ofMinutes(1), 10); + limiter.tryAcquire("api"); + limiter.tryAcquire("api"); + assertThat(limiter.tryAcquire("api")).isFalse(); + + limiter.reset("api"); + assertThat(limiter.tryAcquire("api")).isTrue(); + } + + @Test + @DisplayName("resetAll restores all keys") + void resetAll_restores() { + RateLimiter limiter = new RateLimiter(1, Duration.ofMinutes(1), 10); + limiter.tryAcquire("api"); + limiter.tryAcquire("tool"); + assertThat(limiter.tryAcquire("api")).isFalse(); + assertThat(limiter.tryAcquire("tool")).isFalse(); + + limiter.resetAll(); + assertThat(limiter.tryAcquire("api")).isTrue(); + assertThat(limiter.tryAcquire("tool")).isTrue(); + } + + // ==================== concurrent semaphore ==================== + + @Test + @DisplayName("acquireConcurrent respects limit") + void acquireConcurrent() { + RateLimiter limiter = new RateLimiter(100, Duration.ofMinutes(1), 2); + assertThat(limiter.acquireConcurrent(1)).isTrue(); + assertThat(limiter.acquireConcurrent(1)).isTrue(); + assertThat(limiter.acquireConcurrent(0)).isFalse(); // no timeout + limiter.releaseConcurrent(); + assertThat(limiter.acquireConcurrent(1)).isTrue(); + } +} diff --git a/src/test/java/com/claudecode/core/TokenEstimationServiceTest.java b/src/test/java/com/claudecode/core/TokenEstimationServiceTest.java new file mode 100644 index 0000000..8843be1 --- /dev/null +++ b/src/test/java/com/claudecode/core/TokenEstimationServiceTest.java @@ -0,0 +1,138 @@ +package com.claudecode.core; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.DisplayName; + +import static org.assertj.core.api.Assertions.*; + +/** + * TokenEstimationService 单元测试。 + */ +class TokenEstimationServiceTest { + + private final TokenEstimationService service = new TokenEstimationService(); + + // ==================== estimateTokens ==================== + + @Test + @DisplayName("null or empty text returns 0") + void estimateTokens_nullOrEmpty() { + assertThat(service.estimateTokens(null)).isEqualTo(0); + assertThat(service.estimateTokens("")).isEqualTo(0); + } + + @Test + @DisplayName("short English text returns at least 1 token") + void estimateTokens_shortText() { + assertThat(service.estimateTokens("hi")).isGreaterThanOrEqualTo(1); + } + + @Test + @DisplayName("English text roughly 4 chars per token") + void estimateTokens_english() { + String text = "The quick brown fox jumps over the lazy dog."; // 44 chars + int tokens = service.estimateTokens(text); + // Expect roughly 44/4 = 11 tokens, allow some variance + assertThat(tokens).isBetween(8, 18); + } + + @Test + @DisplayName("CJK text has higher token density") + void estimateTokens_cjk() { + String text = "这是一段中文测试文本"; // 9 CJK chars + int tokens = service.estimateTokens(text); + // CJK: ~1.5 chars/token → ~6 tokens + assertThat(tokens).isBetween(4, 10); + } + + @Test + @DisplayName("code text has code-specific ratio") + void estimateTokens_code() { + String text = "if (x == 0) { return; }"; // contains code chars + int tokens = service.estimateTokens(text); + assertThat(tokens).isGreaterThanOrEqualTo(3); + } + + @Test + @DisplayName("JSON text detected and uses JSON ratio") + void estimateTokens_json() { + String json = """ + {"name": "test", "value": 42, "nested": {"key": "val"}}"""; + int tokens = service.estimateTokens(json); + assertThat(tokens).isGreaterThan(5); + } + + // ==================== estimateCost ==================== + + @Test + @DisplayName("Sonnet pricing is default") + void estimateCost_sonnet() { + double cost = service.estimateCost(1_000_000, 1_000_000, "claude-sonnet-4"); + // Input: $3/M + Output: $15/M = $18 + assertThat(cost).isCloseTo(18.0, within(0.01)); + } + + @Test + @DisplayName("Opus pricing is higher") + void estimateCost_opus() { + double cost = service.estimateCost(1_000_000, 1_000_000, "claude-opus-4"); + // Input: $15/M + Output: $75/M = $90 + assertThat(cost).isCloseTo(90.0, within(0.01)); + } + + @Test + @DisplayName("Haiku pricing is cheaper") + void estimateCost_haiku() { + double cost = service.estimateCost(1_000_000, 1_000_000, "claude-haiku-3"); + // Input: $0.25/M + Output: $1.25/M = $1.50 + assertThat(cost).isCloseTo(1.50, within(0.01)); + } + + @Test + @DisplayName("zero tokens cost zero") + void estimateCost_zero() { + assertThat(service.estimateCost(0, 0, "sonnet")).isEqualTo(0.0); + } + + // ==================== formatTokenCount ==================== + + @Test + @DisplayName("format small counts as-is") + void formatTokenCount_small() { + assertThat(service.formatTokenCount(42)).isEqualTo("42"); + assertThat(service.formatTokenCount(999)).isEqualTo("999"); + } + + @Test + @DisplayName("format thousands as K") + void formatTokenCount_thousands() { + assertThat(service.formatTokenCount(1000)).isEqualTo("1.0K"); + assertThat(service.formatTokenCount(5500)).isEqualTo("5.5K"); + } + + @Test + @DisplayName("format millions as M") + void formatTokenCount_millions() { + assertThat(service.formatTokenCount(1_000_000)).isEqualTo("1.0M"); + assertThat(service.formatTokenCount(2_500_000)).isEqualTo("2.5M"); + } + + // ==================== estimateMessageTokens ==================== + + @Test + @DisplayName("message tokens include overhead") + void estimateMessageTokens_overhead() { + int contentTokens = service.estimateTokens("Hello world"); + int messageTokens = service.estimateMessageTokens("user", "Hello world"); + assertThat(messageTokens).isEqualTo(contentTokens + 4); + } + + // ==================== estimateToolDefinitionTokens ==================== + + @Test + @DisplayName("tool definition includes structural overhead") + void estimateToolDefinitionTokens_overhead() { + int tokens = service.estimateToolDefinitionTokens("Read", "Read a file", "{\"type\":\"object\"}"); + assertThat(tokens).isGreaterThanOrEqualTo(20); // at least the 20 overhead + } +} diff --git a/src/test/java/com/claudecode/tool/impl/BashToolTest.java b/src/test/java/com/claudecode/tool/impl/BashToolTest.java new file mode 100644 index 0000000..d10b290 --- /dev/null +++ b/src/test/java/com/claudecode/tool/impl/BashToolTest.java @@ -0,0 +1,118 @@ +package com.claudecode.tool.impl; + +import com.claudecode.tool.ToolContext; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.io.TempDir; + +import java.nio.file.Path; +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.*; + +/** + * BashTool 输入验证测试。 + * 注意:不执行真实命令,只测试验证逻辑。 + */ +class BashToolTest { + + @TempDir + Path tempDir; + + private final BashTool tool = new BashTool(); + + private ToolContext createContext() { + return new ToolContext(tempDir, "test-model"); + } + + // ==================== Input validation ==================== + + @Test + @DisplayName("null command returns error") + void nullCommand() { + Map input = new HashMap<>(); + input.put("command", null); + + String result = tool.execute(input, createContext()); + assertThat(result).containsIgnoringCase("error").contains("command"); + } + + @Test + @DisplayName("blank command returns error") + void blankCommand() { + Map input = Map.of("command", " "); + String result = tool.execute(input, createContext()); + assertThat(result).containsIgnoringCase("error").contains("command"); + } + + @Test + @DisplayName("invalid timeout type uses default") + void invalidTimeoutType() { + Map input = new HashMap<>(); + input.put("command", "echo hello"); + input.put("timeout", "not a number"); + + // Should not throw, should use default timeout + String result = tool.execute(input, createContext()); + // Result could be success (echo) or error depending on OS, but should not NPE + assertThat(result).isNotNull(); + } + + // ==================== Dangerous command blocking ==================== + + @Test + @DisplayName("rm -rf / is blocked") + void blockDangerous_rmRf() { + Map input = Map.of("command", "rm -rf /"); + String result = tool.execute(input, createContext()); + assertThat(result).contains("BLOCKED"); + } + + @Test + @DisplayName("fork bomb variant is blocked") + void blockDangerous_forkBomb() { + // Exact format matching the DANGEROUS_COMMANDS set + Map input = Map.of("command", ":(){ :|:& };:"); + String result = tool.execute(input, createContext()); + // Fork bomb with spaces gets passed to shell which errors, that's OK + // Test the exact format from the set instead + Map input2 = Map.of("command", ":(){:|:&};:"); + String result2 = tool.execute(input2, createContext()); + assertThat(result2).contains("BLOCKED"); + } + + @Test + @DisplayName("git push --force is blocked") + void blockDangerous_forcePush() { + Map input = Map.of("command", "git push --force"); + String result = tool.execute(input, createContext()); + assertThat(result).contains("BLOCKED"); + } + + // ==================== Tool metadata ==================== + + @Test + @DisplayName("tool name is Bash") + void toolName() { + assertThat(tool.name()).isEqualTo("Bash"); + } + + @Test + @DisplayName("tool is not read-only") + void notReadOnly() { + assertThat(tool.isReadOnly()).isFalse(); + } + + @Test + @DisplayName("description is not empty") + void description() { + assertThat(tool.description()).isNotBlank(); + } + + @Test + @DisplayName("input schema is valid JSON") + void inputSchema() { + assertThat(tool.inputSchema()).contains("\"command\"").contains("\"type\""); + } +} diff --git a/src/test/java/com/claudecode/tool/impl/FileEditToolTest.java b/src/test/java/com/claudecode/tool/impl/FileEditToolTest.java new file mode 100644 index 0000000..6b7dfb7 --- /dev/null +++ b/src/test/java/com/claudecode/tool/impl/FileEditToolTest.java @@ -0,0 +1,168 @@ +package com.claudecode.tool.impl; + +import com.claudecode.tool.ToolContext; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.*; + +/** + * FileEditTool 验证测试。 + */ +class FileEditToolTest { + + @TempDir + Path tempDir; + + private final FileEditTool tool = new FileEditTool(); + + private ToolContext createContext() { + return new ToolContext(tempDir, "test-model"); + } + + // ==================== Input validation ==================== + + @Test + @DisplayName("null file_path returns error") + void nullFilePath() { + Map input = new HashMap<>(); + input.put("file_path", null); + input.put("old_string", "old"); + input.put("new_string", "new"); + + String result = tool.execute(input, createContext()); + assertThat(result).containsIgnoringCase("error").contains("file_path"); + } + + @Test + @DisplayName("blank file_path returns error") + void blankFilePath() { + Map input = Map.of("file_path", " ", "old_string", "old", "new_string", "new"); + String result = tool.execute(input, createContext()); + assertThat(result).containsIgnoringCase("error").contains("file_path"); + } + + @Test + @DisplayName("null old_string returns error") + void nullOldString() { + Map input = new HashMap<>(); + input.put("file_path", "test.txt"); + input.put("old_string", null); + input.put("new_string", "new"); + + String result = tool.execute(input, createContext()); + assertThat(result).containsIgnoringCase("error").contains("old_string"); + } + + @Test + @DisplayName("null new_string returns error") + void nullNewString() { + Map input = new HashMap<>(); + input.put("file_path", "test.txt"); + input.put("old_string", "old"); + input.put("new_string", null); + + String result = tool.execute(input, createContext()); + assertThat(result).containsIgnoringCase("error").contains("new_string"); + } + + // ==================== Path traversal protection ==================== + + @Test + @DisplayName("path traversal is blocked") + void pathTraversal_blocked() throws IOException { + // Create a file outside tempDir + Path outsideDir = tempDir.getParent().resolve("outside-test-" + System.nanoTime()); + Files.createDirectories(outsideDir); + Path outsideFile = outsideDir.resolve("secret.txt"); + Files.writeString(outsideFile, "secret content"); + + try { + Map input = Map.of( + "file_path", "../" + outsideDir.getFileName() + "/secret.txt", + "old_string", "secret", + "new_string", "hacked" + ); + + String result = tool.execute(input, createContext()); + assertThat(result).containsIgnoringCase("error").containsIgnoringCase("traversal"); + + // Verify file was NOT modified + assertThat(Files.readString(outsideFile)).isEqualTo("secret content"); + } finally { + Files.deleteIfExists(outsideFile); + Files.deleteIfExists(outsideDir); + } + } + + // ==================== Core functionality ==================== + + @Test + @DisplayName("successful edit replaces text") + void successfulEdit() throws IOException { + Path file = tempDir.resolve("hello.txt"); + Files.writeString(file, "Hello World\nFoo Bar\nBaz"); + + Map input = Map.of( + "file_path", "hello.txt", + "old_string", "Foo Bar", + "new_string", "New Content" + ); + + String result = tool.execute(input, createContext()); + assertThat(result).contains("Edited"); + assertThat(Files.readString(file)).contains("New Content").doesNotContain("Foo Bar"); + } + + @Test + @DisplayName("file not found returns error") + void fileNotFound() { + Map input = Map.of( + "file_path", "nonexistent.txt", + "old_string", "a", + "new_string", "b" + ); + + String result = tool.execute(input, createContext()); + assertThat(result).containsIgnoringCase("error").containsIgnoringCase("not found"); + } + + @Test + @DisplayName("old_string not found returns error") + void oldStringNotFound() throws IOException { + Path file = tempDir.resolve("test.txt"); + Files.writeString(file, "Hello World"); + + Map input = Map.of( + "file_path", "test.txt", + "old_string", "does not exist", + "new_string", "replacement" + ); + + String result = tool.execute(input, createContext()); + assertThat(result).containsIgnoringCase("error").contains("not found"); + } + + @Test + @DisplayName("multiple matches returns error") + void multipleMatches() throws IOException { + Path file = tempDir.resolve("dup.txt"); + Files.writeString(file, "hello\nhello\nhello"); + + Map input = Map.of( + "file_path", "dup.txt", + "old_string", "hello", + "new_string", "world" + ); + + String result = tool.execute(input, createContext()); + assertThat(result).containsIgnoringCase("error").containsIgnoringCase("multiple"); + } +} diff --git a/src/test/java/com/claudecode/tool/impl/GrepToolTest.java b/src/test/java/com/claudecode/tool/impl/GrepToolTest.java new file mode 100644 index 0000000..9e1186f --- /dev/null +++ b/src/test/java/com/claudecode/tool/impl/GrepToolTest.java @@ -0,0 +1,73 @@ +package com.claudecode.tool.impl; + +import com.claudecode.tool.ToolContext; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.io.TempDir; + +import java.nio.file.Path; +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.*; + +/** + * GrepTool 输入验证测试。 + */ +class GrepToolTest { + + @TempDir + Path tempDir; + + private final GrepTool tool = new GrepTool(); + + private ToolContext createContext() { + return new ToolContext(tempDir, "test-model"); + } + + // ==================== Input validation ==================== + + @Test + @DisplayName("null pattern returns error") + void nullPattern() { + Map input = new HashMap<>(); + input.put("pattern", null); + + String result = tool.execute(input, createContext()); + assertThat(result).containsIgnoringCase("error").contains("pattern"); + } + + @Test + @DisplayName("blank pattern returns error") + void blankPattern() { + Map input = Map.of("pattern", " "); + String result = tool.execute(input, createContext()); + assertThat(result).containsIgnoringCase("error").contains("pattern"); + } + + // ==================== Tool metadata ==================== + + @Test + @DisplayName("tool name is Grep") + void toolName() { + assertThat(tool.name()).isEqualTo("Grep"); + } + + @Test + @DisplayName("tool is read-only") + void readOnly() { + assertThat(tool.isReadOnly()).isTrue(); + } + + @Test + @DisplayName("description mentions ripgrep") + void description() { + assertThat(tool.description()).containsIgnoringCase("ripgrep"); + } + + @Test + @DisplayName("input schema has pattern field") + void inputSchema() { + assertThat(tool.inputSchema()).contains("\"pattern\""); + } +}