From 6e49c4fdc72d838214d73200d6e4fd70ca963d27 Mon Sep 17 00:00:00 2001 From: abel533 Date: Sun, 5 Apr 2026 09:44:28 +0800 Subject: [PATCH] feat: SessionMemory service with post-sampling hook integration - SessionMemoryService: threshold-based memory extraction (50K init, 20K update) - Async extraction via virtual threads with forked agent - Post-sampling hook in AgentLoop after each model response - System prompt injection of existing session memory - /memory session sub-command to view session memory - AppConfig wiring: bean registration, agent factory, AgentLoop hook Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../command/impl/MemoryCommand.java | 39 +++ .../java/com/claudecode/config/AppConfig.java | 27 +- .../java/com/claudecode/core/AgentLoop.java | 18 ++ .../claudecode/core/SessionMemoryService.java | 276 ++++++++++++++++++ 4 files changed, 356 insertions(+), 4 deletions(-) create mode 100644 src/main/java/com/claudecode/core/SessionMemoryService.java diff --git a/src/main/java/com/claudecode/command/impl/MemoryCommand.java b/src/main/java/com/claudecode/command/impl/MemoryCommand.java index 0ebcf2f..6408740 100644 --- a/src/main/java/com/claudecode/command/impl/MemoryCommand.java +++ b/src/main/java/com/claudecode/command/impl/MemoryCommand.java @@ -48,6 +48,8 @@ public class MemoryCommand implements SlashCommand { return handleEdit(); } else if (args.equals("user")) { return showUserMemory(); + } else if (args.equals("session")) { + return showSessionMemory(); } else { return showProjectMemory(); } @@ -154,4 +156,41 @@ public class MemoryCommand implements SlashCommand { return AnsiStyle.red(" ✗ Failed to open editor: " + e.getMessage()); } } + + /** 显示会话记忆 (SESSION_MEMORY.md) */ + private String showSessionMemory() { + Path projectDir = Path.of(System.getProperty("user.dir")); + String sanitized = projectDir.toAbsolutePath().toString() + .replace(":", "_") + .replace("\\", "_") + .replace("/", "_"); + Path memoryFile = Path.of(System.getProperty("user.home")) + .resolve(".claude").resolve("projects").resolve(sanitized) + .resolve("memory").resolve("SESSION_MEMORY.md"); + + StringBuilder sb = new StringBuilder(); + sb.append("\n"); + sb.append(AnsiStyle.bold(" 🧠 Session Memory\n")); + sb.append(" ").append("─".repeat(50)).append("\n"); + sb.append(" ").append(AnsiStyle.dim("Path: " + memoryFile)).append("\n\n"); + + if (Files.exists(memoryFile)) { + try { + String content = Files.readString(memoryFile, StandardCharsets.UTF_8); + if (content.isBlank()) { + sb.append(AnsiStyle.dim(" (Session memory is empty)\n")); + } else { + content.lines().forEach(line -> sb.append(" ").append(line).append("\n")); + } + } catch (IOException e) { + sb.append(AnsiStyle.red(" ✗ Read failed: " + e.getMessage() + "\n")); + } + } else { + sb.append(AnsiStyle.dim(" (No session memory yet)\n\n")); + sb.append(AnsiStyle.dim(" Session memory is automatically created after extended conversations.\n")); + sb.append(AnsiStyle.dim(" It captures key decisions, code changes, and discoveries.\n")); + } + + return sb.toString(); + } } diff --git a/src/main/java/com/claudecode/config/AppConfig.java b/src/main/java/com/claudecode/config/AppConfig.java index a8756e8..eb7d9e5 100644 --- a/src/main/java/com/claudecode/config/AppConfig.java +++ b/src/main/java/com/claudecode/config/AppConfig.java @@ -7,6 +7,7 @@ import com.claudecode.context.GitContext; import com.claudecode.context.SkillLoader; import com.claudecode.context.SystemPromptBuilder; import com.claudecode.core.AgentLoop; +import com.claudecode.core.SessionMemoryService; import com.claudecode.core.TaskManager; import com.claudecode.core.TokenTracker; import com.claudecode.core.compact.AutoCompactManager; @@ -231,6 +232,12 @@ public class AppConfig { return new AutoCompactManager(activeChatModel, tokenTracker); } + @Bean + public SessionMemoryService sessionMemoryService() { + Path projectDir = Path.of(System.getProperty("user.dir")); + return new SessionMemoryService(projectDir); + } + @Bean public TokenTracker tokenTracker(ProviderInfo info) { TokenTracker tracker = new TokenTracker(); @@ -239,7 +246,7 @@ public class AppConfig { } @Bean - public String systemPrompt(ToolContext toolContext) { + public String systemPrompt(ToolContext toolContext, SessionMemoryService sessionMemoryService) { Path projectDir = Path.of(System.getProperty("user.dir")); ClaudeMdLoader claudeLoader = new ClaudeMdLoader(projectDir); @@ -252,13 +259,20 @@ public class AppConfig { // Inject SkillLoader into ToolContext for SkillTool toolContext.set(SkillTool.SKILL_LOADER_KEY, skillLoader); + // Inject SessionMemoryService into ToolContext + toolContext.set("SESSION_MEMORY_SERVICE", sessionMemoryService); + GitContext gitContext = new GitContext(projectDir).collect(); String gitSummary = gitContext.buildSummary(); + // Load existing session memory + String sessionMemory = sessionMemoryService.getMemoryContent(); + return new SystemPromptBuilder() .claudeMd(claudeMd) .skills(skillsSummary) .git(gitSummary) + .sessionMemory(sessionMemory) .build(); } @@ -266,7 +280,7 @@ public class AppConfig { public AgentLoop agentLoop(ChatModel activeChatModel, ToolRegistry toolRegistry, ToolContext toolContext, String systemPrompt, TokenTracker tokenTracker, PluginManager pluginManager, PermissionRuleEngine permissionRuleEngine, - AutoCompactManager autoCompactManager) { + AutoCompactManager autoCompactManager, SessionMemoryService sessionMemoryService) { AgentLoop mainLoop = new AgentLoop(activeChatModel, toolRegistry, toolContext, systemPrompt, tokenTracker); // 注入权限引擎和自动压缩管理器 @@ -274,11 +288,16 @@ public class AppConfig { mainLoop.setAutoCompactManager(autoCompactManager); // 注册子 Agent 工厂 - toolContext.set(AgentTool.AGENT_FACTORY_KEY, + java.util.function.Function agentFactory = (java.util.function.Function) prompt -> { AgentLoop subLoop = new AgentLoop(activeChatModel, toolRegistry, toolContext, systemPrompt); return subLoop.run(prompt); - }); + }; + toolContext.set(AgentTool.AGENT_FACTORY_KEY, agentFactory); + + // Wire SessionMemoryService with agent factory and agent loop + sessionMemoryService.setAgentFactory(agentFactory); + mainLoop.setSessionMemoryService(sessionMemoryService); // 注册 PluginManager 到 ToolContext toolContext.set("PLUGIN_MANAGER", pluginManager); diff --git a/src/main/java/com/claudecode/core/AgentLoop.java b/src/main/java/com/claudecode/core/AgentLoop.java index 33a6730..7bf5b69 100644 --- a/src/main/java/com/claudecode/core/AgentLoop.java +++ b/src/main/java/com/claudecode/core/AgentLoop.java @@ -63,6 +63,9 @@ public class AgentLoop { /** 自动压缩管理器(可选) */ private AutoCompactManager autoCompactManager; + /** 会话记忆服务(可选) */ + private SessionMemoryService sessionMemoryService; + /** 拒绝追踪器 */ private final DenialTracker denialTracker = new DenialTracker(); @@ -127,6 +130,10 @@ public class AgentLoop { this.autoCompactManager = manager; } + public void setSessionMemoryService(SessionMemoryService service) { + this.sessionMemoryService = service; + } + public AutoCompactManager getAutoCompactManager() { return autoCompactManager; } @@ -244,6 +251,17 @@ public class AgentLoop { this::replaceHistory ); } + + // 会话记忆提取检查(异步,不阻塞主循环) + if (sessionMemoryService != null) { + int toolCallCount = result.assistant.hasToolCalls() + ? result.assistant.getToolCalls().size() : 0; + sessionMemoryService.onPostSampling( + result.promptTokens + result.completionTokens, + toolCallCount, + messageHistory + ); + } } if (iteration >= MAX_ITERATIONS) { diff --git a/src/main/java/com/claudecode/core/SessionMemoryService.java b/src/main/java/com/claudecode/core/SessionMemoryService.java new file mode 100644 index 0000000..7610b04 --- /dev/null +++ b/src/main/java/com/claudecode/core/SessionMemoryService.java @@ -0,0 +1,276 @@ +package com.claudecode.core; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.chat.messages.Message; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; + +/** + * 会话记忆服务 —— 对应 claude-code/src/services/SessionMemory/。 + *

+ * 自动维护 SESSION_MEMORY.md 文件,记录当前会话的关键发现、决策和上下文。 + * 在后台运行,不中断主对话。 + *

+ * 触发条件: + *

    + *
  • 初始化:上下文 token 超过 50,000
  • + *
  • 更新:自上次提取后增长超过 20,000 token 且工具调用 >= 5
  • + *
+ */ +public class SessionMemoryService { + + private static final Logger log = LoggerFactory.getLogger(SessionMemoryService.class); + + /** 初始化阈值:50K tokens */ + private static final long MINIMUM_TOKENS_TO_INIT = 50_000; + /** 更新阈值:自上次提取后增长 20K tokens */ + private static final long MINIMUM_TOKENS_BETWEEN_UPDATE = 20_000; + /** 更新阈值:工具调用次数 */ + private static final int MINIMUM_TOOL_CALLS_BETWEEN_UPDATE = 5; + + private final Path memoryDir; + private final Path memoryFile; + private final AtomicLong lastExtractionTokens = new AtomicLong(0); + private final AtomicInteger toolCallsSinceLastExtraction = new AtomicInteger(0); + private volatile boolean initialized = false; + private volatile boolean extracting = false; + + /** Agent factory for forked extraction agent */ + private Function agentFactory; + + public SessionMemoryService(Path projectDir) { + String sanitized = projectDir.toAbsolutePath().toString() + .replace(":", "_") + .replace("\\", "_") + .replace("/", "_"); + this.memoryDir = Path.of(System.getProperty("user.home")) + .resolve(".claude") + .resolve("projects") + .resolve(sanitized) + .resolve("memory"); + this.memoryFile = memoryDir.resolve("SESSION_MEMORY.md"); + } + + public void setAgentFactory(Function agentFactory) { + this.agentFactory = agentFactory; + } + + /** Cumulative token count for threshold tracking */ + private final AtomicLong cumulativeTokens = new AtomicLong(0); + + /** + * Post-sampling hook: 在每次模型响应后调用。 + * 根据阈值决定是否触发记忆提取。 + * + * @param tokensThisTurn 本次迭代使用的 token 数 + * @param toolCallCount 本次响应中的工具调用数量 + * @param messageHistory 当前消息历史(用于提取上下文) + */ + public void onPostSampling(long tokensThisTurn, int toolCallCount, List messageHistory) { + if (extracting) return; // Already extracting + + long currentTokens = cumulativeTokens.addAndGet(tokensThisTurn); + toolCallsSinceLastExtraction.addAndGet(toolCallCount); + + if (shouldExtractMemory(currentTokens)) { + extractMemoryAsync(currentTokens); + } + } + + /** + * 记录工具调用(用于计数阈值)。 + */ + public void recordToolCall() { + toolCallsSinceLastExtraction.incrementAndGet(); + } + + /** + * 判断是否应该提取记忆。 + */ + boolean shouldExtractMemory(long currentTokens) { + if (!initialized) { + // First extraction: need enough context + return currentTokens >= MINIMUM_TOKENS_TO_INIT; + } + + // Subsequent extractions + long tokenGrowth = currentTokens - lastExtractionTokens.get(); + if (tokenGrowth < MINIMUM_TOKENS_BETWEEN_UPDATE) { + return false; + } + + // Token threshold met + tool call threshold + int toolCalls = toolCallsSinceLastExtraction.get(); + return toolCalls >= MINIMUM_TOOL_CALLS_BETWEEN_UPDATE; + } + + /** + * 异步提取记忆。 + */ + private void extractMemoryAsync(long currentTokens) { + if (agentFactory == null) { + log.debug("SessionMemory: no agent factory, skipping extraction"); + return; + } + + extracting = true; + Thread.ofVirtual().name("session-memory-extraction").start(() -> { + try { + extractMemory(currentTokens); + } catch (Exception e) { + log.debug("SessionMemory extraction failed", e); + } finally { + extracting = false; + } + }); + } + + /** + * 执行记忆提取。 + */ + void extractMemory(long currentTokens) { + log.info("SessionMemory: starting extraction (tokens: {}, initialized: {})", + currentTokens, initialized); + + try { + // Ensure directory exists + Files.createDirectories(memoryDir); + + String existingMemory = ""; + if (Files.exists(memoryFile)) { + existingMemory = Files.readString(memoryFile, StandardCharsets.UTF_8); + } + + // Build extraction prompt + String prompt = initialized + ? buildUpdatePrompt(existingMemory) + : buildInitPrompt(); + + // Run forked agent for extraction + String result = agentFactory.apply(prompt); + + // The agent should have written to the file via FileWrite/FileEdit tools + // But as a fallback, if it returned content, write it + if (result != null && !result.isBlank() && !Files.exists(memoryFile)) { + Files.writeString(memoryFile, result, StandardCharsets.UTF_8); + } + + // Update tracking + lastExtractionTokens.set(currentTokens); + toolCallsSinceLastExtraction.set(0); + initialized = true; + + log.info("SessionMemory: extraction complete"); + } catch (IOException e) { + log.warn("SessionMemory: failed to write memory file", e); + } + } + + /** + * 初始化提取提示词。 + */ + String buildInitPrompt() { + return """ + You are a session memory extractor. Your job is to create a SESSION_MEMORY.md file \ + that captures the key information from this conversation. + + Create the file at: %s + + The file should include these sections: + + # Session Memory + + ## Task Overview + - What is the user working on? + - What are the main goals? + + ## Key Decisions + - Important decisions made during the conversation + - Rationale for each decision + + ## Code Changes + - Files modified and why + - Key patterns or approaches used + + ## Discoveries + - Important findings about the codebase + - Architecture or design insights + + ## Next Steps + - What remains to be done + - Known issues or blockers + + Extract information from the conversation history. Be concise but comprehensive. \ + Focus on information that would be valuable if the conversation were interrupted \ + and needed to be resumed later. + """.formatted(memoryFile); + } + + /** + * 更新提取提示词。 + */ + String buildUpdatePrompt(String existingMemory) { + return """ + You are a session memory extractor. Update the existing SESSION_MEMORY.md file \ + with new information from the recent conversation. + + File location: %s + + Current content: + ``` + %s + ``` + + Update the file with new information. Rules: + - Keep existing information that is still relevant + - Add new decisions, changes, and discoveries + - Update the "Next Steps" section + - Remove outdated information + - Be concise — this file should stay under 200 lines + - Use FileEdit to update specific sections, or FileWrite to rewrite entirely + """.formatted(memoryFile, existingMemory); + } + + /** + * 读取当前记忆内容(用于系统提示注入)。 + */ + public String getMemoryContent() { + if (!Files.exists(memoryFile)) return null; + try { + String content = Files.readString(memoryFile, StandardCharsets.UTF_8); + return content.isBlank() ? null : content; + } catch (IOException e) { + return null; + } + } + + /** + * 获取记忆文件路径。 + */ + public Path getMemoryFile() { + return memoryFile; + } + + /** + * 是否已初始化(至少提取过一次)。 + */ + public boolean isInitialized() { + return initialized; + } + + /** + * 是否正在提取。 + */ + public boolean isExtracting() { + return extracting; + } +}