fix: add thread safety for concurrent state modifications

- Add stateLock to synchronize all getState/setState read-modify-write
  operations across UI thread and AgentLoop background threads
- Synchronize onInput, onPaste, addMessage, appendToStreamingMessage,
  finishStreamingMessage, completeLastToolCall, setThinking, runAgent
- Extract addMessageInternal() to avoid double-locking
- Harden extractToolSummary with explicit indexOf checks

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
pull/1/head
abel533 1 month ago
parent 75ecaeca58
commit b72fcfea79
  1. 231
      src/main/java/com/claudecode/tui/ClaudeCodeComponent.java

@ -71,6 +71,7 @@ public class ClaudeCodeComponent extends Component<ClaudeCodeComponent.TuiState>
private final Runnable onExit; private final Runnable onExit;
// --- 内部状态 --- // --- 内部状态 ---
private final Object stateLock = new Object(); // 保护 getState/setState 的读-改-写操作
private final List<String> inputHistory = new ArrayList<>(); private final List<String> inputHistory = new ArrayList<>();
private int historyIndex = -1; private int historyIndex = -1;
private String savedInput = ""; private String savedInput = "";
@ -441,78 +442,82 @@ public class ClaudeCodeComponent extends Component<ClaudeCodeComponent.TuiState>
@Override @Override
public void onInput(String input, Key key) { public void onInput(String input, Key key) {
TuiState s = getState(); synchronized (stateLock) {
TuiState s = getState();
// Ctrl+D: 退出
if (key.ctrl() && "d".equals(input)) {
if (onExit != null) onExit.run();
return;
}
// Ctrl+C: 取消当前输入或中断 Agent // Ctrl+D: 退出
if (key.ctrl() && "c".equals(input)) { if (key.ctrl() && "d".equals(input)) {
if (agentRunning.get()) { if (onExit != null) onExit.run();
// TODO: 中断 Agent 运行 return;
addMessage(new SystemMsg("^C (interrupt)", Color.BRIGHT_YELLOW));
} else {
setState(new TuiState("", s.messages, s.scrollOffset, false, ""));
} }
return;
}
// 权限确认模式 // Ctrl+C: 取消当前输入或中断 Agent
if (permissionCallback != null) { if (key.ctrl() && "c".equals(input)) {
handlePermissionInput(input, key, s); if (agentRunning.get()) {
return; // TODO: 中断 Agent 运行
} addMessageInternal(new SystemMsg("^C (interrupt)", Color.BRIGHT_YELLOW), s);
} else {
setState(new TuiState("", s.messages, s.scrollOffset, false, ""));
}
return;
}
// AI 运行中时忽略大部分输入(但允许滚动) // 权限确认模式
if (agentRunning.get()) { if (permissionCallback != null) {
handleScrollInput(key, s); handlePermissionInput(input, key, s);
return; return;
} }
if (key.return_() && key.meta()) { // AI 运行中时忽略大部分输入(但允许滚动)
// Shift+Enter: 多行换行 if (agentRunning.get()) {
setState(new TuiState(s.inputText + "\n", s.messages, 0, false, "")); handleScrollInput(key, s);
} else if (key.return_()) { return;
// Enter: 发送
if (!s.inputText.isEmpty()) {
submitInput(s.inputText, s);
} }
} else if (key.backspace()) {
if (!s.inputText.isEmpty()) { if (key.return_() && key.meta()) {
// Shift+Enter: 多行换行
setState(new TuiState(s.inputText + "\n", s.messages, 0, false, ""));
} else if (key.return_()) {
// Enter: 发送
if (!s.inputText.isEmpty()) {
submitInput(s.inputText, s);
}
} else if (key.backspace()) {
if (!s.inputText.isEmpty()) {
abandonHistoryPreview();
String newText = s.inputText.substring(0, s.inputText.length() - 1);
setState(new TuiState(newText, s.messages, s.scrollOffset, false, ""));
}
} else if (key.upArrow()) {
browseHistoryUp(s);
} else if (key.downArrow()) {
browseHistoryDown(s);
} else if (key.scrollUp()) {
scroll(s, 3);
} else if (key.scrollDown()) {
scroll(s, -3);
} else if (key.pageUp()) {
scroll(s, 10);
} else if (key.pageDown()) {
scroll(s, -10);
} else if (key.escape()) {
// Esc: 清空输入
setState(new TuiState("", s.messages, s.scrollOffset, false, ""));
} else if (!input.isEmpty() && isPrintableInput(input, key)) {
abandonHistoryPreview(); abandonHistoryPreview();
String newText = s.inputText.substring(0, s.inputText.length() - 1); setState(new TuiState(s.inputText + input, s.messages, s.scrollOffset, false, ""));
setState(new TuiState(newText, s.messages, s.scrollOffset, false, ""));
} }
} else if (key.upArrow()) {
browseHistoryUp(s);
} else if (key.downArrow()) {
browseHistoryDown(s);
} else if (key.scrollUp()) {
scroll(s, 3);
} else if (key.scrollDown()) {
scroll(s, -3);
} else if (key.pageUp()) {
scroll(s, 10);
} else if (key.pageDown()) {
scroll(s, -10);
} else if (key.escape()) {
// Esc: 清空输入
setState(new TuiState("", s.messages, s.scrollOffset, false, ""));
} else if (!input.isEmpty() && isPrintableInput(input, key)) {
abandonHistoryPreview();
setState(new TuiState(s.inputText + input, s.messages, s.scrollOffset, false, ""));
} }
} }
@Override @Override
public void onPaste(String text) { public void onPaste(String text) {
if (agentRunning.get() || text == null || text.isEmpty()) return; synchronized (stateLock) {
TuiState s = getState(); if (agentRunning.get() || text == null || text.isEmpty()) return;
abandonHistoryPreview(); TuiState s = getState();
setState(new TuiState(s.inputText + text, s.messages, s.scrollOffset, false, "")); abandonHistoryPreview();
setState(new TuiState(s.inputText + text, s.messages, s.scrollOffset, false, ""));
}
} }
/** 处理权限确认输入 */ /** 处理权限确认输入 */
@ -605,8 +610,10 @@ public class ClaudeCodeComponent extends Component<ClaudeCodeComponent.TuiState>
addMessage(new SystemMsg("Error: " + e.getMessage(), Color.BRIGHT_RED)); addMessage(new SystemMsg("Error: " + e.getMessage(), Color.BRIGHT_RED));
} finally { } finally {
agentRunning.set(false); agentRunning.set(false);
TuiState cs = getState(); synchronized (stateLock) {
setState(new TuiState(cs.inputText, cs.messages, 0, false, "")); TuiState cs = getState();
setState(new TuiState(cs.inputText, cs.messages, 0, false, ""));
}
} }
}); });
} }
@ -615,7 +622,13 @@ public class ClaudeCodeComponent extends Component<ClaudeCodeComponent.TuiState>
/** 添加一条消息 */ /** 添加一条消息 */
public void addMessage(UIMessage msg) { public void addMessage(UIMessage msg) {
TuiState s = getState(); synchronized (stateLock) {
addMessageInternal(msg, getState());
}
}
/** 内部添加消息(调用方需持有 stateLock) */
private void addMessageInternal(UIMessage msg, TuiState s) {
List<UIMessage> newMsgs = new ArrayList<>(s.messages); List<UIMessage> newMsgs = new ArrayList<>(s.messages);
newMsgs.add(msg); newMsgs.add(msg);
setState(new TuiState(s.inputText, Collections.unmodifiableList(newMsgs), setState(new TuiState(s.inputText, Collections.unmodifiableList(newMsgs),
@ -624,46 +637,52 @@ public class ClaudeCodeComponent extends Component<ClaudeCodeComponent.TuiState>
/** 追加 token 到当前流式助手消息 */ /** 追加 token 到当前流式助手消息 */
private void appendToStreamingMessage(String token) { private void appendToStreamingMessage(String token) {
TuiState s = getState(); synchronized (stateLock) {
List<UIMessage> msgs = new ArrayList<>(s.messages); TuiState s = getState();
List<UIMessage> msgs = new ArrayList<>(s.messages);
// 查找最后一个 streaming AssistantMsg // 查找最后一个 streaming AssistantMsg
if (!msgs.isEmpty() && msgs.getLast() instanceof AssistantMsg am && am.streaming()) { if (!msgs.isEmpty() && msgs.getLast() instanceof AssistantMsg am && am.streaming()) {
msgs.set(msgs.size() - 1, am.appendText(token)); msgs.set(msgs.size() - 1, am.appendText(token));
} else { } else {
msgs.add(new AssistantMsg(token, true)); msgs.add(new AssistantMsg(token, true));
} }
setState(new TuiState(s.inputText, Collections.unmodifiableList(msgs), setState(new TuiState(s.inputText, Collections.unmodifiableList(msgs),
0, s.thinking, s.thinkingText)); 0, s.thinking, s.thinkingText));
}
} }
/** 完成当前流式消息(公开给 JinkReplSession 使用) */ /** 完成当前流式消息(公开给 JinkReplSession 使用) */
public void finishStreamingMessage() { public void finishStreamingMessage() {
TuiState s = getState(); synchronized (stateLock) {
List<UIMessage> msgs = new ArrayList<>(s.messages); TuiState s = getState();
List<UIMessage> msgs = new ArrayList<>(s.messages);
if (!msgs.isEmpty() && msgs.getLast() instanceof AssistantMsg am && am.streaming()) {
msgs.set(msgs.size() - 1, am.finish()); if (!msgs.isEmpty() && msgs.getLast() instanceof AssistantMsg am && am.streaming()) {
setState(new TuiState(s.inputText, Collections.unmodifiableList(msgs), msgs.set(msgs.size() - 1, am.finish());
0, s.thinking, s.thinkingText)); setState(new TuiState(s.inputText, Collections.unmodifiableList(msgs),
0, s.thinking, s.thinkingText));
}
} }
} }
/** 更新最后一个工具调用消息的结果 */ /** 更新最后一个工具调用消息的结果 */
public void completeLastToolCall(String result) { public void completeLastToolCall(String result) {
TuiState s = getState(); synchronized (stateLock) {
List<UIMessage> msgs = new ArrayList<>(s.messages); TuiState s = getState();
List<UIMessage> msgs = new ArrayList<>(s.messages);
for (int i = msgs.size() - 1; i >= 0; i--) {
if (msgs.get(i) instanceof ToolCallMsg tcm && tcm.running()) { for (int i = msgs.size() - 1; i >= 0; i--) {
msgs.set(i, tcm.complete(result)); if (msgs.get(i) instanceof ToolCallMsg tcm && tcm.running()) {
break; msgs.set(i, tcm.complete(result));
break;
}
} }
}
setState(new TuiState(s.inputText, Collections.unmodifiableList(msgs), setState(new TuiState(s.inputText, Collections.unmodifiableList(msgs),
s.scrollOffset, s.thinking, s.thinkingText)); s.scrollOffset, s.thinking, s.thinkingText));
}
} }
/** 设置权限确认回调 */ /** 设置权限确认回调 */
@ -673,8 +692,10 @@ public class ClaudeCodeComponent extends Component<ClaudeCodeComponent.TuiState>
/** 设置 thinking 状态 */ /** 设置 thinking 状态 */
public void setThinking(boolean thinking, String text) { public void setThinking(boolean thinking, String text) {
TuiState s = getState(); synchronized (stateLock) {
setState(new TuiState(s.inputText, s.messages, s.scrollOffset, thinking, text)); TuiState s = getState();
setState(new TuiState(s.inputText, s.messages, s.scrollOffset, thinking, text));
}
} }
/** 设置首次用户输入回调 */ /** 设置首次用户输入回调 */
@ -747,20 +768,20 @@ public class ClaudeCodeComponent extends Component<ClaudeCodeComponent.TuiState>
String[] keys = {"command", "file_path", "pattern", "query", "url"}; String[] keys = {"command", "file_path", "pattern", "query", "url"};
for (String key : keys) { for (String key : keys) {
String search = "\"" + key + "\""; String search = "\"" + key + "\"";
if (args.contains(search)) { int start = args.indexOf(search);
int start = args.indexOf(search); if (start < 0) continue;
int valStart = args.indexOf("\"", start + search.length()) + 1; int colonPos = args.indexOf("\"", start + search.length());
int valEnd = args.indexOf("\"", valStart); if (colonPos < 0) continue;
if (valStart > 0 && valEnd > valStart) { int valStart = colonPos + 1;
String val = args.substring(valStart, Math.min(valEnd, valStart + 60)); int valEnd = args.indexOf("\"", valStart);
return switch (key) { if (valEnd < 0 || valEnd <= valStart) continue;
case "command" -> "$ " + val; String val = args.substring(valStart, Math.min(valEnd, valStart + 60));
case "pattern" -> "pattern: " + val; return switch (key) {
case "query" -> "\"" + val + "\""; case "command" -> "$ " + val;
default -> val; case "pattern" -> "pattern: " + val;
}; case "query" -> "\"" + val + "\"";
} default -> val;
} };
} }
} catch (Exception ignored) {} } catch (Exception ignored) {}
return null; return null;

Loading…
Cancel
Save