Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyrit/backend/services/converter_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ async def list_converters_async(self) -> ConverterInstanceListResponse:
ConverterInstanceListResponse containing all registered converters.
"""
items = [
self._build_instance_from_object(converter_id=name, converter_obj=obj)
for name, obj in self._registry.get_all_instances().items()
self._build_instance_from_object(converter_id=entry.name, converter_obj=entry.instance)
for entry in self._registry.get_all_instances()
]
return ConverterInstanceListResponse(items=items)

Expand Down
4 changes: 2 additions & 2 deletions pyrit/backend/services/target_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ async def list_targets_async(
TargetListResponse containing paginated targets.
"""
items = [
self._build_instance_from_object(target_registry_name=name, target_obj=obj)
for name, obj in self._registry.get_all_instances().items()
self._build_instance_from_object(target_registry_name=entry.name, target_obj=entry.instance)
for entry in self._registry.get_all_instances()
]
page, has_more = self._paginate(items, cursor, limit)
next_cursor = page[-1].target_registry_name if has_more and page else None
Expand Down
2 changes: 2 additions & 0 deletions pyrit/registry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from pyrit.registry.instance_registries import (
BaseInstanceRegistry,
RegistryEntry,
ScorerRegistry,
TargetRegistry,
)
Expand All @@ -32,6 +33,7 @@
"discover_subclasses_in_loaded_modules",
"InitializerMetadata",
"InitializerRegistry",
"RegistryEntry",
"RegistryProtocol",
"ScenarioMetadata",
"ScenarioRegistry",
Expand Down
2 changes: 2 additions & 0 deletions pyrit/registry/instance_registries/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from pyrit.registry.instance_registries.base_instance_registry import (
BaseInstanceRegistry,
RegistryEntry,
)
from pyrit.registry.instance_registries.converter_registry import (
ConverterRegistry,
Expand All @@ -27,6 +28,7 @@
__all__ = [
# Base class
"BaseInstanceRegistry",
"RegistryEntry",
# Concrete registries
"ConverterRegistry",
"ScorerRegistry",
Expand Down
101 changes: 91 additions & 10 deletions pyrit/registry/instance_registries/base_instance_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union

from pyrit.identifiers import ComponentIdentifier
from pyrit.registry.base import RegistryProtocol
Expand All @@ -28,6 +29,25 @@
MetadataT = TypeVar("MetadataT", bound=ComponentIdentifier)


@dataclass
class RegistryEntry(Generic[T]):
"""
A wrapper around a registered instance, holding its name, tags, and the instance itself.

Tags are always stored as ``dict[str, str]``. When callers pass a plain
``list[str]``, each string is normalised to a key with an empty-string value.

Attributes:
name: The registry name for this entry.
instance: The registered object.
tags: Key-value tags for categorisation and filtering.
"""

name: str
instance: T
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clarifying question: if we know this is always a registered instance, would it be better to use a specific type or protocols instead of T? For example, I imagine that for registered instances we might want to have some protocol (RegisteredInstanceProtocol)? etc. that lets us expose helper methods to RegistryEntry. A generic type might be the simplest solution since this is just meant to hold tags, but I'm curious if there's a tradeoff

tags: dict[str, str] = field(default_factory=dict)


class BaseInstanceRegistry(ABC, RegistryProtocol[MetadataT], Generic[T, MetadataT]):
"""
Abstract base class for registries that store pre-configured instances.
Expand Down Expand Up @@ -71,26 +91,48 @@ def reset_instance(cls) -> None:
if cls in cls._instances:
del cls._instances[cls]

@staticmethod
def _normalize_tags(tags: Optional[Union[dict[str, str], list[str]]] = None) -> dict[str, str]:
"""
Normalize tags into a ``dict[str, str]``.

Args:
tags: Tags as a dict, a list of string keys (values default to ``""``),
or ``None`` (returns empty dict).

Returns:
A ``dict[str, str]`` of normalised tags.
"""
if tags is None:
return {}
if isinstance(tags, list):
return dict.fromkeys(tags, "")
return dict(tags)

def __init__(self) -> None:
"""Initialize the instance registry."""
# Maps registry names to registered items
self._registry_items: dict[str, T] = {}
# Maps registry names to registry entries
self._registry_items: dict[str, RegistryEntry[T]] = {}
self._metadata_cache: Optional[list[MetadataT]] = None

def register(
self,
instance: T,
*,
name: str,
tags: Optional[Union[dict[str, str], list[str]]] = None,
) -> None:
"""
Register an instance.

Args:
instance: The pre-configured instance to register.
name: The registry name for this instance.
tags: Optional tags for categorisation. Accepts a ``dict[str, str]``
or a ``list[str]`` (each string becomes a key with value ``""``).
"""
self._registry_items[name] = instance
normalized = self._normalize_tags(tags)
self._registry_items[name] = RegistryEntry(name=name, instance=instance, tags=normalized)
self._metadata_cache = None

def get(self, name: str) -> Optional[T]:
Expand All @@ -103,6 +145,21 @@ def get(self, name: str) -> Optional[T]:
Returns:
The instance, or None if not found.
"""
entry = self._registry_items.get(name)
if entry is None:
return None
return entry.instance

def get_entry(self, name: str) -> Optional[RegistryEntry[T]]:
"""
Get a full registry entry by name, including tags.

Args:
name: The registry name of the entry.

Returns:
The RegistryEntry, or None if not found.
"""
return self._registry_items.get(name)

def get_names(self) -> list[str]:
Expand All @@ -114,14 +171,38 @@ def get_names(self) -> list[str]:
"""
return sorted(self._registry_items.keys())

def get_all_instances(self) -> dict[str, T]:
def get_all_instances(self) -> list[RegistryEntry[T]]:
"""
Get all registered instances as a name -> instance mapping.
Get all registered entries sorted by name.

Returns:
List of RegistryEntry objects sorted by name.
"""
return [self._registry_items[name] for name in sorted(self._registry_items.keys())]

def get_by_tag(
self,
*,
tag: str,
value: Optional[str] = None,
) -> list[RegistryEntry[T]]:
"""
Get all entries that have a given tag, optionally matching a specific value.

Args:
tag: The tag key to match.
value: If provided, only entries whose tag value equals this are returned.
If ``None``, any entry that has the tag key is returned regardless of value.

Returns:
Dict mapping registry names to their instances.
List of matching RegistryEntry objects sorted by name.
"""
return dict(self._registry_items)
results: list[RegistryEntry[T]] = []
for name in sorted(self._registry_items.keys()):
entry = self._registry_items[name]
if tag in entry.tags and (value is None or entry.tags[tag] == value):
results.append(entry)
return results

def list_metadata(
self,
Expand Down Expand Up @@ -152,8 +233,8 @@ def list_metadata(
if self._metadata_cache is None:
items = []
for name in sorted(self._registry_items.keys()):
instance = self._registry_items[name]
items.append(self._build_metadata(name, instance))
entry = self._registry_items[name]
items.append(self._build_metadata(name, entry.instance))
self._metadata_cache = items

if not include_filters and not exclude_filters:
Expand Down
7 changes: 5 additions & 2 deletions pyrit/registry/instance_registries/converter_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union

from pyrit.identifiers import ComponentIdentifier
from pyrit.registry.instance_registries.base_instance_registry import (
Expand Down Expand Up @@ -49,6 +49,7 @@ def register_instance(
converter: PromptConverter,
*,
name: Optional[str] = None,
tags: Optional[Union[dict[str, str], list[str]]] = None,
) -> None:
"""
Register a converter instance.
Expand All @@ -57,11 +58,13 @@ def register_instance(
converter: The pre-configured converter instance (not a class).
name: Optional custom registry name. If not provided,
derived from the converter's unique identifier.
tags: Optional tags for categorisation. Accepts a ``dict[str, str]``
or a ``list[str]`` (each string becomes a key with value ``""``).
"""
if name is None:
name = converter.get_identifier().unique_name

self.register(converter, name=name)
self.register(converter, name=name, tags=tags)
logger.debug(f"Registered converter instance: {name} ({converter.__class__.__name__})")

def get_instance_by_name(self, name: str) -> Optional[PromptConverter]:
Expand Down
7 changes: 5 additions & 2 deletions pyrit/registry/instance_registries/scorer_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union

from pyrit.identifiers import ComponentIdentifier
from pyrit.registry.instance_registries.base_instance_registry import (
Expand Down Expand Up @@ -50,6 +50,7 @@ def register_instance(
scorer: Scorer,
*,
name: Optional[str] = None,
tags: Optional[Union[dict[str, str], list[str]]] = None,
) -> None:
"""
Register a scorer instance.
Expand All @@ -61,11 +62,13 @@ def register_instance(
scorer: The pre-configured scorer instance (not a class).
name: Optional custom registry name. If not provided,
derived from the scorer's unique identifier.
tags: Optional tags for categorisation. Accepts a ``dict[str, str]``
or a ``list[str]`` (each string becomes a key with value ``""``).
"""
if name is None:
name = scorer.get_identifier().unique_name

self.register(scorer, name=name)
self.register(scorer, name=name, tags=tags)
logger.debug(f"Registered scorer instance: {name} ({scorer.__class__.__name__})")

def get_instance_by_name(self, name: str) -> Optional[Scorer]:
Expand Down
7 changes: 5 additions & 2 deletions pyrit/registry/instance_registries/target_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union

from pyrit.identifiers import ComponentIdentifier
from pyrit.registry.instance_registries.base_instance_registry import (
Expand Down Expand Up @@ -50,6 +50,7 @@ def register_instance(
target: PromptTarget,
*,
name: Optional[str] = None,
tags: Optional[Union[dict[str, str], list[str]]] = None,
) -> None:
"""
Register a target instance.
Expand All @@ -62,11 +63,13 @@ def register_instance(
name: Optional custom registry name. If not provided,
derived from class name with identifier hash appended
(e.g., OpenAIChatTarget -> openai_chat_abc123).
tags: Optional tags for categorisation. Accepts a ``dict[str, str]``
or a ``list[str]`` (each string becomes a key with value ``""``).
"""
if name is None:
name = target.get_identifier().unique_name

self.register(target, name=name)
self.register(target, name=name, tags=tags)
logger.debug(f"Registered target instance: {name} ({target.__class__.__name__})")

def get_instance_by_name(self, name: str) -> Optional[PromptTarget]:
Expand Down
Loading
Loading