diff --git a/pyrit/backend/services/converter_service.py b/pyrit/backend/services/converter_service.py index a0579239cc..e75ef62c05 100644 --- a/pyrit/backend/services/converter_service.py +++ b/pyrit/backend/services/converter_service.py @@ -88,8 +88,8 @@ async def list_converters_async(self) -> ConverterInstanceListResponse: ConverterInstanceListResponse containing all registered converters. """ items = [ - self._build_instance_from_object(converter_id=name, converter_obj=obj) - for name, obj in self._registry.get_all_instances().items() + self._build_instance_from_object(converter_id=entry.name, converter_obj=entry.instance) + for entry in self._registry.get_all_instances() ] return ConverterInstanceListResponse(items=items) diff --git a/pyrit/backend/services/target_service.py b/pyrit/backend/services/target_service.py index 41d7164970..7959a80408 100644 --- a/pyrit/backend/services/target_service.py +++ b/pyrit/backend/services/target_service.py @@ -108,8 +108,8 @@ async def list_targets_async( TargetListResponse containing paginated targets. """ items = [ - self._build_instance_from_object(target_registry_name=name, target_obj=obj) - for name, obj in self._registry.get_all_instances().items() + self._build_instance_from_object(target_registry_name=entry.name, target_obj=entry.instance) + for entry in self._registry.get_all_instances() ] page, has_more = self._paginate(items, cursor, limit) next_cursor = page[-1].target_registry_name if has_more and page else None diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py index 790b9284d8..76c162d342 100644 --- a/pyrit/registry/__init__.py +++ b/pyrit/registry/__init__.py @@ -19,6 +19,7 @@ ) from pyrit.registry.instance_registries import ( BaseInstanceRegistry, + RegistryEntry, ScorerRegistry, TargetRegistry, ) @@ -32,6 +33,7 @@ "discover_subclasses_in_loaded_modules", "InitializerMetadata", "InitializerRegistry", + "RegistryEntry", "RegistryProtocol", "ScenarioMetadata", "ScenarioRegistry", diff --git a/pyrit/registry/instance_registries/__init__.py b/pyrit/registry/instance_registries/__init__.py index 761735e261..d635813936 100644 --- a/pyrit/registry/instance_registries/__init__.py +++ b/pyrit/registry/instance_registries/__init__.py @@ -13,6 +13,7 @@ from pyrit.registry.instance_registries.base_instance_registry import ( BaseInstanceRegistry, + RegistryEntry, ) from pyrit.registry.instance_registries.converter_registry import ( ConverterRegistry, @@ -27,6 +28,7 @@ __all__ = [ # Base class "BaseInstanceRegistry", + "RegistryEntry", # Concrete registries "ConverterRegistry", "ScorerRegistry", diff --git a/pyrit/registry/instance_registries/base_instance_registry.py b/pyrit/registry/instance_registries/base_instance_registry.py index 558ef70157..38602ae120 100644 --- a/pyrit/registry/instance_registries/base_instance_registry.py +++ b/pyrit/registry/instance_registries/base_instance_registry.py @@ -16,7 +16,8 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union from pyrit.identifiers import ComponentIdentifier from pyrit.registry.base import RegistryProtocol @@ -28,6 +29,25 @@ MetadataT = TypeVar("MetadataT", bound=ComponentIdentifier) +@dataclass +class RegistryEntry(Generic[T]): + """ + A wrapper around a registered instance, holding its name, tags, and the instance itself. + + Tags are always stored as ``dict[str, str]``. When callers pass a plain + ``list[str]``, each string is normalised to a key with an empty-string value. + + Attributes: + name: The registry name for this entry. + instance: The registered object. + tags: Key-value tags for categorisation and filtering. + """ + + name: str + instance: T + tags: dict[str, str] = field(default_factory=dict) + + class BaseInstanceRegistry(ABC, RegistryProtocol[MetadataT], Generic[T, MetadataT]): """ Abstract base class for registries that store pre-configured instances. @@ -71,10 +91,28 @@ def reset_instance(cls) -> None: if cls in cls._instances: del cls._instances[cls] + @staticmethod + def _normalize_tags(tags: Optional[Union[dict[str, str], list[str]]] = None) -> dict[str, str]: + """ + Normalize tags into a ``dict[str, str]``. + + Args: + tags: Tags as a dict, a list of string keys (values default to ``""``), + or ``None`` (returns empty dict). + + Returns: + A ``dict[str, str]`` of normalised tags. + """ + if tags is None: + return {} + if isinstance(tags, list): + return dict.fromkeys(tags, "") + return dict(tags) + def __init__(self) -> None: """Initialize the instance registry.""" - # Maps registry names to registered items - self._registry_items: dict[str, T] = {} + # Maps registry names to registry entries + self._registry_items: dict[str, RegistryEntry[T]] = {} self._metadata_cache: Optional[list[MetadataT]] = None def register( @@ -82,6 +120,7 @@ def register( instance: T, *, name: str, + tags: Optional[Union[dict[str, str], list[str]]] = None, ) -> None: """ Register an instance. @@ -89,8 +128,11 @@ def register( Args: instance: The pre-configured instance to register. name: The registry name for this instance. + tags: Optional tags for categorisation. Accepts a ``dict[str, str]`` + or a ``list[str]`` (each string becomes a key with value ``""``). """ - self._registry_items[name] = instance + normalized = self._normalize_tags(tags) + self._registry_items[name] = RegistryEntry(name=name, instance=instance, tags=normalized) self._metadata_cache = None def get(self, name: str) -> Optional[T]: @@ -103,6 +145,21 @@ def get(self, name: str) -> Optional[T]: Returns: The instance, or None if not found. """ + entry = self._registry_items.get(name) + if entry is None: + return None + return entry.instance + + def get_entry(self, name: str) -> Optional[RegistryEntry[T]]: + """ + Get a full registry entry by name, including tags. + + Args: + name: The registry name of the entry. + + Returns: + The RegistryEntry, or None if not found. + """ return self._registry_items.get(name) def get_names(self) -> list[str]: @@ -114,14 +171,38 @@ def get_names(self) -> list[str]: """ return sorted(self._registry_items.keys()) - def get_all_instances(self) -> dict[str, T]: + def get_all_instances(self) -> list[RegistryEntry[T]]: """ - Get all registered instances as a name -> instance mapping. + Get all registered entries sorted by name. + + Returns: + List of RegistryEntry objects sorted by name. + """ + return [self._registry_items[name] for name in sorted(self._registry_items.keys())] + + def get_by_tag( + self, + *, + tag: str, + value: Optional[str] = None, + ) -> list[RegistryEntry[T]]: + """ + Get all entries that have a given tag, optionally matching a specific value. + + Args: + tag: The tag key to match. + value: If provided, only entries whose tag value equals this are returned. + If ``None``, any entry that has the tag key is returned regardless of value. Returns: - Dict mapping registry names to their instances. + List of matching RegistryEntry objects sorted by name. """ - return dict(self._registry_items) + results: list[RegistryEntry[T]] = [] + for name in sorted(self._registry_items.keys()): + entry = self._registry_items[name] + if tag in entry.tags and (value is None or entry.tags[tag] == value): + results.append(entry) + return results def list_metadata( self, @@ -152,8 +233,8 @@ def list_metadata( if self._metadata_cache is None: items = [] for name in sorted(self._registry_items.keys()): - instance = self._registry_items[name] - items.append(self._build_metadata(name, instance)) + entry = self._registry_items[name] + items.append(self._build_metadata(name, entry.instance)) self._metadata_cache = items if not include_filters and not exclude_filters: diff --git a/pyrit/registry/instance_registries/converter_registry.py b/pyrit/registry/instance_registries/converter_registry.py index 2ddbfa269a..987c94fa0f 100644 --- a/pyrit/registry/instance_registries/converter_registry.py +++ b/pyrit/registry/instance_registries/converter_registry.py @@ -12,7 +12,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union from pyrit.identifiers import ComponentIdentifier from pyrit.registry.instance_registries.base_instance_registry import ( @@ -49,6 +49,7 @@ def register_instance( converter: PromptConverter, *, name: Optional[str] = None, + tags: Optional[Union[dict[str, str], list[str]]] = None, ) -> None: """ Register a converter instance. @@ -57,11 +58,13 @@ def register_instance( converter: The pre-configured converter instance (not a class). name: Optional custom registry name. If not provided, derived from the converter's unique identifier. + tags: Optional tags for categorisation. Accepts a ``dict[str, str]`` + or a ``list[str]`` (each string becomes a key with value ``""``). """ if name is None: name = converter.get_identifier().unique_name - self.register(converter, name=name) + self.register(converter, name=name, tags=tags) logger.debug(f"Registered converter instance: {name} ({converter.__class__.__name__})") def get_instance_by_name(self, name: str) -> Optional[PromptConverter]: diff --git a/pyrit/registry/instance_registries/scorer_registry.py b/pyrit/registry/instance_registries/scorer_registry.py index c83747d3cb..e1fb7e1e9c 100644 --- a/pyrit/registry/instance_registries/scorer_registry.py +++ b/pyrit/registry/instance_registries/scorer_registry.py @@ -10,7 +10,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union from pyrit.identifiers import ComponentIdentifier from pyrit.registry.instance_registries.base_instance_registry import ( @@ -50,6 +50,7 @@ def register_instance( scorer: Scorer, *, name: Optional[str] = None, + tags: Optional[Union[dict[str, str], list[str]]] = None, ) -> None: """ Register a scorer instance. @@ -61,11 +62,13 @@ def register_instance( scorer: The pre-configured scorer instance (not a class). name: Optional custom registry name. If not provided, derived from the scorer's unique identifier. + tags: Optional tags for categorisation. Accepts a ``dict[str, str]`` + or a ``list[str]`` (each string becomes a key with value ``""``). """ if name is None: name = scorer.get_identifier().unique_name - self.register(scorer, name=name) + self.register(scorer, name=name, tags=tags) logger.debug(f"Registered scorer instance: {name} ({scorer.__class__.__name__})") def get_instance_by_name(self, name: str) -> Optional[Scorer]: diff --git a/pyrit/registry/instance_registries/target_registry.py b/pyrit/registry/instance_registries/target_registry.py index d5f71f4805..867a2b6995 100644 --- a/pyrit/registry/instance_registries/target_registry.py +++ b/pyrit/registry/instance_registries/target_registry.py @@ -10,7 +10,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union from pyrit.identifiers import ComponentIdentifier from pyrit.registry.instance_registries.base_instance_registry import ( @@ -50,6 +50,7 @@ def register_instance( target: PromptTarget, *, name: Optional[str] = None, + tags: Optional[Union[dict[str, str], list[str]]] = None, ) -> None: """ Register a target instance. @@ -62,11 +63,13 @@ def register_instance( name: Optional custom registry name. If not provided, derived from class name with identifier hash appended (e.g., OpenAIChatTarget -> openai_chat_abc123). + tags: Optional tags for categorisation. Accepts a ``dict[str, str]`` + or a ``list[str]`` (each string becomes a key with value ``""``). """ if name is None: name = target.get_identifier().unique_name - self.register(target, name=name) + self.register(target, name=name, tags=tags) logger.debug(f"Registered target instance: {name} ({target.__class__.__name__})") def get_instance_by_name(self, name: str) -> Optional[PromptTarget]: diff --git a/tests/unit/registry/test_base_instance_registry.py b/tests/unit/registry/test_base_instance_registry.py index d2e1ad8318..91ef0b6868 100644 --- a/tests/unit/registry/test_base_instance_registry.py +++ b/tests/unit/registry/test_base_instance_registry.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. from pyrit.identifiers import ComponentIdentifier -from pyrit.registry.instance_registries.base_instance_registry import BaseInstanceRegistry +from pyrit.registry.instance_registries.base_instance_registry import BaseInstanceRegistry, RegistryEntry class ConcreteTestRegistry(BaseInstanceRegistry[str, ComponentIdentifier]): @@ -126,6 +126,34 @@ def test_get_nonexistent_returns_none(self): assert result is None +class TestBaseInstanceRegistryGetEntry: + """Tests for get_entry functionality in BaseInstanceRegistry.""" + + def setup_method(self): + """Reset and get a fresh registry for each test.""" + ConcreteTestRegistry.reset_instance() + self.registry = ConcreteTestRegistry.get_registry_singleton() + self.registry.register("test_value", name="test_name", tags={"role": "scorer"}) + + def teardown_method(self): + """Reset the singleton after each test.""" + ConcreteTestRegistry.reset_instance() + + def test_get_entry_returns_registry_entry(self): + """Test that get_entry returns a RegistryEntry with correct fields.""" + entry = self.registry.get_entry("test_name") + assert entry is not None + assert isinstance(entry, RegistryEntry) + assert entry.name == "test_name" + assert entry.instance == "test_value" + assert entry.tags == {"role": "scorer"} + + def test_get_entry_nonexistent_returns_none(self): + """Test that get_entry returns None for a non-existent name.""" + result = self.registry.get_entry("nonexistent") + assert result is None + + class TestBaseInstanceRegistryGetNames: """Tests for get_names functionality in BaseInstanceRegistry.""" @@ -153,6 +181,53 @@ def test_get_names_returns_sorted_list(self): assert names == ["alpha", "beta", "zeta"] +class TestBaseInstanceRegistryGetAllInstances: + """Tests for get_all_instances functionality in BaseInstanceRegistry.""" + + def setup_method(self): + """Reset and get a fresh registry for each test.""" + ConcreteTestRegistry.reset_instance() + self.registry = ConcreteTestRegistry.get_registry_singleton() + + def teardown_method(self): + """Reset the singleton after each test.""" + ConcreteTestRegistry.reset_instance() + + def test_get_all_instances_returns_list_of_registry_entries(self): + """Test that get_all_instances returns a list of RegistryEntry objects.""" + self.registry.register("value1", name="name1") + self.registry.register("value2", name="name2") + + result = self.registry.get_all_instances() + assert isinstance(result, list) + assert len(result) == 2 + assert all(isinstance(entry, RegistryEntry) for entry in result) + + def test_get_all_instances_sorted_by_name(self): + """Test that get_all_instances returns entries sorted by name.""" + self.registry.register("value_z", name="zeta") + self.registry.register("value_a", name="alpha") + self.registry.register("value_b", name="beta") + + result = self.registry.get_all_instances() + assert [e.name for e in result] == ["alpha", "beta", "zeta"] + + def test_get_all_instances_preserves_tags(self): + """Test that get_all_instances preserves tags on entries.""" + self.registry.register("value1", name="name1", tags={"role": "scorer"}) + self.registry.register("value2", name="name2", tags=["fast"]) + + result = self.registry.get_all_instances() + entry_map = {e.name: e for e in result} + assert entry_map["name1"].tags == {"role": "scorer"} + assert entry_map["name2"].tags == {"fast": ""} + + def test_get_all_instances_empty_registry(self): + """Test that get_all_instances returns empty list on empty registry.""" + result = self.registry.get_all_instances() + assert result == [] + + class TestBaseInstanceRegistryListMetadata: """Tests for list_metadata functionality in BaseInstanceRegistry.""" @@ -224,6 +299,118 @@ def test_list_metadata_caching(self): assert len(metadata1) == 3 +class TestBaseInstanceRegistryTags: + """Tests for tag registration and retrieval in BaseInstanceRegistry.""" + + def setup_method(self): + """Reset and get a fresh registry for each test.""" + ConcreteTestRegistry.reset_instance() + self.registry = ConcreteTestRegistry.get_registry_singleton() + + def teardown_method(self): + """Reset the singleton after each test.""" + ConcreteTestRegistry.reset_instance() + + def test_register_with_dict_tags(self): + """Test that dict tags are stored correctly.""" + self.registry.register("value", name="name1", tags={"role": "scorer", "provider": "azure"}) + + entry = self.registry.get_entry("name1") + assert entry is not None + assert entry.tags == {"role": "scorer", "provider": "azure"} + + def test_register_with_list_tags(self): + """Test that list tags are normalized to dict with empty string values.""" + self.registry.register("value", name="name1", tags=["fast", "default"]) + + entry = self.registry.get_entry("name1") + assert entry is not None + assert entry.tags == {"fast": "", "default": ""} + + def test_register_without_tags(self): + """Test that registering without tags defaults to empty dict.""" + self.registry.register("value", name="name1") + + entry = self.registry.get_entry("name1") + assert entry is not None + assert entry.tags == {} + + def test_get_by_tag_key_only(self): + """Test get_by_tag matching by key only (any value).""" + self.registry.register("v1", name="n1", tags={"role": "scorer"}) + self.registry.register("v2", name="n2", tags={"role": "target"}) + self.registry.register("v3", name="n3", tags={"provider": "azure"}) + + results = self.registry.get_by_tag(tag="role") + assert len(results) == 2 + assert {e.name for e in results} == {"n1", "n2"} + + def test_get_by_tag_key_and_value(self): + """Test get_by_tag matching by key and specific value.""" + self.registry.register("v1", name="n1", tags={"role": "scorer"}) + self.registry.register("v2", name="n2", tags={"role": "target"}) + self.registry.register("v3", name="n3", tags={"role": "scorer"}) + + results = self.registry.get_by_tag(tag="role", value="scorer") + assert len(results) == 2 + assert {e.name for e in results} == {"n1", "n3"} + + def test_get_by_tag_no_match(self): + """Test get_by_tag returns empty list when no entries match.""" + self.registry.register("v1", name="n1", tags={"role": "scorer"}) + + results = self.registry.get_by_tag(tag="nonexistent") + assert results == [] + + def test_get_by_tag_value_no_match(self): + """Test get_by_tag returns empty when key exists but value does not match.""" + self.registry.register("v1", name="n1", tags={"role": "scorer"}) + + results = self.registry.get_by_tag(tag="role", value="nonexistent") + assert results == [] + + def test_get_by_tag_returns_sorted_by_name(self): + """Test that get_by_tag results are sorted by name.""" + self.registry.register("v3", name="zeta", tags=["shared"]) + self.registry.register("v1", name="alpha", tags=["shared"]) + self.registry.register("v2", name="beta", tags=["shared"]) + + results = self.registry.get_by_tag(tag="shared") + assert [e.name for e in results] == ["alpha", "beta", "zeta"] + + def test_get_by_tag_with_list_tags(self): + """Test get_by_tag works with list-style tags (normalized to empty string values).""" + self.registry.register("v1", name="n1", tags=["fast", "default"]) + self.registry.register("v2", name="n2", tags=["slow"]) + + results = self.registry.get_by_tag(tag="fast") + assert len(results) == 1 + assert results[0].name == "n1" + + def test_get_by_tag_with_list_tags_value_empty_string(self): + """Test get_by_tag with explicit empty string value matches list-style tags.""" + self.registry.register("v1", name="n1", tags=["fast"]) + + results = self.registry.get_by_tag(tag="fast", value="") + assert len(results) == 1 + assert results[0].name == "n1" + + def test_normalize_tags_none(self): + """Test _normalize_tags returns empty dict for None.""" + assert BaseInstanceRegistry._normalize_tags(None) == {} + + def test_normalize_tags_list(self): + """Test _normalize_tags converts list to dict with empty values.""" + assert BaseInstanceRegistry._normalize_tags(["a", "b"]) == {"a": "", "b": ""} + + def test_normalize_tags_dict(self): + """Test _normalize_tags returns a copy of the dict.""" + original = {"key": "val"} + result = BaseInstanceRegistry._normalize_tags(original) + assert result == {"key": "val"} + assert result is not original + + class TestBaseInstanceRegistryDunderMethods: """Tests for dunder methods (__contains__, __len__, __iter__) in BaseInstanceRegistry.""" diff --git a/tests/unit/registry/test_converter_registry.py b/tests/unit/registry/test_converter_registry.py index c737e53538..b62fedb6da 100644 --- a/tests/unit/registry/test_converter_registry.py +++ b/tests/unit/registry/test_converter_registry.py @@ -339,11 +339,12 @@ def test_get_names_returns_sorted_list(self): assert names == ["alpha_converter", "test_converter", "zeta_converter"] def test_get_all_instances_returns_all(self): - """Test get_all_instances returns dict of all registered instances.""" + """Test get_all_instances returns list of all registered entries.""" image_converter = MockImageConverter() self.registry.register_instance(image_converter, name="image_converter") - all_instances = self.registry.get_all_instances() - assert len(all_instances) == 2 - assert all_instances["test_converter"] is self.converter - assert all_instances["image_converter"] is image_converter + all_entries = self.registry.get_all_instances() + assert len(all_entries) == 2 + entry_map = {e.name: e for e in all_entries} + assert entry_map["test_converter"].instance is self.converter + assert entry_map["image_converter"].instance is image_converter