feat: SessionMemory service with post-sampling hook integration

- SessionMemoryService: threshold-based memory extraction (50K init, 20K update)
- Async extraction via virtual threads with forked agent
- Post-sampling hook in AgentLoop after each model response
- System prompt injection of existing session memory
- /memory session sub-command to view session memory
- AppConfig wiring: bean registration, agent factory, AgentLoop hook

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
pull/1/head
abel533 1 month ago
parent 6088678c4f
commit 6e49c4fdc7
  1. 39
      src/main/java/com/claudecode/command/impl/MemoryCommand.java
  2. 27
      src/main/java/com/claudecode/config/AppConfig.java
  3. 18
      src/main/java/com/claudecode/core/AgentLoop.java
  4. 276
      src/main/java/com/claudecode/core/SessionMemoryService.java

@ -48,6 +48,8 @@ public class MemoryCommand implements SlashCommand {
return handleEdit(); return handleEdit();
} else if (args.equals("user")) { } else if (args.equals("user")) {
return showUserMemory(); return showUserMemory();
} else if (args.equals("session")) {
return showSessionMemory();
} else { } else {
return showProjectMemory(); return showProjectMemory();
} }
@ -154,4 +156,41 @@ public class MemoryCommand implements SlashCommand {
return AnsiStyle.red(" ✗ Failed to open editor: " + e.getMessage()); return AnsiStyle.red(" ✗ Failed to open editor: " + e.getMessage());
} }
} }
/** 显示会话记忆 (SESSION_MEMORY.md) */
private String showSessionMemory() {
Path projectDir = Path.of(System.getProperty("user.dir"));
String sanitized = projectDir.toAbsolutePath().toString()
.replace(":", "_")
.replace("\\", "_")
.replace("/", "_");
Path memoryFile = Path.of(System.getProperty("user.home"))
.resolve(".claude").resolve("projects").resolve(sanitized)
.resolve("memory").resolve("SESSION_MEMORY.md");
StringBuilder sb = new StringBuilder();
sb.append("\n");
sb.append(AnsiStyle.bold(" 🧠 Session Memory\n"));
sb.append(" ").append("─".repeat(50)).append("\n");
sb.append(" ").append(AnsiStyle.dim("Path: " + memoryFile)).append("\n\n");
if (Files.exists(memoryFile)) {
try {
String content = Files.readString(memoryFile, StandardCharsets.UTF_8);
if (content.isBlank()) {
sb.append(AnsiStyle.dim(" (Session memory is empty)\n"));
} else {
content.lines().forEach(line -> sb.append(" ").append(line).append("\n"));
}
} catch (IOException e) {
sb.append(AnsiStyle.red(" ✗ Read failed: " + e.getMessage() + "\n"));
}
} else {
sb.append(AnsiStyle.dim(" (No session memory yet)\n\n"));
sb.append(AnsiStyle.dim(" Session memory is automatically created after extended conversations.\n"));
sb.append(AnsiStyle.dim(" It captures key decisions, code changes, and discoveries.\n"));
}
return sb.toString();
}
} }

