Skip to content

Commit b3466d9

Browse files
committed
fix: database select
1 parent b1322b3 commit b3466d9

File tree

5 files changed

+95
-131
lines changed

5 files changed

+95
-131
lines changed

Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ run-neo4j:
133133
@docker run -d \
134134
--name gograph-neo4j \
135135
-p 7474:7474 -p 7687:7687 \
136+
-v gograph-neo4j-data:/data \
136137
-e NEO4J_AUTH=neo4j/password \
137138
-e NEO4J_PLUGINS='["apoc","graph-data-science"]' \
138139
neo4j:5-community

cmd/gograph/commands/mcp.go

Lines changed: 16 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ import (
1717
"github.com/compozy/gograph/pkg/logger"
1818
mcpconfig "github.com/compozy/gograph/pkg/mcp"
1919
"github.com/joho/godotenv"
20-
"github.com/neo4j/neo4j-go-driver/v5/neo4j"
2120
"github.com/spf13/cobra"
2221
"github.com/spf13/viper"
2322
)
@@ -88,13 +87,8 @@ func runServeMCP(cmd *cobra.Command, _ []string) error {
8887
return err
8988
}
9089

91-
driver, err := initializeNeo4jConnection(ctx)
92-
if err != nil {
93-
return err
94-
}
95-
defer driver.Close(ctx)
96-
97-
server := createMCPServer(config)
90+
server, cleanup := createMCPServer(config)
91+
defer cleanup()
9892

9993
runMCPServerWithGracefulShutdown(ctx, cancel, server)
10094
return nil
@@ -122,28 +116,7 @@ func applyCommandLineFlagOverrides(cmd *cobra.Command, config *mcpconfig.Config)
122116
}
123117
}
124118

