diff --git a/poetry.lock b/poetry.lock index 0df9b1ddc..70be43c96 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5022,6 +5022,26 @@ platformdirs = ">=3.9.1,<5" docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8) ; platform_python_implementation == \"PyPy\" or platform_python_implementation == \"GraalVM\" or platform_python_implementation == \"CPython\" and sys_platform == \"win32\" and python_version >= \"3.13\"", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10) ; platform_python_implementation == \"CPython\""] +[[package]] +name = "volcengine-python-sdk" +version = "4.0.5" +description = "Volcengine SDK for Python" +optional = false +python-versions = "*" +groups = ["main"] +files = [ + {file = "volcengine-python-sdk-4.0.5.tar.gz", hash = "sha256:303e19437d0517e4043a97edb8a1e496782e73afbc13d09a1cb6756538d3b732"}, +] + +[package.dependencies] +certifi = ">=2017.4.17" +python-dateutil = ">=2.1" +six = ">=1.10" +urllib3 = ">=1.23" + +[package.extras] +ark = ["anyio (>=3.5.0,<5)", "cached-property ; python_version < \"3.8\"", "cryptography (>=42.0.0)", "httpx (>=0.23.0,<1)", "pydantic (>=1.9.0,<3)"] + [[package]] name = "watchfiles" version = "1.1.0" @@ -5515,4 +5535,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.1" python-versions = "^3.10" -content-hash = "b56786e7bcdae03bbe0f031e2fe0bee284fef76534a3d3bfb24d6f70a44f4f0d" +content-hash = "558e8a7a9293537be027d0a84beea695809de12b869e3739a25d4c59535a5e1b" diff --git a/pyproject.toml b/pyproject.toml index e67acea0d..e6a0bbf4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ sentence-transformers = "^4.1.0" 
class ArkEmbedder(BaseEmbedder):
    """Embedder that delegates to the Ark multimodal embeddings endpoint."""

    def __init__(self, config: ArkEmbedderConfig):
        """
        Store the configuration and create the Ark client.

        Args:
            config: Ark embedder configuration. `embedding_dims` is reset to
                None when set, and `model_name_or_path` falls back to a
                default model when empty.
        """
        self.config = config

        if self.config.embedding_dims is not None:
            # Each fragment ends with a space so the concatenated sentences
            # do not run together in the log output.
            logger.warning(
                "Ark does not support specifying embedding dimensions. "
                "The embedding dimensions are determined by the model. "
                "`embedding_dims` will be set to None."
            )
            self.config.embedding_dims = None

        # Default model if not specified
        if not self.config.model_name_or_path:
            self.config.model_name_or_path = "doubao-embedding-vision-250615"

        # Initialize ark client
        self.client = Ark(api_key=self.config.api_key, base_url=self.config.api_base)

    def embed(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embeddings for the given texts.

        Args:
            texts: List of texts to embed.

        Returns:
            List of embeddings, each represented as a list of floats.
        """
        texts_input = [
            MultimodalEmbeddingContentPartTextParam(text=text, type="text") for text in texts
        ]
        return self.multimodal_embeddings(texts_input, chunk_size=self.config.chunk_size)

    def multimodal_embeddings(
        self, inputs: list[EmbeddingInputParam], chunk_size: int | None = None
    ) -> list[list[float]]:
        """
        Embed multimodal inputs, sending them to Ark in chunks.

        Args:
            inputs: Input parts to embed.
            chunk_size: Number of inputs sent per request. Falls back to
                `config.chunk_size` when None or 0.

        Returns:
            The embeddings collected from every request, in request order.

        Raises:
            ValueError: If the effective chunk size is smaller than 1.
        """
        chunk_size_ = chunk_size or self.config.chunk_size
        if chunk_size_ < 1:
            # A negative step would make `range` empty and silently drop every
            # input; a zero step would fail with an unrelated `range()` error.
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size_}")

        embeddings: list[list[float]] = []

        # NOTE(review): if the service returns one fused embedding per request,
        # chunk sizes above 1 yield fewer vectors than inputs - confirm against
        # the Ark API documentation.
        for i in range(0, len(inputs), chunk_size_):
            response: MultimodalEmbeddingResponse = self.client.multimodal_embeddings.create(
                model=self.config.model_name_or_path,
                input=inputs[i : i + chunk_size_],
            )
            embeddings.extend(self._extract_embeddings(response.data))

        return embeddings

    @staticmethod
    def _extract_embeddings(data) -> list[list[float]]:
        """Return the embedding vectors held by a response payload.

        Accepts a single item or a sequence of items, where an item is either
        a dict with an "embedding" key or an object exposing `.embedding`.
        NOTE(review): the exact payload type depends on the SDK version -
        confirm which of these shapes is actually returned.
        """
        is_single = isinstance(data, dict) or hasattr(data, "embedding")
        items = [data] if is_single else data

        vectors: list[list[float]] = []
        for item in items:
            if not isinstance(item, dict) and hasattr(item, "embedding"):
                vectors.append(item.embedding)
            else:
                vectors.append(item["embedding"])
        return vectors
class TestEmbedderFactory(unittest.TestCase):
    """Factory-level tests for the Ark backend with `ArkEmbedder.embed` patched out."""

    @staticmethod
    def _make_embedder():
        """Build an Ark embedder through the config factory and the embedder factory."""
        backend_settings = {
            "model_name_or_path": "doubao-embedding-vision-250615",
            "embedding_dims": 2048,
            "api_key": "your-api-key",
            "api_base": "https://ark.cn-beijing.volces.com/api/v3",
        }
        validated = EmbedderConfigFactory.model_validate(
            {"backend": "ark", "config": backend_settings}
        )
        return EmbedderFactory.from_config(validated)

    @patch.object(ArkEmbedder, "embed")
    def test_embed_single_text(self, mock_embed):
        """A single text is forwarded to `embed` and the mocked vector is returned."""
        mock_embed.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]]

        ark = self._make_embedder()
        sentence = "This is a sample text for embedding generation."
        vectors = ark.embed([sentence])

        mock_embed.assert_called_once_with([sentence])
        self.assertEqual(len(vectors[0]), 6)

    @patch.object(ArkEmbedder, "embed")
    def test_embed_batch_text(self, mock_embed):
        """A batch of texts is forwarded to `embed` in one call, one vector per text."""
        canned_vectors = [
            [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            [0.6, 0.5, 0.4, 0.3, 0.2, 0.1],
            [0.3, 0.4, 0.5, 0.6, 0.1, 0.2],
        ]
        mock_embed.return_value = canned_vectors

        ark = self._make_embedder()
        batch = [
            "First sample text for batch embedding.",
            "Second sample text for batch embedding.",
            "Third sample text for batch embedding.",
        ]
        vectors = ark.embed(batch)

        mock_embed.assert_called_once_with(batch)
        self.assertEqual(len(vectors), 3)
        self.assertEqual(len(vectors[0]), 6)