@ -7,6 +7,7 @@ import com.claudecode.context.GitContext;
import com.claudecode.context.SkillLoader; import com.claudecode.context.SkillLoader;
import com.claudecode.context.SystemPromptBuilder; import com.claudecode.context.SystemPromptBuilder;
import com.claudecode.core.AgentLoop; import com.claudecode.core.AgentLoop;
import com.claudecode.core.SessionMemoryService;
import com.claudecode.core.TaskManager; import com.claudecode.core.TaskManager;
import com.claudecode.core.TokenTracker; import com.claudecode.core.TokenTracker;
import com.claudecode.core.compact.AutoCompactManager; import com.claudecode.core.compact.AutoCompactManager;
@ -231,6 +232,12 @@ public class AppConfig {
return new AutoCompactManager(activeChatModel, tokenTracker); return new AutoCompactManager(activeChatModel, tokenTracker);
} }
@Bean
public SessionMemoryService sessionMemoryService() {
Path projectDir = Path.of(System.getProperty("user.dir"));
return new SessionMemoryService(projectDir);
}
@Bean @Bean
public TokenTracker tokenTracker(ProviderInfo info) { public TokenTracker tokenTracker(ProviderInfo info) {
TokenTracker tracker = new TokenTracker(); TokenTracker tracker = new TokenTracker();
@ -239,7 +246,7 @@ public class AppConfig {
} }
@Bean @Bean
public String systemPrompt(ToolContext toolContext) { public String systemPrompt(ToolContext toolContext, SessionMemoryService sessionMemoryService) {
Path projectDir = Path.of(System.getProperty("user.dir")); Path projectDir = Path.of(System.getProperty("user.dir"));
ClaudeMdLoader claudeLoader = new ClaudeMdLoader(projectDir); ClaudeMdLoader claudeLoader = new ClaudeMdLoader(projectDir);
@ -252,13 +259,20 @@ public class AppConfig {
// Inject SkillLoader into ToolContext for SkillTool // Inject SkillLoader into ToolContext for SkillTool
toolContext.set(SkillTool.SKILL_LOADER_KEY, skillLoader); toolContext.set(SkillTool.SKILL_LOADER_KEY, skillLoader);
// Inject SessionMemoryService into ToolContext
toolContext.set("SESSION_MEMORY_SERVICE", sessionMemoryService);
GitContext gitContext = new GitContext(projectDir).collect(); GitContext gitContext = new GitContext(projectDir).collect();
String gitSummary = gitContext.buildSummary(); String gitSummary = gitContext.buildSummary();
// Load existing session memory
String sessionMemory = sessionMemoryService.getMemoryContent();
return new SystemPromptBuilder() return new SystemPromptBuilder()
.claudeMd(claudeMd) .claudeMd(claudeMd)
.skills(skillsSummary) .skills(skillsSummary)
.git(gitSummary) .git(gitSummary)
.sessionMemory(sessionMemory)
.build(); .build();
} }
@ -266,7 +280,7 @@ public class AppConfig {
public AgentLoop agentLoop(ChatModel activeChatModel, ToolRegistry toolRegistry, public AgentLoop agentLoop(ChatModel activeChatModel, ToolRegistry toolRegistry,
ToolContext toolContext, String systemPrompt, TokenTracker tokenTracker, ToolContext toolContext, String systemPrompt, TokenTracker tokenTracker,
PluginManager pluginManager, PermissionRuleEngine permissionRuleEngine, PluginManager pluginManager, PermissionRuleEngine permissionRuleEngine,
AutoCompactManager autoCompactManager) { AutoCompactManager autoCompactManager, SessionMemoryService sessionMemoryService) {
AgentLoop mainLoop = new AgentLoop(activeChatModel, toolRegistry, toolContext, systemPrompt, tokenTracker); AgentLoop mainLoop = new AgentLoop(activeChatModel, toolRegistry, toolContext, systemPrompt, tokenTracker);
// 注入权限引擎和自动压缩管理器 // 注入权限引擎和自动压缩管理器
@ -274,11 +288,16 @@ public class AppConfig {
mainLoop.setAutoCompactManager(autoCompactManager); mainLoop.setAutoCompactManager(autoCompactManager);
// 注册子 Agent 工厂 // 注册子 Agent 工厂
toolContext.set(AgentTool.AGENT_FACTORY_KEY, java.util.function.Function<String, String> agentFactory =
(java.util.function.Function<String, String>) prompt -> { (java.util.function.Function<String, String>) prompt -> {
AgentLoop subLoop = new AgentLoop(activeChatModel, toolRegistry, toolContext, systemPrompt); AgentLoop subLoop = new AgentLoop(activeChatModel, toolRegistry, toolContext, systemPrompt);
return subLoop.run(prompt); return subLoop.run(prompt);
}); };
toolContext.set(AgentTool.AGENT_FACTORY_KEY, agentFactory);
// Wire SessionMemoryService with agent factory and agent loop
sessionMemoryService.setAgentFactory(agentFactory);
mainLoop.setSessionMemoryService(sessionMemoryService);
// 注册 PluginManager 到 ToolContext // 注册 PluginManager 到 ToolContext
toolContext.set("PLUGIN_MANAGER", pluginManager); toolContext.set("PLUGIN_MANAGER", pluginManager);

@ -63,6 +63,9 @@ public class AgentLoop {
/** 自动压缩管理器(可选) */ /** 自动压缩管理器(可选) */
private AutoCompactManager autoCompactManager; private AutoCompactManager autoCompactManager;
/** 会话记忆服务(可选) */
private SessionMemoryService sessionMemoryService;
/** 拒绝追踪器 */ /** 拒绝追踪器 */
private final DenialTracker denialTracker = new DenialTracker(); private final DenialTracker denialTracker = new DenialTracker();
@ -127,6 +130,10 @@ public class AgentLoop {
this.autoCompactManager = manager; this.autoCompactManager = manager;
} }
public void setSessionMemoryService(SessionMemoryService service) {
this.sessionMemoryService = service;
}
public AutoCompactManager getAutoCompactManager() { public AutoCompactManager getAutoCompactManager() {
return autoCompactManager; return autoCompactManager;
} }
@ -244,6 +251,17 @@ public class AgentLoop {
this::replaceHistory this::replaceHistory
); );
} }
// 会话记忆提取检查(异步,不阻塞主循环)
if (sessionMemoryService != null) {
int toolCallCount = result.assistant.hasToolCalls()
? result.assistant.getToolCalls().size() : 0;
sessionMemoryService.onPostSampling(
result.promptTokens + result.completionTokens,
toolCallCount,
messageHistory
);
}
} }
if (iteration >= MAX_ITERATIONS) { if (iteration >= MAX_ITERATIONS) {

@ -0,0 +1,276 @@
package com.claudecode.core;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.Message;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
/**
* 会话记忆服务 对应 claude-code/src/services/SessionMemory/
* <p>
* 自动维护 SESSION_MEMORY.md 文件记录当前会话的关键发现决策和上下文
* 在后台运行不中断主对话
* <p>
* 触发条件
* <ul>
* <li>初始化上下文 token 超过 50,000</li>
* <li>更新自上次提取后增长超过 20,000 token 且工具调用 >= 5</li>
* </ul>
*/
public class SessionMemoryService {
private static final Logger log = LoggerFactory.getLogger(SessionMemoryService.class);
/** 初始化阈值:50K tokens */
private static final long MINIMUM_TOKENS_TO_INIT = 50_000;
/** 更新阈值:自上次提取后增长 20K tokens */
private static final long MINIMUM_TOKENS_BETWEEN_UPDATE = 20_000;
/** 更新阈值:工具调用次数 */
private static final int MINIMUM_TOOL_CALLS_BETWEEN_UPDATE = 5;
private final Path memoryDir;
private final Path memoryFile;
private final AtomicLong lastExtractionTokens = new AtomicLong(0);
private final AtomicInteger toolCallsSinceLastExtraction = new AtomicInteger(0);
private volatile boolean initialized = false;
private volatile boolean extracting = false;
/** Agent factory for forked extraction agent */
private Function<String, String> agentFactory;
public SessionMemoryService(Path projectDir) {
String sanitized = projectDir.toAbsolutePath().toString()
.replace(":", "_")
.replace("\\", "_")
.replace("/", "_");
this.memoryDir = Path.of(System.getProperty("user.home"))
.resolve(".claude")
.resolve("projects")
.resolve(sanitized)
.resolve("memory");
this.memoryFile = memoryDir.resolve("SESSION_MEMORY.md");
}
public void setAgentFactory(Function<String, String> agentFactory) {
this.agentFactory = agentFactory;
}
/** Cumulative token count for threshold tracking */
private final AtomicLong cumulativeTokens = new AtomicLong(0);
/**
* Post-sampling hook: 在每次模型响应后调用
* 根据阈值决定是否触发记忆提取
*
* @param tokensThisTurn 本次迭代使用的 token
* @param toolCallCount 本次响应中的工具调用数量
* @param messageHistory 当前消息历史用于提取上下文
*/
public void onPostSampling(long tokensThisTurn, int toolCallCount, List<Message> messageHistory) {
if (extracting) return; // Already extracting
long currentTokens = cumulativeTokens.addAndGet(tokensThisTurn);
toolCallsSinceLastExtraction.addAndGet(toolCallCount);
if (shouldExtractMemory(currentTokens)) {
extractMemoryAsync(currentTokens);
}
}
/**
* 记录工具调用用于计数阈值
*/
public void recordToolCall() {
toolCallsSinceLastExtraction.incrementAndGet();
}
/**
* 判断是否应该提取记忆
*/
boolean shouldExtractMemory(long currentTokens) {
if (!initialized) {
// First extraction: need enough context
return currentTokens >= MINIMUM_TOKENS_TO_INIT;
}
// Subsequent extractions
long tokenGrowth = currentTokens - lastExtractionTokens.get();
if (tokenGrowth < MINIMUM_TOKENS_BETWEEN_UPDATE) {
return false;
}
// Token threshold met + tool call threshold
int toolCalls = toolCallsSinceLastExtraction.get();
return toolCalls >= MINIMUM_TOOL_CALLS_BETWEEN_UPDATE;
}
/**
* 异步提取记忆
*/
private void extractMemoryAsync(long currentTokens) {
if (agentFactory == null) {
log.debug("SessionMemory: no agent factory, skipping extraction");
return;
}
extracting = true;
Thread.ofVirtual().name("session-memory-extraction").start(() -> {
try {
extractMemory(currentTokens);
} catch (Exception e) {
log.debug("SessionMemory extraction failed", e);
} finally {
extracting = false;
}
});
}
/**
* 执行记忆提取
*/
void extractMemory(long currentTokens) {
log.info("SessionMemory: starting extraction (tokens: {}, initialized: {})",
currentTokens, initialized);
try {
// Ensure directory exists
Files.createDirectories(memoryDir);
String existingMemory = "";
if (Files.exists(memoryFile)) {
existingMemory = Files.readString(memoryFile, StandardCharsets.UTF_8);
}
// Build extraction prompt
String prompt = initialized
? buildUpdatePrompt(existingMemory)
: buildInitPrompt();
// Run forked agent for extraction
String result = agentFactory.apply(prompt);
// The agent should have written to the file via FileWrite/FileEdit tools
// But as a fallback, if it returned content, write it
if (result != null && !result.isBlank() && !Files.exists(memoryFile)) {
Files.writeString(memoryFile, result, StandardCharsets.UTF_8);
}
// Update tracking
lastExtractionTokens.set(currentTokens);
toolCallsSinceLastExtraction.set(0);
initialized = true;
log.info("SessionMemory: extraction complete");
} catch (IOException e) {
log.warn("SessionMemory: failed to write memory file", e);
}
}
/**
* 初始化提取提示词
*/
String buildInitPrompt() {
return """
You are a session memory extractor. Your job is to create a SESSION_MEMORY.md file \
that captures the key information from this conversation.
Create the file at: %s
The file should include these sections:
# Session Memory
## Task Overview
- What is the user working on?
- What are the main goals?
## Key Decisions
- Important decisions made during the conversation
- Rationale for each decision
## Code Changes
- Files modified and why
- Key patterns or approaches used
## Discoveries
- Important findings about the codebase
- Architecture or design insights
## Next Steps
- What remains to be done
- Known issues or blockers
Extract information from the conversation history. Be concise but comprehensive. \
Focus on information that would be valuable if the conversation were interrupted \
and needed to be resumed later.
""".formatted(memoryFile);
}
/**
* 更新提取提示词
*/
String buildUpdatePrompt(String existingMemory) {
return """
You are a session memory extractor. Update the existing SESSION_MEMORY.md file \
with new information from the recent conversation.
File location: %s
Current content:
```
%s
```
Update the file with new information. Rules:
- Keep existing information that is still relevant
- Add new decisions, changes, and discoveries
- Update the "Next Steps" section
- Remove outdated information
- Be concise this file should stay under 200 lines
- Use FileEdit to update specific sections, or FileWrite to rewrite entirely
""".formatted(memoryFile, existingMemory);
}
/**
* 读取当前记忆内容用于系统提示注入
*/
public String getMemoryContent() {
if (!Files.exists(memoryFile)) return null;
try {
String content = Files.readString(memoryFile, StandardCharsets.UTF_8);
return content.isBlank() ? null : content;
} catch (IOException e) {
return null;
}
}
/**
* 获取记忆文件路径
*/
public Path getMemoryFile() {
return memoryFile;
}
/**
* 是否已初始化至少提取过一次
*/
public boolean isInitialized() {
return initialized;
}
/**
* 是否正在提取
*/
public boolean isExtracting() {
return extracting;
}
}
Loading…
Cancel
Save