
Commit 29b84f0 (1 parent: 432241f)

refactor: encapsulate model ID and token limits in session

File tree: 4 files changed, +72 −32 lines changed

server/aws-lsp-codewhisperer/src/language-server/agenticChat/agenticChatController.test.ts

Lines changed: 4 additions & 4 deletions

@@ -3059,7 +3059,7 @@ ${' '.repeat(8)}}
             // Create a session and set initial model
             chatController.onTabAdd({ tabId: mockTabId })
             const session = chatSessionManagementService.getSession(mockTabId).data!
-            session.modelId = initialModelId
+            session.setModel(initialModelId, cachedModels)

             // Get initial token limits (default 200K)
             const initialLimits = session.tokenLimits
@@ -3137,7 +3137,7 @@ ${' '.repeat(8)}}
             getCachedModelsStub.returns(cachedData)

             const session = chatSessionManagementService.getSession(mockTabId).data!
-            session.modelId = 'model1'
+            session.setModel('model1', cachedData.models)

             const result = await chatController.onListAvailableModels({ tabId: mockTabId })

@@ -3324,7 +3324,7 @@ ${' '.repeat(8)}}

         it('should use defaultModelId from cache when session has no modelId', async () => {
             const session = chatSessionManagementService.getSession(mockTabId).data!
-            session.modelId = undefined
+            session.setModel(undefined, undefined)

             const result = await chatController.onListAvailableModels({ tabId: mockTabId })

@@ -3341,7 +3341,7 @@ ${' '.repeat(8)}}
            })

            const session = chatSessionManagementService.getSession(mockTabId).data!
-            session.modelId = undefined
+            session.setModel(undefined, undefined)

            const result = await chatController.onListAvailableModels({ tabId: mockTabId })


server/aws-lsp-codewhisperer/src/language-server/agenticChat/agenticChatController.ts

Lines changed: 12 additions & 23 deletions

@@ -780,13 +780,12 @@ export class AgenticChatController implements ChatHandlers {

         // Handle error cases by returning default model
         if (!success || errorFromAPI) {
-            // Even in error cases, calculate token limits from the default/fallback model
+            // Even in error cases, set the model with token limits
             if (success && session) {
-                const fallbackModel = models.find(model => model.id === DEFAULT_MODEL_ID)
-                const maxInputTokens = TokenLimitsCalculator.extractMaxInputTokens(fallbackModel)
-                const tokenLimits = TokenLimitsCalculator.calculate(maxInputTokens)
-                session.setTokenLimits(tokenLimits)
-                this.#log(`Token limits calculated for fallback model (error case): ${JSON.stringify(tokenLimits)}`)
+                session.setModel(DEFAULT_MODEL_ID, models)
+                this.#log(
+                    `Model set for fallback (error case): ${DEFAULT_MODEL_ID}, tokenLimits: ${JSON.stringify(session.tokenLimits)}`
+                )
             }
             return {
                 tabId: params.tabId,
@@ -828,16 +827,10 @@
             selectedModelId = defaultModelId || getMappedModelId(DEFAULT_MODEL_ID)
         }

-        // Store the selected model in the session
-        session.modelId = selectedModelId
-
-        // Extract maxInputTokens from the selected model and calculate token limits
-        const selectedModel = models.find(model => model.id === selectedModelId)
-        const maxInputTokens = TokenLimitsCalculator.extractMaxInputTokens(selectedModel)
-        const tokenLimits = TokenLimitsCalculator.calculate(maxInputTokens)
-        session.setTokenLimits(tokenLimits)
+        // Store the selected model in the session (automatically calculates token limits)
+        session.setModel(selectedModelId, models)
         this.#log(
-            `Token limits calculated for initial model selection (${selectedModelId}): ${JSON.stringify(tokenLimits)}`
+            `Model set for initial selection: ${selectedModelId}, tokenLimits: ${JSON.stringify(session.tokenLimits)}`
         )

         return {
@@ -4678,17 +4671,13 @@
         session.pairProgrammingMode = params.optionsValues['pair-programmer-mode'] === 'true'
         const newModelId = params.optionsValues['model-selection']

-        // Recalculate token limits when model changes
-        if (newModelId && newModelId !== session.modelId) {
+        // Set model (automatically recalculates token limits)
+        if (newModelId !== session.modelId) {
             const cachedData = this.#chatHistoryDb.getCachedModels()
-            const selectedModel = cachedData?.models?.find(model => model.id === newModelId)
-            const maxInputTokens = TokenLimitsCalculator.extractMaxInputTokens(selectedModel)
-            const tokenLimits = TokenLimitsCalculator.calculate(maxInputTokens)
-            session.setTokenLimits(tokenLimits)
-            this.#log(`Token limits calculated for model switch (${newModelId}): ${JSON.stringify(tokenLimits)}`)
+            session.setModel(newModelId, cachedData?.models)
+            this.#log(`Model set for model switch: ${newModelId}, tokenLimits: ${JSON.stringify(session.tokenLimits)}`)
         }

-        session.modelId = newModelId
         this.#chatHistoryDb.setModelId(session.modelId)
         this.#chatHistoryDb.setPairProgrammingMode(session.pairProgrammingMode)
     }
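
Note on the model-switch hunk above: the old guard (newModelId && newModelId !== session.modelId) skipped recalculation when the incoming value was undefined and relied on a separate unconditional assignment afterwards, while the new guard runs whenever the value differs, so clearing the selection also resets the limits to the defaults. A minimal sketch of the resulting call pattern (the model ID and caller context here are illustrative, not from this commit):

    // Hypothetical caller; 'model-2' and chatHistoryDb stand in for real values.
    const cachedData = chatHistoryDb.getCachedModels()
    session.setModel('model-2', cachedData?.models) // modelId and token limits updated together
    session.setModel(undefined, undefined) // clearing the model falls back to the 200K defaults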

server/aws-lsp-codewhisperer/src/language-server/chat/chatSessionService.test.ts

Lines changed: 39 additions & 0 deletions

@@ -329,6 +329,45 @@ describe('Chat Session Service', () => {
        })
    })

+    describe('setModel encapsulation', () => {
+        let chatSessionService: ChatSessionService
+
+        beforeEach(() => {
+            chatSessionService = new ChatSessionService()
+        })
+
+        it('should initialize with undefined modelId and default token limits', () => {
+            assert.strictEqual(chatSessionService.modelId, undefined)
+            assert.strictEqual(chatSessionService.tokenLimits.maxInputTokens, 200_000)
+        })
+
+        it('should set modelId and calculate token limits together', () => {
+            const models = [
+                { id: 'model-1', name: 'Model 1', description: 'Test', tokenLimits: { maxInputTokens: 300_000 } },
+            ]
+
+            chatSessionService.setModel('model-1', models)
+
+            assert.strictEqual(chatSessionService.modelId, 'model-1')
+            assert.strictEqual(chatSessionService.tokenLimits.maxInputTokens, 300_000)
+            assert.strictEqual(chatSessionService.tokenLimits.maxOverallCharacters, Math.floor(300_000 * 3.5))
+        })
+
+        it('should use default token limits when model not found in list', () => {
+            chatSessionService.setModel('unknown-model', [])
+
+            assert.strictEqual(chatSessionService.modelId, 'unknown-model')
+            assert.strictEqual(chatSessionService.tokenLimits.maxInputTokens, 200_000)
+        })
+
+        it('should use default token limits when models list is undefined', () => {
+            chatSessionService.setModel('some-model', undefined)
+
+            assert.strictEqual(chatSessionService.modelId, 'some-model')
+            assert.strictEqual(chatSessionService.tokenLimits.maxInputTokens, 200_000)
+        })
+    })
+
    describe('IAM client source property', () => {
        it('sets source to Origin.IDE when using StreamingClientServiceIAM', async () => {
            const codeWhispererStreamingClientIAM = stubInterface<StreamingClientServiceIAM>()
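
These assertions pin down the calculator behavior that setModel relies on: a 200,000-token default when no model metadata is available, and an overall character budget of floor(maxInputTokens * 3.5), suggesting a ratio of roughly 3.5 characters per token. A rough sketch of what TokenLimitsCalculator must provide for these tests to pass (the real implementation lives in agenticChat/utils/tokenLimitsCalculator.ts and is not part of this commit, so any detail beyond the asserted fields is an assumption):

    // Sketch inferred from the tests above; not the actual implementation.
    const DEFAULT_MAX_INPUT_TOKENS = 200_000
    const CHARS_PER_TOKEN = 3.5 // assumed ratio, consistent with Math.floor(300_000 * 3.5)

    interface TokenLimits {
        maxInputTokens: number
        maxOverallCharacters: number
        // the real interface may carry additional fields
    }

    class TokenLimitsCalculator {
        // Reads the per-model limit from the metadata shape used in the tests.
        static extractMaxInputTokens(model?: { tokenLimits?: { maxInputTokens?: number } }): number | undefined {
            return model?.tokenLimits?.maxInputTokens
        }

        // Falls back to the 200K default when no limit is provided.
        static calculate(maxInputTokens: number = DEFAULT_MAX_INPUT_TOKENS): TokenLimits {
            return {
                maxInputTokens,
                maxOverallCharacters: Math.floor(maxInputTokens * CHARS_PER_TOKEN),
            }
        }
    }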

server/aws-lsp-codewhisperer/src/language-server/chat/chatSessionService.ts

Lines changed: 17 additions & 5 deletions

@@ -18,6 +18,7 @@ import { QErrorTransformer } from '../agenticChat/retry/errorTransformer'
 import { DelayNotification } from '../agenticChat/retry/delayInterceptor'
 import { MAX_REQUEST_ATTEMPTS } from '../agenticChat/constants/constants'
 import { TokenLimits, TokenLimitsCalculator } from '../agenticChat/utils/tokenLimitsCalculator'
+import { Model } from '@aws/language-server-runtimes/protocol'

 export type ChatSessionServiceConfig = CodeWhispererStreamingClientConfig
 type FileChange = { before?: string; after?: string }
@@ -29,8 +30,8 @@ type DeferredHandler = {
 export class ChatSessionService {
     public pairProgrammingMode: boolean = true
     public contextListSent: boolean = false
-    public modelId: string | undefined
     public isMemoryBankGeneration: boolean = false
+    #modelId: string | undefined
     #lsp?: Features['lsp']
     #abortController?: AbortController
     #currentPromptId?: string
@@ -145,6 +146,13 @@ export class ChatSessionService {
         this.#tokenLimits = TokenLimitsCalculator.calculate()
     }

+    /**
+     * Gets the model ID for this session
+     */
+    public get modelId(): string | undefined {
+        return this.#modelId
+    }
+
     /**
      * Gets the token limits for this session
      */
@@ -153,11 +161,15 @@ export class ChatSessionService {
     }

     /**
-     * Sets the token limits for this session
-     * @param limits The token limits to set
+     * Sets the model for this session, automatically calculating token limits.
+     * This encapsulates model ID and token limits as a single entity.
+     * @param modelId The model ID to set
+     * @param models Optional list of available models to look up token limits from
      */
-    public setTokenLimits(limits: TokenLimits): void {
-        this.#tokenLimits = limits
+    public setModel(modelId: string | undefined, models?: Model[]): void {
+        this.#modelId = modelId
+        const maxInputTokens = TokenLimitsCalculator.extractMaxInputTokens(models?.find(m => m.id === modelId))
+        this.#tokenLimits = TokenLimitsCalculator.calculate(maxInputTokens)
     }

     public async sendMessage(request: SendMessageCommandInput): Promise<SendMessageCommandOutput> {
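
With the setter in place, callers no longer thread TokenLimitsCalculator through by hand, and because modelId is now exposed only through a getter, a direct write like session.modelId = '...' no longer compiles: the two values can only change together. A short usage sketch mirroring the new tests (the model metadata is illustrative):

    const session = new ChatSessionService()
    session.setModel('model-1', [
        { id: 'model-1', name: 'Model 1', description: 'Test', tokenLimits: { maxInputTokens: 300_000 } },
    ])
    session.modelId // 'model-1'
    session.tokenLimits.maxInputTokens // 300_000, read from the model metadata

    session.setModel('unknown-model', []) // model not found in the list
    session.tokenLimits.maxInputTokens // back to the 200_000 default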
