Skip to content

Commit cd451c4

Browse files
authored
Fix contextual type for array completions (#2309)
1 parent ad341f4 commit cd451c4

File tree

3 files changed

+88
-41
lines changed

3 files changed

+88
-41
lines changed

internal/checker/services.go

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -671,16 +671,34 @@ func (c *Checker) GetContextualDeclarationsForObjectLiteralElement(objectLiteral
671671
return result
672672
}
673673

674-
// GetContextualTypeForArrayElement returns the contextual type for an element at the given index
674+
// GetContextualTypeForArrayLiteralAtPosition returns the contextual type for an element at the given position
675675
// in an array with the given contextual type.
676-
func (c *Checker) GetContextualTypeForArrayElement(contextualArrayType *Type, elementIndex int) *Type {
676+
func (c *Checker) GetContextualTypeForArrayLiteralAtPosition(contextualArrayType *Type, arrayLiteral *ast.Node, position int) *Type {
677677
if contextualArrayType == nil {
678678
return nil
679679
}
680-
// Pass -1 for length, firstSpreadIndex, and lastSpreadIndex since we don't have
681-
// access to the actual array literal. This falls back to getting the iterated type
682-
// or checking numeric properties, which is appropriate for completion contexts.
683-
return c.getContextualTypeForElementExpression(contextualArrayType, elementIndex, -1, -1, -1)
680+
firstSpreadIndex, lastSpreadIndex := -1, -1
681+
elementIndex := 0
682+
elements := arrayLiteral.Elements()
683+
for i, elem := range elements {
684+
if elem.Pos() < position {
685+
elementIndex++
686+
}
687+
if ast.IsSpreadElement(elem) {
688+
if firstSpreadIndex == -1 {
689+
firstSpreadIndex = i
690+
}
691+
lastSpreadIndex = i
692+
}
693+
}
694+
// The array may be incomplete, so we don't know its final length.
695+
return c.getContextualTypeForElementExpression(
696+
contextualArrayType,
697+
elementIndex,
698+
-1, /*length*/
699+
firstSpreadIndex,
700+
lastSpreadIndex,
701+
)
684702
}
685703

686704
var knownGenericTypeNames = map[string]struct{}{
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package fourslash_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/microsoft/typescript-go/internal/fourslash"
7+
. "github.com/microsoft/typescript-go/internal/fourslash/tests/util"
8+
"github.com/microsoft/typescript-go/internal/testutil"
9+
)
10+
11+
func TestArgumentCompletions(t *testing.T) {
12+
t.Parallel()
13+
14+
defer testutil.RecoverAndFail(t, "Panic on fourslash test")
15+
const content = `
16+
function foo(a: "a", b: "b") {}
17+
foo("a", /*1*/);
18+
19+
20+
const t3 = ['x', 'y', 'z'] as const;
21+
const x: [string, string, string, 'a' | 'b'] = [...t3, /*2*/];
22+
`
23+
f, done := fourslash.NewFourslash(t, nil /*capabilities*/, content)
24+
defer done()
25+
f.VerifyCompletions(t, "1", &fourslash.CompletionsExpectedList{
26+
ItemDefaults: &fourslash.CompletionsExpectedItemDefaults{
27+
CommitCharacters: &DefaultCommitCharacters,
28+
},
29+
Items: &fourslash.CompletionsExpectedItems{
30+
Includes: []fourslash.CompletionsExpectedItem{`"b"`},
31+
},
32+
})
33+
f.VerifyCompletions(t, "2", &fourslash.CompletionsExpectedList{
34+
ItemDefaults: &fourslash.CompletionsExpectedItemDefaults{
35+
CommitCharacters: &DefaultCommitCharacters,
36+
},
37+
Items: &fourslash.CompletionsExpectedItems{
38+
Includes: []fourslash.CompletionsExpectedItem{`"b"`},
39+
},
40+
})
41+
}

internal/ls/completions.go

Lines changed: 23 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2977,27 +2977,7 @@ func getContextualType(previousToken *ast.Node, position int, file *ast.SourceFi
29772977
contextualArrayType := typeChecker.GetContextualType(parent, checker.ContextFlagsNone)
29782978
if contextualArrayType != nil {
29792979
// Get the type for the first element (index 0)
2980-
return typeChecker.GetContextualTypeForArrayElement(contextualArrayType, 0)
2981-
}
2982-
}
2983-
return nil
2984-
case ast.KindCommaToken:
2985-
// When completing after `,` in an array literal (e.g., `[x, /*here*/]`),
2986-
// we should provide contextual type for the element after the comma
2987-
if ast.IsArrayLiteralExpression(parent) {
2988-
contextualArrayType := typeChecker.GetContextualType(parent, checker.ContextFlagsNone)
2989-
if contextualArrayType != nil {
2990-
// Count how many elements come before the cursor position
2991-
arrayLiteral := parent.AsArrayLiteralExpression()
2992-
elementIndex := 0
2993-
for _, elem := range arrayLiteral.Elements.Nodes {
2994-
if elem.Pos() < position {
2995-
elementIndex++
2996-
} else {
2997-
break
2998-
}
2999-
}
3000-
return typeChecker.GetContextualTypeForArrayElement(contextualArrayType, elementIndex)
2980+
return typeChecker.GetContextualTypeForArrayLiteralAtPosition(contextualArrayType, parent, position)
30012981
}
30022982
}
30032983
return nil
@@ -3023,22 +3003,30 @@ func getContextualType(previousToken *ast.Node, position int, file *ast.SourceFi
30233003
if ast.IsConditionalExpression(parent) {
30243004
return getContextualTypeForConditionalExpression(parent, position, file, typeChecker)
30253005
}
3026-
// Fall through to default for other colon contexts (object literals, etc.)
3027-
fallthrough
3028-
default:
3029-
argInfo := getArgumentInfoForCompletions(previousToken, position, file, typeChecker)
3030-
if argInfo != nil {
3031-
return typeChecker.GetContextualTypeForArgumentAtIndex(argInfo.invocation, argInfo.argumentIndex)
3032-
} else if isEqualityOperatorKind(previousToken.Kind) && ast.IsBinaryExpression(parent) && isEqualityOperatorKind(parent.AsBinaryExpression().OperatorToken.Kind) {
3033-
// completion at `x ===/**/`
3034-
return typeChecker.GetTypeAtLocation(parent.AsBinaryExpression().Left)
3035-
} else {
3036-
contextualType := typeChecker.GetContextualType(previousToken, checker.ContextFlagsCompletions)
3037-
if contextualType != nil {
3038-
return contextualType
3006+
case ast.KindCommaToken:
3007+
// When completing after `,` in an array literal (e.g., `[x, /*here*/]`),
3008+
// we should provide contextual type for the element after the comma.
3009+
if ast.IsArrayLiteralExpression(parent) {
3010+
contextualArrayType := typeChecker.GetContextualType(parent, checker.ContextFlagsNone)
3011+
if contextualArrayType != nil {
3012+
return typeChecker.GetContextualTypeForArrayLiteralAtPosition(contextualArrayType, parent, position)
30393013
}
3040-
return typeChecker.GetContextualType(previousToken, checker.ContextFlagsNone)
3014+
return nil
3015+
}
3016+
}
3017+
// Default case: see if we're in an argument position.
3018+
argInfo := getArgumentInfoForCompletions(previousToken, position, file, typeChecker)
3019+
if argInfo != nil {
3020+
return typeChecker.GetContextualTypeForArgumentAtIndex(argInfo.invocation, argInfo.argumentIndex)
3021+
} else if isEqualityOperatorKind(previousToken.Kind) && ast.IsBinaryExpression(parent) && isEqualityOperatorKind(parent.AsBinaryExpression().OperatorToken.Kind) {
3022+
// completion at `x ===/**/`
3023+
return typeChecker.GetTypeAtLocation(parent.AsBinaryExpression().Left)
3024+
} else {
3025+
contextualType := typeChecker.GetContextualType(previousToken, checker.ContextFlagsCompletions)
3026+
if contextualType != nil {
3027+
return contextualType
30413028
}
3029+
return typeChecker.GetContextualType(previousToken, checker.ContextFlagsNone)
30423030
}
30433031
}
30443032

0 commit comments

Comments
 (0)