From edb87cd2e71f40daf7d3e4c11e83987e530f42d7 Mon Sep 17 00:00:00 2001 From: hgaol Date: Mon, 2 Mar 2026 18:37:04 +0800 Subject: [PATCH 1/6] feat: support semantic search in AI chat and embedding ability --- cmd/wire_gen.go | 22 +- docs/docs.go | 23 + docs/swagger.json | 23 + docs/swagger.yaml | 17 + i18n/en_US.yaml | 18 + i18n/zh_CN.yaml | 18 + internal/base/constant/ai_config.go | 2 + internal/controller/ai_controller.go | 7 + internal/controller/mcp_controller.go | 132 +++++ internal/entity/embedding.go | 59 ++ internal/migrations/migrations.go | 1 + internal/migrations/v32.go | 9 + internal/repo/embedding/embedding_repo.go | 197 +++++++ internal/repo/provider.go | 2 + internal/schema/mcp_schema.go | 62 ++- internal/schema/mcp_tools/mcp_tools.go | 15 + internal/schema/siteinfo_schema.go | 13 +- .../service/embedding/embedding_service.go | 516 ++++++++++++++++++ internal/service/provider.go | 2 + internal/service/siteinfo/siteinfo_service.go | 13 +- ui/src/common/interface.ts | 5 + ui/src/pages/Admin/AiSettings/index.tsx | 175 ++++++ 22 files changed, 1307 insertions(+), 24 deletions(-) create mode 100644 internal/entity/embedding.go create mode 100644 internal/repo/embedding/embedding_repo.go create mode 100644 internal/service/embedding/embedding_service.go diff --git a/cmd/wire_gen.go b/cmd/wire_gen.go index 9fe134ed6..d5aa06c7f 100644 --- a/cmd/wire_gen.go +++ b/cmd/wire_gen.go @@ -34,7 +34,7 @@ import ( "github.com/apache/answer/internal/base/server" "github.com/apache/answer/internal/base/translator" "github.com/apache/answer/internal/controller" - "github.com/apache/answer/internal/controller/template_render" + templaterender "github.com/apache/answer/internal/controller/template_render" "github.com/apache/answer/internal/controller_admin" "github.com/apache/answer/internal/repo/activity" "github.com/apache/answer/internal/repo/activity_common" @@ -49,6 +49,7 @@ import ( "github.com/apache/answer/internal/repo/collection" 
"github.com/apache/answer/internal/repo/comment" "github.com/apache/answer/internal/repo/config" + "github.com/apache/answer/internal/repo/embedding" "github.com/apache/answer/internal/repo/export" "github.com/apache/answer/internal/repo/file_record" "github.com/apache/answer/internal/repo/limit" @@ -76,17 +77,18 @@ import ( activity_common2 "github.com/apache/answer/internal/service/activity_common" "github.com/apache/answer/internal/service/activityqueue" ai_conversation2 "github.com/apache/answer/internal/service/ai_conversation" - "github.com/apache/answer/internal/service/answer_common" + answercommon "github.com/apache/answer/internal/service/answer_common" "github.com/apache/answer/internal/service/apikey" auth2 "github.com/apache/answer/internal/service/auth" badge2 "github.com/apache/answer/internal/service/badge" collection2 "github.com/apache/answer/internal/service/collection" - "github.com/apache/answer/internal/service/collection_common" + collectioncommon "github.com/apache/answer/internal/service/collection_common" comment2 "github.com/apache/answer/internal/service/comment" "github.com/apache/answer/internal/service/comment_common" config2 "github.com/apache/answer/internal/service/config" "github.com/apache/answer/internal/service/content" "github.com/apache/answer/internal/service/dashboard" + embedding2 "github.com/apache/answer/internal/service/embedding" "github.com/apache/answer/internal/service/eventqueue" export2 "github.com/apache/answer/internal/service/export" "github.com/apache/answer/internal/service/feature_toggle" @@ -94,13 +96,13 @@ import ( "github.com/apache/answer/internal/service/follow" "github.com/apache/answer/internal/service/importer" meta2 "github.com/apache/answer/internal/service/meta" - "github.com/apache/answer/internal/service/meta_common" + metacommon "github.com/apache/answer/internal/service/meta_common" "github.com/apache/answer/internal/service/noticequeue" "github.com/apache/answer/internal/service/notification" 
- "github.com/apache/answer/internal/service/notification_common" + notificationcommon "github.com/apache/answer/internal/service/notification_common" "github.com/apache/answer/internal/service/object_info" "github.com/apache/answer/internal/service/plugin_common" - "github.com/apache/answer/internal/service/question_common" + questioncommon "github.com/apache/answer/internal/service/question_common" rank2 "github.com/apache/answer/internal/service/rank" reason2 "github.com/apache/answer/internal/service/reason" report2 "github.com/apache/answer/internal/service/report" @@ -116,7 +118,7 @@ import ( tag_common2 "github.com/apache/answer/internal/service/tag_common" "github.com/apache/answer/internal/service/uploader" "github.com/apache/answer/internal/service/user_admin" - "github.com/apache/answer/internal/service/user_common" + usercommon "github.com/apache/answer/internal/service/user_common" user_external_login2 "github.com/apache/answer/internal/service/user_external_login" user_notification_config2 "github.com/apache/answer/internal/service/user_notification_config" "github.com/segmentfault/pacman" @@ -247,7 +249,9 @@ func initApplication(debug bool, serverConf *conf.Server, dbConf *data.Database, reasonService := reason2.NewReasonService(reasonRepo) reasonController := controller.NewReasonController(reasonService) themeController := controller_admin.NewThemeController() - siteInfoService := siteinfo.NewSiteInfoService(siteInfoRepo, siteInfoCommonService, emailService, tagCommonService, configService, questionCommon, fileRecordService) + embeddingRepo := embedding.NewEmbeddingRepo(dataData) + embeddingService := embedding2.NewEmbeddingService(embeddingRepo, searchService, answerService, questionCommon, commentRepo, siteInfoCommonService) + siteInfoService := siteinfo.NewSiteInfoService(siteInfoRepo, siteInfoCommonService, emailService, tagCommonService, configService, questionCommon, fileRecordService, embeddingService) siteInfoController := 
controller_admin.NewSiteInfoController(siteInfoService) controllerSiteInfoController := controller.NewSiteInfoController(siteInfoCommonService) notificationCommon := notificationcommon.NewNotificationCommon(dataData, notificationRepo, userCommon, activityRepo, followRepo, objService, noticequeueService, userExternalLoginRepo, siteInfoCommonService) @@ -283,7 +287,7 @@ func initApplication(debug bool, serverConf *conf.Server, dbConf *data.Database, apiKeyService := apikey.NewAPIKeyService(apiKeyRepo) adminAPIKeyController := controller_admin.NewAdminAPIKeyController(apiKeyService) featureToggleService := feature_toggle.NewFeatureToggleService(siteInfoRepo) - mcpController := controller.NewMCPController(searchService, siteInfoCommonService, tagCommonService, questionCommon, commentRepo, userCommon, answerRepo, featureToggleService) + mcpController := controller.NewMCPController(searchService, siteInfoCommonService, tagCommonService, questionCommon, commentRepo, userCommon, answerRepo, featureToggleService, embeddingService) aiConversationRepo := ai_conversation.NewAIConversationRepo(dataData) aiConversationService := ai_conversation2.NewAIConversationService(aiConversationRepo, userCommon) aiController := controller.NewAIController(searchService, siteInfoCommonService, tagCommonService, questionCommon, commentRepo, userCommon, answerRepo, mcpController, aiConversationService, featureToggleService) diff --git a/docs/docs.go b/docs/docs.go index 57a23d432..8f65e05da 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -11712,6 +11712,24 @@ const docTemplate = `{ "type": "string", "maxLength": 256 }, + "embedding_crontab": { + "type": "string", + "maxLength": 100 + }, + "embedding_dimensions": { + "type": "integer" + }, + "embedding_level": { + "type": "string", + "enum": [ + "question", + "answer" + ] + }, + "embedding_model": { + "type": "string", + "maxLength": 100 + }, "model": { "type": "string", "maxLength": 100 @@ -11719,6 +11737,11 @@ const docTemplate = `{ 
"provider": { "type": "string", "maxLength": 50 + }, + "similarity_threshold": { + "type": "number", + "maximum": 1, + "minimum": 0 } } }, diff --git a/docs/swagger.json b/docs/swagger.json index dac2b38fd..71e802dfa 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -11685,6 +11685,24 @@ "type": "string", "maxLength": 256 }, + "embedding_crontab": { + "type": "string", + "maxLength": 100 + }, + "embedding_dimensions": { + "type": "integer" + }, + "embedding_level": { + "type": "string", + "enum": [ + "question", + "answer" + ] + }, + "embedding_model": { + "type": "string", + "maxLength": 100 + }, "model": { "type": "string", "maxLength": 100 @@ -11692,6 +11710,11 @@ "provider": { "type": "string", "maxLength": 50 + }, + "similarity_threshold": { + "type": "number", + "maximum": 1, + "minimum": 0 } } }, diff --git a/docs/swagger.yaml b/docs/swagger.yaml index 7a7adb681..5dc68d6c4 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -2250,12 +2250,29 @@ definitions: api_key: maxLength: 256 type: string + embedding_crontab: + maxLength: 100 + type: string + embedding_dimensions: + type: integer + embedding_level: + enum: + - question + - answer + type: string + embedding_model: + maxLength: 100 + type: string model: maxLength: 100 type: string provider: maxLength: 50 type: string + similarity_threshold: + maximum: 1 + minimum: 0 + type: number type: object schema.SiteAIReq: properties: diff --git a/i18n/en_US.yaml b/i18n/en_US.yaml index 9a0d198b3..ec385e763 100644 --- a/i18n/en_US.yaml +++ b/i18n/en_US.yaml @@ -2355,6 +2355,24 @@ ui: label: Model msg: Model is required add_success: AI settings updated successfully. + embedding_settings: Embedding Settings + embedding_model: + label: Embedding model + text: "The model used to generate vector embeddings for semantic search (e.g. text-embedding-3-small)." + embedding_dimensions: + label: Embedding dimensions + text: "The number of dimensions for the embedding vectors (e.g. 
1536 for text-embedding-3-small)." + embedding_level: + label: Embedding level + text: "Choose whether to create embeddings at the question level (question + all answers + comments) or answer level (each answer separately)." + question: Question level + answer: Answer level + embedding_crontab: + label: Embedding schedule + text: "Cron expression for periodic embedding calculation (e.g. '0 */6 * * *' for every 6 hours). Leave empty to disable automatic indexing." + similarity_threshold: + label: Similarity threshold + text: "Minimum cosine similarity score (0-1) for semantic search results. Only results with a score above this threshold will be returned. Default is 0 (no filtering)." conversations: topic: Topic helpful: Helpful diff --git a/i18n/zh_CN.yaml b/i18n/zh_CN.yaml index f16ed9fad..191619dd1 100644 --- a/i18n/zh_CN.yaml +++ b/i18n/zh_CN.yaml @@ -2319,6 +2319,24 @@ ui: label: 模型 msg: 模型是必需的 add_success: AI 设置更新成功。 + embedding_settings: Embedding 设置 + embedding_model: + label: Embedding 模型 + text: "用于生成语义搜索向量 Embedding 的模型(例如 text-embedding-3-small)。" + embedding_dimensions: + label: Embedding 维度 + text: "Embedding 向量的维度数(例如 text-embedding-3-small 为 1536)。" + embedding_level: + label: Embedding 级别 + text: "选择在问题级别(问题 + 所有回答 + 评论)还是回答级别(每个回答单独)创建 Embedding。" + question: 问题级别 + answer: 回答级别 + embedding_crontab: + label: Embedding 计划 + text: "定期计算 Embedding 的 Cron 表达式(例如 '0 */6 * * *' 表示每 6 小时)。留空则禁用自动索引。" + similarity_threshold: + label: 相似度阈值 + text: "语义搜索结果的最低余弦相似度分数(0-1)。只有分数高于此阈值的结果才会被返回。默认值为 0(不过滤)。" conversations: topic: 主题 helpful: 有帮助 diff --git a/internal/base/constant/ai_config.go b/internal/base/constant/ai_config.go index aa733bbaf..a25e47a45 100644 --- a/internal/base/constant/ai_config.go +++ b/internal/base/constant/ai_config.go @@ -33,6 +33,7 @@ const ( - get_tags: 搜索标签信息 - get_tag_detail: 获取特定标签的详细信息 - get_user: 搜索用户信息 +- semantic_search: 通过语义相似度搜索问题和答案。当用户的问题与现有内容概念相关但可能不匹配确切关键词时使用此工具。当 get_questions 关键词搜索返回较差结果时,请使用 semantic_search。 
请根据用户的问题智能地使用这些工具来提供准确的答案。如果需要查询系统信息,请先使用相应的工具获取数据。` DefaultAIPromptConfigEnUS = `You are an intelligent assistant that can help users query information in the system. User question: %s @@ -44,6 +45,7 @@ You can use the following tools to query system information: - get_tags: Search for tag information - get_tag_detail: Get detailed information about a specific tag - get_user: Search for user information +- semantic_search: Search questions and answers by semantic meaning. Use this when the user's question relates conceptually to existing content but may not match exact keywords. When get_questions keyword search returns poor results, use semantic_search instead. Please intelligently use these tools based on the user's question to provide accurate answers. If you need to query system information, please use the appropriate tools to get the data first.` ) diff --git a/internal/controller/ai_controller.go b/internal/controller/ai_controller.go index e020ed30e..125cdab22 100644 --- a/internal/controller/ai_controller.go +++ b/internal/controller/ai_controller.go @@ -446,6 +446,7 @@ func (c *AIController) handleAIConversation(ctx *gin.Context, w http.ResponseWri toolCalls, newMessages, finished, aiResponse := c.processAIStream(ctx, w, id, conversationCtx.Model, client, aiReq, messages) messages = newMessages + log.Debugf("Round %d: toolCalls=%v", round+1, toolCalls) if aiResponse != "" { conversationCtx.Messages = append(conversationCtx.Messages, &ai_conversation.ConversationMessage{ Role: "assistant", @@ -497,6 +498,10 @@ func (c *AIController) processAIStream( break } + if len(response.Choices) == 0 { + continue + } + choice := response.Choices[0] if len(choice.Delta.ToolCalls) > 0 { @@ -735,6 +740,8 @@ func (c *AIController) callMCPTool(ctx context.Context, toolName string, argumen result, err = c.mcpController.MCPTagDetailsHandler()(ctx, request) case "get_user": result, err = c.mcpController.MCPUserDetailsHandler()(ctx, request) + case "semantic_search": + result, 
err = c.mcpController.MCPSemanticSearchHandler()(ctx, request) default: return "", fmt.Errorf("unknown tool: %s", toolName) } diff --git a/internal/controller/mcp_controller.go b/internal/controller/mcp_controller.go index d52f57979..fecdbef60 100644 --- a/internal/controller/mcp_controller.go +++ b/internal/controller/mcp_controller.go @@ -31,6 +31,7 @@ import ( answercommon "github.com/apache/answer/internal/service/answer_common" "github.com/apache/answer/internal/service/comment" "github.com/apache/answer/internal/service/content" + "github.com/apache/answer/internal/service/embedding" "github.com/apache/answer/internal/service/feature_toggle" questioncommon "github.com/apache/answer/internal/service/question_common" "github.com/apache/answer/internal/service/siteinfo_common" @@ -49,6 +50,7 @@ type MCPController struct { userCommon *usercommon.UserCommon answerRepo answercommon.AnswerRepo featureToggleSvc *feature_toggle.FeatureToggleService + embeddingService *embedding.EmbeddingService } // NewMCPController new site info controller. 
@@ -61,6 +63,7 @@ func NewMCPController( userCommon *usercommon.UserCommon, answerRepo answercommon.AnswerRepo, featureToggleSvc *feature_toggle.FeatureToggleService, + embeddingService *embedding.EmbeddingService, ) *MCPController { return &MCPController{ searchService: searchService, @@ -71,6 +74,7 @@ func NewMCPController( userCommon: userCommon, answerRepo: answerRepo, featureToggleSvc: featureToggleSvc, + embeddingService: embeddingService, } } @@ -349,3 +353,131 @@ func (c *MCPController) MCPUserDetailsHandler() func(ctx context.Context, reques return mcp.NewToolResultText(string(res)), nil } } + +func (c *MCPController) MCPSemanticSearchHandler() func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if err := c.ensureMCPEnabled(ctx); err != nil { + return nil, err + } + cond := schema.NewMCPSemanticSearchCond(request) + if len(cond.Query) == 0 { + return mcp.NewToolResultText("Query is required for semantic search."), nil + } + + siteGeneral, err := c.siteInfoService.GetSiteGeneral(ctx) + if err != nil { + log.Errorf("get site general info failed: %v", err) + return nil, err + } + + results, err := c.embeddingService.SearchSimilar(ctx, cond.Query, cond.TopK) + if err != nil { + log.Errorf("semantic search failed: %v", err) + return mcp.NewToolResultText("Semantic search is not available. 
Embedding may not be configured."), nil + } + if len(results) == 0 { + return mcp.NewToolResultText("No semantically similar content found."), nil + } + + resp := make([]*schema.MCPSemanticSearchResp, 0, len(results)) + for _, r := range results { + var meta entity.EmbeddingMetadata + _ = json.Unmarshal([]byte(r.Metadata), &meta) + + item := &schema.MCPSemanticSearchResp{ + ObjectID: r.ObjectID, + ObjectType: r.ObjectType, + Score: r.Score, + } + + // Compose link from metadata + if r.ObjectType == "answer" && meta.AnswerID != "" { + item.Link = fmt.Sprintf("%s/questions/%s/%s", siteGeneral.SiteUrl, meta.QuestionID, meta.AnswerID) + } else { + item.Link = fmt.Sprintf("%s/questions/%s", siteGeneral.SiteUrl, meta.QuestionID) + } + + // Query content from DB using IDs stored in metadata + if r.ObjectType == "question" { + question, qErr := c.questioncommon.Info(ctx, meta.QuestionID, "") + if qErr != nil { + log.Warnf("get question %s for semantic search failed: %v", meta.QuestionID, qErr) + } else { + item.Title = question.Title + item.Content = question.Content + } + + // Fetch answers by ID from metadata + for _, a := range meta.Answers { + answerEntity, exist, aErr := c.answerRepo.GetAnswer(ctx, a.AnswerID) + if aErr != nil || !exist { + continue + } + answerItem := &schema.MCPSemanticSearchAnswer{ + AnswerID: a.AnswerID, + Content: answerEntity.OriginalText, + } + // Fetch comments on this answer from DB + for _, ac := range a.Comments { + cmt, cExist, cErr := c.commentRepo.GetComment(ctx, ac.CommentID) + if cErr == nil && cExist { + answerItem.Comments = append(answerItem.Comments, &schema.MCPSemanticSearchComment{ + CommentID: ac.CommentID, + Content: cmt.OriginalText, + }) + } + } + item.Answers = append(item.Answers, answerItem) + } + + // Fetch question comments from DB + for _, qc := range meta.Comments { + cmt, cExist, cErr := c.commentRepo.GetComment(ctx, qc.CommentID) + if cErr == nil && cExist { + item.Comments = append(item.Comments, 
&schema.MCPSemanticSearchComment{ + CommentID: qc.CommentID, + Content: cmt.OriginalText, + }) + } + } + } else if r.ObjectType == "answer" { + // Fetch question title for context + question, qErr := c.questioncommon.Info(ctx, meta.QuestionID, "") + if qErr == nil { + item.Title = question.Title + } + + // Fetch answer content from DB + if meta.AnswerID != "" { + answerEntity, exist, aErr := c.answerRepo.GetAnswer(ctx, meta.AnswerID) + if aErr == nil && exist { + item.Content = answerEntity.OriginalText + } + } else if len(meta.Answers) > 0 { + answerEntity, exist, aErr := c.answerRepo.GetAnswer(ctx, meta.Answers[0].AnswerID) + if aErr == nil && exist { + item.Content = answerEntity.OriginalText + } + } + + // Fetch answer comments from DB + if len(meta.Answers) > 0 { + for _, ac := range meta.Answers[0].Comments { + cmt, cExist, cErr := c.commentRepo.GetComment(ctx, ac.CommentID) + if cErr == nil && cExist { + item.Comments = append(item.Comments, &schema.MCPSemanticSearchComment{ + CommentID: ac.CommentID, + Content: cmt.OriginalText, + }) + } + } + } + } + + resp = append(resp, item) + } + + data, _ := json.Marshal(resp) + return mcp.NewToolResultText(string(data)), nil + } +} diff --git a/internal/entity/embedding.go b/internal/entity/embedding.go new file mode 100644 index 000000000..3ea500d92 --- /dev/null +++ b/internal/entity/embedding.go @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package entity + +import "time" + +// Embedding stores vector embeddings for questions or answers. +type Embedding struct { + ID int `xorm:"not null pk autoincr INT(11) id"` + CreatedAt time.Time `xorm:"created not null default CURRENT_TIMESTAMP TIMESTAMP created_at"` + UpdatedAt time.Time `xorm:"updated not null default CURRENT_TIMESTAMP TIMESTAMP updated_at"` + ObjectID string `xorm:"not null BIGINT(20) INDEX object_id unique(object_embedding)"` + ObjectType string `xorm:"not null default '' VARCHAR(20) object_type unique(object_embedding)"` + ContentHash string `xorm:"not null default '' VARCHAR(64) content_hash"` + Metadata string `xorm:"not null MEDIUMTEXT metadata"` + Embedding string `xorm:"not null MEDIUMTEXT embedding"` + Dimensions int `xorm:"not null default 0 INT(11) dimensions"` +} + +// TableName returns the table name +func (Embedding) TableName() string { + return "embedding" +} + +// EmbeddingMetadata holds IDs for URI composition and content retrieval at query time. +type EmbeddingMetadata struct { + QuestionID string `json:"question_id"` + AnswerID string `json:"answer_id,omitempty"` + Answers []EmbeddingMetadataAnswer `json:"answers,omitempty"` + Comments []EmbeddingMetadataComment `json:"comments,omitempty"` +} + +// EmbeddingMetadataAnswer stores answer ID and comment IDs in metadata. +type EmbeddingMetadataAnswer struct { + AnswerID string `json:"answer_id"` + Comments []EmbeddingMetadataComment `json:"comments,omitempty"` +} + +// EmbeddingMetadataComment stores comment ID in metadata. 
+type EmbeddingMetadataComment struct { + CommentID string `json:"comment_id"` +} diff --git a/internal/migrations/migrations.go b/internal/migrations/migrations.go index 682a7b207..e6722cc08 100644 --- a/internal/migrations/migrations.go +++ b/internal/migrations/migrations.go @@ -108,6 +108,7 @@ var migrations = []Migration{ NewMigration("v1.8.0", "change admin menu", updateAdminMenuSettings, true), NewMigration("v1.8.1", "ai feat", aiFeat, true), NewMigration("v2.0.1", "change avatar type to text", updateAvatarType, false), + NewMigration("v2.0.2", "add embedding table", addEmbeddingTable, false), } func GetMigrations() []Migration { diff --git a/internal/migrations/v32.go b/internal/migrations/v32.go index fc6614b11..438097348 100644 --- a/internal/migrations/v32.go +++ b/internal/migrations/v32.go @@ -24,6 +24,7 @@ import ( "fmt" "github.com/apache/answer/internal/entity" + "github.com/segmentfault/pacman/log" "xorm.io/xorm" ) @@ -35,3 +36,11 @@ func updateAvatarType(ctx context.Context, x *xorm.Engine) error { } return nil } + +func addEmbeddingTable(ctx context.Context, x *xorm.Engine) error { + if err := x.Context(ctx).Sync(new(entity.Embedding)); err != nil { + return fmt.Errorf("sync embedding table failed: %w", err) + } + log.Info("Embedding table migration completed successfully") + return nil +} diff --git a/internal/repo/embedding/embedding_repo.go b/internal/repo/embedding/embedding_repo.go new file mode 100644 index 000000000..67bde1d5d --- /dev/null +++ b/internal/repo/embedding/embedding_repo.go @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package embedding + +import ( + "context" + "encoding/json" + "math" + "sort" + + "github.com/apache/answer/internal/base/data" + "github.com/apache/answer/internal/entity" + "github.com/segmentfault/pacman/log" + "xorm.io/builder" +) + +// EmbeddingRepo defines the interface for embedding data access. +type EmbeddingRepo interface { + Upsert(ctx context.Context, emb *entity.Embedding) error + GetByObjectID(ctx context.Context, objectID, objectType string) (*entity.Embedding, bool, error) + GetAll(ctx context.Context) ([]*entity.Embedding, error) + SearchSimilar(ctx context.Context, queryVector []float32, topK int) ([]SimilarResult, error) + DeleteByObjectID(ctx context.Context, objectID, objectType string) error + Count(ctx context.Context) (int64, error) +} + +// SimilarResult holds a similarity search result. +type SimilarResult struct { + ObjectID string `json:"object_id"` + ObjectType string `json:"object_type"` + Metadata string `json:"metadata"` + Score float64 `json:"score"` +} + +type embeddingRepo struct { + data *data.Data +} + +// NewEmbeddingRepo creates a new EmbeddingRepo. +func NewEmbeddingRepo(data *data.Data) EmbeddingRepo { + return &embeddingRepo{data: data} +} + +// Upsert inserts or updates an embedding by (object_id, object_type). +func (r *embeddingRepo) Upsert(ctx context.Context, emb *entity.Embedding) error { + existing := &entity.Embedding{} + exist, err := r.data.DB.Context(ctx). + Where(builder.Eq{"object_id": emb.ObjectID, "object_type": emb.ObjectType}). 
+ Get(existing) + if err != nil { + log.Errorf("check embedding existence failed: %v", err) + return err + } + + if exist { + emb.ID = existing.ID + _, err = r.data.DB.Context(ctx).ID(existing.ID). + Cols("content_hash", "metadata", "embedding", "dimensions", "updated_at"). + Update(emb) + if err != nil { + log.Errorf("update embedding failed: %v", err) + return err + } + return nil + } + + _, err = r.data.DB.Context(ctx).Insert(emb) + if err != nil { + log.Errorf("insert embedding failed: %v", err) + return err + } + return nil +} + +// GetByObjectID returns an embedding by object ID and type. +func (r *embeddingRepo) GetByObjectID(ctx context.Context, objectID, objectType string) (*entity.Embedding, bool, error) { + emb := &entity.Embedding{} + exist, err := r.data.DB.Context(ctx). + Where(builder.Eq{"object_id": objectID, "object_type": objectType}). + Get(emb) + if err != nil { + log.Errorf("get embedding failed: %v", err) + return nil, false, err + } + return emb, exist, nil +} + +// GetAll returns all embeddings. +func (r *embeddingRepo) GetAll(ctx context.Context) ([]*entity.Embedding, error) { + var list []*entity.Embedding + err := r.data.DB.Context(ctx).Find(&list) + if err != nil { + log.Errorf("get all embeddings failed: %v", err) + return nil, err + } + return list, nil +} + +// SearchSimilar performs brute-force cosine similarity search in Go. 
+func (r *embeddingRepo) SearchSimilar(ctx context.Context, queryVector []float32, topK int) ([]SimilarResult, error) { + allEmbeddings, err := r.GetAll(ctx) + if err != nil { + return nil, err + } + + type scored struct { + emb *entity.Embedding + score float64 + } + results := make([]scored, 0, len(allEmbeddings)) + + for _, emb := range allEmbeddings { + var vec []float32 + if err := json.Unmarshal([]byte(emb.Embedding), &vec); err != nil { + log.Warnf("skip embedding id=%d, unmarshal failed: %v", emb.ID, err) + continue + } + score := cosineSimilarity(queryVector, vec) + results = append(results, scored{emb: emb, score: score}) + } + + sort.Slice(results, func(i, j int) bool { + return results[i].score > results[j].score + }) + + if topK > len(results) { + topK = len(results) + } + + out := make([]SimilarResult, 0, topK) + for i := 0; i < topK; i++ { + out = append(out, SimilarResult{ + ObjectID: results[i].emb.ObjectID, + ObjectType: results[i].emb.ObjectType, + Metadata: results[i].emb.Metadata, + Score: results[i].score, + }) + } + return out, nil +} + +// DeleteByObjectID deletes an embedding by object ID and type. +func (r *embeddingRepo) DeleteByObjectID(ctx context.Context, objectID, objectType string) error { + _, err := r.data.DB.Context(ctx). + Where(builder.Eq{"object_id": objectID, "object_type": objectType}). + Delete(&entity.Embedding{}) + if err != nil { + log.Errorf("delete embedding failed: %v", err) + return err + } + return nil +} + +// Count returns the total number of embeddings. +func (r *embeddingRepo) Count(ctx context.Context) (int64, error) { + count, err := r.data.DB.Context(ctx).Count(&entity.Embedding{}) + if err != nil { + log.Errorf("count embeddings failed: %v", err) + return 0, err + } + return count, nil +} + +// cosineSimilarity computes cosine similarity between two vectors. 
// cosineSimilarity returns the cosine of the angle between vectors a and b.
//
// It returns 0 when the vectors are empty, differ in length, or when either
// has zero magnitude. The result is always a finite value in [-1, 1]:
// floating-point rounding can push the raw quotient marginally outside that
// range, and a non-finite component would otherwise produce a NaN score,
// which cannot be ordered reliably by sort nor encoded by encoding/json.
func cosineSimilarity(a, b []float32) float64 {
	if len(a) == 0 || len(a) != len(b) {
		return 0
	}

	// Accumulate in float64 to limit rounding error over long vectors.
	var dot, sumSqA, sumSqB float64
	for i, av := range a {
		x, y := float64(av), float64(b[i])
		dot += x * y
		sumSqA += x * x
		sumSqB += y * y
	}

	denom := math.Sqrt(sumSqA) * math.Sqrt(sumSqB)
	if denom == 0 {
		return 0
	}

	score := dot / denom
	switch {
	case math.IsNaN(score) || math.IsInf(score, 0):
		return 0
	case score > 1:
		return 1
	case score < -1:
		return -1
	}
	return score
}
MCPSearchCondSemanticQuery = "query" + MCPSearchCondTopK = "top_k" ) type MCPSearchCond struct { @@ -98,6 +100,48 @@ type MCPSearchCommentInfoResp struct { Link string `json:"link"` } +// MCPSemanticSearchCond is the condition for semantic search. +type MCPSemanticSearchCond struct { + Query string `json:"query"` + TopK int `json:"top_k"` +} + +// MCPSemanticSearchResp is a single semantic search result. +type MCPSemanticSearchResp struct { + ObjectID string `json:"object_id"` + ObjectType string `json:"object_type"` + Title string `json:"title"` + Content string `json:"content"` + Score float64 `json:"score"` + Link string `json:"link"` + Answers []*MCPSemanticSearchAnswer `json:"answers,omitempty"` + Comments []*MCPSemanticSearchComment `json:"comments,omitempty"` +} + +// MCPSemanticSearchAnswer is an answer in a semantic search result. +type MCPSemanticSearchAnswer struct { + AnswerID string `json:"answer_id"` + Content string `json:"content"` + Comments []*MCPSemanticSearchComment `json:"comments,omitempty"` +} + +// MCPSemanticSearchComment is a comment in a semantic search result. 
+type MCPSemanticSearchComment struct { + CommentID string `json:"comment_id"` + Content string `json:"content"` +} + +func NewMCPSemanticSearchCond(request mcp.CallToolRequest) *MCPSemanticSearchCond { + cond := &MCPSemanticSearchCond{TopK: 5} + if query, ok := getRequestValue(request, MCPSearchCondSemanticQuery); ok { + cond.Query = query + } + if topK, ok := getRequestNumber(request, MCPSearchCondTopK); ok && topK > 0 { + cond.TopK = topK + } + return cond +} + func NewMCPSearchCond(request mcp.CallToolRequest) *MCPSearchCond { cond := &MCPSearchCond{} if keyword, ok := getRequestValue(request, MCPSearchCondKeyword); ok { diff --git a/internal/schema/mcp_tools/mcp_tools.go b/internal/schema/mcp_tools/mcp_tools.go index 949a738c7..3ae6b3bea 100644 --- a/internal/schema/mcp_tools/mcp_tools.go +++ b/internal/schema/mcp_tools/mcp_tools.go @@ -32,6 +32,7 @@ var ( NewTagsTool(), NewTagDetailTool(), NewUserTool(), + NewSemanticSearchTool(), } ) @@ -103,3 +104,17 @@ func NewUserTool() mcp.Tool { ) return listFilesTool } + +func NewSemanticSearchTool() mcp.Tool { + tool := mcp.NewTool("semantic_search", + mcp.WithDescription("Search questions and answers by semantic meaning. Use this when the user's question relates conceptually to existing content but may not match exact keywords. 
Returns the most semantically similar content."), + mcp.WithString(schema.MCPSearchCondSemanticQuery, + mcp.Required(), + mcp.Description("The search query text to find semantically similar questions and answers"), + ), + mcp.WithNumber(schema.MCPSearchCondTopK, + mcp.Description("Maximum number of results to return (default 5)"), + ), + ) + return tool +} diff --git a/internal/schema/siteinfo_schema.go b/internal/schema/siteinfo_schema.go index bdf2308d3..84bc54e30 100644 --- a/internal/schema/siteinfo_schema.go +++ b/internal/schema/siteinfo_schema.go @@ -281,10 +281,15 @@ func (s *SiteAIResp) GetProvider() *SiteAIProvider { } type SiteAIProvider struct { - Provider string `validate:"omitempty,lte=50" form:"provider" json:"provider"` - APIHost string `validate:"omitempty,lte=512" form:"api_host" json:"api_host"` - APIKey string `validate:"omitempty,lte=256" form:"api_key" json:"api_key"` - Model string `validate:"omitempty,lte=100" form:"model" json:"model"` + Provider string `validate:"omitempty,lte=50" form:"provider" json:"provider"` + APIHost string `validate:"omitempty,lte=512" form:"api_host" json:"api_host"` + APIKey string `validate:"omitempty,lte=256" form:"api_key" json:"api_key"` + Model string `validate:"omitempty,lte=100" form:"model" json:"model"` + EmbeddingModel string `validate:"omitempty,lte=100" form:"embedding_model" json:"embedding_model"` + EmbeddingDimensions int `validate:"omitempty" form:"embedding_dimensions" json:"embedding_dimensions"` + EmbeddingLevel string `validate:"omitempty,oneof=question answer" form:"embedding_level" json:"embedding_level"` + EmbeddingCrontab string `validate:"omitempty,lte=100" form:"embedding_crontab" json:"embedding_crontab"` + SimilarityThreshold float64 `validate:"omitempty,gte=0,lte=1" form:"similarity_threshold" json:"similarity_threshold"` } // SiteAIResp AI configuration response diff --git a/internal/service/embedding/embedding_service.go b/internal/service/embedding/embedding_service.go new file mode 
100644 index 000000000..72e008abb --- /dev/null +++ b/internal/service/embedding/embedding_service.go @@ -0,0 +1,516 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package embedding + +import ( + "context" + "crypto/sha256" + "encoding/json" + "fmt" + "strings" + "sync" + + "github.com/apache/answer/internal/base/pager" + "github.com/apache/answer/internal/entity" + embeddingRepo "github.com/apache/answer/internal/repo/embedding" + "github.com/apache/answer/internal/schema" + "github.com/apache/answer/internal/service/comment" + "github.com/apache/answer/internal/service/content" + questioncommon "github.com/apache/answer/internal/service/question_common" + "github.com/apache/answer/internal/service/siteinfo_common" + "github.com/robfig/cron/v3" + "github.com/sashabaranov/go-openai" + "github.com/segmentfault/pacman/log" +) + +const ( + EmbeddingLevelQuestion = "question" + EmbeddingLevelAnswer = "answer" +) + +// EmbeddingService handles embedding generation, text aggregation, and indexing. 
+type EmbeddingService struct { + embeddingRepo embeddingRepo.EmbeddingRepo + searchService *content.SearchService + answerService *content.AnswerService + questionCommon *questioncommon.QuestionCommon + commentRepo comment.CommentRepo + siteInfoService siteinfo_common.SiteInfoCommonService + + mu sync.Mutex + cronJob *cron.Cron + cronSpec string +} + +// NewEmbeddingService creates a new EmbeddingService. +func NewEmbeddingService( + embeddingRepo embeddingRepo.EmbeddingRepo, + searchService *content.SearchService, + answerService *content.AnswerService, + questionCommon *questioncommon.QuestionCommon, + commentRepo comment.CommentRepo, + siteInfoService siteinfo_common.SiteInfoCommonService, +) *EmbeddingService { + return &EmbeddingService{ + embeddingRepo: embeddingRepo, + searchService: searchService, + answerService: answerService, + questionCommon: questionCommon, + commentRepo: commentRepo, + siteInfoService: siteInfoService, + } +} + +// getAIConfig returns the current AI configuration. +func (s *EmbeddingService) getAIConfig(ctx context.Context) (*schema.SiteAIResp, *schema.SiteAIProvider, error) { + aiConfig, err := s.siteInfoService.GetSiteAI(ctx) + if err != nil { + return nil, nil, fmt.Errorf("get AI config failed: %w", err) + } + if !aiConfig.Enabled { + return nil, nil, fmt.Errorf("AI feature is disabled") + } + provider := aiConfig.GetProvider() + if provider.EmbeddingModel == "" { + return nil, nil, fmt.Errorf("embedding model not configured") + } + return aiConfig, provider, nil +} + +// createEmbeddingClient creates an OpenAI-compatible client for embedding requests. +func (s *EmbeddingService) createEmbeddingClient(provider *schema.SiteAIProvider) *openai.Client { + config := openai.DefaultConfig(provider.APIKey) + config.BaseURL = provider.APIHost + if !strings.HasSuffix(config.BaseURL, "/v1") { + config.BaseURL += "/v1" + } + return openai.NewClientWithConfig(config) +} + +// GenerateEmbedding generates an embedding vector for the given text. 
+func (s *EmbeddingService) GenerateEmbedding(ctx context.Context, text string) ([]float32, error) { + _, provider, err := s.getAIConfig(ctx) + if err != nil { + return nil, err + } + + client := s.createEmbeddingClient(provider) + resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequestStrings{ + Input: []string{text}, + Model: openai.EmbeddingModel(provider.EmbeddingModel), + }) + if err != nil { + return nil, fmt.Errorf("create embeddings failed: %w", err) + } + if len(resp.Data) == 0 { + return nil, fmt.Errorf("no embedding returned") + } + return resp.Data[0].Embedding, nil +} + +// ComputeContentHash computes SHA256 of the text. +func ComputeContentHash(text string) string { + h := sha256.Sum256([]byte(text)) + return fmt.Sprintf("%x", h) +} + +// BuildTextForQuestion aggregates question title + body + all answers + comments into one text. +// Uses SearchService and QuestionCommon to respect the plugin architecture. +func (s *EmbeddingService) BuildTextForQuestion(ctx context.Context, questionID string) (text string, meta *entity.EmbeddingMetadata, err error) { + // Get question detail via service layer + question, err := s.questionCommon.Info(ctx, questionID, "") + if err != nil { + return "", nil, fmt.Errorf("get question info failed: %w", err) + } + + meta = &entity.EmbeddingMetadata{ + QuestionID: questionID, + } + + var parts []string + parts = append(parts, fmt.Sprintf("Question: %s\n%s", question.Title, question.Content)) + + // Get answers via AnswerService + answerInfoList, _, err := s.answerService.SearchList(ctx, &schema.AnswerListReq{ + QuestionID: questionID, + Page: 1, + PageSize: 50, + }) + if err != nil { + log.Warnf("get answers for question %s failed: %v", questionID, err) + } else { + for _, a := range answerInfoList { + parts = append(parts, fmt.Sprintf("Answer: %s", a.Content)) + answerMeta := entity.EmbeddingMetadataAnswer{ + AnswerID: a.ID, + } + + // Get comments on this answer + answerComments, _, err := 
s.commentRepo.GetCommentPage(ctx, &comment.CommentQuery{ + PageCond: pager.PageCond{Page: 1, PageSize: 50}, + ObjectID: a.ID, + QueryCond: "newest", + }) + if err != nil { + log.Warnf("get comments for answer %s failed: %v", a.ID, err) + } else { + for _, c := range answerComments { + parts = append(parts, fmt.Sprintf("Comment on answer: %s", c.OriginalText)) + answerMeta.Comments = append(answerMeta.Comments, entity.EmbeddingMetadataComment{ + CommentID: c.ID, + }) + } + } + meta.Answers = append(meta.Answers, answerMeta) + } + } + + // Get comments on the question + commentList, _, err := s.commentRepo.GetCommentPage(ctx, &comment.CommentQuery{ + PageCond: pager.PageCond{Page: 1, PageSize: 50}, + ObjectID: questionID, + QueryCond: "newest", + }) + if err != nil { + log.Warnf("get comments for question %s failed: %v", questionID, err) + } else { + for _, c := range commentList { + parts = append(parts, fmt.Sprintf("Comment: %s", c.OriginalText)) + meta.Comments = append(meta.Comments, entity.EmbeddingMetadataComment{ + CommentID: c.ID, + }) + } + } + + return strings.Join(parts, "\n\n"), meta, nil +} + +// BuildTextForAnswer aggregates answer body + parent question title + answer comments into one text. 
+func (s *EmbeddingService) BuildTextForAnswer(ctx context.Context, answerID, questionID string) (text string, meta *entity.EmbeddingMetadata, err error) { + // Get parent question title + question, err := s.questionCommon.Info(ctx, questionID, "") + if err != nil { + return "", nil, fmt.Errorf("get question info for answer failed: %w", err) + } + + meta = &entity.EmbeddingMetadata{ + QuestionID: questionID, + AnswerID: answerID, + } + + // Get the specific answer's content via AnswerService + answerInfo, err := s.answerService.GetDetail(ctx, answerID) + if err != nil { + return "", nil, fmt.Errorf("get answer failed: %w", err) + } + + var answerText string + if answerInfo != nil { + answerText = answerInfo.Content + } + + var parts []string + parts = append(parts, fmt.Sprintf("Question: %s", question.Title)) + if answerText != "" { + parts = append(parts, fmt.Sprintf("Answer: %s", answerText)) + meta.Answers = append(meta.Answers, entity.EmbeddingMetadataAnswer{ + AnswerID: answerID, + }) + } + + // Get comments on the answer + commentList, _, err := s.commentRepo.GetCommentPage(ctx, &comment.CommentQuery{ + PageCond: pager.PageCond{Page: 1, PageSize: 50}, + ObjectID: answerID, + QueryCond: "newest", + }) + if err != nil { + log.Warnf("get comments for answer %s failed: %v", answerID, err) + } else { + for _, c := range commentList { + parts = append(parts, fmt.Sprintf("Comment: %s", c.OriginalText)) + if len(meta.Answers) > 0 { + meta.Answers[0].Comments = append(meta.Answers[0].Comments, entity.EmbeddingMetadataComment{ + CommentID: c.ID, + }) + } else { + meta.Comments = append(meta.Comments, entity.EmbeddingMetadataComment{ + CommentID: c.ID, + }) + } + } + } + + return strings.Join(parts, "\n\n"), meta, nil +} + +// IndexQuestion indexes a single question embedding. 
+func (s *EmbeddingService) IndexQuestion(ctx context.Context, questionID string) error { + text, meta, err := s.BuildTextForQuestion(ctx, questionID) + if err != nil { + return err + } + + contentHash := ComputeContentHash(text) + + // Check staleness + existing, exist, _ := s.embeddingRepo.GetByObjectID(ctx, questionID, EmbeddingLevelQuestion) + if exist && existing.ContentHash == contentHash { + return nil // already up to date + } + + vec, err := s.GenerateEmbedding(ctx, text) + if err != nil { + return fmt.Errorf("generate embedding for question %s failed: %w", questionID, err) + } + + metaJSON, _ := json.Marshal(meta) + vecJSON, _ := json.Marshal(vec) + + return s.embeddingRepo.Upsert(ctx, &entity.Embedding{ + ObjectID: questionID, + ObjectType: EmbeddingLevelQuestion, + ContentHash: contentHash, + Metadata: string(metaJSON), + Embedding: string(vecJSON), + Dimensions: len(vec), + }) +} + +// IndexAnswer indexes a single answer embedding. +func (s *EmbeddingService) IndexAnswer(ctx context.Context, answerID, questionID string) error { + text, meta, err := s.BuildTextForAnswer(ctx, answerID, questionID) + if err != nil { + return err + } + + contentHash := ComputeContentHash(text) + + existing, exist, _ := s.embeddingRepo.GetByObjectID(ctx, answerID, EmbeddingLevelAnswer) + if exist && existing.ContentHash == contentHash { + return nil + } + + vec, err := s.GenerateEmbedding(ctx, text) + if err != nil { + return fmt.Errorf("generate embedding for answer %s failed: %w", answerID, err) + } + + metaJSON, _ := json.Marshal(meta) + vecJSON, _ := json.Marshal(vec) + + return s.embeddingRepo.Upsert(ctx, &entity.Embedding{ + ObjectID: answerID, + ObjectType: EmbeddingLevelAnswer, + ContentHash: contentHash, + Metadata: string(metaJSON), + Embedding: string(vecJSON), + Dimensions: len(vec), + }) +} + +// SearchSimilar performs semantic search and returns top-K similar results. +// Results below the configured similarity threshold are filtered out. 
+func (s *EmbeddingService) SearchSimilar(ctx context.Context, query string, topK int) ([]embeddingRepo.SimilarResult, error) { + vec, err := s.GenerateEmbedding(ctx, query) + if err != nil { + return nil, fmt.Errorf("generate query embedding failed: %w", err) + } + results, err := s.embeddingRepo.SearchSimilar(ctx, vec, topK) + if err != nil { + return nil, err + } + + for _, r := range results { + log.Debugf("semantic search result: object_id=%s object_type=%s score=%.6f", r.ObjectID, r.ObjectType, r.Score) + } + + // Apply similarity threshold from config (default 0 means no filtering) + _, provider, cfgErr := s.getAIConfig(ctx) + if cfgErr == nil && provider.SimilarityThreshold > 0 { + filtered := make([]embeddingRepo.SimilarResult, 0, len(results)) + for _, r := range results { + if r.Score >= provider.SimilarityThreshold { + filtered = append(filtered, r) + } + } + log.Debugf("semantic search: %d/%d results passed threshold %.4f", len(filtered), len(results), provider.SimilarityThreshold) + return filtered, nil + } + + return results, nil +} + +// GetEmbeddingCount returns the total number of stored embeddings. +func (s *EmbeddingService) GetEmbeddingCount(ctx context.Context) (int64, error) { + return s.embeddingRepo.Count(ctx) +} + +// RemoveEmbedding removes an embedding by object ID and type. +func (s *EmbeddingService) RemoveEmbedding(ctx context.Context, objectID, objectType string) error { + return s.embeddingRepo.DeleteByObjectID(ctx, objectID, objectType) +} + +// IndexAll indexes all questions (and optionally answers) based on the configured embedding level. 
+func (s *EmbeddingService) IndexAll(ctx context.Context) error { + _, provider, err := s.getAIConfig(ctx) + if err != nil { + log.Warnf("embedding indexer: %v", err) + return err + } + + level := provider.EmbeddingLevel + if level == "" { + level = EmbeddingLevelQuestion + } + + log.Debugf("Starting embedding indexer at level: %s", level) + + page := 1 + totalIndexed := 0 + for { + searchResp, err := s.searchService.Search(ctx, &schema.SearchDTO{ + Query: "is:question", + Page: page, + Size: 50, + Order: "newest", + }) + if err != nil { + return fmt.Errorf("search questions for indexing failed: %w", err) + } + if searchResp == nil || len(searchResp.SearchResults) == 0 { + break + } + + for _, result := range searchResp.SearchResults { + if result.Object == nil { + continue + } + qID := result.Object.QuestionID + if level == EmbeddingLevelQuestion { + if err := s.IndexQuestion(ctx, qID); err != nil { + log.Warnf("index question %s failed: %v", qID, err) + continue + } + totalIndexed++ + } else if level == EmbeddingLevelAnswer { + // Index each answer for this question via AnswerService + answerInfoList, _, err := s.answerService.SearchList(ctx, &schema.AnswerListReq{ + QuestionID: qID, + Page: 1, + PageSize: 50, + }) + if err != nil { + log.Warnf("get answers for question %s failed: %v", qID, err) + continue + } + for _, a := range answerInfoList { + if err := s.IndexAnswer(ctx, a.ID, qID); err != nil { + log.Warnf("index answer %s failed: %v", a.ID, err) + continue + } + totalIndexed++ + } + } + } + + if int64((page)*50) >= searchResp.Total { + break + } + page++ + } + + log.Infof("Embedding indexer completed: %d items indexed", totalIndexed) + return nil +} + +// StartScheduler starts a cron job to periodically run IndexAll. 
+func (s *EmbeddingService) StartScheduler(spec string) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Stop existing cron if running + if s.cronJob != nil { + s.cronJob.Stop() + s.cronJob = nil + s.cronSpec = "" + } + + if spec == "" { + return nil + } + + c := cron.New() + _, err := c.AddFunc(spec, func() { + ctx := context.Background() + log.Infof("embedding cron triggered (spec=%s)", spec) + if err := s.IndexAll(ctx); err != nil { + log.Errorf("embedding cron IndexAll failed: %v", err) + } + }) + if err != nil { + return fmt.Errorf("invalid cron expression %q: %w", spec, err) + } + + c.Start() + s.cronJob = c + s.cronSpec = spec + log.Infof("embedding scheduler started with cron: %s", spec) + return nil +} + +// StopScheduler stops the embedding cron scheduler. +func (s *EmbeddingService) StopScheduler() { + s.mu.Lock() + defer s.mu.Unlock() + if s.cronJob != nil { + s.cronJob.Stop() + s.cronJob = nil + s.cronSpec = "" + log.Infof("embedding scheduler stopped") + } +} + +// ApplyConfig reads the current AI config and starts or stops the scheduler accordingly. 
+func (s *EmbeddingService) ApplyConfig(ctx context.Context) { + aiConfig, provider, err := s.getAIConfig(ctx) + if err != nil || aiConfig == nil || provider == nil { + s.StopScheduler() + return + } + + if provider.EmbeddingModel == "" || provider.EmbeddingCrontab == "" { + s.StopScheduler() + return + } + + // Only restart if the cron spec changed + s.mu.Lock() + currentSpec := s.cronSpec + s.mu.Unlock() + + if currentSpec == provider.EmbeddingCrontab { + return + } + + if err := s.StartScheduler(provider.EmbeddingCrontab); err != nil { + log.Errorf("failed to start embedding scheduler: %v", err) + } +} diff --git a/internal/service/provider.go b/internal/service/provider.go index 3e43b0ae0..26f1c4309 100644 --- a/internal/service/provider.go +++ b/internal/service/provider.go @@ -36,6 +36,7 @@ import ( "github.com/apache/answer/internal/service/config" "github.com/apache/answer/internal/service/content" "github.com/apache/answer/internal/service/dashboard" + "github.com/apache/answer/internal/service/embedding" "github.com/apache/answer/internal/service/eventqueue" "github.com/apache/answer/internal/service/export" "github.com/apache/answer/internal/service/feature_toggle" @@ -134,4 +135,5 @@ var ProviderSetService = wire.NewSet( apikey.NewAPIKeyService, ai_conversation.NewAIConversationService, feature_toggle.NewFeatureToggleService, + embedding.NewEmbeddingService, ) diff --git a/internal/service/siteinfo/siteinfo_service.go b/internal/service/siteinfo/siteinfo_service.go index 1e25cbaa4..70003984c 100644 --- a/internal/service/siteinfo/siteinfo_service.go +++ b/internal/service/siteinfo/siteinfo_service.go @@ -33,6 +33,7 @@ import ( "github.com/apache/answer/internal/entity" "github.com/apache/answer/internal/schema" "github.com/apache/answer/internal/service/config" + embeddingService "github.com/apache/answer/internal/service/embedding" "github.com/apache/answer/internal/service/export" "github.com/apache/answer/internal/service/file_record" questioncommon 
"github.com/apache/answer/internal/service/question_common" @@ -53,6 +54,7 @@ type SiteInfoService struct { configService *config.ConfigService questioncommon *questioncommon.QuestionCommon fileRecordService *file_record.FileRecordService + embeddingService *embeddingService.EmbeddingService } func NewSiteInfoService( @@ -63,7 +65,7 @@ func NewSiteInfoService( configService *config.ConfigService, questioncommon *questioncommon.QuestionCommon, fileRecordService *file_record.FileRecordService, - + embeddingSvc *embeddingService.EmbeddingService, ) *SiteInfoService { plugin.RegisterGetSiteURLFunc(func() string { generalSiteInfo, err := siteInfoCommonService.GetSiteGeneral(context.Background()) @@ -82,6 +84,7 @@ func NewSiteInfoService( configService: configService, questioncommon: questioncommon, fileRecordService: fileRecordService, + embeddingService: embeddingSvc, } } @@ -409,7 +412,13 @@ func (s *SiteInfoService) SaveSiteAI(ctx context.Context, req *schema.SiteAIReq) Content: string(content), Status: 1, } - return s.siteInfoRepo.SaveByType(ctx, constant.SiteTypeAI, siteInfo) + if err := s.siteInfoRepo.SaveByType(ctx, constant.SiteTypeAI, siteInfo); err != nil { + return err + } + + // Apply embedding scheduler config (start/stop cron based on settings) + go s.embeddingService.ApplyConfig(ctx) + return nil } func (s *SiteInfoService) maskAIKeys(resp *schema.SiteAIResp) { diff --git a/ui/src/common/interface.ts b/ui/src/common/interface.ts index 308726e80..b01b621f7 100644 --- a/ui/src/common/interface.ts +++ b/ui/src/common/interface.ts @@ -833,6 +833,11 @@ export interface AiConfig { api_host: string; api_key: string; model: string; + embedding_model: string; + embedding_dimensions: number; + embedding_level: string; + embedding_crontab: string; + similarity_threshold: number; }>; } diff --git a/ui/src/pages/Admin/AiSettings/index.tsx b/ui/src/pages/Admin/AiSettings/index.tsx index 2270aa5c5..de284ff00 100644 --- a/ui/src/pages/Admin/AiSettings/index.tsx +++ 
b/ui/src/pages/Admin/AiSettings/index.tsx @@ -68,6 +68,31 @@ const Index = () => { isInvalid: false, errorMsg: '', }, + embedding_model: { + value: '', + isInvalid: false, + errorMsg: '', + }, + embedding_dimensions: { + value: '', + isInvalid: false, + errorMsg: '', + }, + embedding_level: { + value: 'question', + isInvalid: false, + errorMsg: '', + }, + embedding_crontab: { + value: '', + isInvalid: false, + errorMsg: '', + }, + similarity_threshold: { + value: '0', + isInvalid: false, + errorMsg: '', + }, }); const [apiHostPlaceholder, setApiHostPlaceholder] = useState(''); const [modelsData, setModels] = useState<{ id: string }[]>([]); @@ -146,6 +171,31 @@ const Index = () => { isInvalid: false, errorMsg: '', }, + embedding_model: { + value: findHistoryProvider?.embedding_model || '', + isInvalid: false, + errorMsg: '', + }, + embedding_dimensions: { + value: String(findHistoryProvider?.embedding_dimensions || ''), + isInvalid: false, + errorMsg: '', + }, + embedding_level: { + value: findHistoryProvider?.embedding_level || 'question', + isInvalid: false, + errorMsg: '', + }, + embedding_crontab: { + value: findHistoryProvider?.embedding_crontab || '', + isInvalid: false, + errorMsg: '', + }, + similarity_threshold: { + value: String(findHistoryProvider?.similarity_threshold || '0'), + isInvalid: false, + errorMsg: '', + }, }); const provider = aiProviders?.find((item) => item.name === value); const host = findHistoryProvider?.api_host || provider?.default_api_host; @@ -218,6 +268,13 @@ const Index = () => { api_host: formData.api_host.value, api_key: formData.api_key.value, model: formData.model.value, + embedding_model: formData.embedding_model.value, + embedding_dimensions: + Number(formData.embedding_dimensions.value) || 0, + embedding_level: formData.embedding_level.value, + embedding_crontab: formData.embedding_crontab.value, + similarity_threshold: + Number(formData.similarity_threshold.value) || 0, }; } return v; @@ -295,6 +352,31 @@ const Index = () => 
{ isInvalid: false, errorMsg: '', }, + embedding_model: { + value: currentAiConfig?.embedding_model || '', + isInvalid: false, + errorMsg: '', + }, + embedding_dimensions: { + value: String(currentAiConfig?.embedding_dimensions || ''), + isInvalid: false, + errorMsg: '', + }, + embedding_level: { + value: currentAiConfig?.embedding_level || 'question', + isInvalid: false, + errorMsg: '', + }, + embedding_crontab: { + value: currentAiConfig?.embedding_crontab || '', + isInvalid: false, + errorMsg: '', + }, + similarity_threshold: { + value: String(currentAiConfig?.similarity_threshold || '0'), + isInvalid: false, + errorMsg: '', + }, }); }; @@ -477,6 +559,99 @@ const Index = () => {
{formData.model.errorMsg}
+
+
{t('embedding_settings')}
+ + + {t('embedding_model.label')} + + handleValueChange({ + embedding_model: { + value: e.target.value, + errorMsg: '', + isInvalid: false, + }, + }) + } + /> + + {t('embedding_model.text')} + + + + + {t('embedding_level.label')} + + handleValueChange({ + embedding_level: { + value: e.target.value, + errorMsg: '', + isInvalid: false, + }, + }) + }> + + + + + {t('embedding_level.text')} + + + + + {t('embedding_crontab.label')} + + handleValueChange({ + embedding_crontab: { + value: e.target.value, + errorMsg: '', + isInvalid: false, + }, + }) + } + /> + + {t('embedding_crontab.text')} + + + + + {t('similarity_threshold.label')} + + handleValueChange({ + similarity_threshold: { + value: e.target.value, + errorMsg: '', + isInvalid: false, + }, + }) + } + /> + + {t('similarity_threshold.text')} + + + From 851c7787fe41045bb5289c3c6af0137594e95926 Mon Sep 17 00:00:00 2001 From: hgaol Date: Mon, 2 Mar 2026 19:02:06 +0800 Subject: [PATCH 2/6] fix lint issue --- cmd/wire_gen.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/cmd/wire_gen.go b/cmd/wire_gen.go index d5aa06c7f..60b2f8935 100644 --- a/cmd/wire_gen.go +++ b/cmd/wire_gen.go @@ -34,7 +34,7 @@ import ( "github.com/apache/answer/internal/base/server" "github.com/apache/answer/internal/base/translator" "github.com/apache/answer/internal/controller" - templaterender "github.com/apache/answer/internal/controller/template_render" + "github.com/apache/answer/internal/controller/template_render" "github.com/apache/answer/internal/controller_admin" "github.com/apache/answer/internal/repo/activity" "github.com/apache/answer/internal/repo/activity_common" @@ -77,12 +77,12 @@ import ( activity_common2 "github.com/apache/answer/internal/service/activity_common" "github.com/apache/answer/internal/service/activityqueue" ai_conversation2 "github.com/apache/answer/internal/service/ai_conversation" - answercommon "github.com/apache/answer/internal/service/answer_common" + 
"github.com/apache/answer/internal/service/answer_common" "github.com/apache/answer/internal/service/apikey" auth2 "github.com/apache/answer/internal/service/auth" badge2 "github.com/apache/answer/internal/service/badge" collection2 "github.com/apache/answer/internal/service/collection" - collectioncommon "github.com/apache/answer/internal/service/collection_common" + "github.com/apache/answer/internal/service/collection_common" comment2 "github.com/apache/answer/internal/service/comment" "github.com/apache/answer/internal/service/comment_common" config2 "github.com/apache/answer/internal/service/config" @@ -96,13 +96,13 @@ import ( "github.com/apache/answer/internal/service/follow" "github.com/apache/answer/internal/service/importer" meta2 "github.com/apache/answer/internal/service/meta" - metacommon "github.com/apache/answer/internal/service/meta_common" + "github.com/apache/answer/internal/service/meta_common" "github.com/apache/answer/internal/service/noticequeue" "github.com/apache/answer/internal/service/notification" - notificationcommon "github.com/apache/answer/internal/service/notification_common" + "github.com/apache/answer/internal/service/notification_common" "github.com/apache/answer/internal/service/object_info" "github.com/apache/answer/internal/service/plugin_common" - questioncommon "github.com/apache/answer/internal/service/question_common" + "github.com/apache/answer/internal/service/question_common" rank2 "github.com/apache/answer/internal/service/rank" reason2 "github.com/apache/answer/internal/service/reason" report2 "github.com/apache/answer/internal/service/report" @@ -118,7 +118,7 @@ import ( tag_common2 "github.com/apache/answer/internal/service/tag_common" "github.com/apache/answer/internal/service/uploader" "github.com/apache/answer/internal/service/user_admin" - usercommon "github.com/apache/answer/internal/service/user_common" + "github.com/apache/answer/internal/service/user_common" user_external_login2 
"github.com/apache/answer/internal/service/user_external_login" user_notification_config2 "github.com/apache/answer/internal/service/user_notification_config" "github.com/segmentfault/pacman" From 899638b673fd39d6badcb42e250f8349c9b3cc68 Mon Sep 17 00:00:00 2001 From: hgaol Date: Mon, 2 Mar 2026 19:26:16 +0800 Subject: [PATCH 3/6] update init tables --- internal/migrations/init_data.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/migrations/init_data.go b/internal/migrations/init_data.go index 5af41bbfc..e65c9f11b 100644 --- a/internal/migrations/init_data.go +++ b/internal/migrations/init_data.go @@ -79,6 +79,7 @@ var ( &entity.APIKey{}, &entity.AIConversation{}, &entity.AIConversationRecord{}, + &entity.Embedding{}, } roles = []*entity.Role{ From 65282216c397c3851610a35df6d33d1c9842330a Mon Sep 17 00:00:00 2001 From: hgaol Date: Sun, 12 Apr 2026 13:22:00 +0800 Subject: [PATCH 4/6] feat: implement vector search plugin and syncer for question/answer embeddings - Added a new vector search syncer to aggregate questions and answers with comments for vector embedding. - Introduced a new VectorSearch interface and related structures for managing vector storage and similarity search. - Refactored embedding service to delegate semantic search to the new vector search plugin. - Removed embedding-related fields from SiteAIProvider and UI forms as part of the transition to the new vector search architecture. - Updated plugin registration to include vector search capabilities. - Cleaned up embedding service methods and removed unused dependencies. 
--- cmd/wire_gen.go | 22 +- docs/docs.go | 23 - docs/swagger.json | 23 - docs/swagger.yaml | 17 - i18n/en_US.yaml | 18 - i18n/zh_CN.yaml | 18 - internal/controller/mcp_controller.go | 3 +- internal/entity/embedding.go | 59 --- internal/migrations/init_data.go | 1 - internal/migrations/migrations.go | 1 - internal/migrations/v32.go | 9 - internal/repo/embedding/embedding_repo.go | 197 ------- internal/repo/provider.go | 2 - internal/repo/vector_search_sync/syncer.go | 196 +++++++ internal/schema/siteinfo_schema.go | 13 +- .../service/embedding/embedding_service.go | 496 +----------------- .../plugin_common/plugin_common_service.go | 17 + internal/service/siteinfo/siteinfo_service.go | 6 - plugin/plugin.go | 4 + plugin/vector_search.go | 174 ++++++ ui/src/common/interface.ts | 5 - ui/src/pages/Admin/AiSettings/index.tsx | 175 ------ 22 files changed, 427 insertions(+), 1052 deletions(-) delete mode 100644 internal/entity/embedding.go delete mode 100644 internal/repo/embedding/embedding_repo.go create mode 100644 internal/repo/vector_search_sync/syncer.go create mode 100644 plugin/vector_search.go diff --git a/cmd/wire_gen.go b/cmd/wire_gen.go index 60b2f8935..2c5931699 100644 --- a/cmd/wire_gen.go +++ b/cmd/wire_gen.go @@ -34,7 +34,7 @@ import ( "github.com/apache/answer/internal/base/server" "github.com/apache/answer/internal/base/translator" "github.com/apache/answer/internal/controller" - "github.com/apache/answer/internal/controller/template_render" + templaterender "github.com/apache/answer/internal/controller/template_render" "github.com/apache/answer/internal/controller_admin" "github.com/apache/answer/internal/repo/activity" "github.com/apache/answer/internal/repo/activity_common" @@ -49,7 +49,6 @@ import ( "github.com/apache/answer/internal/repo/collection" "github.com/apache/answer/internal/repo/comment" "github.com/apache/answer/internal/repo/config" - "github.com/apache/answer/internal/repo/embedding" "github.com/apache/answer/internal/repo/export" 
"github.com/apache/answer/internal/repo/file_record" "github.com/apache/answer/internal/repo/limit" @@ -77,18 +76,18 @@ import ( activity_common2 "github.com/apache/answer/internal/service/activity_common" "github.com/apache/answer/internal/service/activityqueue" ai_conversation2 "github.com/apache/answer/internal/service/ai_conversation" - "github.com/apache/answer/internal/service/answer_common" + answercommon "github.com/apache/answer/internal/service/answer_common" "github.com/apache/answer/internal/service/apikey" auth2 "github.com/apache/answer/internal/service/auth" badge2 "github.com/apache/answer/internal/service/badge" collection2 "github.com/apache/answer/internal/service/collection" - "github.com/apache/answer/internal/service/collection_common" + collectioncommon "github.com/apache/answer/internal/service/collection_common" comment2 "github.com/apache/answer/internal/service/comment" "github.com/apache/answer/internal/service/comment_common" config2 "github.com/apache/answer/internal/service/config" "github.com/apache/answer/internal/service/content" "github.com/apache/answer/internal/service/dashboard" - embedding2 "github.com/apache/answer/internal/service/embedding" + "github.com/apache/answer/internal/service/embedding" "github.com/apache/answer/internal/service/eventqueue" export2 "github.com/apache/answer/internal/service/export" "github.com/apache/answer/internal/service/feature_toggle" @@ -96,13 +95,13 @@ import ( "github.com/apache/answer/internal/service/follow" "github.com/apache/answer/internal/service/importer" meta2 "github.com/apache/answer/internal/service/meta" - "github.com/apache/answer/internal/service/meta_common" + metacommon "github.com/apache/answer/internal/service/meta_common" "github.com/apache/answer/internal/service/noticequeue" "github.com/apache/answer/internal/service/notification" - "github.com/apache/answer/internal/service/notification_common" + notificationcommon 
"github.com/apache/answer/internal/service/notification_common" "github.com/apache/answer/internal/service/object_info" "github.com/apache/answer/internal/service/plugin_common" - "github.com/apache/answer/internal/service/question_common" + questioncommon "github.com/apache/answer/internal/service/question_common" rank2 "github.com/apache/answer/internal/service/rank" reason2 "github.com/apache/answer/internal/service/reason" report2 "github.com/apache/answer/internal/service/report" @@ -118,7 +117,7 @@ import ( tag_common2 "github.com/apache/answer/internal/service/tag_common" "github.com/apache/answer/internal/service/uploader" "github.com/apache/answer/internal/service/user_admin" - "github.com/apache/answer/internal/service/user_common" + usercommon "github.com/apache/answer/internal/service/user_common" user_external_login2 "github.com/apache/answer/internal/service/user_external_login" user_notification_config2 "github.com/apache/answer/internal/service/user_notification_config" "github.com/segmentfault/pacman" @@ -249,9 +248,8 @@ func initApplication(debug bool, serverConf *conf.Server, dbConf *data.Database, reasonService := reason2.NewReasonService(reasonRepo) reasonController := controller.NewReasonController(reasonService) themeController := controller_admin.NewThemeController() - embeddingRepo := embedding.NewEmbeddingRepo(dataData) - embeddingService := embedding2.NewEmbeddingService(embeddingRepo, searchService, answerService, questionCommon, commentRepo, siteInfoCommonService) - siteInfoService := siteinfo.NewSiteInfoService(siteInfoRepo, siteInfoCommonService, emailService, tagCommonService, configService, questionCommon, fileRecordService, embeddingService) + embeddingService := embedding.NewEmbeddingService() + siteInfoService := siteinfo.NewSiteInfoService(siteInfoRepo, siteInfoCommonService, emailService, tagCommonService, configService, questionCommon, fileRecordService) siteInfoController := 
controller_admin.NewSiteInfoController(siteInfoService) controllerSiteInfoController := controller.NewSiteInfoController(siteInfoCommonService) notificationCommon := notificationcommon.NewNotificationCommon(dataData, notificationRepo, userCommon, activityRepo, followRepo, objService, noticequeueService, userExternalLoginRepo, siteInfoCommonService) diff --git a/docs/docs.go b/docs/docs.go index 8f65e05da..b50a3cf78 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -11712,24 +11712,6 @@ const docTemplate = `{ "type": "string", "maxLength": 256 }, - "embedding_crontab": { - "type": "string", - "maxLength": 100 - }, - "embedding_dimensions": { - "type": "integer" - }, - "embedding_level": { - "type": "string", - "enum": [ - "question", - "answer" - ] - }, - "embedding_model": { - "type": "string", - "maxLength": 100 - }, "model": { "type": "string", "maxLength": 100 @@ -11738,11 +11720,6 @@ const docTemplate = `{ "type": "string", "maxLength": 50 }, - "similarity_threshold": { - "type": "number", - "maximum": 1, - "minimum": 0 - } } }, "schema.SiteAIReq": { diff --git a/docs/swagger.json b/docs/swagger.json index 71e802dfa..441862208 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -11685,24 +11685,6 @@ "type": "string", "maxLength": 256 }, - "embedding_crontab": { - "type": "string", - "maxLength": 100 - }, - "embedding_dimensions": { - "type": "integer" - }, - "embedding_level": { - "type": "string", - "enum": [ - "question", - "answer" - ] - }, - "embedding_model": { - "type": "string", - "maxLength": 100 - }, "model": { "type": "string", "maxLength": 100 @@ -11711,11 +11693,6 @@ "type": "string", "maxLength": 50 }, - "similarity_threshold": { - "type": "number", - "maximum": 1, - "minimum": 0 - } } }, "schema.SiteAIReq": { diff --git a/docs/swagger.yaml b/docs/swagger.yaml index 5dc68d6c4..7a7adb681 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -2250,29 +2250,12 @@ definitions: api_key: maxLength: 256 type: string - embedding_crontab: - 
maxLength: 100 - type: string - embedding_dimensions: - type: integer - embedding_level: - enum: - - question - - answer - type: string - embedding_model: - maxLength: 100 - type: string model: maxLength: 100 type: string provider: maxLength: 50 type: string - similarity_threshold: - maximum: 1 - minimum: 0 - type: number type: object schema.SiteAIReq: properties: diff --git a/i18n/en_US.yaml b/i18n/en_US.yaml index ec385e763..9a0d198b3 100644 --- a/i18n/en_US.yaml +++ b/i18n/en_US.yaml @@ -2355,24 +2355,6 @@ ui: label: Model msg: Model is required add_success: AI settings updated successfully. - embedding_settings: Embedding Settings - embedding_model: - label: Embedding model - text: "The model used to generate vector embeddings for semantic search (e.g. text-embedding-3-small)." - embedding_dimensions: - label: Embedding dimensions - text: "The number of dimensions for the embedding vectors (e.g. 1536 for text-embedding-3-small)." - embedding_level: - label: Embedding level - text: "Choose whether to create embeddings at the question level (question + all answers + comments) or answer level (each answer separately)." - question: Question level - answer: Answer level - embedding_crontab: - label: Embedding schedule - text: "Cron expression for periodic embedding calculation (e.g. '0 */6 * * *' for every 6 hours). Leave empty to disable automatic indexing." - similarity_threshold: - label: Similarity threshold - text: "Minimum cosine similarity score (0-1) for semantic search results. Only results with a score above this threshold will be returned. Default is 0 (no filtering)." 
conversations: topic: Topic helpful: Helpful diff --git a/i18n/zh_CN.yaml b/i18n/zh_CN.yaml index 191619dd1..f16ed9fad 100644 --- a/i18n/zh_CN.yaml +++ b/i18n/zh_CN.yaml @@ -2319,24 +2319,6 @@ ui: label: 模型 msg: 模型是必需的 add_success: AI 设置更新成功。 - embedding_settings: Embedding 设置 - embedding_model: - label: Embedding 模型 - text: "用于生成语义搜索向量 Embedding 的模型(例如 text-embedding-3-small)。" - embedding_dimensions: - label: Embedding 维度 - text: "Embedding 向量的维度数(例如 text-embedding-3-small 为 1536)。" - embedding_level: - label: Embedding 级别 - text: "选择在问题级别(问题 + 所有回答 + 评论)还是回答级别(每个回答单独)创建 Embedding。" - question: 问题级别 - answer: 回答级别 - embedding_crontab: - label: Embedding 计划 - text: "定期计算 Embedding 的 Cron 表达式(例如 '0 */6 * * *' 表示每 6 小时)。留空则禁用自动索引。" - similarity_threshold: - label: 相似度阈值 - text: "语义搜索结果的最低余弦相似度分数(0-1)。只有分数高于此阈值的结果才会被返回。默认值为 0(不过滤)。" conversations: topic: 主题 helpful: 有帮助 diff --git a/internal/controller/mcp_controller.go b/internal/controller/mcp_controller.go index fecdbef60..b40c58cf4 100644 --- a/internal/controller/mcp_controller.go +++ b/internal/controller/mcp_controller.go @@ -37,6 +37,7 @@ import ( "github.com/apache/answer/internal/service/siteinfo_common" tagcommonser "github.com/apache/answer/internal/service/tag_common" usercommon "github.com/apache/answer/internal/service/user_common" + "github.com/apache/answer/plugin" "github.com/mark3labs/mcp-go/mcp" "github.com/segmentfault/pacman/log" ) @@ -381,7 +382,7 @@ func (c *MCPController) MCPSemanticSearchHandler() func(ctx context.Context, req resp := make([]*schema.MCPSemanticSearchResp, 0, len(results)) for _, r := range results { - var meta entity.EmbeddingMetadata + var meta plugin.VectorSearchMetadata _ = json.Unmarshal([]byte(r.Metadata), &meta) item := &schema.MCPSemanticSearchResp{ diff --git a/internal/entity/embedding.go b/internal/entity/embedding.go deleted file mode 100644 index 3ea500d92..000000000 --- a/internal/entity/embedding.go +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache 
Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package entity - -import "time" - -// Embedding stores vector embeddings for questions or answers. -type Embedding struct { - ID int `xorm:"not null pk autoincr INT(11) id"` - CreatedAt time.Time `xorm:"created not null default CURRENT_TIMESTAMP TIMESTAMP created_at"` - UpdatedAt time.Time `xorm:"updated not null default CURRENT_TIMESTAMP TIMESTAMP updated_at"` - ObjectID string `xorm:"not null BIGINT(20) INDEX object_id unique(object_embedding)"` - ObjectType string `xorm:"not null default '' VARCHAR(20) object_type unique(object_embedding)"` - ContentHash string `xorm:"not null default '' VARCHAR(64) content_hash"` - Metadata string `xorm:"not null MEDIUMTEXT metadata"` - Embedding string `xorm:"not null MEDIUMTEXT embedding"` - Dimensions int `xorm:"not null default 0 INT(11) dimensions"` -} - -// TableName returns the table name -func (Embedding) TableName() string { - return "embedding" -} - -// EmbeddingMetadata holds IDs for URI composition and content retrieval at query time. 
-type EmbeddingMetadata struct { - QuestionID string `json:"question_id"` - AnswerID string `json:"answer_id,omitempty"` - Answers []EmbeddingMetadataAnswer `json:"answers,omitempty"` - Comments []EmbeddingMetadataComment `json:"comments,omitempty"` -} - -// EmbeddingMetadataAnswer stores answer ID and comment IDs in metadata. -type EmbeddingMetadataAnswer struct { - AnswerID string `json:"answer_id"` - Comments []EmbeddingMetadataComment `json:"comments,omitempty"` -} - -// EmbeddingMetadataComment stores comment ID in metadata. -type EmbeddingMetadataComment struct { - CommentID string `json:"comment_id"` -} diff --git a/internal/migrations/init_data.go b/internal/migrations/init_data.go index e65c9f11b..5af41bbfc 100644 --- a/internal/migrations/init_data.go +++ b/internal/migrations/init_data.go @@ -79,7 +79,6 @@ var ( &entity.APIKey{}, &entity.AIConversation{}, &entity.AIConversationRecord{}, - &entity.Embedding{}, } roles = []*entity.Role{ diff --git a/internal/migrations/migrations.go b/internal/migrations/migrations.go index e6722cc08..682a7b207 100644 --- a/internal/migrations/migrations.go +++ b/internal/migrations/migrations.go @@ -108,7 +108,6 @@ var migrations = []Migration{ NewMigration("v1.8.0", "change admin menu", updateAdminMenuSettings, true), NewMigration("v1.8.1", "ai feat", aiFeat, true), NewMigration("v2.0.1", "change avatar type to text", updateAvatarType, false), - NewMigration("v2.0.2", "add embedding table", addEmbeddingTable, false), } func GetMigrations() []Migration { diff --git a/internal/migrations/v32.go b/internal/migrations/v32.go index 438097348..fc6614b11 100644 --- a/internal/migrations/v32.go +++ b/internal/migrations/v32.go @@ -24,7 +24,6 @@ import ( "fmt" "github.com/apache/answer/internal/entity" - "github.com/segmentfault/pacman/log" "xorm.io/xorm" ) @@ -36,11 +35,3 @@ func updateAvatarType(ctx context.Context, x *xorm.Engine) error { } return nil } - -func addEmbeddingTable(ctx context.Context, x *xorm.Engine) error { - 
if err := x.Context(ctx).Sync(new(entity.Embedding)); err != nil { - return fmt.Errorf("sync embedding table failed: %w", err) - } - log.Info("Embedding table migration completed successfully") - return nil -} diff --git a/internal/repo/embedding/embedding_repo.go b/internal/repo/embedding/embedding_repo.go deleted file mode 100644 index 67bde1d5d..000000000 --- a/internal/repo/embedding/embedding_repo.go +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package embedding - -import ( - "context" - "encoding/json" - "math" - "sort" - - "github.com/apache/answer/internal/base/data" - "github.com/apache/answer/internal/entity" - "github.com/segmentfault/pacman/log" - "xorm.io/builder" -) - -// EmbeddingRepo defines the interface for embedding data access. 
-type EmbeddingRepo interface { - Upsert(ctx context.Context, emb *entity.Embedding) error - GetByObjectID(ctx context.Context, objectID, objectType string) (*entity.Embedding, bool, error) - GetAll(ctx context.Context) ([]*entity.Embedding, error) - SearchSimilar(ctx context.Context, queryVector []float32, topK int) ([]SimilarResult, error) - DeleteByObjectID(ctx context.Context, objectID, objectType string) error - Count(ctx context.Context) (int64, error) -} - -// SimilarResult holds a similarity search result. -type SimilarResult struct { - ObjectID string `json:"object_id"` - ObjectType string `json:"object_type"` - Metadata string `json:"metadata"` - Score float64 `json:"score"` -} - -type embeddingRepo struct { - data *data.Data -} - -// NewEmbeddingRepo creates a new EmbeddingRepo. -func NewEmbeddingRepo(data *data.Data) EmbeddingRepo { - return &embeddingRepo{data: data} -} - -// Upsert inserts or updates an embedding by (object_id, object_type). -func (r *embeddingRepo) Upsert(ctx context.Context, emb *entity.Embedding) error { - existing := &entity.Embedding{} - exist, err := r.data.DB.Context(ctx). - Where(builder.Eq{"object_id": emb.ObjectID, "object_type": emb.ObjectType}). - Get(existing) - if err != nil { - log.Errorf("check embedding existence failed: %v", err) - return err - } - - if exist { - emb.ID = existing.ID - _, err = r.data.DB.Context(ctx).ID(existing.ID). - Cols("content_hash", "metadata", "embedding", "dimensions", "updated_at"). - Update(emb) - if err != nil { - log.Errorf("update embedding failed: %v", err) - return err - } - return nil - } - - _, err = r.data.DB.Context(ctx).Insert(emb) - if err != nil { - log.Errorf("insert embedding failed: %v", err) - return err - } - return nil -} - -// GetByObjectID returns an embedding by object ID and type. 
-func (r *embeddingRepo) GetByObjectID(ctx context.Context, objectID, objectType string) (*entity.Embedding, bool, error) { - emb := &entity.Embedding{} - exist, err := r.data.DB.Context(ctx). - Where(builder.Eq{"object_id": objectID, "object_type": objectType}). - Get(emb) - if err != nil { - log.Errorf("get embedding failed: %v", err) - return nil, false, err - } - return emb, exist, nil -} - -// GetAll returns all embeddings. -func (r *embeddingRepo) GetAll(ctx context.Context) ([]*entity.Embedding, error) { - var list []*entity.Embedding - err := r.data.DB.Context(ctx).Find(&list) - if err != nil { - log.Errorf("get all embeddings failed: %v", err) - return nil, err - } - return list, nil -} - -// SearchSimilar performs brute-force cosine similarity search in Go. -func (r *embeddingRepo) SearchSimilar(ctx context.Context, queryVector []float32, topK int) ([]SimilarResult, error) { - allEmbeddings, err := r.GetAll(ctx) - if err != nil { - return nil, err - } - - type scored struct { - emb *entity.Embedding - score float64 - } - results := make([]scored, 0, len(allEmbeddings)) - - for _, emb := range allEmbeddings { - var vec []float32 - if err := json.Unmarshal([]byte(emb.Embedding), &vec); err != nil { - log.Warnf("skip embedding id=%d, unmarshal failed: %v", emb.ID, err) - continue - } - score := cosineSimilarity(queryVector, vec) - results = append(results, scored{emb: emb, score: score}) - } - - sort.Slice(results, func(i, j int) bool { - return results[i].score > results[j].score - }) - - if topK > len(results) { - topK = len(results) - } - - out := make([]SimilarResult, 0, topK) - for i := 0; i < topK; i++ { - out = append(out, SimilarResult{ - ObjectID: results[i].emb.ObjectID, - ObjectType: results[i].emb.ObjectType, - Metadata: results[i].emb.Metadata, - Score: results[i].score, - }) - } - return out, nil -} - -// DeleteByObjectID deletes an embedding by object ID and type. 
-func (r *embeddingRepo) DeleteByObjectID(ctx context.Context, objectID, objectType string) error { - _, err := r.data.DB.Context(ctx). - Where(builder.Eq{"object_id": objectID, "object_type": objectType}). - Delete(&entity.Embedding{}) - if err != nil { - log.Errorf("delete embedding failed: %v", err) - return err - } - return nil -} - -// Count returns the total number of embeddings. -func (r *embeddingRepo) Count(ctx context.Context) (int64, error) { - count, err := r.data.DB.Context(ctx).Count(&entity.Embedding{}) - if err != nil { - log.Errorf("count embeddings failed: %v", err) - return 0, err - } - return count, nil -} - -// cosineSimilarity computes cosine similarity between two vectors. -func cosineSimilarity(a, b []float32) float64 { - if len(a) != len(b) || len(a) == 0 { - return 0 - } - var dotProduct, normA, normB float64 - for i := range a { - dotProduct += float64(a[i]) * float64(b[i]) - normA += float64(a[i]) * float64(a[i]) - normB += float64(b[i]) * float64(b[i]) - } - denom := math.Sqrt(normA) * math.Sqrt(normB) - if denom == 0 { - return 0 - } - return dotProduct / denom -} diff --git a/internal/repo/provider.go b/internal/repo/provider.go index 2a9717d00..510a94aaa 100644 --- a/internal/repo/provider.go +++ b/internal/repo/provider.go @@ -34,7 +34,6 @@ import ( "github.com/apache/answer/internal/repo/collection" "github.com/apache/answer/internal/repo/comment" "github.com/apache/answer/internal/repo/config" - "github.com/apache/answer/internal/repo/embedding" "github.com/apache/answer/internal/repo/export" "github.com/apache/answer/internal/repo/file_record" "github.com/apache/answer/internal/repo/limit" @@ -114,5 +113,4 @@ var ProviderSetRepo = wire.NewSet( file_record.NewFileRecordRepo, api_key.NewAPIKeyRepo, ai_conversation.NewAIConversationRepo, - embedding.NewEmbeddingRepo, ) diff --git a/internal/repo/vector_search_sync/syncer.go b/internal/repo/vector_search_sync/syncer.go new file mode 100644 index 000000000..f27b1fd2c --- /dev/null +++ 
b/internal/repo/vector_search_sync/syncer.go @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vector_search_sync + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/apache/answer/internal/base/data" + "github.com/apache/answer/internal/entity" + "github.com/apache/answer/pkg/uid" + "github.com/apache/answer/plugin" + "github.com/segmentfault/pacman/log" +) + +// NewPluginSyncer creates a new VectorSearchSyncer that reads from the database. +func NewPluginSyncer(data *data.Data) plugin.VectorSearchSyncer { + return &PluginSyncer{data: data} +} + +// PluginSyncer implements plugin.VectorSearchSyncer. +// It aggregates question/answer text with comments for vector embedding. +type PluginSyncer struct { + data *data.Data +} + +// GetQuestionsPage returns a page of questions with aggregated text +// (question title + body + all answers + all comments). 
+func (p *PluginSyncer) GetQuestionsPage(ctx context.Context, page, pageSize int) ( + []*plugin.VectorSearchContent, error) { + questions := make([]*entity.Question, 0) + startNum := (page - 1) * pageSize + err := p.data.DB.Context(ctx).Limit(pageSize, startNum).Find(&questions) + if err != nil { + return nil, err + } + return p.buildQuestionContents(ctx, questions) +} + +// GetAnswersPage returns a page of answers with aggregated text +// (parent question title + answer body + answer comments). +func (p *PluginSyncer) GetAnswersPage(ctx context.Context, page, pageSize int) ( + []*plugin.VectorSearchContent, error) { + answers := make([]*entity.Answer, 0) + startNum := (page - 1) * pageSize + err := p.data.DB.Context(ctx).Limit(pageSize, startNum).Find(&answers) + if err != nil { + return nil, err + } + return p.buildAnswerContents(ctx, answers) +} + +// buildQuestionContents aggregates each question with its answers and comments. +func (p *PluginSyncer) buildQuestionContents(ctx context.Context, questions []*entity.Question) ( + []*plugin.VectorSearchContent, error) { + result := make([]*plugin.VectorSearchContent, 0, len(questions)) + for _, q := range questions { + meta := plugin.VectorSearchMetadata{ + QuestionID: uid.DeShortID(q.ID), + } + + var parts []string + parts = append(parts, fmt.Sprintf("Question: %s\n%s", q.Title, q.OriginalText)) + + // Get answers for this question + answers := make([]*entity.Answer, 0) + err := p.data.DB.Context(ctx).Where("question_id = ?", q.ID).Find(&answers) + if err != nil { + log.Warnf("get answers for question %s failed: %v", q.ID, err) + } else { + for _, a := range answers { + parts = append(parts, fmt.Sprintf("Answer: %s", a.OriginalText)) + answerMeta := plugin.VectorSearchMetadataAnswer{ + AnswerID: uid.DeShortID(a.ID), + } + + // Get comments on this answer + answerComments := make([]*entity.Comment, 0) + err := p.data.DB.Context(ctx).Where("object_id = ?", a.ID). 
+ OrderBy("created_at ASC").Limit(50).Find(&answerComments) + if err != nil { + log.Warnf("get comments for answer %s failed: %v", a.ID, err) + } else { + for _, c := range answerComments { + parts = append(parts, fmt.Sprintf("Comment on answer: %s", c.OriginalText)) + answerMeta.Comments = append(answerMeta.Comments, plugin.VectorSearchMetadataComment{ + CommentID: uid.DeShortID(c.ID), + }) + } + } + meta.Answers = append(meta.Answers, answerMeta) + } + } + + // Get comments on the question + questionComments := make([]*entity.Comment, 0) + err = p.data.DB.Context(ctx).Where("object_id = ?", q.ID). + OrderBy("created_at ASC").Limit(50).Find(&questionComments) + if err != nil { + log.Warnf("get comments for question %s failed: %v", q.ID, err) + } else { + for _, c := range questionComments { + parts = append(parts, fmt.Sprintf("Comment: %s", c.OriginalText)) + meta.Comments = append(meta.Comments, plugin.VectorSearchMetadataComment{ + CommentID: uid.DeShortID(c.ID), + }) + } + } + + metaJSON, _ := json.Marshal(meta) + result = append(result, &plugin.VectorSearchContent{ + ObjectID: uid.DeShortID(q.ID), + ObjectType: "question", + Title: q.Title, + Content: strings.Join(parts, "\n\n"), + Metadata: string(metaJSON), + }) + } + return result, nil +} + +// buildAnswerContents aggregates each answer with its parent question title and comments. 
+func (p *PluginSyncer) buildAnswerContents(ctx context.Context, answers []*entity.Answer) ( + []*plugin.VectorSearchContent, error) { + result := make([]*plugin.VectorSearchContent, 0, len(answers)) + for _, a := range answers { + // Get parent question for title + question := &entity.Question{} + exist, err := p.data.DB.Context(ctx).Where("id = ?", a.QuestionID).Get(question) + if err != nil { + log.Errorf("get question %s failed: %v", a.QuestionID, err) + continue + } + if !exist { + continue + } + + meta := plugin.VectorSearchMetadata{ + QuestionID: uid.DeShortID(a.QuestionID), + AnswerID: uid.DeShortID(a.ID), + } + + var parts []string + parts = append(parts, fmt.Sprintf("Question: %s", question.Title)) + parts = append(parts, fmt.Sprintf("Answer: %s", a.OriginalText)) + + answerMeta := plugin.VectorSearchMetadataAnswer{ + AnswerID: uid.DeShortID(a.ID), + } + + // Get comments on this answer + answerComments := make([]*entity.Comment, 0) + err = p.data.DB.Context(ctx).Where("object_id = ?", a.ID). 
+ OrderBy("created_at ASC").Limit(50).Find(&answerComments) + if err != nil { + log.Warnf("get comments for answer %s failed: %v", a.ID, err) + } else { + for _, c := range answerComments { + parts = append(parts, fmt.Sprintf("Comment: %s", c.OriginalText)) + answerMeta.Comments = append(answerMeta.Comments, plugin.VectorSearchMetadataComment{ + CommentID: uid.DeShortID(c.ID), + }) + } + } + meta.Answers = append(meta.Answers, answerMeta) + + metaJSON, _ := json.Marshal(meta) + result = append(result, &plugin.VectorSearchContent{ + ObjectID: uid.DeShortID(a.ID), + ObjectType: "answer", + Title: question.Title, + Content: strings.Join(parts, "\n\n"), + Metadata: string(metaJSON), + }) + } + return result, nil +} diff --git a/internal/schema/siteinfo_schema.go b/internal/schema/siteinfo_schema.go index 84bc54e30..bdf2308d3 100644 --- a/internal/schema/siteinfo_schema.go +++ b/internal/schema/siteinfo_schema.go @@ -281,15 +281,10 @@ func (s *SiteAIResp) GetProvider() *SiteAIProvider { } type SiteAIProvider struct { - Provider string `validate:"omitempty,lte=50" form:"provider" json:"provider"` - APIHost string `validate:"omitempty,lte=512" form:"api_host" json:"api_host"` - APIKey string `validate:"omitempty,lte=256" form:"api_key" json:"api_key"` - Model string `validate:"omitempty,lte=100" form:"model" json:"model"` - EmbeddingModel string `validate:"omitempty,lte=100" form:"embedding_model" json:"embedding_model"` - EmbeddingDimensions int `validate:"omitempty" form:"embedding_dimensions" json:"embedding_dimensions"` - EmbeddingLevel string `validate:"omitempty,oneof=question answer" form:"embedding_level" json:"embedding_level"` - EmbeddingCrontab string `validate:"omitempty,lte=100" form:"embedding_crontab" json:"embedding_crontab"` - SimilarityThreshold float64 `validate:"omitempty,gte=0,lte=1" form:"similarity_threshold" json:"similarity_threshold"` + Provider string `validate:"omitempty,lte=50" form:"provider" json:"provider"` + APIHost string 
`validate:"omitempty,lte=512" form:"api_host" json:"api_host"` + APIKey string `validate:"omitempty,lte=256" form:"api_key" json:"api_key"` + Model string `validate:"omitempty,lte=100" form:"model" json:"model"` } // SiteAIResp AI configuration response diff --git a/internal/service/embedding/embedding_service.go b/internal/service/embedding/embedding_service.go index 72e008abb..c69d60d8e 100644 --- a/internal/service/embedding/embedding_service.go +++ b/internal/service/embedding/embedding_service.go @@ -21,496 +21,40 @@ package embedding import ( "context" - "crypto/sha256" - "encoding/json" "fmt" - "strings" - "sync" - "github.com/apache/answer/internal/base/pager" - "github.com/apache/answer/internal/entity" - embeddingRepo "github.com/apache/answer/internal/repo/embedding" - "github.com/apache/answer/internal/schema" - "github.com/apache/answer/internal/service/comment" - "github.com/apache/answer/internal/service/content" - questioncommon "github.com/apache/answer/internal/service/question_common" - "github.com/apache/answer/internal/service/siteinfo_common" - "github.com/robfig/cron/v3" - "github.com/sashabaranov/go-openai" - "github.com/segmentfault/pacman/log" + "github.com/apache/answer/plugin" ) -const ( - EmbeddingLevelQuestion = "question" - EmbeddingLevelAnswer = "answer" -) - -// EmbeddingService handles embedding generation, text aggregation, and indexing. -type EmbeddingService struct { - embeddingRepo embeddingRepo.EmbeddingRepo - searchService *content.SearchService - answerService *content.AnswerService - questionCommon *questioncommon.QuestionCommon - commentRepo comment.CommentRepo - siteInfoService siteinfo_common.SiteInfoCommonService - - mu sync.Mutex - cronJob *cron.Cron - cronSpec string -} +// EmbeddingService is a thin facade that delegates semantic search to a VectorSearch plugin. +// If no plugin is enabled, semantic search is unavailable. +type EmbeddingService struct{} // NewEmbeddingService creates a new EmbeddingService. 
-func NewEmbeddingService( - embeddingRepo embeddingRepo.EmbeddingRepo, - searchService *content.SearchService, - answerService *content.AnswerService, - questionCommon *questioncommon.QuestionCommon, - commentRepo comment.CommentRepo, - siteInfoService siteinfo_common.SiteInfoCommonService, -) *EmbeddingService { - return &EmbeddingService{ - embeddingRepo: embeddingRepo, - searchService: searchService, - answerService: answerService, - questionCommon: questionCommon, - commentRepo: commentRepo, - siteInfoService: siteInfoService, - } -} - -// getAIConfig returns the current AI configuration. -func (s *EmbeddingService) getAIConfig(ctx context.Context) (*schema.SiteAIResp, *schema.SiteAIProvider, error) { - aiConfig, err := s.siteInfoService.GetSiteAI(ctx) - if err != nil { - return nil, nil, fmt.Errorf("get AI config failed: %w", err) - } - if !aiConfig.Enabled { - return nil, nil, fmt.Errorf("AI feature is disabled") - } - provider := aiConfig.GetProvider() - if provider.EmbeddingModel == "" { - return nil, nil, fmt.Errorf("embedding model not configured") - } - return aiConfig, provider, nil -} - -// createEmbeddingClient creates an OpenAI-compatible client for embedding requests. -func (s *EmbeddingService) createEmbeddingClient(provider *schema.SiteAIProvider) *openai.Client { - config := openai.DefaultConfig(provider.APIKey) - config.BaseURL = provider.APIHost - if !strings.HasSuffix(config.BaseURL, "/v1") { - config.BaseURL += "/v1" - } - return openai.NewClientWithConfig(config) -} - -// GenerateEmbedding generates an embedding vector for the given text. 
-func (s *EmbeddingService) GenerateEmbedding(ctx context.Context, text string) ([]float32, error) { - _, provider, err := s.getAIConfig(ctx) - if err != nil { - return nil, err - } - - client := s.createEmbeddingClient(provider) - resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequestStrings{ - Input: []string{text}, - Model: openai.EmbeddingModel(provider.EmbeddingModel), - }) - if err != nil { - return nil, fmt.Errorf("create embeddings failed: %w", err) - } - if len(resp.Data) == 0 { - return nil, fmt.Errorf("no embedding returned") - } - return resp.Data[0].Embedding, nil -} - -// ComputeContentHash computes SHA256 of the text. -func ComputeContentHash(text string) string { - h := sha256.Sum256([]byte(text)) - return fmt.Sprintf("%x", h) -} - -// BuildTextForQuestion aggregates question title + body + all answers + comments into one text. -// Uses SearchService and QuestionCommon to respect the plugin architecture. -func (s *EmbeddingService) BuildTextForQuestion(ctx context.Context, questionID string) (text string, meta *entity.EmbeddingMetadata, err error) { - // Get question detail via service layer - question, err := s.questionCommon.Info(ctx, questionID, "") - if err != nil { - return "", nil, fmt.Errorf("get question info failed: %w", err) - } - - meta = &entity.EmbeddingMetadata{ - QuestionID: questionID, - } - - var parts []string - parts = append(parts, fmt.Sprintf("Question: %s\n%s", question.Title, question.Content)) - - // Get answers via AnswerService - answerInfoList, _, err := s.answerService.SearchList(ctx, &schema.AnswerListReq{ - QuestionID: questionID, - Page: 1, - PageSize: 50, - }) - if err != nil { - log.Warnf("get answers for question %s failed: %v", questionID, err) - } else { - for _, a := range answerInfoList { - parts = append(parts, fmt.Sprintf("Answer: %s", a.Content)) - answerMeta := entity.EmbeddingMetadataAnswer{ - AnswerID: a.ID, - } - - // Get comments on this answer - answerComments, _, err := 
s.commentRepo.GetCommentPage(ctx, &comment.CommentQuery{ - PageCond: pager.PageCond{Page: 1, PageSize: 50}, - ObjectID: a.ID, - QueryCond: "newest", - }) - if err != nil { - log.Warnf("get comments for answer %s failed: %v", a.ID, err) - } else { - for _, c := range answerComments { - parts = append(parts, fmt.Sprintf("Comment on answer: %s", c.OriginalText)) - answerMeta.Comments = append(answerMeta.Comments, entity.EmbeddingMetadataComment{ - CommentID: c.ID, - }) - } - } - meta.Answers = append(meta.Answers, answerMeta) - } - } - - // Get comments on the question - commentList, _, err := s.commentRepo.GetCommentPage(ctx, &comment.CommentQuery{ - PageCond: pager.PageCond{Page: 1, PageSize: 50}, - ObjectID: questionID, - QueryCond: "newest", - }) - if err != nil { - log.Warnf("get comments for question %s failed: %v", questionID, err) - } else { - for _, c := range commentList { - parts = append(parts, fmt.Sprintf("Comment: %s", c.OriginalText)) - meta.Comments = append(meta.Comments, entity.EmbeddingMetadataComment{ - CommentID: c.ID, - }) - } - } - - return strings.Join(parts, "\n\n"), meta, nil -} - -// BuildTextForAnswer aggregates answer body + parent question title + answer comments into one text. 
-func (s *EmbeddingService) BuildTextForAnswer(ctx context.Context, answerID, questionID string) (text string, meta *entity.EmbeddingMetadata, err error) { - // Get parent question title - question, err := s.questionCommon.Info(ctx, questionID, "") - if err != nil { - return "", nil, fmt.Errorf("get question info for answer failed: %w", err) - } - - meta = &entity.EmbeddingMetadata{ - QuestionID: questionID, - AnswerID: answerID, - } - - // Get the specific answer's content via AnswerService - answerInfo, err := s.answerService.GetDetail(ctx, answerID) - if err != nil { - return "", nil, fmt.Errorf("get answer failed: %w", err) - } - - var answerText string - if answerInfo != nil { - answerText = answerInfo.Content - } - - var parts []string - parts = append(parts, fmt.Sprintf("Question: %s", question.Title)) - if answerText != "" { - parts = append(parts, fmt.Sprintf("Answer: %s", answerText)) - meta.Answers = append(meta.Answers, entity.EmbeddingMetadataAnswer{ - AnswerID: answerID, - }) - } - - // Get comments on the answer - commentList, _, err := s.commentRepo.GetCommentPage(ctx, &comment.CommentQuery{ - PageCond: pager.PageCond{Page: 1, PageSize: 50}, - ObjectID: answerID, - QueryCond: "newest", - }) - if err != nil { - log.Warnf("get comments for answer %s failed: %v", answerID, err) - } else { - for _, c := range commentList { - parts = append(parts, fmt.Sprintf("Comment: %s", c.OriginalText)) - if len(meta.Answers) > 0 { - meta.Answers[0].Comments = append(meta.Answers[0].Comments, entity.EmbeddingMetadataComment{ - CommentID: c.ID, - }) - } else { - meta.Comments = append(meta.Comments, entity.EmbeddingMetadataComment{ - CommentID: c.ID, - }) - } - } - } - - return strings.Join(parts, "\n\n"), meta, nil -} - -// IndexQuestion indexes a single question embedding. 
-func (s *EmbeddingService) IndexQuestion(ctx context.Context, questionID string) error { - text, meta, err := s.BuildTextForQuestion(ctx, questionID) - if err != nil { - return err - } - - contentHash := ComputeContentHash(text) - - // Check staleness - existing, exist, _ := s.embeddingRepo.GetByObjectID(ctx, questionID, EmbeddingLevelQuestion) - if exist && existing.ContentHash == contentHash { - return nil // already up to date - } - - vec, err := s.GenerateEmbedding(ctx, text) - if err != nil { - return fmt.Errorf("generate embedding for question %s failed: %w", questionID, err) - } - - metaJSON, _ := json.Marshal(meta) - vecJSON, _ := json.Marshal(vec) - - return s.embeddingRepo.Upsert(ctx, &entity.Embedding{ - ObjectID: questionID, - ObjectType: EmbeddingLevelQuestion, - ContentHash: contentHash, - Metadata: string(metaJSON), - Embedding: string(vecJSON), - Dimensions: len(vec), - }) +func NewEmbeddingService() *EmbeddingService { + return &EmbeddingService{} } -// IndexAnswer indexes a single answer embedding. -func (s *EmbeddingService) IndexAnswer(ctx context.Context, answerID, questionID string) error { - text, meta, err := s.BuildTextForAnswer(ctx, answerID, questionID) - if err != nil { - return err - } - - contentHash := ComputeContentHash(text) +// SearchSimilar delegates to the VectorSearch plugin. +// Returns an error if no plugin is enabled. 
+func (s *EmbeddingService) SearchSimilar(ctx context.Context, query string, topK int) ([]plugin.VectorSearchResult, error) { + var results []plugin.VectorSearchResult + var searchErr error + found := false - existing, exist, _ := s.embeddingRepo.GetByObjectID(ctx, answerID, EmbeddingLevelAnswer) - if exist && existing.ContentHash == contentHash { + err := plugin.CallVectorSearch(func(vs plugin.VectorSearch) error { + found = true + results, searchErr = vs.SearchSimilar(ctx, query, topK) return nil - } - - vec, err := s.GenerateEmbedding(ctx, text) - if err != nil { - return fmt.Errorf("generate embedding for answer %s failed: %w", answerID, err) - } - - metaJSON, _ := json.Marshal(meta) - vecJSON, _ := json.Marshal(vec) - - return s.embeddingRepo.Upsert(ctx, &entity.Embedding{ - ObjectID: answerID, - ObjectType: EmbeddingLevelAnswer, - ContentHash: contentHash, - Metadata: string(metaJSON), - Embedding: string(vecJSON), - Dimensions: len(vec), }) -} - -// SearchSimilar performs semantic search and returns top-K similar results. -// Results below the configured similarity threshold are filtered out. 
-func (s *EmbeddingService) SearchSimilar(ctx context.Context, query string, topK int) ([]embeddingRepo.SimilarResult, error) { - vec, err := s.GenerateEmbedding(ctx, query) - if err != nil { - return nil, fmt.Errorf("generate query embedding failed: %w", err) - } - results, err := s.embeddingRepo.SearchSimilar(ctx, vec, topK) if err != nil { - return nil, err + return nil, fmt.Errorf("call vector search plugin failed: %w", err) } - - for _, r := range results { - log.Debugf("semantic search result: object_id=%s object_type=%s score=%.6f", r.ObjectID, r.ObjectType, r.Score) + if !found { + return nil, fmt.Errorf("semantic search is not available: no vector search plugin is enabled") } - - // Apply similarity threshold from config (default 0 means no filtering) - _, provider, cfgErr := s.getAIConfig(ctx) - if cfgErr == nil && provider.SimilarityThreshold > 0 { - filtered := make([]embeddingRepo.SimilarResult, 0, len(results)) - for _, r := range results { - if r.Score >= provider.SimilarityThreshold { - filtered = append(filtered, r) - } - } - log.Debugf("semantic search: %d/%d results passed threshold %.4f", len(filtered), len(results), provider.SimilarityThreshold) - return filtered, nil + if searchErr != nil { + return nil, searchErr } - return results, nil } - -// GetEmbeddingCount returns the total number of stored embeddings. -func (s *EmbeddingService) GetEmbeddingCount(ctx context.Context) (int64, error) { - return s.embeddingRepo.Count(ctx) -} - -// RemoveEmbedding removes an embedding by object ID and type. -func (s *EmbeddingService) RemoveEmbedding(ctx context.Context, objectID, objectType string) error { - return s.embeddingRepo.DeleteByObjectID(ctx, objectID, objectType) -} - -// IndexAll indexes all questions (and optionally answers) based on the configured embedding level. 
-func (s *EmbeddingService) IndexAll(ctx context.Context) error { - _, provider, err := s.getAIConfig(ctx) - if err != nil { - log.Warnf("embedding indexer: %v", err) - return err - } - - level := provider.EmbeddingLevel - if level == "" { - level = EmbeddingLevelQuestion - } - - log.Debugf("Starting embedding indexer at level: %s", level) - - page := 1 - totalIndexed := 0 - for { - searchResp, err := s.searchService.Search(ctx, &schema.SearchDTO{ - Query: "is:question", - Page: page, - Size: 50, - Order: "newest", - }) - if err != nil { - return fmt.Errorf("search questions for indexing failed: %w", err) - } - if searchResp == nil || len(searchResp.SearchResults) == 0 { - break - } - - for _, result := range searchResp.SearchResults { - if result.Object == nil { - continue - } - qID := result.Object.QuestionID - if level == EmbeddingLevelQuestion { - if err := s.IndexQuestion(ctx, qID); err != nil { - log.Warnf("index question %s failed: %v", qID, err) - continue - } - totalIndexed++ - } else if level == EmbeddingLevelAnswer { - // Index each answer for this question via AnswerService - answerInfoList, _, err := s.answerService.SearchList(ctx, &schema.AnswerListReq{ - QuestionID: qID, - Page: 1, - PageSize: 50, - }) - if err != nil { - log.Warnf("get answers for question %s failed: %v", qID, err) - continue - } - for _, a := range answerInfoList { - if err := s.IndexAnswer(ctx, a.ID, qID); err != nil { - log.Warnf("index answer %s failed: %v", a.ID, err) - continue - } - totalIndexed++ - } - } - } - - if int64((page)*50) >= searchResp.Total { - break - } - page++ - } - - log.Infof("Embedding indexer completed: %d items indexed", totalIndexed) - return nil -} - -// StartScheduler starts a cron job to periodically run IndexAll. 
-func (s *EmbeddingService) StartScheduler(spec string) error { - s.mu.Lock() - defer s.mu.Unlock() - - // Stop existing cron if running - if s.cronJob != nil { - s.cronJob.Stop() - s.cronJob = nil - s.cronSpec = "" - } - - if spec == "" { - return nil - } - - c := cron.New() - _, err := c.AddFunc(spec, func() { - ctx := context.Background() - log.Infof("embedding cron triggered (spec=%s)", spec) - if err := s.IndexAll(ctx); err != nil { - log.Errorf("embedding cron IndexAll failed: %v", err) - } - }) - if err != nil { - return fmt.Errorf("invalid cron expression %q: %w", spec, err) - } - - c.Start() - s.cronJob = c - s.cronSpec = spec - log.Infof("embedding scheduler started with cron: %s", spec) - return nil -} - -// StopScheduler stops the embedding cron scheduler. -func (s *EmbeddingService) StopScheduler() { - s.mu.Lock() - defer s.mu.Unlock() - if s.cronJob != nil { - s.cronJob.Stop() - s.cronJob = nil - s.cronSpec = "" - log.Infof("embedding scheduler stopped") - } -} - -// ApplyConfig reads the current AI config and starts or stops the scheduler accordingly. 
-func (s *EmbeddingService) ApplyConfig(ctx context.Context) { - aiConfig, provider, err := s.getAIConfig(ctx) - if err != nil || aiConfig == nil || provider == nil { - s.StopScheduler() - return - } - - if provider.EmbeddingModel == "" || provider.EmbeddingCrontab == "" { - s.StopScheduler() - return - } - - // Only restart if the cron spec changed - s.mu.Lock() - currentSpec := s.cronSpec - s.mu.Unlock() - - if currentSpec == provider.EmbeddingCrontab { - return - } - - if err := s.StartScheduler(provider.EmbeddingCrontab); err != nil { - log.Errorf("failed to start embedding scheduler: %v", err) - } -} diff --git a/internal/service/plugin_common/plugin_common_service.go b/internal/service/plugin_common/plugin_common_service.go index eb46b5ac2..b3674a230 100644 --- a/internal/service/plugin_common/plugin_common_service.go +++ b/internal/service/plugin_common/plugin_common_service.go @@ -25,6 +25,7 @@ import ( "github.com/apache/answer/internal/base/data" "github.com/apache/answer/internal/repo/search_sync" + "github.com/apache/answer/internal/repo/vector_search_sync" "github.com/segmentfault/pacman/errors" "github.com/segmentfault/pacman/log" @@ -103,6 +104,12 @@ func (ps *PluginCommonService) UpdatePluginConfig(ctx context.Context, req *sche } return nil }) + _ = plugin.CallVectorSearch(func(vs plugin.VectorSearch) error { + if vs.Info().SlugName == req.PluginSlugName { + vs.RegisterSyncer(ctx, vector_search_sync.NewPluginSyncer(ps.data)) + } + return nil + }) _ = plugin.CallImporter(func(importer plugin.Importer) error { importer.RegisterImporterFunc(ctx, ps.importerService.NewImporterFunc()) return nil @@ -176,6 +183,16 @@ func (ps *PluginCommonService) initPluginData() { }) } + // register syncers for search and vector search plugins on startup + _ = plugin.CallSearch(func(search plugin.Search) error { + search.RegisterSyncer(context.Background(), search_sync.NewPluginSyncer(ps.data)) + return nil + }) + _ = plugin.CallVectorSearch(func(vs 
plugin.VectorSearch) error { + vs.RegisterSyncer(context.Background(), vector_search_sync.NewPluginSyncer(ps.data)) + return nil + }) + // init plugin user config plugin.RegisterGetPluginUserConfigFunc(func(userID, pluginSlugName string) []byte { pluginUserConfig, exist, err := ps.pluginUserConfigRepo.GetPluginUserConfig(context.Background(), userID, pluginSlugName) diff --git a/internal/service/siteinfo/siteinfo_service.go b/internal/service/siteinfo/siteinfo_service.go index 70003984c..c29051b43 100644 --- a/internal/service/siteinfo/siteinfo_service.go +++ b/internal/service/siteinfo/siteinfo_service.go @@ -33,7 +33,6 @@ import ( "github.com/apache/answer/internal/entity" "github.com/apache/answer/internal/schema" "github.com/apache/answer/internal/service/config" - embeddingService "github.com/apache/answer/internal/service/embedding" "github.com/apache/answer/internal/service/export" "github.com/apache/answer/internal/service/file_record" questioncommon "github.com/apache/answer/internal/service/question_common" @@ -54,7 +53,6 @@ type SiteInfoService struct { configService *config.ConfigService questioncommon *questioncommon.QuestionCommon fileRecordService *file_record.FileRecordService - embeddingService *embeddingService.EmbeddingService } func NewSiteInfoService( @@ -65,7 +63,6 @@ func NewSiteInfoService( configService *config.ConfigService, questioncommon *questioncommon.QuestionCommon, fileRecordService *file_record.FileRecordService, - embeddingSvc *embeddingService.EmbeddingService, ) *SiteInfoService { plugin.RegisterGetSiteURLFunc(func() string { generalSiteInfo, err := siteInfoCommonService.GetSiteGeneral(context.Background()) @@ -84,7 +81,6 @@ func NewSiteInfoService( configService: configService, questioncommon: questioncommon, fileRecordService: fileRecordService, - embeddingService: embeddingSvc, } } @@ -416,8 +412,6 @@ func (s *SiteInfoService) SaveSiteAI(ctx context.Context, req *schema.SiteAIReq) return err } - // Apply embedding scheduler 
config (start/stop cron based on settings) - go s.embeddingService.ApplyConfig(ctx) return nil } diff --git a/plugin/plugin.go b/plugin/plugin.go index 8778b1625..3a657fc40 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -130,6 +130,10 @@ func Register(p Base) { if _, ok := p.(Sidebar); ok { registerSidebar(p.(Sidebar)) } + + if _, ok := p.(VectorSearch); ok { + registerVectorSearch(p.(VectorSearch)) + } } type Stack[T Base] struct { diff --git a/plugin/vector_search.go b/plugin/vector_search.go new file mode 100644 index 000000000..a01701522 --- /dev/null +++ b/plugin/vector_search.go @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package plugin + +import ( + "context" + "fmt" + "strings" + + "github.com/sashabaranov/go-openai" + "github.com/segmentfault/pacman/log" +) + +// VectorSearchResult holds a single similarity search result returned by a VectorSearch plugin. +type VectorSearchResult struct { + // ObjectID is the unique identifier of the matched object (question ID or answer ID). + ObjectID string `json:"object_id"` + // ObjectType is "question" or "answer". 
+ ObjectType string `json:"object_type"` + // Metadata is a JSON string containing VectorSearchMetadata for link composition and content retrieval. + Metadata string `json:"metadata"` + // Score is the cosine similarity score (0-1). + Score float64 `json:"score"` +} + +// VectorSearchContent is the document structure passed to plugins for indexing. +type VectorSearchContent struct { + // ObjectID is the unique identifier (question ID or answer ID). + ObjectID string `json:"objectID"` + // ObjectType is "question" or "answer". + ObjectType string `json:"objectType"` + // Title is the question title. + Title string `json:"title"` + // Content is the aggregated text to be embedded (question body + answers + comments). + Content string `json:"content"` + // Metadata is a JSON string containing VectorSearchMetadata. + Metadata string `json:"metadata"` +} + +// VectorSearchDesc describes the vector search engine for display purposes. +type VectorSearchDesc struct { + // Icon is an SVG icon for display. Optional. + Icon string `json:"icon"` + // Link is the URL of the vector search engine. Optional. + Link string `json:"link"` +} + +// VectorSearchMetadata holds IDs for URI composition and content retrieval at query time. +// Shared between plugins and the core MCP controller. +type VectorSearchMetadata struct { + QuestionID string `json:"question_id"` + AnswerID string `json:"answer_id,omitempty"` + Answers []VectorSearchMetadataAnswer `json:"answers,omitempty"` + Comments []VectorSearchMetadataComment `json:"comments,omitempty"` +} + +// VectorSearchMetadataAnswer stores answer ID and its comment IDs in metadata. +type VectorSearchMetadataAnswer struct { + AnswerID string `json:"answer_id"` + Comments []VectorSearchMetadataComment `json:"comments,omitempty"` +} + +// VectorSearchMetadataComment stores a comment ID in metadata. 
+type VectorSearchMetadataComment struct { + CommentID string `json:"comment_id"` +} + +// VectorSearch is the plugin interface for vector/semantic search engines. +// Plugins implementing this interface manage their own vector storage, embedding computation, +// data synchronization schedule, and similarity search. +type VectorSearch interface { + Base + + // Description returns metadata about the vector search engine. + Description() VectorSearchDesc + + // RegisterSyncer is called by the core to provide a data syncer. + // The plugin should store the syncer and use it to bulk-sync content + // (typically in a background goroutine). + RegisterSyncer(ctx context.Context, syncer VectorSearchSyncer) + + // SearchSimilar performs a semantic similarity search. + // The plugin is responsible for embedding the query text and searching its vector store. + // Returns up to topK results sorted by similarity score (descending). + SearchSimilar(ctx context.Context, query string, topK int) ([]VectorSearchResult, error) + + // UpdateContent upserts a single document in the vector store. + // Called by the core on incremental content changes. + UpdateContent(ctx context.Context, content *VectorSearchContent) error + + // DeleteContent removes a document from the vector store by object ID. + DeleteContent(ctx context.Context, objectID string) error +} + +// VectorSearchSyncer is implemented by the core and provided to plugins via RegisterSyncer. +// Plugins call these methods to pull all content for bulk indexing. +type VectorSearchSyncer interface { + // GetQuestionsPage returns a page of questions with aggregated text (title + body + answers + comments). + GetQuestionsPage(ctx context.Context, page, pageSize int) ([]*VectorSearchContent, error) + // GetAnswersPage returns a page of answers with aggregated text (answer body + parent question title + comments). 
+ GetAnswersPage(ctx context.Context, page, pageSize int) ([]*VectorSearchContent, error) +} + +var ( + // CallVectorSearch is a function that calls all registered VectorSearch plugins. + CallVectorSearch, + registerVectorSearch = MakePlugin[VectorSearch](false) +) + +// GenerateEmbedding is a base utility function that generates an embedding vector +// using an OpenAI-compatible API. Plugins that don't have a built-in vectorizer +// (most vector databases) can call this function with their own credentials. +// Plugins with built-in vectorizers (e.g., Weaviate) can skip this and use their own. +// +// Parameters: +// - ctx: context for cancellation +// - apiHost: the API base URL (e.g. "https://api.openai.com"); "/v1" is appended if missing +// - apiKey: the API key for authentication +// - model: the embedding model name (e.g. "text-embedding-3-small") +// - text: the text to embed +// +// Returns the embedding vector as []float32, or an error. +func GenerateEmbedding(ctx context.Context, apiHost, apiKey, model, text string) ([]float32, error) { + if model == "" { + return nil, fmt.Errorf("embedding model is not configured") + } + if text == "" { + return nil, fmt.Errorf("text is empty") + } + + config := openai.DefaultConfig(apiKey) + config.BaseURL = apiHost + if !strings.HasSuffix(config.BaseURL, "/v1") { + config.BaseURL += "/v1" + } + + log.Debugf("embedding: requesting model=%s baseURL=%s textLen=%d", model, config.BaseURL, len(text)) + + client := openai.NewClientWithConfig(config) + + resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequestStrings{ + Input: []string{text}, + Model: openai.EmbeddingModel(model), + }) + if err != nil { + log.Errorf("embedding: request failed model=%s baseURL=%s err=%v", model, config.BaseURL, err) + return nil, fmt.Errorf("create embeddings failed: %w", err) + } + if len(resp.Data) == 0 { + log.Errorf("embedding: no data returned model=%s baseURL=%s", model, config.BaseURL) + return nil, fmt.Errorf("no embedding 
returned") + } + + log.Debugf("embedding: success model=%s dimensions=%d usage={prompt=%d,total=%d}", + model, len(resp.Data[0].Embedding), resp.Usage.PromptTokens, resp.Usage.TotalTokens) + return resp.Data[0].Embedding, nil +} diff --git a/ui/src/common/interface.ts b/ui/src/common/interface.ts index b01b621f7..308726e80 100644 --- a/ui/src/common/interface.ts +++ b/ui/src/common/interface.ts @@ -833,11 +833,6 @@ export interface AiConfig { api_host: string; api_key: string; model: string; - embedding_model: string; - embedding_dimensions: number; - embedding_level: string; - embedding_crontab: string; - similarity_threshold: number; }>; } diff --git a/ui/src/pages/Admin/AiSettings/index.tsx b/ui/src/pages/Admin/AiSettings/index.tsx index de284ff00..2270aa5c5 100644 --- a/ui/src/pages/Admin/AiSettings/index.tsx +++ b/ui/src/pages/Admin/AiSettings/index.tsx @@ -68,31 +68,6 @@ const Index = () => { isInvalid: false, errorMsg: '', }, - embedding_model: { - value: '', - isInvalid: false, - errorMsg: '', - }, - embedding_dimensions: { - value: '', - isInvalid: false, - errorMsg: '', - }, - embedding_level: { - value: 'question', - isInvalid: false, - errorMsg: '', - }, - embedding_crontab: { - value: '', - isInvalid: false, - errorMsg: '', - }, - similarity_threshold: { - value: '0', - isInvalid: false, - errorMsg: '', - }, }); const [apiHostPlaceholder, setApiHostPlaceholder] = useState(''); const [modelsData, setModels] = useState<{ id: string }[]>([]); @@ -171,31 +146,6 @@ const Index = () => { isInvalid: false, errorMsg: '', }, - embedding_model: { - value: findHistoryProvider?.embedding_model || '', - isInvalid: false, - errorMsg: '', - }, - embedding_dimensions: { - value: String(findHistoryProvider?.embedding_dimensions || ''), - isInvalid: false, - errorMsg: '', - }, - embedding_level: { - value: findHistoryProvider?.embedding_level || 'question', - isInvalid: false, - errorMsg: '', - }, - embedding_crontab: { - value: findHistoryProvider?.embedding_crontab || 
'', - isInvalid: false, - errorMsg: '', - }, - similarity_threshold: { - value: String(findHistoryProvider?.similarity_threshold || '0'), - isInvalid: false, - errorMsg: '', - }, }); const provider = aiProviders?.find((item) => item.name === value); const host = findHistoryProvider?.api_host || provider?.default_api_host; @@ -268,13 +218,6 @@ const Index = () => { api_host: formData.api_host.value, api_key: formData.api_key.value, model: formData.model.value, - embedding_model: formData.embedding_model.value, - embedding_dimensions: - Number(formData.embedding_dimensions.value) || 0, - embedding_level: formData.embedding_level.value, - embedding_crontab: formData.embedding_crontab.value, - similarity_threshold: - Number(formData.similarity_threshold.value) || 0, }; } return v; @@ -352,31 +295,6 @@ const Index = () => { isInvalid: false, errorMsg: '', }, - embedding_model: { - value: currentAiConfig?.embedding_model || '', - isInvalid: false, - errorMsg: '', - }, - embedding_dimensions: { - value: String(currentAiConfig?.embedding_dimensions || ''), - isInvalid: false, - errorMsg: '', - }, - embedding_level: { - value: currentAiConfig?.embedding_level || 'question', - isInvalid: false, - errorMsg: '', - }, - embedding_crontab: { - value: currentAiConfig?.embedding_crontab || '', - isInvalid: false, - errorMsg: '', - }, - similarity_threshold: { - value: String(currentAiConfig?.similarity_threshold || '0'), - isInvalid: false, - errorMsg: '', - }, }); }; @@ -559,99 +477,6 @@ const Index = () => {
{formData.model.errorMsg}
-
-
{t('embedding_settings')}
- - - {t('embedding_model.label')} - - handleValueChange({ - embedding_model: { - value: e.target.value, - errorMsg: '', - isInvalid: false, - }, - }) - } - /> - - {t('embedding_model.text')} - - - - - {t('embedding_level.label')} - - handleValueChange({ - embedding_level: { - value: e.target.value, - errorMsg: '', - isInvalid: false, - }, - }) - }> - - - - - {t('embedding_level.text')} - - - - - {t('embedding_crontab.label')} - - handleValueChange({ - embedding_crontab: { - value: e.target.value, - errorMsg: '', - isInvalid: false, - }, - }) - } - /> - - {t('embedding_crontab.text')} - - - - - {t('similarity_threshold.label')} - - handleValueChange({ - similarity_threshold: { - value: e.target.value, - errorMsg: '', - isInvalid: false, - }, - }) - } - /> - - {t('similarity_threshold.text')} - - - From 3c1aa87e54200305ee544be5f82c3faaae0f01f9 Mon Sep 17 00:00:00 2001 From: hgaol Date: Sun, 12 Apr 2026 13:31:52 +0800 Subject: [PATCH 5/6] fix lint --- cmd/wire_gen.go | 2 +- docs/docs.go | 2 +- docs/swagger.json | 2 +- plugin/vector_search.go | 12 ++++++------ 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/cmd/wire_gen.go b/cmd/wire_gen.go index 2c5931699..df3b0a1b8 100644 --- a/cmd/wire_gen.go +++ b/cmd/wire_gen.go @@ -248,7 +248,6 @@ func initApplication(debug bool, serverConf *conf.Server, dbConf *data.Database, reasonService := reason2.NewReasonService(reasonRepo) reasonController := controller.NewReasonController(reasonService) themeController := controller_admin.NewThemeController() - embeddingService := embedding.NewEmbeddingService() siteInfoService := siteinfo.NewSiteInfoService(siteInfoRepo, siteInfoCommonService, emailService, tagCommonService, configService, questionCommon, fileRecordService) siteInfoController := controller_admin.NewSiteInfoController(siteInfoService) controllerSiteInfoController := controller.NewSiteInfoController(siteInfoCommonService) @@ -285,6 +284,7 @@ func initApplication(debug bool, serverConf *conf.Server, dbConf 
*data.Database, apiKeyService := apikey.NewAPIKeyService(apiKeyRepo) adminAPIKeyController := controller_admin.NewAdminAPIKeyController(apiKeyService) featureToggleService := feature_toggle.NewFeatureToggleService(siteInfoRepo) + embeddingService := embedding.NewEmbeddingService() mcpController := controller.NewMCPController(searchService, siteInfoCommonService, tagCommonService, questionCommon, commentRepo, userCommon, answerRepo, featureToggleService, embeddingService) aiConversationRepo := ai_conversation.NewAIConversationRepo(dataData) aiConversationService := ai_conversation2.NewAIConversationService(aiConversationRepo, userCommon) diff --git a/docs/docs.go b/docs/docs.go index b50a3cf78..57a23d432 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -11719,7 +11719,7 @@ const docTemplate = `{ "provider": { "type": "string", "maxLength": 50 - }, + } } }, "schema.SiteAIReq": { diff --git a/docs/swagger.json b/docs/swagger.json index 441862208..dac2b38fd 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -11692,7 +11692,7 @@ "provider": { "type": "string", "maxLength": 50 - }, + } } }, "schema.SiteAIReq": { diff --git a/plugin/vector_search.go b/plugin/vector_search.go index a01701522..134247d6c 100644 --- a/plugin/vector_search.go +++ b/plugin/vector_search.go @@ -65,16 +65,16 @@ type VectorSearchDesc struct { // VectorSearchMetadata holds IDs for URI composition and content retrieval at query time. // Shared between plugins and the core MCP controller. 
type VectorSearchMetadata struct { - QuestionID string `json:"question_id"` - AnswerID string `json:"answer_id,omitempty"` - Answers []VectorSearchMetadataAnswer `json:"answers,omitempty"` - Comments []VectorSearchMetadataComment `json:"comments,omitempty"` + QuestionID string `json:"question_id"` + AnswerID string `json:"answer_id,omitempty"` + Answers []VectorSearchMetadataAnswer `json:"answers,omitempty"` + Comments []VectorSearchMetadataComment `json:"comments,omitempty"` } // VectorSearchMetadataAnswer stores answer ID and its comment IDs in metadata. type VectorSearchMetadataAnswer struct { - AnswerID string `json:"answer_id"` - Comments []VectorSearchMetadataComment `json:"comments,omitempty"` + AnswerID string `json:"answer_id"` + Comments []VectorSearchMetadataComment `json:"comments,omitempty"` } // VectorSearchMetadataComment stores a comment ID in metadata. From ebd7c7f0f4b104210fd85bccdfcd497a5ffbac88 Mon Sep 17 00:00:00 2001 From: hgaol Date: Sun, 12 Apr 2026 13:39:07 +0800 Subject: [PATCH 6/6] fix lint --- cmd/wire_gen.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/cmd/wire_gen.go b/cmd/wire_gen.go index df3b0a1b8..3ead33fb1 100644 --- a/cmd/wire_gen.go +++ b/cmd/wire_gen.go @@ -34,7 +34,7 @@ import ( "github.com/apache/answer/internal/base/server" "github.com/apache/answer/internal/base/translator" "github.com/apache/answer/internal/controller" - templaterender "github.com/apache/answer/internal/controller/template_render" + "github.com/apache/answer/internal/controller/template_render" "github.com/apache/answer/internal/controller_admin" "github.com/apache/answer/internal/repo/activity" "github.com/apache/answer/internal/repo/activity_common" @@ -76,12 +76,12 @@ import ( activity_common2 "github.com/apache/answer/internal/service/activity_common" "github.com/apache/answer/internal/service/activityqueue" ai_conversation2 "github.com/apache/answer/internal/service/ai_conversation" - answercommon 
"github.com/apache/answer/internal/service/answer_common" + "github.com/apache/answer/internal/service/answer_common" "github.com/apache/answer/internal/service/apikey" auth2 "github.com/apache/answer/internal/service/auth" badge2 "github.com/apache/answer/internal/service/badge" collection2 "github.com/apache/answer/internal/service/collection" - collectioncommon "github.com/apache/answer/internal/service/collection_common" + "github.com/apache/answer/internal/service/collection_common" comment2 "github.com/apache/answer/internal/service/comment" "github.com/apache/answer/internal/service/comment_common" config2 "github.com/apache/answer/internal/service/config" @@ -95,13 +95,13 @@ import ( "github.com/apache/answer/internal/service/follow" "github.com/apache/answer/internal/service/importer" meta2 "github.com/apache/answer/internal/service/meta" - metacommon "github.com/apache/answer/internal/service/meta_common" + "github.com/apache/answer/internal/service/meta_common" "github.com/apache/answer/internal/service/noticequeue" "github.com/apache/answer/internal/service/notification" - notificationcommon "github.com/apache/answer/internal/service/notification_common" + "github.com/apache/answer/internal/service/notification_common" "github.com/apache/answer/internal/service/object_info" "github.com/apache/answer/internal/service/plugin_common" - questioncommon "github.com/apache/answer/internal/service/question_common" + "github.com/apache/answer/internal/service/question_common" rank2 "github.com/apache/answer/internal/service/rank" reason2 "github.com/apache/answer/internal/service/reason" report2 "github.com/apache/answer/internal/service/report" @@ -117,7 +117,7 @@ import ( tag_common2 "github.com/apache/answer/internal/service/tag_common" "github.com/apache/answer/internal/service/uploader" "github.com/apache/answer/internal/service/user_admin" - usercommon "github.com/apache/answer/internal/service/user_common" + 
"github.com/apache/answer/internal/service/user_common" user_external_login2 "github.com/apache/answer/internal/service/user_external_login" user_notification_config2 "github.com/apache/answer/internal/service/user_notification_config" "github.com/segmentfault/pacman"