125-
func initializeNeo4jConnection(_ context.Context) (neo4j.DriverWithContext, error) {
126-
neo4jConfig := &infra.Neo4jConfig{
127-
URI: viper.GetString("neo4j.uri"),
128-
Username: viper.GetString("neo4j.username"),
129-
Password: viper.GetString("neo4j.password"),
130-
Database: viper.GetString("neo4j.database"),
131-
MaxRetries: 3,
132-
BatchSize: 1000,
133-
}
134-
135-
driver, err := neo4j.NewDriverWithContext(
136-
neo4jConfig.URI,
137-
neo4j.BasicAuth(neo4jConfig.Username, neo4jConfig.Password, ""),
138-
)
139-
if err != nil {
140-
return nil, fmt.Errorf("failed to connect to Neo4j: %w", err)
141-
}
142-
143-
return driver, nil
144-
}
145-
146-
func createMCPServer(config *mcpconfig.Config) *mcp.Server {
119+
func createMCPServer(config *mcpconfig.Config) (*mcp.Server, func()) {
147120
logger.Info("Creating MCP server with full service configuration")
148121

149122
// Initialize Neo4j repository
@@ -160,17 +133,7 @@ func createMCPServer(config *mcpconfig.Config) *mcp.Server {
160133
if err != nil {
161134
logger.Error("Failed to create Neo4j repository", "error", err)
162135
// Return server with nil services as fallback, but log the error
163-
return mcp.NewServer(config, nil, nil, nil, nil)
164-
}
165-
166-
// Create Neo4j driver for service adapter
167-
driver, err := neo4j.NewDriverWithContext(
168-
neo4jConfig.URI,
169-
neo4j.BasicAuth(neo4jConfig.Username, neo4jConfig.Password, ""),
170-
)
171-
if err != nil {
172-
logger.Error("Failed to create Neo4j driver", "error", err)
173-
return mcp.NewServer(config, nil, nil, nil, nil)
136+
return mcp.NewServer(config, nil, nil, nil, nil), func() {}
174137
}
175138

176139
// Create core services
@@ -187,8 +150,8 @@ func createMCPServer(config *mcpconfig.Config) *mcp.Server {
187150
graph.DefaultServiceConfig(),
188151
)
189152

190-
// Create service adapter
191-
serviceAdapter := mcp.NewServiceAdapter(driver, graphService, parserService, analyzerService)
153+
// Create service adapter using repository instead of driver
154+
serviceAdapter := mcp.NewServiceAdapter(repository, graphService, parserService, analyzerService)
192155

193156
// Initialize LLM service if OpenAI API key is available
194157
var llmService llm.CypherTranslator
@@ -213,7 +176,16 @@ func createMCPServer(config *mcpconfig.Config) *mcp.Server {
213176
}
214177

215178
// TODO: Initialize context generator and query builder when implemented
216-
return mcp.NewServer(config, serviceAdapter, llmService, nil, nil)
179+
server := mcp.NewServer(config, serviceAdapter, llmService, nil, nil)
180+
181+
// Return cleanup function to close repository
182+
cleanup := func() {
183+
if err := repository.Close(); err != nil {
184+
logger.Error("Failed to close repository", "error", err)
185+
}
186+
}
187+
188+
return server, cleanup
217189
}
218190

219191
func runMCPServerWithGracefulShutdown(ctx context.Context, cancel context.CancelFunc, server *mcp.Server) {

engine/infra/neo4j_repository.go

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"fmt"
77
"reflect"
8+
"sync"
89
"time"
910

1011
"github.com/compozy/gograph/engine/core"
@@ -25,6 +26,9 @@ type Neo4jConfig struct {
2526
BatchSize int // Batch size for bulk operations
2627
}
2728

29+
// Global mutex to prevent concurrent index creation across all repository instances
30+
var indexCreationMutex sync.Mutex
31+
2832
// Neo4jRepository implements the graph.Repository interface
2933
type Neo4jRepository struct {
3034
driver neo4j.DriverWithContext
@@ -131,7 +135,9 @@ func (r *Neo4jRepository) Close() error {
131135

132136
// CreateNode creates a new node in the graph
133137
func (r *Neo4jRepository) CreateNode(ctx context.Context, node *core.Node) error {
134-
session := r.driver.NewSession(ctx, neo4j.SessionConfig{})
138+
session := r.driver.NewSession(ctx, neo4j.SessionConfig{
139+
DatabaseName: r.config.Database,
140+
})
135141
defer session.Close(ctx)
136142

137143
// Build the query with dynamic properties
@@ -176,7 +182,9 @@ func (r *Neo4jRepository) CreateNodes(ctx context.Context, nodes []core.Node) er
176182
return nil
177183
}
178184

179-
session := r.driver.NewSession(ctx, neo4j.SessionConfig{})
185+
session := r.driver.NewSession(ctx, neo4j.SessionConfig{
186+
DatabaseName: r.config.Database,
187+
})
180188
defer session.Close(ctx)
181189

182190
// Determine batch size (default to 1000 if not configured)
@@ -251,7 +259,9 @@ func (r *Neo4jRepository) CreateNodes(ctx context.Context, nodes []core.Node) er
251259

252260
// GetNode retrieves a node by ID
253261
func (r *Neo4jRepository) GetNode(ctx context.Context, id core.ID) (*core.Node, error) {
254-
session := r.driver.NewSession(ctx, neo4j.SessionConfig{})
262+
session := r.driver.NewSession(ctx, neo4j.SessionConfig{
263+
DatabaseName: r.config.Database,
264+
})
255265
defer session.Close(ctx)
256266

257267
query := `
@@ -290,7 +300,9 @@ func (r *Neo4jRepository) GetNode(ctx context.Context, id core.ID) (*core.Node,
290300

291301
// UpdateNode updates an existing node
292302
func (r *Neo4jRepository) UpdateNode(ctx context.Context, node *core.Node) error {
293-
session := r.driver.NewSession(ctx, neo4j.SessionConfig{})
303+
session := r.driver.NewSession(ctx, neo4j.SessionConfig{
304+
DatabaseName: r.config.Database,
305+
})
294306
defer session.Close(ctx)
295307

296308
query := `
@@ -320,7 +332,9 @@ func (r *Neo4jRepository) UpdateNode(ctx context.Context, node *core.Node) error
320332

321333
// DeleteNode deletes a node by ID
322334
func (r *Neo4jRepository) DeleteNode(ctx context.Context, id core.ID) error {
323-
session := r.driver.NewSession(ctx, neo4j.SessionConfig{})
335+
session := r.driver.NewSession(ctx, neo4j.SessionConfig{
336+
DatabaseName: r.config.Database,
337+
})
324338
defer session.Close(ctx)
325339

326340
query := `
@@ -347,7 +361,9 @@ func (r *Neo4jRepository) DeleteNode(ctx context.Context, id core.ID) error {
347361

348362
// CreateRelationship creates a new relationship
349363
func (r *Neo4jRepository) CreateRelationship(ctx context.Context, rel *core.Relationship) error {
350-
session := r.driver.NewSession(ctx, neo4j.SessionConfig{})
364+
session := r.driver.NewSession(ctx, neo4j.SessionConfig{
365+
DatabaseName: r.config.Database,
366+
})
351367
defer session.Close(ctx)
352368

353369
// Build the relationship properties map (excluding node matching properties)
@@ -395,7 +411,9 @@ func (r *Neo4jRepository) CreateRelationships(ctx context.Context, rels []core.R
395411
return nil
396412
}
397413

398-
session := r.driver.NewSession(ctx, neo4j.SessionConfig{})
414+
session := r.driver.NewSession(ctx, neo4j.SessionConfig{
415+
DatabaseName: r.config.Database,
416+
})
399417
defer session.Close(ctx)
400418

401419
// Determine batch size (default to 1000 if not configured)
@@ -476,7 +494,9 @@ func (r *Neo4jRepository) CreateRelationships(ctx context.Context, rels []core.R
476494

477495
// GetRelationship retrieves a relationship by ID
478496
func (r *Neo4jRepository) GetRelationship(ctx context.Context, id core.ID) (*core.Relationship, error) {
479-
session := r.driver.NewSession(ctx, neo4j.SessionConfig{})
497+
session := r.driver.NewSession(ctx, neo4j.SessionConfig{
498+
DatabaseName: r.config.Database,
499+
})
480500
defer session.Close(ctx)
481501

482502
query := `
@@ -515,7 +535,9 @@ func (r *Neo4jRepository) GetRelationship(ctx context.Context, id core.ID) (*cor
515535

516536
// DeleteRelationship deletes a relationship by ID
517537
func (r *Neo4jRepository) DeleteRelationship(ctx context.Context, id core.ID) error {
518-
session := r.driver.NewSession(ctx, neo4j.SessionConfig{})
538+
session := r.driver.NewSession(ctx, neo4j.SessionConfig{
539+
DatabaseName: r.config.Database,
540+
})
519541
defer session.Close(ctx)
520542

521543
query := `
@@ -587,7 +609,9 @@ func (r *Neo4jRepository) ImportAnalysisResult(ctx context.Context, result *core
587609
}
588610

589611
// Use a single session for the entire import to ensure consistency
590-
session := r.driver.NewSession(ctx, neo4j.SessionConfig{})
612+
session := r.driver.NewSession(ctx, neo4j.SessionConfig{
613+
DatabaseName: r.config.Database,
614+
})
591615
defer session.Close(ctx)
592616

593617
// Execute the entire import in a single write transaction to ensure consistency
@@ -781,7 +805,13 @@ func (r *Neo4jRepository) createRelationshipsInTransaction(
781805

782806
// ensureIndexes creates indexes for better query performance on large codebases
783807
func (r *Neo4jRepository) ensureIndexes(ctx context.Context) error {
784-
session := r.driver.NewSession(ctx, neo4j.SessionConfig{})
808+
// Use global mutex to prevent concurrent index creation across all instances
809+
indexCreationMutex.Lock()
810+
defer indexCreationMutex.Unlock()
811+
812+
session := r.driver.NewSession(ctx, neo4j.SessionConfig{
813+
DatabaseName: r.config.Database,
814+
})
785815
defer session.Close(ctx)
786816

787817
// Create single-property indexes
@@ -946,7 +976,9 @@ func (r *Neo4jRepository) createConstraints(ctx context.Context, session neo4j.S
946976

947977
// createProjectMetadata creates a metadata node for the project
948978
func (r *Neo4jRepository) createProjectMetadata(ctx context.Context, result *core.AnalysisResult) error {
949-
session := r.driver.NewSession(ctx, neo4j.SessionConfig{})
979+
session := r.driver.NewSession(ctx, neo4j.SessionConfig{
980+
DatabaseName: r.config.Database,
981+
})
950982
defer session.Close(ctx)
951983

952984
query := `
@@ -981,7 +1013,9 @@ func (r *Neo4jRepository) createProjectMetadata(ctx context.Context, result *cor
9811013

9821014
// ClearProject removes all nodes and relationships for a specific project
9831015
func (r *Neo4jRepository) ClearProject(ctx context.Context, projectID core.ID) error {
984-
session := r.driver.NewSession(ctx, neo4j.SessionConfig{})
1016+
session := r.driver.NewSession(ctx, neo4j.SessionConfig{
1017+
DatabaseName: r.config.Database,
1018+
})
9851019
defer session.Close(ctx)
9861020

9871021
// Only delete nodes and relationships for the specified project_id
@@ -1014,7 +1048,9 @@ func (r *Neo4jRepository) FindNodesByType(
10141048
nodeType core.NodeType,
10151049
projectID core.ID,
10161050
) ([]core.Node, error) {
1017-
session := r.driver.NewSession(ctx, neo4j.SessionConfig{})
1051+
session := r.driver.NewSession(ctx, neo4j.SessionConfig{
1052+
DatabaseName: r.config.Database,
1053+
})
10181054
defer session.Close(ctx)
10191055

10201056
query := fmt.Sprintf(`
@@ -1062,7 +1098,9 @@ func (r *Neo4jRepository) FindNodesByName(
10621098
name string,
10631099
projectID core.ID,
10641100
) ([]core.Node, error) {
1065-
session := r.driver.NewSession(ctx, neo4j.SessionConfig{})
1101+
session := r.driver.NewSession(ctx, neo4j.SessionConfig{
1102+
DatabaseName: r.config.Database,
1103+
})
10661104
defer session.Close(ctx)
10671105

10681106
query := `
@@ -1111,7 +1149,9 @@ func (r *Neo4jRepository) FindRelationshipsByType(
11111149
relType core.RelationType,
11121150
projectID core.ID,
11131151
) ([]core.Relationship, error) {
1114-
session := r.driver.NewSession(ctx, neo4j.SessionConfig{})
1152+
session := r.driver.NewSession(ctx, neo4j.SessionConfig{
1153+
DatabaseName: r.config.Database,
1154+
})
11151155
defer session.Close(ctx)
11161156

11171157
query := fmt.Sprintf(`

engine/mcp/graph_adapter.go

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,22 @@ package mcp
22

33
import (
44
"context"
5-
"fmt"
65

76
"github.com/compozy/gograph/engine/core"
87
"github.com/compozy/gograph/engine/graph"
9-
"github.com/neo4j/neo4j-go-driver/v5/neo4j"
108
)
119

1210
// GraphAdapter provides graph operations needed by MCP server
1311
type GraphAdapter struct {
14-
driver neo4j.DriverWithContext
15-
service graph.Service
12+
repository graph.Repository
13+
service graph.Service
1614
}
1715

1816
// NewGraphAdapter creates a new graph adapter
19-
func NewGraphAdapter(driver neo4j.DriverWithContext, service graph.Service) *GraphAdapter {
17+
func NewGraphAdapter(repository graph.Repository, service graph.Service) *GraphAdapter {
2018
return &GraphAdapter{
21-
driver: driver,
22-
service: service,
19+
repository: repository,
20+
service: service,
2321
}
2422
}
2523

@@ -29,29 +27,7 @@ func (a *GraphAdapter) ExecuteQuery(
2927
query string,
3028
params map[string]any,
3129
) ([]map[string]any, error) {
32-
session := a.driver.NewSession(ctx, neo4j.SessionConfig{})
33-
defer session.Close(ctx)
34-
35-
result, err := session.Run(ctx, query, params)
36-
if err != nil {
37-
return nil, fmt.Errorf("failed to execute query: %w", err)
38-
}
39-
40-
var results []map[string]any
41-
for result.Next(ctx) {
42-
record := result.Record()
43-
row := make(map[string]any)
44-
for i, key := range record.Keys {
45-
row[key] = record.Values[i]
46-
}
47-
results = append(results, row)
48-
}
49-
50-
if err = result.Err(); err != nil {
51-
return nil, fmt.Errorf("error processing results: %w", err)
52-
}
53-
54-
return results, nil
30+
return a.repository.ExecuteQuery(ctx, query, params)
5531
}
5632

5733
// GetProjectStatistics delegates to the underlying service

0 commit comments

Comments
 (0)