Skip to content

Commit fa6b0a8

Browse files
SteffenDEclaudebenbrandt
authored
Fix session pollution: route notifications by threadId (#123)
* Fix session pollution: route notifications by threadId notify() broadcasted every Codex notification to every registered session handler, so a stale handler from a previous prompt would re-emit the new session's deltas as ACP sessionUpdates on the old sessionId. awaitTurnCompleted() had a similar issue: it never disposed its turn/completed listener and didn't filter by threadId, so concurrent prompts across sessions could resolve to the wrong turn. Filter thread-scoped notifications by params.threadId in notify(), and make awaitTurnCompleted thread-scoped with disposal. Non-thread events (account, mcp, app-list, warnings) still broadcast. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Match turn completions by turn id Add runTurn to handle turn/completed races and ensure concurrent prompts resolve only the matching thread and turn. Update tests to cover out-of-order and early completion notifications. --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
1 parent 7475f5a commit fa6b0a8

8 files changed

Lines changed: 262 additions & 58 deletions

src/CodexAcpClient.ts

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ export class CodexAcpClient {
383383
const effort = modelId.effort as ReasoningEffort | null; //TODO remove unsafe conversion
384384

385385
await this.refreshSkills(cwd, request._meta);
386-
await this.codexClient.turnStart({
386+
return await this.codexClient.runTurn({
387387
outputSchema: null,
388388
threadId: request.sessionId,
389389
input: input,
@@ -395,10 +395,6 @@ export class CodexAcpClient {
395395
effort: effort,
396396
model: modelId.model,
397397
});
398-
399-
// Wait for turn completion
400-
// If turnInterrupt() was called, Codex will send turn/completed event with status "interrupted"
401-
return await this.codexClient.awaitTurnCompleted();
402398
}
403399

404400
async listSkills(params?: SkillsListParams): Promise<SkillsListResponse> {

src/CodexAppServerClient.ts

Lines changed: 100 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ export class CodexAppServerClient {
9494
private mcpServerStartupVersion = 0;
9595
private readonly mcpServerStartupStates = new Map<string, McpServerStartupSnapshot>();
9696
private readonly mcpServerStartupResolvers: Array<McpServerStartupResolver> = [];
97+
private readonly pendingTurnCompletionResolvers = new Map<string, Map<string, (event: TurnCompletedNotification) => void>>();
98+
private readonly turnCompletionCaptures = new Map<string, Set<(event: TurnCompletedNotification) => void>>();
9799

98100
constructor(connection: MessageConnection) {
99101
this.connection = connection;
@@ -108,6 +110,9 @@ export class CodexAppServerClient {
108110
});
109111
this.resolveMcpServerStartupResolvers();
110112
}
113+
if (isTurnCompletedNotification(serverNotification)) {
114+
this.recordTurnCompleted(serverNotification.params);
115+
}
111116
this.notify(serverNotification);
112117
for (const callback of this.codexEventHandlers) {
113118
callback({ eventType: "notification", ...serverNotification });
@@ -155,6 +160,27 @@ export class CodexAppServerClient {
155160
return await this.sendRequest({ method: "turn/start", params: params });
156161
}
157162

163+
async runTurn(params: TurnStartParams): Promise<TurnCompletedNotification> {
164+
const capturedCompletions: Array<TurnCompletedNotification> = [];
165+
const releaseCapture = this.captureTurnCompletions(params.threadId, (event) => {
166+
capturedCompletions.push(event);
167+
});
168+
169+
try {
170+
const turnStarted = await this.turnStart(params);
171+
const earlyCompletion = capturedCompletions.find(event => event.turn.id === turnStarted.turn.id);
172+
releaseCapture();
173+
if (earlyCompletion) {
174+
return earlyCompletion;
175+
}
176+
// Wait for turn completion
177+
// If turnInterrupt() was called, Codex will send turn/completed event with status "interrupted"
178+
return await this.awaitTurnCompleted(params.threadId, turnStarted.turn.id);
179+
} finally {
180+
releaseCapture();
181+
}
182+
}
183+
158184
async turnInterrupt(params: TurnInterruptParams): Promise<TurnInterruptResponse> {
159185
return await this.sendRequest({ method: "turn/interrupt", params: params });
160186
}
@@ -224,11 +250,10 @@ export class CodexAppServerClient {
224250
}
225251

226252
//TODO create type-safe helper
227-
async awaitTurnCompleted(): Promise<TurnCompletedNotification> {
253+
async awaitTurnCompleted(threadId: string, turnId: string): Promise<TurnCompletedNotification> {
228254
return await new Promise((resolve) => {
229-
this.connection.onNotification("turn/completed", (event: TurnCompletedNotification) => {
230-
resolve(event);
231-
});
255+
const threadResolvers = this.getOrCreatePendingTurnCompletionResolvers(threadId);
256+
threadResolvers.set(turnId, resolve);
232257
});
233258
}
234259

@@ -255,11 +280,67 @@ export class CodexAppServerClient {
255280

256281
private notificationHandlers = new Map<string, (event: ServerNotification) => void>();
257282
private notify(notification: ServerNotification) {
283+
const threadId = extractThreadId(notification);
284+
if (threadId !== null) {
285+
const handler = this.notificationHandlers.get(threadId);
286+
if (handler) {
287+
handler(notification);
288+
}
289+
return;
290+
}
258291
for (const notificationHandler of this.notificationHandlers.values()) {
259292
notificationHandler(notification);
260293
}
261294
}
262295

296+
private recordTurnCompleted(event: TurnCompletedNotification): void {
297+
const threadResolvers = this.pendingTurnCompletionResolvers.get(event.threadId);
298+
const resolve = threadResolvers?.get(event.turn.id);
299+
if (resolve) {
300+
threadResolvers!.delete(event.turn.id);
301+
if (threadResolvers!.size === 0) {
302+
this.pendingTurnCompletionResolvers.delete(event.threadId);
303+
}
304+
resolve(event);
305+
return;
306+
}
307+
308+
const captures = this.turnCompletionCaptures.get(event.threadId);
309+
if (!captures) {
310+
return;
311+
}
312+
for (const capture of captures) {
313+
capture(event);
314+
}
315+
}
316+
317+
private getOrCreatePendingTurnCompletionResolvers(threadId: string): Map<string, (event: TurnCompletedNotification) => void> {
318+
const existing = this.pendingTurnCompletionResolvers.get(threadId);
319+
if (existing) {
320+
return existing;
321+
}
322+
const created = new Map<string, (event: TurnCompletedNotification) => void>();
323+
this.pendingTurnCompletionResolvers.set(threadId, created);
324+
return created;
325+
}
326+
327+
private captureTurnCompletions(threadId: string, capture: (event: TurnCompletedNotification) => void): () => void {
328+
const captures = this.turnCompletionCaptures.get(threadId) ?? new Set<(event: TurnCompletedNotification) => void>();
329+
captures.add(capture);
330+
this.turnCompletionCaptures.set(threadId, captures);
331+
let released = false;
332+
return () => {
333+
if (released) {
334+
return;
335+
}
336+
released = true;
337+
captures.delete(capture);
338+
if (captures.size === 0) {
339+
this.turnCompletionCaptures.delete(threadId);
340+
}
341+
};
342+
}
343+
263344
private resolveMcpServerStartupResolvers(): void {
264345
const pendingResolvers: Array<McpServerStartupResolver> = [];
265346
for (const resolver of this.mcpServerStartupResolvers) {
@@ -352,3 +433,18 @@ function isMcpServerStatusUpdatedNotification(notification: ServerNotification):
352433
} {
353434
return notification.method === "mcpServer/startupStatus/updated";
354435
}
436+
437+
function isTurnCompletedNotification(notification: ServerNotification): notification is {
438+
method: "turn/completed";
439+
params: TurnCompletedNotification;
440+
} {
441+
return notification.method === "turn/completed";
442+
}
443+
444+
function extractThreadId(notification: ServerNotification): string | null {
445+
const params = notification.params as { threadId?: unknown } | undefined;
446+
if (params && typeof params.threadId === "string") {
447+
return params.threadId;
448+
}
449+
return null;
450+
}

src/__tests__/CodexACPAgent/CodexAcpClient.test.ts

Lines changed: 124 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import {createTestFixture, createCodexMockTestFixture, createTestSessionState, t
77
import type {ServerNotification} from "../../app-server";
88
import type {SessionState} from "../../CodexAcpServer";
99
import {AgentMode} from "../../AgentMode";
10-
import type {ListMcpServerStatusResponse, Model, SkillsListResponse} from "../../app-server/v2";
10+
import type {ListMcpServerStatusResponse, Model, SkillsListResponse, TurnStartParams} from "../../app-server/v2";
1111
import type {RateLimitsMap} from "../../RateLimitsMap";
1212
import {ModelId} from "../../ModelId";
1313

@@ -407,6 +407,32 @@ describe('ACP server test', { timeout: 40_000 }, () => {
407407
return onServerNotification;
408408
}
409409

410+
function createTurn(id: string, status: "inProgress" | "completed") {
411+
return {
412+
id,
413+
items: [],
414+
status,
415+
error: null,
416+
startedAt: null,
417+
completedAt: null,
418+
durationMs: null,
419+
};
420+
}
421+
422+
function createTurnCompletedNotification(threadId: string, turnId: string): ServerNotification {
423+
return {
424+
method: "turn/completed",
425+
params: {
426+
threadId,
427+
turn: createTurn(turnId, "completed"),
428+
},
429+
};
430+
}
431+
432+
async function flushAsyncWork(): Promise<void> {
433+
await new Promise((resolve) => setTimeout(resolve, 0));
434+
}
435+
410436
it('should map events from dump', async () => {
411437
fixture.getCodexAppServerClient().onServerNotification = loadNotifications();
412438

@@ -462,9 +488,9 @@ describe('ACP server test', { timeout: 40_000 }, () => {
462488

463489
// Trigger notifications after both prompts - should produce only 3 events, not 6
464490
const serverNotifications: ServerNotification[] = [
465-
{ method: "item/agentMessage/delta", params: { threadId: "string", turnId: "string", itemId: "string", delta: "He", }},
466-
{ method: "item/agentMessage/delta", params: { threadId: "string", turnId: "string", itemId: "string", delta: "ll", }},
467-
{ method: "item/agentMessage/delta", params: { threadId: "string", turnId: "string", itemId: "string", delta: "o!", }},
491+
{ method: "item/agentMessage/delta", params: { threadId: "id", turnId: "string", itemId: "string", delta: "He", }},
492+
{ method: "item/agentMessage/delta", params: { threadId: "id", turnId: "string", itemId: "string", delta: "ll", }},
493+
{ method: "item/agentMessage/delta", params: { threadId: "id", turnId: "string", itemId: "string", delta: "o!", }},
468494
];
469495
for (const notification of serverNotifications) {
470496
mockFixture.sendServerNotification(notification);
@@ -486,10 +512,6 @@ describe('ACP server test', { timeout: 40_000 }, () => {
486512
mockFixture.getCodexAppServerClient().turnStart = vi.fn().mockResolvedValue({
487513
turn: { id: "turn-id", items: [], status: "inProgress", error: null }
488514
});
489-
mockFixture.getCodexAppServerClient().awaitTurnCompleted = vi.fn().mockResolvedValue({
490-
threadId: "id",
491-
turn: { id: "turn-id", items: [], status: "completed", error: null }
492-
});
493515

494516
const sessionState1: SessionState = createTestSessionState({
495517
sessionId: "session-1",
@@ -506,15 +528,23 @@ describe('ACP server test', { timeout: 40_000 }, () => {
506528
return sessionId === "session-1" ? sessionState1 : sessionState2;
507529
});
508530

531+
// awaitTurnCompleted is per-turn; resolve the matching thread and turn.
532+
mockFixture.getCodexAppServerClient().awaitTurnCompleted = vi.fn().mockImplementation((threadId: string, turnId: string) => Promise.resolve({
533+
threadId,
534+
turn: createTurn(turnId, "completed")
535+
}));
536+
509537
// Start prompts for two different sessions
510538
await codexAcpAgent.prompt({ sessionId: "session-1", prompt: [{type: "text", text: "Message to session 1"}] });
511539
await codexAcpAgent.prompt({ sessionId: "session-2", prompt: [{type: "text", text: "Message to session 2"}] });
512540

513541
mockFixture.clearAcpConnectionDump();
514542

515-
// Trigger notifications - both session handlers should receive them
543+
// Each notification carries the threadId of the session it belongs to,
544+
// and must only be dispatched to that session.
516545
const serverNotifications: ServerNotification[] = [
517-
{ method: "item/agentMessage/delta", params: { threadId: "string", turnId: "string", itemId: "string", delta: "Hello", }},
546+
{ method: "item/agentMessage/delta", params: { threadId: "session-1", turnId: "string", itemId: "string", delta: "Hello-1", }},
547+
{ method: "item/agentMessage/delta", params: { threadId: "session-2", turnId: "string", itemId: "string", delta: "Hello-2", }},
518548
];
519549
for (const notification of serverNotifications) {
520550
mockFixture.sendServerNotification(notification);
@@ -523,18 +553,100 @@ describe('ACP server test', { timeout: 40_000 }, () => {
523553
// Wait for async handlers to complete
524554
await vi.waitFor(() => {
525555
const dump = mockFixture.getAcpConnectionDump([]);
526-
expect(dump.length).toBeGreaterThan(0);
556+
expect(dump.length).toBeGreaterThanOrEqual(2);
527557
});
528558

529-
// Should have 2 events - one for each session's handler
559+
// Should have exactly 2 events - the session-1 delta only on session-1, and
560+
// the session-2 delta only on session-2 (no cross-session pollution).
530561
await expect(mockFixture.getAcpConnectionDump([])).toMatchFileSnapshot("data/multiple-sessions.json");
531562
});
532563

564+
it('should complete concurrent prompts by matching thread and turn id', async () => {
565+
const mockFixture = createCodexMockTestFixture();
566+
const codexAcpAgent = mockFixture.getCodexAcpAgent();
567+
568+
const turnIds = new Map([
569+
["session-1", "turn-1"],
570+
["session-2", "turn-2"],
571+
]);
572+
const turnStart = vi.fn().mockImplementation((params: TurnStartParams) => Promise.resolve({
573+
turn: createTurn(turnIds.get(params.threadId) ?? "unknown-turn", "inProgress"),
574+
}));
575+
mockFixture.getCodexAppServerClient().turnStart = turnStart;
576+
577+
const sessionState1: SessionState = createTestSessionState({
578+
sessionId: "session-1",
579+
currentModelId: "model-id[effort]",
580+
agentMode: AgentMode.DEFAULT_AGENT_MODE
581+
});
582+
const sessionState2: SessionState = createTestSessionState({
583+
sessionId: "session-2",
584+
currentModelId: "model-id[effort]",
585+
agentMode: AgentMode.DEFAULT_AGENT_MODE
586+
});
587+
vi.spyOn(codexAcpAgent, "getSessionState").mockImplementation((sessionId: string) => {
588+
return sessionId === "session-1" ? sessionState1 : sessionState2;
589+
});
590+
591+
const prompt1 = codexAcpAgent.prompt({ sessionId: "session-1", prompt: [{type: "text", text: "Message to session 1"}] });
592+
const prompt2 = codexAcpAgent.prompt({ sessionId: "session-2", prompt: [{type: "text", text: "Message to session 2"}] });
593+
594+
await vi.waitFor(() => {
595+
expect(turnStart).toHaveBeenCalledTimes(2);
596+
});
597+
598+
let prompt1Settled = false;
599+
void prompt1.then(() => {
600+
prompt1Settled = true;
601+
}, () => {
602+
prompt1Settled = true;
603+
});
604+
605+
mockFixture.sendServerNotification(createTurnCompletedNotification("session-1", "old-turn"));
606+
await flushAsyncWork();
607+
expect(prompt1Settled).toBe(false);
608+
609+
mockFixture.sendServerNotification(createTurnCompletedNotification("session-2", "turn-2"));
610+
await expect(prompt2).resolves.toMatchObject({stopReason: "end_turn"});
611+
expect(prompt1Settled).toBe(false);
612+
613+
mockFixture.sendServerNotification(createTurnCompletedNotification("session-1", "turn-1"));
614+
await expect(prompt1).resolves.toMatchObject({stopReason: "end_turn"});
615+
});
616+
617+
it('should handle a turn completion that arrives before awaitTurnCompleted is called', async () => {
618+
const mockFixture = createCodexMockTestFixture();
619+
const codexAcpAgent = mockFixture.getCodexAcpAgent();
620+
621+
mockFixture.getCodexAppServerClient().turnStart = vi.fn().mockImplementation((params: TurnStartParams) => {
622+
mockFixture.sendServerNotification(createTurnCompletedNotification(params.threadId, "fast-turn"));
623+
return Promise.resolve({
624+
turn: createTurn("fast-turn", "inProgress"),
625+
});
626+
});
627+
628+
vi.spyOn(codexAcpAgent, "getSessionState").mockReturnValue(createTestSessionState({
629+
sessionId: "fast-session",
630+
currentModelId: "model-id[effort]",
631+
agentMode: AgentMode.DEFAULT_AGENT_MODE
632+
}));
633+
634+
await expect(codexAcpAgent.prompt({
635+
sessionId: "fast-session",
636+
prompt: [{type: "text", text: "Fast completion"}],
637+
})).resolves.toMatchObject({stopReason: "end_turn"});
638+
});
639+
533640
it('should send attachments as prompt items', async () => {
534641
const mockFixture = createCodexMockTestFixture();
535642
const codexAcpAgent = mockFixture.getCodexAcpAgent();
536643
const codexAppServerClient = mockFixture.getCodexAppServerClient();
537644

645+
const realTurnStart = codexAppServerClient.turnStart.bind(codexAppServerClient);
646+
vi.spyOn(codexAppServerClient, "turnStart").mockImplementation(async (params) => {
647+
await realTurnStart(params);
648+
return {turn: createTurn("turn-id", "inProgress")};
649+
});
538650
vi.spyOn(codexAppServerClient, "awaitTurnCompleted").mockResolvedValue({
539651
threadId: "session-id",
540652
turn: {

0 commit comments

Comments
 (0)