22 changes: 21 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -27,6 +27,7 @@ sentence-transformers = "^4.1.0"
sqlalchemy = "^2.0.41"
redis = "^6.2.0"
schedule = "^1.2.2"
volcengine-python-sdk = "^4.0.4"

[tool.poetry.group.dev]
optional = false
9 changes: 9 additions & 0 deletions src/memos/configs/embedder.py
@@ -18,6 +18,14 @@ class OllamaEmbedderConfig(BaseEmbedderConfig):
api_base: str = Field(default="http://localhost:11434", description="Base URL for Ollama API")


class ArkEmbedderConfig(BaseEmbedderConfig):
api_key: str = Field(..., description="Ark API key")
api_base: str = Field(
default="https://ark.cn-beijing.volces.com/api/v3/", description="Base URL for Ark API"
)
chunk_size: int = Field(default=1, description="Chunk size for Ark API")


class SenTranEmbedderConfig(BaseEmbedderConfig):
"""Configuration class for Sentence Transformer embeddings."""

@@ -36,6 +44,7 @@ class EmbedderConfigFactory(BaseConfig):
backend_to_class: ClassVar[dict[str, Any]] = {
"ollama": OllamaEmbedderConfig,
"sentence_transformer": SenTranEmbedderConfig,
"ark": ArkEmbedderConfig,
}

@field_validator("backend")
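As a point of reference, a minimal sketch of constructing the new config directly (the API key is a placeholder; api_base and chunk_size fall back to the defaults declared above when omitted):

from memos.configs.embedder import ArkEmbedderConfig

# Sketch only: the API key below is a placeholder, not a working credential.
config = ArkEmbedderConfig(
    model_name_or_path="doubao-embedding-vision-250615",
    api_key="your-api-key",
)
print(config.api_base)    # https://ark.cn-beijing.volces.com/api/v3/ (default)
print(config.chunk_size)  # 1 (default)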
67 changes: 67 additions & 0 deletions src/memos/embedders/ark.py
@@ -0,0 +1,67 @@
from volcenginesdkarkruntime import Ark
from volcenginesdkarkruntime.types.multimodal_embedding import (
EmbeddingInputParam,
MultimodalEmbeddingContentPartTextParam,
MultimodalEmbeddingResponse,
)

from memos.configs.embedder import ArkEmbedderConfig
from memos.embedders.base import BaseEmbedder
from memos.log import get_logger


logger = get_logger(__name__)


class ArkEmbedder(BaseEmbedder):
"""Ark Embedder class."""

def __init__(self, config: ArkEmbedderConfig):
self.config = config

if self.config.embedding_dims is not None:
logger.warning(
"Ark does not support specifying embedding dimensions. "
"The embedding dimensions is determined by the model."
"`embedding_dims` will be set to None."
)
self.config.embedding_dims = None

# Default model if not specified
if not self.config.model_name_or_path:
self.config.model_name_or_path = "doubao-embedding-vision-250615"

# Initialize ark client
self.client = Ark(api_key=self.config.api_key, base_url=self.config.api_base)

def embed(self, texts: list[str]) -> list[list[float]]:
"""
Generate embeddings for the given texts.

Args:
texts: List of texts to embed.

Returns:
List of embeddings, each represented as a list of floats.
"""
texts_input = [
MultimodalEmbeddingContentPartTextParam(text=text, type="text") for text in texts
]
return self.multimodal_embeddings(texts_input, chunk_size=self.config.chunk_size)

def multimodal_embeddings(
self, inputs: list[EmbeddingInputParam], chunk_size: int | None = None
) -> list[list[float]]:
chunk_size_ = chunk_size or self.config.chunk_size
embeddings: list[list[float]] = []

for i in range(0, len(inputs), chunk_size_):
response: MultimodalEmbeddingResponse = self.client.multimodal_embeddings.create(
model=self.config.model_name_or_path,
input=inputs[i : i + chunk_size_],
)

data = [response.data] if isinstance(response.data, dict) else response.data
embeddings.extend(r["embedding"] for r in data)

return embeddings
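A rough usage sketch of the embedder defined above (assumes a valid Ark API key; chunk_size=2 is chosen only to show that the chunking loop splits the inputs into batches of that size):

from memos.configs.embedder import ArkEmbedderConfig
from memos.embedders.ark import ArkEmbedder

# Sketch only: placeholder credentials.
config = ArkEmbedderConfig(
    model_name_or_path="doubao-embedding-vision-250615",
    api_key="your-api-key",
    chunk_size=2,
)
embedder = ArkEmbedder(config)

# With chunk_size=2, the three texts are sent in two API calls (2 + 1),
# and one embedding (a list of floats) comes back per input text.
vectors = embedder.embed(["first text", "second text", "third text"])
print(len(vectors))  # 3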
2 changes: 2 additions & 0 deletions src/memos/embedders/factory.py
@@ -1,6 +1,7 @@
from typing import Any, ClassVar

from memos.configs.embedder import EmbedderConfigFactory
from memos.embedders.ark import ArkEmbedder
from memos.embedders.base import BaseEmbedder
from memos.embedders.ollama import OllamaEmbedder
from memos.embedders.sentence_transformer import SenTranEmbedder
@@ -12,6 +13,7 @@ class EmbedderFactory(BaseEmbedder):
backend_to_class: ClassVar[dict[str, Any]] = {
"ollama": OllamaEmbedder,
"sentence_transformer": SenTranEmbedder,
"ark": ArkEmbedder,
}

@classmethod
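Together with the config registration earlier, the factory path would look roughly like this (a sketch mirroring tests/embedders/test_ark.py; placeholder credentials, and from_config is assumed to resolve the backend through backend_to_class as it does for the existing backends):

from memos.configs.embedder import EmbedderConfigFactory
from memos.embedders.factory import ArkEmbedder, EmbedderFactory

# Sketch only: placeholder API key, same config shape as in the tests below.
config = EmbedderConfigFactory.model_validate(
    {
        "backend": "ark",
        "config": {
            "model_name_or_path": "doubao-embedding-vision-250615",
            "api_key": "your-api-key",
        },
    }
)
embedder = EmbedderFactory.from_config(config)
assert isinstance(embedder, ArkEmbedder)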
4 changes: 2 additions & 2 deletions src/memos/memories/textual/general.py
@@ -7,7 +7,7 @@
from tenacity import retry, retry_if_exception_type, stop_after_attempt

from memos.configs.memory import GeneralTextMemoryConfig
from memos.embedders.factory import EmbedderFactory, OllamaEmbedder
from memos.embedders.factory import ArkEmbedder, EmbedderFactory, OllamaEmbedder
from memos.llms.factory import LLMFactory, OllamaLLM, OpenAILLM
from memos.log import get_logger
from memos.memories.textual.base import BaseTextMemory
@@ -28,7 +28,7 @@ def __init__(self, config: GeneralTextMemoryConfig):
self.config: GeneralTextMemoryConfig = config
self.extractor_llm: OpenAILLM | OllamaLLM = LLMFactory.from_config(config.extractor_llm)
self.vector_db: QdrantVecDB = VecDBFactory.from_config(config.vector_db)
self.embedder: OllamaEmbedder = EmbedderFactory.from_config(config.embedder)
self.embedder: OllamaEmbedder | ArkEmbedder = EmbedderFactory.from_config(config.embedder)

@retry(
stop=stop_after_attempt(3),
64 changes: 64 additions & 0 deletions tests/embedders/test_ark.py
@@ -0,0 +1,64 @@
import unittest

from unittest.mock import patch

from memos.configs.embedder import EmbedderConfigFactory
from memos.embedders.factory import ArkEmbedder, EmbedderFactory


class TestEmbedderFactory(unittest.TestCase):
@patch.object(ArkEmbedder, "embed")
def test_embed_single_text(self, mock_embed):
"""Test embedding a single text."""
mock_embed.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]]

config = EmbedderConfigFactory.model_validate(
{
"backend": "ark",
"config": {
"model_name_or_path": "doubao-embedding-vision-250615",
"embedding_dims": 2048,
"api_key": "your-api-key",
"api_base": "https://ark.cn-beijing.volces.com/api/v3",
},
}
)
embedder = EmbedderFactory.from_config(config)
text = "This is a sample text for embedding generation."
result = embedder.embed([text])

mock_embed.assert_called_once_with([text])
self.assertEqual(len(result[0]), 6)

@patch.object(ArkEmbedder, "embed")
def test_embed_batch_text(self, mock_embed):
"""Test embedding multiple texts at once."""
mock_embed.return_value = [
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
[0.6, 0.5, 0.4, 0.3, 0.2, 0.1],
[0.3, 0.4, 0.5, 0.6, 0.1, 0.2],
]

config = EmbedderConfigFactory.model_validate(
{
"backend": "ark",
"config": {
"model_name_or_path": "doubao-embedding-vision-250615",
"embedding_dims": 2048,
"api_key": "your-api-key",
"api_base": "https://ark.cn-beijing.volces.com/api/v3",
},
}
)
embedder = EmbedderFactory.from_config(config)
texts = [
"First sample text for batch embedding.",
"Second sample text for batch embedding.",
"Third sample text for batch embedding.",
]

result = embedder.embed(texts)

mock_embed.assert_called_once_with(texts)
self.assertEqual(len(result), 3)
self.assertEqual(len(result[0]), 6)