diff --git a/internal/ast/ast.go b/internal/ast/ast.go index 4936cd1e71..b52c641b48 100644 --- a/internal/ast/ast.go +++ b/internal/ast/ast.go @@ -605,15 +605,13 @@ func (n *Node) Type() *Node { return n.AsCommonJSExport().Type case KindBinaryExpression: return n.AsBinaryExpression().Type - case KindEnumMember, KindBindingElement: - return nil default: funcLike := n.FunctionLikeData() if funcLike != nil { return funcLike.Type } } - panic("Unhandled case in Node.Type: " + n.Kind.String()) + return nil } func (n *Node) Initializer() *Node { diff --git a/internal/ls/definition.go b/internal/ls/definition.go index 7cc4be72b3..8fffa2d628 100644 --- a/internal/ls/definition.go +++ b/internal/ls/definition.go @@ -119,7 +119,7 @@ func getDeclarationNameForKeyword(node *ast.Node) *ast.Node { if decl := core.FirstOrNil(node.Parent.AsVariableDeclarationList().Declarations.Nodes); decl != nil && decl.Name() != nil { return decl.Name() } - } else if node.Parent.Name() != nil && node.Pos() < node.Parent.Name().Pos() { + } else if node.Parent.DeclarationData() != nil && node.Parent.Name() != nil && node.Pos() < node.Parent.Name().Pos() { return node.Parent.Name() } } diff --git a/internal/ls/findallreferences.go b/internal/ls/findallreferences.go index fead25582a..fb464b3f93 100644 --- a/internal/ls/findallreferences.go +++ b/internal/ls/findallreferences.go @@ -128,14 +128,10 @@ func newNodeEntryWithKind(node *ast.Node, kind entryKind) *referenceEntry { func newNodeEntry(node *ast.Node) *referenceEntry { // creates nodeEntry with `kind == entryKindNode` - n := node - if node != nil && node.Name() != nil { - n = node.Name() - } return &referenceEntry{ kind: entryKindNode, - node: node, - context: getContextNodeForNodeEntry(n), + node: core.OrElse(node.Name(), node), + context: getContextNodeForNodeEntry(node), } } @@ -410,22 +406,53 @@ func (l *LanguageService) ProvideReferences(params *lsproto.ReferenceParams) []* symbolsAndEntries := l.getReferencedSymbolsForNode(position, node, program, program.GetSourceFiles(), options, nil) - return core.FlatMap(symbolsAndEntries, l.convertSymbolAndEntryToLocation) + return core.FlatMap(symbolsAndEntries, l.convertSymbolAndEntriesToLocations) +} + +func (l *LanguageService) ProvideImplementations(params *lsproto.ImplementationParams) []*lsproto.Location { + program, sourceFile := l.getProgramAndFile(params.TextDocument.Uri) + position := int(l.converters.LineAndCharacterToPosition(sourceFile, params.Position)) + node := astnav.GetTouchingPropertyName(sourceFile, position) + + var seenNodes collections.Set[*ast.Node] + var entries []*referenceEntry + queue := l.getImplementationReferenceEntries(program, node, position) + for len(queue) != 0 { + entry := queue[0] + queue = queue[1:] + if !seenNodes.Has(entry.node) { + seenNodes.Add(entry.node) + entries = append(entries, entry) + queue = append(queue, l.getImplementationReferenceEntries(program, entry.node, entry.node.Pos())...) + } + } + + return l.convertEntriesToLocations(entries) +} + +func (l *LanguageService) getImplementationReferenceEntries(program *compiler.Program, node *ast.Node, position int) []*referenceEntry { + options := refOptions{use: referenceUseReferences, implementations: true} + symbolsAndEntries := l.getReferencedSymbolsForNode(position, node, program, program.GetSourceFiles(), options, nil) + return core.FlatMap(symbolsAndEntries, func(s *SymbolAndEntries) []*referenceEntry { return s.references }) } // == functions for conversions == -func (l *LanguageService) convertSymbolAndEntryToLocation(s *SymbolAndEntries) []*lsproto.Location { - var locations []*lsproto.Location - for _, ref := range s.references { - if ref.textRange == nil { - sourceFile := ast.GetSourceFileOfNode(ref.node) - ref.textRange = l.getRangeOfNode(ref.node, sourceFile, nil /*endNode*/) - ref.fileName = sourceFile.FileName() +func (l *LanguageService) convertSymbolAndEntriesToLocations(s *SymbolAndEntries) []*lsproto.Location { + return l.convertEntriesToLocations(s.references) +} + +func (l *LanguageService) convertEntriesToLocations(entries []*referenceEntry) []*lsproto.Location { + locations := make([]*lsproto.Location, len(entries)) + for i, entry := range entries { + if entry.textRange == nil { + sourceFile := ast.GetSourceFileOfNode(entry.node) + entry.textRange = l.getRangeOfNode(entry.node, sourceFile, nil /*endNode*/) + entry.fileName = sourceFile.FileName() + } + locations[i] = &lsproto.Location{ + Uri: FileNameToDocumentURI(entry.fileName), + Range: *entry.textRange, } - locations = append(locations, &lsproto.Location{ - Uri: FileNameToDocumentURI(ref.fileName), - Range: *ref.textRange, - }) } return locations } diff --git a/internal/lsp/server.go b/internal/lsp/server.go index ba2dcf28f7..51cdd08bd9 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -488,6 +488,8 @@ func (s *Server) handleRequestOrNotification(ctx context.Context, req *lsproto.R return s.handleCompletion(ctx, req) case *lsproto.ReferenceParams: return s.handleReferences(ctx, req) + case *lsproto.ImplementationParams: + return s.handleImplementations(ctx, req) case *lsproto.SignatureHelpParams: return s.handleSignatureHelp(ctx, req) case *lsproto.DocumentFormattingParams: @@ -560,6 +562,9 @@ func (s *Server) handleInitialize(req *lsproto.RequestMessage) { ReferencesProvider: &lsproto.BooleanOrReferenceOptions{ Boolean: ptrTo(true), }, + ImplementationProvider: &lsproto.BooleanOrImplementationOptionsOrImplementationRegistrationOptions{ + Boolean: ptrTo(true), + }, DiagnosticProvider: &lsproto.DiagnosticOptionsOrDiagnosticRegistrationOptions{ DiagnosticOptions: &lsproto.DiagnosticOptions{ InterFileDependencies: true, @@ -725,6 +730,17 @@ func (s *Server) handleReferences(ctx context.Context, req *lsproto.RequestMessa return nil } +func (s *Server) handleImplementations(ctx context.Context, req *lsproto.RequestMessage) error { + // goToImplementation + params := req.Params.(*lsproto.ImplementationParams) + project := s.projectService.EnsureDefaultProjectForURI(params.TextDocument.Uri) + languageService, done := project.GetLanguageServiceForRequest(ctx) + defer done() + locations := languageService.ProvideImplementations(params) + s.sendResult(req.ID, locations) + return nil +} + func (s *Server) handleCompletion(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.CompletionParams) project := s.projectService.EnsureDefaultProjectForURI(params.TextDocument.Uri)