Commit 4e3f77b
feat: add singleton for memos and llm
1 parent: c7df5ad

10 files changed (+466 −195 lines)

src/memos/api/config.py
Lines changed: 2 additions & 2 deletions

@@ -46,9 +46,9 @@ def get_activation_config() -> Dict[str, Any]:
             "config": {
                 "memory_filename": "activation_memory.pickle",
                 "extractor_llm": {
-                    "backend": "huggingface",
+                    "backend": "huggingface_singleton",
                     "config": {
-                        "model_name_or_path": "Qwen/Qwen3-1.7B",
+                        "model_name_or_path": os.getenv("MOS_CHAT_MODEL", "Qwen/Qwen3-1.7B"),
                         "temperature": 0.8,
                         "max_tokens": 1024,
                         "top_p": 0.9,

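With this change the activation-memory extractor switches to the singleton backend and resolves its model from the MOS_CHAT_MODEL environment variable, falling back to Qwen/Qwen3-1.7B. A minimal sketch of the override, assuming the nested dictionary shape shown in the hunk above (the exact return structure of get_activation_config may differ):

import os

os.environ["MOS_CHAT_MODEL"] = "Qwen/Qwen3-4B"  # hypothetical override, set before the config is built

from memos.api.config import get_activation_config

activation = get_activation_config()
extractor = activation["config"]["extractor_llm"]
print(extractor["backend"])                       # "huggingface_singleton"
print(extractor["config"]["model_name_or_path"])  # "Qwen/Qwen3-4B"
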
src/memos/api/product_models.py
Lines changed: 1 addition & 1 deletion

@@ -44,7 +44,7 @@ class UserRegisterRequest(BaseRequest):
 class GetMemoryRequest(BaseRequest):
     """Request model for getting memories."""
     user_id: str = Field(..., description="User ID")
-    memory_type: Literal["text_mem", "act_mem", "param_mem"] = Field(..., description="Memory type")
+    memory_type: Literal["text_mem", "act_mem", "param_mem", "para_mem"] = Field(..., description="Memory type")
     mem_cube_ids: list[str] | None = Field(None, description="Cube IDs")
     search_query: str | None = Field(None, description="Search query")

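GetMemoryRequest now also accepts "para_mem" alongside the existing literals. A quick validation sketch (field values are invented):

from memos.api.product_models import GetMemoryRequest

req = GetMemoryRequest(user_id="user-123", memory_type="para_mem")  # accepted after this change
print(req.memory_type)
# Any value outside the listed literals would raise a pydantic ValidationError.
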
src/memos/api/routers/product_router.py
Lines changed: 6 additions & 6 deletions

@@ -27,7 +27,7 @@ def get_mos_product_instance():
     from memos.configs.mem_os import MOSConfig
     mos_config = MOSConfig(**default_config)
     MOS_PRODUCT_INSTANCE = MOSProduct(default_config=mos_config)
-    logger.info("MOSProduct instance created successfully")
+    logger.info("MOSProduct instance created successfully with inheritance architecture")
     return MOS_PRODUCT_INSTANCE

 get_mos_product_instance()

@@ -265,24 +265,24 @@ async def update_user_config(user_id: str, config_data: dict):
         raise HTTPException(status_code=500, detail=str(traceback.format_exc()))


-@router.get("/instances/status", summary="Get user instance status", response_model=BaseResponse[dict])
+@router.get("/instances/status", summary="Get user configuration status", response_model=BaseResponse[dict])
 async def get_instance_status():
-    """Get information about active user instances in memory."""
+    """Get information about active user configurations in memory."""
     try:
         mos_product = get_mos_product_instance()
         status_info = mos_product.get_user_instance_info()
         return BaseResponse(
-            message="Instance status retrieved successfully",
+            message="User configuration status retrieved successfully",
             data=status_info
         )
     except Exception as e:
-        logger.error(f"Failed to get instance status: {traceback.format_exc()}")
+        logger.error(f"Failed to get user configuration status: {traceback.format_exc()}")
         raise HTTPException(status_code=500, detail=str(traceback.format_exc()))


 @router.get("/instances/count", summary="Get active user count", response_model=BaseResponse[int])
 async def get_active_user_count():
-    """Get the number of active user instances in memory."""
+    """Get the number of active user configurations in memory."""
     try:
         mos_product = get_mos_product_instance()
         count = mos_product.get_active_user_count()

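The /instances/status and /instances/count routes keep their paths; only summaries, docstrings, and log/response messages now speak of user configurations rather than instances. A hedged call sketch, assuming the API is served locally (base URL and any route prefix are guesses):

import requests

BASE_URL = "http://localhost:8000"  # hypothetical deployment
status = requests.get(f"{BASE_URL}/instances/status").json()
count = requests.get(f"{BASE_URL}/instances/count").json()
print(status["message"])  # "User configuration status retrieved successfully"
print(count["data"])      # number of active user configurations
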
src/memos/configs/llm.py
Lines changed: 1 addition & 0 deletions

@@ -54,6 +54,7 @@ class LLMConfigFactory(BaseConfig):
         "openai": OpenAILLMConfig,
         "ollama": OllamaLLMConfig,
         "huggingface": HFLLMConfig,
+        "huggingface_singleton": HFLLMConfig,  # Add singleton support
     }

     @field_validator("backend")

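The new backend key maps to the existing HFLLMConfig schema, so switching a config to the singleton path is just a string change. A sketch, with config fields borrowed from the api/config.py hunk above:

from memos.configs.llm import LLMConfigFactory

llm_config = LLMConfigFactory(
    backend="huggingface_singleton",  # validated against the mapping added here
    config={
        "model_name_or_path": "Qwen/Qwen3-1.7B",
        "temperature": 0.8,
        "max_tokens": 1024,
        "top_p": 0.9,
    },
)
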
src/memos/configs/memory.py
Lines changed: 4 additions & 4 deletions

@@ -51,9 +51,9 @@ class KVCacheMemoryConfig(BaseActMemoryConfig):
     @classmethod
     def validate_extractor_llm(cls, extractor_llm: LLMConfigFactory) -> LLMConfigFactory:
         """Validate the extractor_llm field."""
-        if extractor_llm.backend != "huggingface":
+        if extractor_llm.backend not in ["huggingface", "huggingface_singleton"]:
             raise ConfigurationError(
-                f"KVCacheMemoryConfig requires extractor_llm backend to be 'huggingface', got '{extractor_llm.backend}'"
+                f"KVCacheMemoryConfig requires extractor_llm backend to be 'huggingface' or 'huggingface_singleton', got '{extractor_llm.backend}'"
             )
         return extractor_llm

@@ -83,9 +83,9 @@ class LoRAMemoryConfig(BaseParaMemoryConfig):
     @classmethod
     def validate_extractor_llm(cls, extractor_llm: LLMConfigFactory) -> LLMConfigFactory:
         """Validate the extractor_llm field."""
-        if extractor_llm.backend not in ["huggingface"]:
+        if extractor_llm.backend not in ["huggingface", "huggingface_singleton"]:
             raise ConfigurationError(
-                f"LoRAMemoryConfig requires extractor_llm backend to be 'huggingface', got '{extractor_llm.backend}'"
+                f"LoRAMemoryConfig requires extractor_llm backend to be 'huggingface' or 'huggingface_singleton', got '{extractor_llm.backend}'"
             )
         return extractor_llm

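Both validators now treat the two HF backends as interchangeable. A sketch that calls one validator directly for illustration (pydantic normally invokes it during model construction):

from memos.configs.llm import LLMConfigFactory
from memos.configs.memory import KVCacheMemoryConfig

extractor = LLMConfigFactory(
    backend="huggingface_singleton",
    config={"model_name_or_path": "Qwen/Qwen3-1.7B"},
)
# Before this commit the check required backend == "huggingface" and would raise ConfigurationError here.
KVCacheMemoryConfig.validate_extractor_llm(extractor)
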
src/memos/llms/factory.py
Lines changed: 2 additions & 0 deletions

@@ -3,6 +3,7 @@
 from memos.configs.llm import LLMConfigFactory
 from memos.llms.base import BaseLLM
 from memos.llms.hf import HFLLM
+from memos.llms.hf_singleton import HFSingletonLLM
 from memos.llms.ollama import OllamaLLM
 from memos.llms.openai import OpenAILLM

@@ -14,6 +15,7 @@ class LLMFactory(BaseLLM):
         "openai": OpenAILLM,
         "ollama": OllamaLLM,
         "huggingface": HFLLM,
+        "huggingface_singleton": HFSingletonLLM,  # Add singleton version
     }

     @classmethod

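With the mapping in place, requesting the huggingface_singleton backend through the factory yields an HFSingletonLLM, so repeated construction with the same model path reuses one loaded model. A sketch, assuming the factory's entry point mirrors its existing backends (the exact classmethod name is not shown in this hunk and is a guess):

from memos.configs.llm import LLMConfigFactory
from memos.llms.factory import LLMFactory

config = LLMConfigFactory(
    backend="huggingface_singleton",
    config={"model_name_or_path": "Qwen/Qwen3-1.7B"},
)
llm_a = LLMFactory.from_config(config)  # hypothetical entry point; loads the model once
llm_b = LLMFactory.from_config(config)  # returns the cached singleton
assert llm_a is llm_b
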
src/memos/llms/hf_singleton.py (new file)
Lines changed: 121 additions & 0 deletions

import threading
from typing import Dict, Optional

from memos.configs.llm import HFLLMConfig
from memos.llms.hf import HFLLM
from memos.log import get_logger

logger = get_logger(__name__)


class HFSingletonLLM(HFLLM):
    """
    Singleton version of HFLLM that prevents multiple loading of the same model.
    This class inherits from HFLLM and adds singleton behavior.
    """

    _instances: Dict[str, 'HFSingletonLLM'] = {}
    _lock = threading.Lock()

    def __new__(cls, config: HFLLMConfig):
        """
        Singleton pattern implementation.
        Returns existing instance if config already exists, otherwise creates new one.
        """
        config_key = cls._get_config_key(config)

        if config_key in cls._instances:
            logger.debug(f"Reusing existing HF model: {config.model_name_or_path}")
            return cls._instances[config_key]

        with cls._lock:
            # Double-check pattern to prevent race conditions
            if config_key in cls._instances:
                logger.debug(f"Reusing existing HF model: {config.model_name_or_path}")
                return cls._instances[config_key]

            logger.info(f"Creating new HF model: {config.model_name_or_path}")
            instance = super().__new__(cls)
            cls._instances[config_key] = instance
            return instance

    def __init__(self, config: HFLLMConfig):
        """
        Initialize the singleton HFLLM instance.
        Only initializes if this is a new instance.
        """
        # Check if already initialized
        if hasattr(self, '_initialized'):
            return

        # Call parent constructor
        super().__init__(config)
        self._initialized = True

    @classmethod
    def _get_config_key(cls, config: HFLLMConfig) -> str:
        """
        Generate a unique key for the HF model configuration.

        Args:
            config: The HFLLM configuration

        Returns:
            A unique string key representing the configuration
        """
        # Create a unique key based on model path and key parameters
        # str(config.temperature),
        # str(config.max_tokens),
        # str(config.top_p),
        # str(config.top_k),
        # str(config.add_generation_prompt),
        # str(config.remove_think_prefix),
        # str(config.do_sample)
        key_parts = [
            config.model_name_or_path
        ]
        return "|".join(key_parts)

    @classmethod
    def get_instance_count(cls) -> int:
        """
        Get the number of unique HF model instances currently managed.

        Returns:
            Number of HF model instances
        """
        return len(cls._instances)

    @classmethod
    def get_instance_info(cls) -> Dict[str, str]:
        """
        Get information about all managed HF model instances.

        Returns:
            Dictionary mapping config keys to model paths
        """
        return {key: instance.config.model_name_or_path
                for key, instance in cls._instances.items()}

    @classmethod
    def clear_all(cls) -> None:
        """
        Clear all HF model instances from memory.
        This should be used carefully as it will force reloading of models.
        """
        with cls._lock:
            cls._instances.clear()
            logger.info("All HF model instances cleared from singleton manager")


# Convenience function to get singleton manager info
def get_hf_singleton_info() -> Dict[str, int]:
    """
    Get information about the HF singleton manager.

    Returns:
        Dictionary with instance count and info
    """
    return {
        "instance_count": HFSingletonLLM.get_instance_count(),
        "instance_info": HFSingletonLLM.get_instance_info()
    }

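Because _get_config_key uses only model_name_or_path, two configs that differ solely in generation parameters map to one loaded model; the second __init__ returns early via the _initialized guard, so the first config's settings win. A minimal usage sketch, assuming HFLLMConfig accepts these keyword fields:

from memos.configs.llm import HFLLMConfig
from memos.llms.hf_singleton import HFSingletonLLM, get_hf_singleton_info

cfg_a = HFLLMConfig(model_name_or_path="Qwen/Qwen3-1.7B", temperature=0.8)
cfg_b = HFLLMConfig(model_name_or_path="Qwen/Qwen3-1.7B", temperature=0.1)

llm_a = HFSingletonLLM(cfg_a)  # loads the model
llm_b = HFSingletonLLM(cfg_b)  # same object; the differing temperature is ignored by the key

assert llm_a is llm_b
print(get_hf_singleton_info())  # roughly: {"instance_count": 1, "instance_info": {"Qwen/Qwen3-1.7B": "Qwen/Qwen3-1.7B"}}
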
src/memos/mem_os/core.py
Lines changed: 0 additions & 10 deletions

@@ -477,16 +477,6 @@ def search(
             logger.info(
                 f"🧠 [Memory] Searched memories from {mem_cube_id}:\n{self._str_memories(memories)}\n"
             )
-            if (
-                (mem_cube_id in install_cube_ids)
-                and (mem_cube.act_mem is not None)
-                and self.config.enable_activation_memory
-            ):
-                memories = mem_cube.act_mem.extract(query)
-                result["act_mem"].append({"cube_id": mem_cube_id, "memories": [memories]})
-                logger.info(
-                    f"🧠 [Memory] Searched memories from {mem_cube_id}:\n{self._str_memories(memories)}\n"
-                )
         return result

     def add(
