diff --git a/doc/code/datasets/1_loading_datasets.ipynb b/doc/code/datasets/1_loading_datasets.ipynb index 7bc6aa60a2..b09f5ad1c7 100644 --- a/doc/code/datasets/1_loading_datasets.ipynb +++ b/doc/code/datasets/1_loading_datasets.ipynb @@ -52,74 +52,13 @@ "execution_count": null, "id": "1", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['adv_bench',\n", - " 'aegis_content_safety',\n", - " 'airt_fairness',\n", - " 'airt_fairness_yes_no',\n", - " 'airt_harassment',\n", - " 'airt_harms',\n", - " 'airt_hate',\n", - " 'airt_illegal',\n", - " 'airt_imminent_crisis',\n", - " 'airt_leakage',\n", - " 'airt_malware',\n", - " 'airt_misinformation',\n", - " 'airt_scams',\n", - " 'airt_sexual',\n", - " 'airt_violence',\n", - " 'aya_redteaming',\n", - " 'babelscape_alert',\n", - " 'beaver_tails',\n", - " 'cbt_bench',\n", - " 'ccp_sensitive_prompts',\n", - " 'dark_bench',\n", - " 'equitymedqa',\n", - " 'forbidden_questions',\n", - " 'garak_access_shell_commands',\n", - " 'garak_slur_terms_en',\n", - " 'garak_web_html_js',\n", - " 'harmbench',\n", - " 'harmbench_multimodal',\n", - " 'harmful_qa',\n", - " 'jbb_behaviors',\n", - " 'librai_do_not_answer',\n", - " 'llm_lat_harmful',\n", - " 'medsafetybench',\n", - " 'mental_health_crisis_multiturn_example',\n", - " 'ml_vlsu',\n", - " 'mlcommons_ailuminate',\n", - " 'multilingual_vulnerability',\n", - " 'or_bench_80k',\n", - " 'or_bench_hard',\n", - " 'or_bench_toxic',\n", - " 'pku_safe_rlhf',\n", - " 'promptintel',\n", - " 'psfuzz_steal_system_prompt',\n", - " 'pyrit_example_dataset',\n", - " 'red_team_social_bias',\n", - " 'salad_bench',\n", - " 'simple_safety_tests',\n", - " 'sorry_bench',\n", - " 'sosbench',\n", - " 'tdc23_redteaming',\n", - " 'toxic_chat',\n", - " 'transphobia_awareness',\n", - " 'xstest']" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "from pyrit.datasets import SeedDatasetProvider\n", + "from pyrit.memory import CentralMemory\n", + "from pyrit.setup.initialization import IN_MEMORY, initialize_pyrit_async\n", "\n", - "SeedDatasetProvider.get_all_dataset_names()" + "await SeedDatasetProvider.get_all_dataset_names_async()" ] }, { @@ -137,64 +76,10 @@ "execution_count": null, "id": "3", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\r", - "Loading datasets - this can take a few minutes: 0%| | 0/58 [00:00 SeedDataset: logger.error(f"Failed to load local dataset from {self.file_path}: {e}") raise + async def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: + """ + Extract metadata from a local YAML file and coerce raw values into typed schema fields. + + YAML produces raw Python primitives (str, list) that must be converted to the + enum and set types expected by SeedDatasetMetadata before _match_filter can work. + + Returns: + Optional[SeedDatasetMetadata]: Parsed metadata if available, otherwise None. + + Raises: + Exception: If the dataset file cannot be read. + """ + valid_fields = [f.name for f in fields(SeedDatasetMetadata)] + try: + with open(self.file_path, encoding="utf-8") as f: + dataset = yaml.safe_load(f) + except Exception as e: + logger.error(f"Failed to load local dataset from {self.file_path}: {e}") + raise + + if not isinstance(dataset, dict): + return None + + raw = {k: v for k, v in dataset.items() if k in valid_fields} + if not raw: + return None + + coerced = SeedDatasetMetadata._coerce_metadata_values(raw_metadata=raw) + result = SeedDatasetMetadata(**coerced) + # Validation after coercion: raw values are strings/lists, not sets. + # _validate_singular_fields needs sets to check cardinality. + SeedDatasetMetadata._validate_singular_fields(metadata=result) + return result + def _register_local_datasets() -> None: """ diff --git a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py index fc6d46e54d..a1e59f68e1 100644 --- a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py @@ -19,6 +19,12 @@ class _HarmBenchDataset(_RemoteDatasetLoader): Reference: https://github.com/centerforaisafety/HarmBench """ + # Metadata + harm_categories: list[str] = ["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"] + modalities: list[str] = ["text"] + size: str = "large" # 504 seeds + tags: set[str] = {"default", "safety"} + def __init__( self, *, diff --git a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py index 5cd9212846..944dd513f3 100644 --- a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py @@ -8,6 +8,7 @@ import tempfile from abc import ABC from collections.abc import Callable +from dataclasses import fields from pathlib import Path from typing import Any, Literal, Optional, TextIO, cast @@ -19,6 +20,7 @@ from pyrit.common.path import DB_DATA_PATH from pyrit.common.text_helper import read_txt, write_txt from pyrit.datasets.seed_datasets.seed_dataset_provider import SeedDatasetProvider +from pyrit.datasets.seed_datasets.seed_metadata import SeedDatasetMetadata logger = logging.getLogger(__name__) @@ -285,3 +287,33 @@ def _load_dataset_sync() -> Any: except Exception as e: logger.error(f"Failed to load HuggingFace dataset {dataset_name}: {e}") raise + + async def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: + """ + Extract metadata from class attributes, wrap in sets, and format into SeedDatasetMetadata. + + Class attributes may be singular values (str, enum), lists, or sets. + All are normalized into sets for the unified SeedDatasetMetadata schema. + + Returns: + Optional[SeedDatasetMetadata]: Parsed metadata if available, otherwise None. + """ + valid_fields = [f.name for f in fields(SeedDatasetMetadata)] + + provider_class = type(self) + raw = {} + for key in valid_fields: + value = getattr(provider_class, key, None) + if value is None: + continue + raw[key] = value + + if not raw: + return None + + coerced = SeedDatasetMetadata._coerce_metadata_values(raw_metadata=raw) + # Validation must happen after coercion because raw values are strings/lists, + # not sets. _validate_singular_fields checks set cardinality (len > 1). + result = SeedDatasetMetadata(**coerced) + SeedDatasetMetadata._validate_singular_fields(metadata=result) + return result diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index 56b61b3996..7ce5eb8c37 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -5,10 +5,12 @@ import inspect import logging from abc import ABC, abstractmethod +from dataclasses import fields as dc_fields from typing import Any, Optional from tqdm import tqdm +from pyrit.datasets.seed_datasets.seed_metadata import SeedDatasetFilter, SeedDatasetLoadTime, SeedDatasetMetadata from pyrit.models.seeds import SeedDataset logger = logging.getLogger(__name__) @@ -25,9 +27,14 @@ class SeedDatasetProvider(ABC): Subclasses must implement: - fetch_dataset(): Fetch and return the dataset as a SeedDataset - dataset_name property: Human-readable name for the dataset + + All subclasses also have a _metadata property that is optional to make + dataset addition easier, but failing to complete it makes downstream + analysis more difficult. """ _registry: dict[str, type["SeedDatasetProvider"]] = {} + load_time: SeedDatasetLoadTime = SeedDatasetLoadTime.UNINITIALIZED def __init_subclass__(cls, **kwargs: Any) -> None: """ @@ -67,6 +74,19 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: Exception: If the dataset cannot be fetched or processed. """ + async def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: + """ + Parse provider-specific metadata into the shared schema. + + Subclasses can override this to source metadata from class attributes, + prompt files, or any other backing format. The default implementation + returns None, which means metadata is not available for this provider. + + Returns: + Optional[SeedDatasetMetadata]: Parsed metadata for this provider, or None. + """ + return None + @classmethod def get_all_providers(cls) -> dict[str, type["SeedDatasetProvider"]]: """ @@ -78,10 +98,13 @@ def get_all_providers(cls) -> dict[str, type["SeedDatasetProvider"]]: return cls._registry.copy() @classmethod - def get_all_dataset_names(cls) -> list[str]: + async def get_all_dataset_names_async(cls, filters: Optional[SeedDatasetFilter] = None) -> list[str]: """ Get the names of all registered datasets. + Args: + filters (Optional[SeedDatasetFilter]): List of filters to apply. + Returns: List[str]: List of dataset names from all registered providers. @@ -89,7 +112,7 @@ def get_all_dataset_names(cls) -> list[str]: ValueError: If no providers are registered or if providers cannot be instantiated. Example: - >>> names = SeedDatasetProvider.get_all_dataset_names() + >>> names = await SeedDatasetProvider.get_all_dataset_names_async() >>> print(f"Available datasets: {', '.join(names)}") """ dataset_names = set() @@ -97,11 +120,109 @@ def get_all_dataset_names(cls) -> list[str]: try: # Instantiate to get dataset name provider = provider_class() + + # Parser ensures a standard metadata format + metadata = await provider._parse_metadata() + + if filters: + # "all" bypasses metadata filtering and returns every dataset + if filters.has_all_tag: + dataset_names.add(provider.dataset_name) + continue + + # Datasets without metadata are skipped for all other filters + if not metadata: + continue + + # Filters detected but no match -> don't add this dataset + if not cls._match_filter_to_metadata(metadata=metadata, dataset_filter=filters): + continue + dataset_names.add(provider.dataset_name) except Exception as e: raise ValueError(f"Could not get dataset name from {provider_class.__name__}: {e}") from e return sorted(dataset_names) + @classmethod + def _match_filter_to_metadata(cls, metadata: SeedDatasetMetadata, dataset_filter: SeedDatasetFilter) -> bool: + """ + Match a dataset's metadata against filter criteria. + + A dataset matches if ANY criterion in filters.criteria matches (OR across + criteria). Within each criterion, ALL specified fields must match (AND + across fields). Within each field: + - strict_match=False: any overlap suffices (set intersection) + - strict_match=True: all filter values must be present (filter is subset) + + Special tags: + - "all": bypasses all filtering, returns True immediately. + - "default": without strict_match, matches if the dataset has "default" tag. + + Args: + metadata: The dataset's metadata. + dataset_filter: The user-provided filter. + + Returns: + Whether the metadata matches any criterion. + """ + # "all" always bypasses + if dataset_filter.has_all_tag: + return True + + return any( + cls._match_single_criterion(metadata=metadata, criterion=c, strict_match=dataset_filter.strict_match) + for c in dataset_filter.criteria + ) + + @classmethod + def _match_single_criterion( + cls, + *, + metadata: SeedDatasetMetadata, + criterion: SeedDatasetMetadata, + strict_match: bool, + ) -> bool: + """ + Match a single SeedDatasetMetadata criterion against dataset metadata. + + Args: + metadata: The dataset's real metadata. + criterion: A single filter criterion. + strict_match: Whether to require all filter values (AND) vs any overlap (OR). + + Returns: + Whether the metadata satisfies this criterion. + """ + # "default" shortcut (only without strict_match): + # When the filter asks for "default" and the dataset has "default" in its + # tags, match immediately. This lets "default" act as a curated-set marker + # that bypasses other filter axes. With strict_match, "default" is treated + # as a normal tag and must satisfy the full subset check. + if ( + not strict_match + and criterion.tags + and "default" in criterion.tags + and metadata.tags + and "default" in metadata.tags + ): + return True + + for field in dc_fields(SeedDatasetMetadata): + filter_vals = getattr(criterion, field.name) + meta_vals = getattr(metadata, field.name) + + if filter_vals is None or meta_vals is None: + continue + + if strict_match: + if filter_vals - meta_vals: + return False + else: + if not (filter_vals & meta_vals): + return False + + return True + @classmethod async def fetch_datasets_async( cls, @@ -141,7 +262,7 @@ async def fetch_datasets_async( """ # Validate dataset names if specified if dataset_names is not None: - available_names = cls.get_all_dataset_names() + available_names = await cls.get_all_dataset_names_async() invalid_names = [name for name in dataset_names if name not in available_names] if invalid_names: raise ValueError(f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py new file mode 100644 index 0000000000..bf481229da --- /dev/null +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -0,0 +1,254 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging +from dataclasses import dataclass, fields +from enum import Enum +from typing import Any, ClassVar, Literal, Optional + +logger = logging.getLogger(__name__) + + +""" +Contains metadata objects for datasets (i.e. subclasses of SeedDatasetProvider). + +SeedDatasetMetadata is the canonical schema for dataset metadata. All fields are +optional sets. Parsers wrap singular values into single-element sets; filters may +have multiple values per field for OR matching. + +SeedDatasetFilter accepts either flat kwargs (simple use) or a list of +SeedDatasetMetadata criteria (composable use). Internally it always stores +criteria as list[SeedDatasetMetadata]. +""" + +SeedDatasetSizeCategory = Literal["tiny", "small", "medium", "large", "huge"] +# tiny (<10), small (10-99), medium (100-499), large (500-4999), huge (5000+) + +SeedDatasetSourceType = Literal["remote", "local"] + + +class SeedDatasetLoadTime(Enum): + """Approximate time to load a dataset. Used to skip slow datasets in fast runs.""" + + FAST = "fast" + NORMAL = "normal" + SLOW = "slow" + UNINITIALIZED = "uninitialized" + + +@dataclass(frozen=True) +class SeedDatasetMetadata: + """ + Unified schema for dataset metadata and filter criteria. + + All fields are optional sets. When used for real dataset metadata, parsers + wrap singular values into single-element sets. When used as filter criteria, + multiple values per field express "match any of these" (OR within axis). + """ + + # All fields are optional sets to support both real metadata (single-element) + # and filter criteria (multi-element). SINGULAR_FIELDS enforces that parsers + # only produce single-element sets for size and source_type. + tags: Optional[set[str]] = None + size: Optional[set[str]] = None + modalities: Optional[set[str]] = None + source_type: Optional[set[str]] = None + load_time: Optional[set[SeedDatasetLoadTime]] = None + harm_categories: Optional[set[str]] = None + + # Fields that must have at most 1 element in real dataset metadata. + SINGULAR_FIELDS: ClassVar[frozenset[str]] = frozenset({"size", "source_type"}) + + @staticmethod + def _coerce_metadata_values(*, raw_metadata: dict[str, Any]) -> dict[str, Any]: + """ + Convert raw values (from YAML or class attributes) into sets for SeedDatasetMetadata. + + Applies .lower().strip() normalization to all string values. Handles str, + list, set inputs for all fields, plus SeedDatasetLoadTime enum for load_time. + + Args: + raw_metadata: Dictionary of field names to raw values. + + Returns: + Dictionary with all values coerced to sets. + """ + coerced: dict[str, Any] = {} + for key, value in raw_metadata.items(): + if key == "load_time": + if isinstance(value, str): + coerced[key] = {SeedDatasetLoadTime(value.strip().lower())} + elif isinstance(value, SeedDatasetLoadTime): + coerced[key] = {value} + else: + logger.warning( + f"Skipping metadata field '{key}' with unexpected type " + f"{type(value).__name__} (value: {value!r})" + ) + elif isinstance(value, (list, set)): + coerced[key] = {v.strip().lower() if isinstance(v, str) else v for v in value} + elif isinstance(value, str): + coerced[key] = {value.strip().lower()} + else: + logger.warning( + f"Skipping metadata field '{key}' with unexpected type {type(value).__name__} (value: {value!r})" + ) + return coerced + + @staticmethod + def _validate_singular_fields(*, metadata: "SeedDatasetMetadata") -> None: + """ + Validate that singular fields have at most 1 element. + + Call this from parsers when constructing real dataset metadata, NOT when + constructing filter criteria where multiple values are valid. + + Raises: + ValueError: If a singular field has more than 1 element. + """ + for field_name in SeedDatasetMetadata.SINGULAR_FIELDS: + value = getattr(metadata, field_name) + if value is not None and len(value) > 1: + raise ValueError( + f"Metadata field '{field_name}' must have at most 1 value " + f"for real dataset metadata, got {len(value)}: {value}" + ) + + +class SeedDatasetFilter: + """ + Filter for discovering datasets by metadata criteria. + + Supports two construction patterns: + + Simple (flat kwargs — wraps into a single SeedDatasetMetadata criterion):: + + f = SeedDatasetFilter(tags={"safety"}, size={"small", "large"}) + + Composed (explicit criteria list — OR across criteria, AND within each):: + + f = SeedDatasetFilter(criteria=[ + SeedDatasetMetadata(size={"small"}, modalities={"text"}), + SeedDatasetMetadata(size={"large"}, modalities={"image"}), + ]) + + Passing both flat kwargs and criteria raises ValueError. + + Special tags: + - "all": Returns every dataset, ignores all other fields. This tag will + override anything else you pass to the filter object. + - "default": Matches datasets with "default" in their tags. With + strict_match=True, loses its shortcut and is treated as a normal tag. + + Args: + criteria: Explicit list of SeedDatasetMetadata to OR-match against. + strict_match: If True, within-axis matching uses AND (all filter values + must be present) instead of OR (any overlap suffices). + **kwargs: Flat metadata fields (tags, size, modalities, etc.) for simple use. + """ + + def __init__( + self, + *, + criteria: Optional[list[SeedDatasetMetadata]] = None, + strict_match: bool = False, + **kwargs: Any, + ) -> None: + """ + Construct a filter from flat metadata kwargs or an explicit criteria list. + + Simple usage (flat kwargs — wraps into a single SeedDatasetMetadata):: + + f = SeedDatasetFilter(tags={"safety"}, size={"small", "large"}) + + Composed usage (explicit criteria — OR across criteria, AND within each):: + + f = SeedDatasetFilter(criteria=[ + SeedDatasetMetadata(size={"small"}, modalities={"text"}), + SeedDatasetMetadata(size={"large"}, modalities={"image"}), + ]) + + Args: + criteria: Explicit list of SeedDatasetMetadata to OR-match against. + strict_match: If True, within-axis matching uses AND instead of OR. + **kwargs: Flat metadata fields passed to SeedDatasetMetadata. + + Raises: + ValueError: If both criteria and flat kwargs are provided. + """ + if criteria is not None and kwargs: + raise ValueError("Cannot pass both 'criteria' and flat metadata kwargs. Use one or the other.") + + if criteria is not None: + self.criteria = criteria + elif kwargs: + self.criteria = [SeedDatasetMetadata(**kwargs)] + else: + self.criteria = [SeedDatasetMetadata()] + + # Normalize tags: strip whitespace and lowercase so "ALL", " All ", etc. work + def _normalize_criterion(c: SeedDatasetMetadata) -> SeedDatasetMetadata: + normalized = { + f.name: ({t.strip().lower() for t in vals} if f.name == "tags" and vals is not None else vals) + for f, vals in zip(fields(c), (getattr(c, f.name) for f in fields(c)), strict=True) + } + return SeedDatasetMetadata(**normalized) + + self.criteria = [_normalize_criterion(c) for c in self.criteria] + + self.strict_match = strict_match + self._validate() + + def _validate(self) -> None: + """ + Warn about contradictory filter configurations. + + Raises: + ValueError: If strict_match is True and any criterion has multiple + values for a singular field (size, source_type). + """ + # strict_match with multi-valued singular fields is logically impossible. + # A dataset can't be both "small" AND "large" — these are mutually exclusive. + if self.strict_match: + for criterion in self.criteria: + for field_name in SeedDatasetMetadata.SINGULAR_FIELDS: + value = getattr(criterion, field_name) + if value is not None and len(value) > 1: + raise ValueError( + f"strict_match=True with multiple values for '{field_name}' " + f"({value}) is logically impossible — a dataset can only have " + f"one {field_name}. Mutually exclusive fields: " + f"{SeedDatasetMetadata.SINGULAR_FIELDS}. " + f"Use strict_match=False for OR matching, " + f"or split into separate criteria." + ) + + if not self.has_all_tag: + return + + all_criterion = next(c for c in self.criteria if c.tags and "all" in c.tags) + + if all_criterion.tags and len(all_criterion.tags) > 1: + logger.warning( + "Filter has 'all' combined with other tags %s. " + "'all' bypasses all filtering — other tags will be ignored.", + all_criterion.tags - {"all"}, + ) + if self.strict_match: + logger.warning( + "Filter has 'all' with strict_match=True. 'all' bypasses all filtering — strict_match has no effect." + ) + other_fields = [ + f.name for f in fields(all_criterion) if f.name != "tags" and getattr(all_criterion, f.name) is not None + ] + if other_fields: + logger.warning( + "Filter has 'all' combined with other filter fields %s. " + "'all' bypasses all filtering — other fields will be ignored.", + other_fields, + ) + + @property + def has_all_tag(self) -> bool: + """True if any criterion has the 'all' tag.""" + return any(c.tags and "all" in c.tags for c in self.criteria) diff --git a/tests/integration/datasets/test_seed_dataset_provider_integration.py b/tests/integration/datasets/test_seed_dataset_provider_integration.py index a3ede4beab..6ea7a3f7cd 100644 --- a/tests/integration/datasets/test_seed_dataset_provider_integration.py +++ b/tests/integration/datasets/test_seed_dataset_provider_integration.py @@ -2,12 +2,19 @@ # Licensed under the MIT license. import logging +import textwrap +from pathlib import Path +from unittest.mock import patch import pytest from pyrit.datasets import SeedDatasetProvider +from pyrit.datasets.seed_datasets.local.local_dataset_loader import _LocalDatasetLoader from pyrit.datasets.seed_datasets.remote import _VLSUMultimodalDataset -from pyrit.models import SeedDataset +from pyrit.datasets.seed_datasets.seed_metadata import ( + SeedDatasetFilter, +) +from pyrit.models import SeedDataset, SeedPrompt logger = logging.getLogger(__name__) @@ -55,3 +62,658 @@ async def test_fetch_dataset_integration(self, name, provider_cls): except Exception as e: pytest.fail(f"Failed to fetch dataset from {name}: {str(e)}") + + +class TestRemoteFilteringIntegration: + """ + Integration test for remote dataset filtering. + + Uses a mocked remote provider with class-level metadata attributes to + validate the full flow: metadata population, filter matching, and + get_all_dataset_names output. + """ + + def _make_remote_provider_cls( + self, + *, + name: str, + tags: set, + size: str, + modalities: list, + harm_categories: list, + ) -> type: + """Build a minimal concrete SeedDatasetProvider with class-level metadata.""" + from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import _RemoteDatasetLoader + + captured_name = name + + async def _fetch_dataset(self, *, cache=True): + return SeedDataset( + seeds=[SeedPrompt(value="x", data_type="text")], + dataset_name=captured_name, + ) + + attrs = { + "tags": tags, + "size": size, + "modalities": modalities, + "harm_categories": harm_categories, + "should_register": False, + "__module__": __name__, + # Concrete implementations satisfy ABC requirements + "dataset_name": property(lambda self: captured_name), + "fetch_dataset": _fetch_dataset, + "_fetch_from_url": lambda self, **kw: [], + } + + return type(f"_Mock_{name}", (_RemoteDatasetLoader,), attrs) + + def test_filter_matches_correct_remote_provider(self): + """Filter by size returns only providers that match.""" + large_cls = self._make_remote_provider_cls( + name="large_ds", + tags={"default"}, + size="large", + modalities=["text"], + harm_categories=["violence"], + ) + small_cls = self._make_remote_provider_cls( + name="small_ds", + tags={"default"}, + size="small", + modalities=["text"], + harm_categories=["cybercrime"], + ) + + with patch.dict( + SeedDatasetProvider._registry, + {"Large": large_cls, "Small": small_cls}, + clear=True, + ): + names = SeedDatasetProvider.get_all_dataset_names_async( + filters=SeedDatasetFilter(sizes=["large"]), + ) + assert names == ["large_ds"] + + def test_filter_all_tag_returns_everything(self): + """tags={'all'} bypasses filtering and returns every provider.""" + cls1 = self._make_remote_provider_cls( + name="ds_a", + tags={"safety"}, + size="tiny", + modalities=["text"], + harm_categories=[], + ) + cls2 = self._make_remote_provider_cls( + name="ds_b", + tags={"custom"}, + size="huge", + modalities=["image"], + harm_categories=["violence"], + ) + + with patch.dict( + SeedDatasetProvider._registry, + {"A": cls1, "B": cls2}, + clear=True, + ): + names = SeedDatasetProvider.get_all_dataset_names_async( + filters=SeedDatasetFilter(tags={"all"}), + ) + assert sorted(names) == ["ds_a", "ds_b"] + + def test_multi_axis_filter(self): + """Multiple filter axes are ANDed together.""" + cls1 = self._make_remote_provider_cls( + name="text_large", + tags={"default"}, + size="large", + modalities=["text"], + harm_categories=["violence"], + ) + cls2 = self._make_remote_provider_cls( + name="image_large", + tags={"default"}, + size="large", + modalities=["image"], + harm_categories=["violence"], + ) + + with patch.dict( + SeedDatasetProvider._registry, + {"TL": cls1, "IL": cls2}, + clear=True, + ): + names = SeedDatasetProvider.get_all_dataset_names_async( + filters=SeedDatasetFilter( + sizes=["large"], + modalities=["text"], + ), + ) + assert names == ["text_large"] + + +class TestLocalFilteringIntegration: + """ + Integration test for local dataset filtering. + + Creates real YAML prompt files on disk, registers them as local providers, + and validates the full flow through get_all_dataset_names with filters. + """ + + @staticmethod + def _make_local_cls(yaml_path: Path) -> type: + """Build a dynamic local provider class for a YAML file.""" + + def make_init(path: Path): + def init_fn(self): + _LocalDatasetLoader.__init__(self, file_path=path) + + return init_fn + + return type( + f"LocalTest_{yaml_path.stem}", + (_LocalDatasetLoader,), + {"__init__": make_init(yaml_path), "should_register": False, "__module__": __name__}, + ) + + def test_local_filter_by_size(self, tmp_path): + """Local YAML with size metadata is correctly coerced and filtered.""" + large_yaml = tmp_path / "large_ds.prompt" + large_yaml.write_text( + textwrap.dedent("""\ + dataset_name: large_local + size: large + harm_categories: + - violence + seeds: + - value: test + data_type: text + """) + ) + small_yaml = tmp_path / "small_ds.prompt" + small_yaml.write_text( + textwrap.dedent("""\ + dataset_name: small_local + size: small + harm_categories: + - cybercrime + seeds: + - value: test + data_type: text + """) + ) + + large_cls = self._make_local_cls(large_yaml) + small_cls = self._make_local_cls(small_yaml) + + with patch.dict( + SeedDatasetProvider._registry, + {"Large": large_cls, "Small": small_cls}, + clear=True, + ): + names = SeedDatasetProvider.get_all_dataset_names_async( + filters=SeedDatasetFilter(sizes=["large"]), + ) + # dataset_name falls back to file stem when SeedDataset.from_yaml_file + # rejects extra keys like "size" during __init__ pre-loading + assert names == ["large_ds"] + + def test_local_filter_by_tags(self, tmp_path): + """Local YAML tags (list) are coerced to set for intersection.""" + yaml_path = tmp_path / "tagged.prompt" + yaml_path.write_text( + textwrap.dedent("""\ + dataset_name: tagged_local + tags: + - safety + - default + harm_categories: + - violence + seeds: + - value: test + data_type: text + """) + ) + cls = self._make_local_cls(yaml_path) + + with patch.dict( + SeedDatasetProvider._registry, + {"Tagged": cls}, + clear=True, + ): + # dataset_name falls back to file stem ("tagged") when + # SeedDataset.from_yaml_file rejects extra keys like "tags" + matched = SeedDatasetProvider.get_all_dataset_names_async( + filters=SeedDatasetFilter(tags={"safety"}), + ) + assert matched == ["tagged"] + + not_matched = SeedDatasetProvider.get_all_dataset_names_async( + filters=SeedDatasetFilter(tags={"unrelated"}), + ) + assert not_matched == [] + + def test_local_no_metadata_skipped(self, tmp_path): + """Local YAML without metadata fields is skipped when filters are provided.""" + yaml_path = tmp_path / "bare.prompt" + yaml_path.write_text( + textwrap.dedent("""\ + dataset_name: bare_local + seeds: + - value: test + data_type: text + """) + ) + cls = self._make_local_cls(yaml_path) + + with patch.dict( + SeedDatasetProvider._registry, + {"Bare": cls}, + clear=True, + ): + # Without filters, the dataset is included + all_names = SeedDatasetProvider.get_all_dataset_names_async() + assert "bare_local" in all_names + + # With filters, it's skipped (no metadata to match against) + filtered = SeedDatasetProvider.get_all_dataset_names_async( + filters=SeedDatasetFilter(tags={"safety"}), + ) + assert filtered == [] + + +class TestEndToEndLocalDatasetWorkflow: + """ + End-to-end integration test that simulates a user workflow: + 1. Write a temporary local dataset to disk + 2. Register it into the provider registry + 3. Discover it via filtered get_all_dataset_names + 4. Fetch the actual dataset by name + 5. Verify metadata is accessible and correct + """ + + @staticmethod + def _make_local_cls(yaml_path: Path) -> type: + """Build a dynamic local provider class for a YAML file.""" + + def make_init(path: Path): + def init_fn(self): + _LocalDatasetLoader.__init__(self, file_path=path) + + return init_fn + + return type( + f"LocalTest_{yaml_path.stem}", + (_LocalDatasetLoader,), + {"__init__": make_init(yaml_path), "should_register": False, "__module__": __name__}, + ) + + @pytest.mark.asyncio + async def test_user_discovers_and_fetches_filtered_dataset(self, tmp_path): + """ + Simulate a user who wants small text datasets about cybercrime: + 1. Filter returns the matching dataset name + 2. Fetch by that name succeeds and returns correct seeds + 3. Metadata on the provider is accessible and consistent + """ + # --- Step 0: Write two datasets to disk --- + matching_yaml = tmp_path / "cyber_small.prompt" + matching_yaml.write_text( + textwrap.dedent("""\ + dataset_name: cyber_prompts + harm_categories: + - cybercrime + seeds: + - value: Explain phishing techniques + data_type: text + - value: Describe SQL injection + data_type: text + """) + ) + unrelated_yaml = tmp_path / "violence_large.prompt" + unrelated_yaml.write_text( + textwrap.dedent("""\ + dataset_name: violence_prompts + harm_categories: + - violence + seeds: + - value: Describe a violent scenario + data_type: text + """) + ) + + matching_cls = self._make_local_cls(matching_yaml) + unrelated_cls = self._make_local_cls(unrelated_yaml) + + with patch.dict( + SeedDatasetProvider._registry, + {"Cyber": matching_cls, "Violence": unrelated_cls}, + clear=True, + ): + # --- Step 1: User filters by harm_categories --- + names = SeedDatasetProvider.get_all_dataset_names_async( + filters=SeedDatasetFilter(harm_categories=["cybercrime"]), + ) + assert len(names) == 1 + dataset_name = names[0] + + # --- Step 2: User fetches the dataset by name --- + datasets = await SeedDatasetProvider.fetch_datasets_async( + dataset_names=[dataset_name], + ) + assert len(datasets) == 1 + dataset = datasets[0] + assert len(dataset.seeds) == 2 + assert dataset.seeds[0].value == "Explain phishing techniques" + assert dataset.seeds[1].value == "Describe SQL injection" + + # --- Step 3: User inspects metadata --- + provider = matching_cls() + metadata = provider._parse_metadata() + assert metadata is not None + assert metadata.harm_categories == ["cybercrime"] + + @pytest.mark.asyncio + async def test_user_fetches_unfiltered(self, tmp_path): + """ + Without filters, get_all_dataset_names returns everything, + and fetch_datasets_async retrieves all of them. + """ + ds1 = tmp_path / "ds_one.prompt" + ds1.write_text( + textwrap.dedent("""\ + dataset_name: dataset_one + seeds: + - value: prompt one + data_type: text + """) + ) + ds2 = tmp_path / "ds_two.prompt" + ds2.write_text( + textwrap.dedent("""\ + dataset_name: dataset_two + seeds: + - value: prompt two + data_type: text + """) + ) + + cls1 = self._make_local_cls(ds1) + cls2 = self._make_local_cls(ds2) + + with patch.dict( + SeedDatasetProvider._registry, + {"One": cls1, "Two": cls2}, + clear=True, + ): + names = SeedDatasetProvider.get_all_dataset_names_async() + assert len(names) == 2 + + datasets = await SeedDatasetProvider.fetch_datasets_async() + assert len(datasets) == 2 + fetched_names = sorted(d.dataset_name for d in datasets) + assert fetched_names == ["dataset_one", "dataset_two"] + + +class TestAllTagBypassIntegration: + """ + Integration tests for the tags={'all'} bypass pattern. + + The 'all' tag is a special escape hatch that returns every registered + dataset regardless of metadata presence or other filter axes. + """ + + @staticmethod + def _make_local_cls(yaml_path: Path) -> type: + """Build a dynamic local provider class for a YAML file.""" + + def make_init(path: Path): + def init_fn(self): + _LocalDatasetLoader.__init__(self, file_path=path) + + return init_fn + + return type( + f"LocalTest_{yaml_path.stem}", + (_LocalDatasetLoader,), + {"__init__": make_init(yaml_path), "should_register": False, "__module__": __name__}, + ) + + def test_all_tag_includes_datasets_without_metadata(self, tmp_path): + """ + A dataset whose YAML has no metadata fields at all is normally + skipped when filters are present. tags={'all'} overrides that. + """ + bare_yaml = tmp_path / "bare.prompt" + bare_yaml.write_text( + textwrap.dedent("""\ + dataset_name: bare_dataset + seeds: + - value: bare prompt + data_type: text + """) + ) + cls = self._make_local_cls(bare_yaml) + + with patch.dict( + SeedDatasetProvider._registry, + {"Bare": cls}, + clear=True, + ): + # Normal filter skips it + filtered = SeedDatasetProvider.get_all_dataset_names_async( + filters=SeedDatasetFilter(tags={"safety"}), + ) + assert filtered == [] + + # 'all' includes it + all_names = SeedDatasetProvider.get_all_dataset_names_async( + filters=SeedDatasetFilter(tags={"all"}), + ) + assert "bare_dataset" in all_names + + def test_all_tag_ignores_other_filter_axes(self, tmp_path): + """ + tags={'all'} returns everything even when other filter axes + would exclude datasets. + """ + small_yaml = tmp_path / "small.prompt" + small_yaml.write_text( + textwrap.dedent("""\ + dataset_name: small_dataset + size: small + harm_categories: + - cybercrime + seeds: + - value: small prompt + data_type: text + """) + ) + cls = self._make_local_cls(small_yaml) + + with patch.dict( + SeedDatasetProvider._registry, + {"Small": cls}, + clear=True, + ): + # Size filter alone would exclude it + size_filtered = SeedDatasetProvider.get_all_dataset_names_async( + filters=SeedDatasetFilter(sizes=["large"]), + ) + assert size_filtered == [] + + # 'all' tag overrides the size filter + all_names = SeedDatasetProvider.get_all_dataset_names_async( + filters=SeedDatasetFilter(tags={"all"}, sizes=["large"]), + ) + assert "small" in all_names + + def test_all_tag_with_mixed_metadata_and_bare_datasets(self, tmp_path): + """ + With a mix of metadata-rich and metadata-bare datasets, + tags={'all'} returns all of them. + """ + rich_yaml = tmp_path / "rich.prompt" + rich_yaml.write_text( + textwrap.dedent("""\ + dataset_name: rich_dataset + harm_categories: + - violence + tags: + - safety + seeds: + - value: rich prompt + data_type: text + """) + ) + bare_yaml = tmp_path / "bare.prompt" + bare_yaml.write_text( + textwrap.dedent("""\ + dataset_name: bare_dataset + seeds: + - value: bare prompt + data_type: text + """) + ) + + rich_cls = self._make_local_cls(rich_yaml) + bare_cls = self._make_local_cls(bare_yaml) + + with patch.dict( + SeedDatasetProvider._registry, + {"Rich": rich_cls, "Bare": bare_cls}, + clear=True, + ): + all_names = SeedDatasetProvider.get_all_dataset_names_async( + filters=SeedDatasetFilter(tags={"all"}), + ) + assert len(all_names) == 2 + assert "bare_dataset" in all_names + + +class TestHarmbenchMetadataInScenario: + """ + Regression tests verifying that metadata changes don't break the dataset + loading flow for scenarios that use datasets with explicit metadata. + + RedTeamAgent is the primary scenario that uses harmbench — the only remote + dataset with class-level metadata (tags, size, modalities, harm_categories). + These tests verify the full pipeline: metadata parsing → dataset loading → + memory storage → scenario initialization. + """ + + @pytest.mark.asyncio + async def test_harmbench_metadata_parses_correctly(self): + """HarmBench's class-level metadata is correctly parsed into sets.""" + from pyrit.datasets.seed_datasets.remote.harmbench_dataset import _HarmBenchDataset + + loader = _HarmBenchDataset() + metadata = await loader._parse_metadata() + + assert metadata is not None + assert isinstance(metadata.tags, set) + assert "default" in metadata.tags + assert "safety" in metadata.tags + assert metadata.size == {"large"} + assert metadata.modalities == {"text"} + assert isinstance(metadata.harm_categories, set) + assert "cybercrime" in metadata.harm_categories + + @pytest.mark.asyncio + async def test_harmbench_discoverable_via_filter(self): + """HarmBench can be found via tag and harm_category filters.""" + names_by_safety = await SeedDatasetProvider.get_all_dataset_names_async( + filters=SeedDatasetFilter(tags={"safety"}), + ) + assert "harmbench" in names_by_safety + + names_by_harm = await SeedDatasetProvider.get_all_dataset_names_async( + filters=SeedDatasetFilter(harm_categories={"cybercrime"}), + ) + assert "harmbench" in names_by_harm + + @pytest.mark.asyncio + async def test_harmbench_loads_and_stores_in_memory(self): + """HarmBench can be fetched and stored in memory for scenario use.""" + from pyrit.memory import CentralMemory + from pyrit.setup import initialize_pyrit_async + + await initialize_pyrit_async(memory_db_type="InMemory") + + datasets = await SeedDatasetProvider.fetch_datasets_async( + dataset_names=["harmbench"], + ) + assert len(datasets) == 1 + assert datasets[0].dataset_name == "harmbench" + assert len(datasets[0].seeds) > 0 + + memory = CentralMemory.get_memory_instance() + await memory.add_seed_datasets_to_memory_async( + datasets=datasets, + added_by="test", + ) + + # Verify seeds are queryable from memory (this is what scenarios do) + seed_groups = memory.get_seed_groups(dataset_name="harmbench") + assert seed_groups is not None + assert len(list(seed_groups)) > 0 + + @pytest.mark.asyncio + async def test_red_team_agent_initializes_with_harmbench(self): + """ + RedTeamAgent can initialize with harmbench dataset loaded in memory. + + This is the critical regression test: if metadata changes break the + parsing/coercion/filtering pipeline, this test will fail during + scenario initialization when it tries to load seed groups from memory. + """ + from unittest.mock import MagicMock + + from pyrit.executor.attack.core.attack_config import AttackScoringConfig + from pyrit.memory import CentralMemory + from pyrit.prompt_target import TextTarget + from pyrit.scenario.scenarios.foundry.red_team_agent import ( + FoundryStrategy, + RedTeamAgent, + ) + from pyrit.score.true_false.true_false_scorer import TrueFalseScorer + from pyrit.setup import initialize_pyrit_async + + await initialize_pyrit_async(memory_db_type="InMemory") + + # Load harmbench into memory + datasets = await SeedDatasetProvider.fetch_datasets_async( + dataset_names=["harmbench"], + ) + memory = CentralMemory.get_memory_instance() + await memory.add_seed_datasets_to_memory_async( + datasets=datasets, + added_by="test", + ) + + # Mock scorer to avoid Azure dependency + mock_scorer = MagicMock(spec=TrueFalseScorer) + mock_scorer.get_identifier.return_value = {"__type__": "MockScorer"} + + target = TextTarget() + rta = RedTeamAgent( + adversarial_chat=target, + attack_scoring_config=AttackScoringConfig(objective_scorer=mock_scorer), + include_baseline=False, + ) + + # This is the critical call — it loads seed groups from memory + # and builds atomic attacks. If metadata broke the pipeline, + # this would raise ValueError about missing seed_groups. + await rta.initialize_async( + objective_target=target, + max_concurrency=1, + scenario_strategies=[FoundryStrategy.Base64], + ) + + # Verify the scenario got objectives from harmbench + attacks = rta._atomic_attacks + assert len(attacks) > 0 + for attack in attacks: + assert len(attack.objectives) > 0 diff --git a/tests/unit/datasets/test_seed_dataset_metadata.py b/tests/unit/datasets/test_seed_dataset_metadata.py new file mode 100644 index 0000000000..a5a1c01084 --- /dev/null +++ b/tests/unit/datasets/test_seed_dataset_metadata.py @@ -0,0 +1,285 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for metadata components related to SeedDatasetProvider. +""" + +import pytest + +from pyrit.datasets.seed_datasets.seed_metadata import ( + SeedDatasetFilter, + SeedDatasetLoadTime, + SeedDatasetMetadata, +) + + +class TestMetadataLifecycle: + """Test that the metadata object can be created with different subsets of values.""" + + def test_has_no_values(self): + metadata = SeedDatasetMetadata() + assert metadata.tags is None + assert metadata.size is None + assert metadata.modalities is None + assert metadata.source_type is None + assert metadata.load_time is None + assert metadata.harm_categories is None + + def test_has_some_values(self): + metadata = SeedDatasetMetadata(tags={"safety"}, size={"large"}) + assert metadata.tags == {"safety"} + assert metadata.size == {"large"} + assert metadata.modalities is None + + def test_has_all_values(self): + metadata = SeedDatasetMetadata( + tags={"default", "safety"}, + size={"medium"}, + modalities={"text", "image"}, + source_type={"remote"}, + load_time={SeedDatasetLoadTime.FAST}, + harm_categories={"violence", "illegal"}, + ) + assert metadata.tags == {"default", "safety"} + assert metadata.size == {"medium"} + assert len(metadata.modalities) == 2 + assert metadata.source_type == {"remote"} + assert SeedDatasetLoadTime.FAST in metadata.load_time + assert metadata.harm_categories == {"violence", "illegal"} + + +class TestFilterLifecycle: + """Test that the filter object wraps metadata correctly.""" + + def test_has_no_values(self): + f = SeedDatasetFilter() + c = f.criteria[0] + assert c.tags is None + assert c.size is None + + def test_has_some_values(self): + f = SeedDatasetFilter(size={"large"}) + assert f.criteria[0].size == {"large"} + assert f.criteria[0].tags is None + + def test_has_all_values(self): + f = SeedDatasetFilter( + tags={"default"}, + size={"small", "medium"}, + modalities={"text"}, + source_type={"remote"}, + load_time={SeedDatasetLoadTime.FAST}, + harm_categories={"violence"}, + ) + c = f.criteria[0] + assert c.tags == {"default"} + assert len(c.size) == 2 + assert c.modalities == {"text"} + + def test_filter_allows_multiple_sizes(self): + """Filters can have multiple values for singular fields like size.""" + f = SeedDatasetFilter(size={"small", "medium", "large"}) + assert len(f.criteria[0].size) == 3 + + +class TestMetadataProperties: + """Test that the metadata fields populate correctly.""" + + def test_size_value(self): + for size in ["tiny", "small", "medium", "large", "huge"]: + metadata = SeedDatasetMetadata(size={size}) + assert size in metadata.size + + def test_load_time_value(self): + for lt in SeedDatasetLoadTime: + metadata = SeedDatasetMetadata(load_time={lt}) + assert lt in metadata.load_time + + def test_source_value(self): + for source_type in ["remote", "local"]: + metadata = SeedDatasetMetadata(source_type={source_type}) + assert source_type in metadata.source_type + + def test_modality_value(self): + for modality in ["text", "image", "video", "audio"]: + metadata = SeedDatasetMetadata(modalities={modality}) + assert modality in metadata.modalities + + def test_tags_value(self): + metadata = SeedDatasetMetadata(tags={"safety", "default", "custom"}) + assert "safety" in metadata.tags + + def test_harm_categories_value(self): + metadata = SeedDatasetMetadata(harm_categories={"violence", "cybercrime"}) + assert "violence" in metadata.harm_categories + + +class TestMetadataCoercion: + """Test that _coerce_metadata_values normalizes raw values into sets.""" + + def test_tags_list_coerced_to_set(self): + result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"tags": ["safety", "default"]}) + assert result["tags"] == {"safety", "default"} + assert isinstance(result["tags"], set) + + def test_tags_string_coerced_to_set(self): + result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"tags": "safety"}) + assert result["tags"] == {"safety"} + + def test_tags_normalized_lower_strip(self): + result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"tags": [" Safety ", " DEFAULT"]}) + assert result["tags"] == {"safety", "default"} + + def test_size_coerced_to_set(self): + result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"size": " Large "}) + assert result["size"] == {"large"} + assert isinstance(result["size"], set) + + def test_source_type_coerced_to_set(self): + result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"source_type": " Remote "}) + assert result["source_type"] == {"remote"} + + def test_load_time_coerced_to_enum_set(self): + result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"load_time": "fast"}) + assert result["load_time"] == {SeedDatasetLoadTime.FAST} + assert isinstance(result["load_time"], set) + + def test_load_time_normalized_strip_lower(self): + result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"load_time": " Slow "}) + assert result["load_time"] == {SeedDatasetLoadTime.SLOW} + + def test_modalities_list_coerced_to_set(self): + result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"modalities": ["Text", " IMAGE "]}) + assert result["modalities"] == {"text", "image"} + + def test_modalities_string_coerced_to_set(self): + result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"modalities": "text"}) + assert result["modalities"] == {"text"} + + def test_harm_categories_list_coerced_to_set(self): + result = SeedDatasetMetadata._coerce_metadata_values( + raw_metadata={"harm_categories": ["Violence", " Cybercrime "]} + ) + assert result["harm_categories"] == {"violence", "cybercrime"} + + def test_harm_categories_string_coerced_to_set(self): + result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"harm_categories": "violence"}) + assert result["harm_categories"] == {"violence"} + + def test_unknown_type_skipped_with_warning(self, caplog): + result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"tags": 12345}) + assert "tags" not in result + assert "Skipping metadata field" in caplog.text + + +class TestSingularFieldValidation: + """Test that singular fields (size, source_type) are validated.""" + + def test_singular_size_passes(self): + metadata = SeedDatasetMetadata(size={"large"}) + SeedDatasetMetadata._validate_singular_fields(metadata=metadata) + + def test_singular_source_type_passes(self): + metadata = SeedDatasetMetadata(source_type={"remote"}) + SeedDatasetMetadata._validate_singular_fields(metadata=metadata) + + def test_multiple_sizes_fails(self): + metadata = SeedDatasetMetadata(size={"small", "large"}) + with pytest.raises(ValueError, match="size"): + SeedDatasetMetadata._validate_singular_fields(metadata=metadata) + + def test_multiple_source_types_fails(self): + metadata = SeedDatasetMetadata(source_type={"remote", "local"}) + with pytest.raises(ValueError, match="source_type"): + SeedDatasetMetadata._validate_singular_fields(metadata=metadata) + + def test_none_fields_pass(self): + metadata = SeedDatasetMetadata() + SeedDatasetMetadata._validate_singular_fields(metadata=metadata) + + def test_multi_value_non_singular_fields_pass(self): + """Tags, modalities, harm_categories can have multiple values.""" + metadata = SeedDatasetMetadata( + tags={"safety", "default"}, + modalities={"text", "image"}, + harm_categories={"violence", "cybercrime"}, + ) + SeedDatasetMetadata._validate_singular_fields(metadata=metadata) + + +class TestStrictMatchSingularFieldValidation: + """ + Test that strict_match rejects multi-valued singular fields. + + A dataset can't be both "small" AND "large" — these are mutually exclusive. + strict_match=True with size={"small", "large"} is logically impossible + and should raise ValueError at filter construction time. + """ + + def test_strict_multi_size_raises(self): + """strict_match with size={'small', 'large'} is impossible.""" + with pytest.raises(ValueError, match="logically impossible"): + SeedDatasetFilter(size={"small", "large"}, strict_match=True) + + def test_strict_multi_source_type_raises(self): + """strict_match with source_type={'remote', 'local'} is impossible.""" + with pytest.raises(ValueError, match="logically impossible"): + SeedDatasetFilter(source_type={"remote", "local"}, strict_match=True) + + def test_strict_single_size_ok(self): + """strict_match with single size value is fine.""" + f = SeedDatasetFilter(size={"large"}, strict_match=True) + assert f.criteria[0].size == {"large"} + + def test_nonstrict_multi_size_ok(self): + """Without strict_match, multiple sizes is OR and perfectly valid.""" + f = SeedDatasetFilter(size={"small", "large"}, strict_match=False) + assert len(f.criteria[0].size) == 2 + + def test_strict_multi_tags_ok(self): + """Tags are NOT singular — strict with multiple tags is valid (AND).""" + f = SeedDatasetFilter(tags={"safety", "default"}, strict_match=True) + assert len(f.criteria[0].tags) == 2 + + def test_strict_multi_harm_categories_ok(self): + """harm_categories are NOT singular — strict with multiple is valid.""" + f = SeedDatasetFilter(harm_categories={"violence", "cybercrime"}, strict_match=True) + assert len(f.criteria[0].harm_categories) == 2 + + def test_strict_criteria_list_multi_size_raises(self): + """strict_match validation also applies to criteria=[] construction.""" + with pytest.raises(ValueError, match="logically impossible"): + SeedDatasetFilter( + criteria=[SeedDatasetMetadata(size={"small", "large"})], + strict_match=True, + ) + + +class TestFilterProperties: + """Test that the filter fields populate correctly via flat kwargs.""" + + def test_sizes_values(self): + f = SeedDatasetFilter(size={"small", "large"}) + assert "small" in f.criteria[0].size + assert "large" in f.criteria[0].size + + def test_load_times_values(self): + f = SeedDatasetFilter(load_time={SeedDatasetLoadTime.FAST, SeedDatasetLoadTime.SLOW}) + assert SeedDatasetLoadTime.FAST in f.criteria[0].load_time + + def test_sources_values(self): + f = SeedDatasetFilter(source_type={"local", "remote"}) + assert "local" in f.criteria[0].source_type + + def test_modalities_values(self): + f = SeedDatasetFilter(modalities={"text", "image"}) + assert "text" in f.criteria[0].modalities + + def test_tags_values(self): + f = SeedDatasetFilter(tags={"safety", "default"}) + assert "safety" in f.criteria[0].tags + + def test_harm_categories_values(self): + f = SeedDatasetFilter(harm_categories={"violence", "cybercrime"}) + assert "violence" in f.criteria[0].harm_categories diff --git a/tests/unit/datasets/test_seed_dataset_provider.py b/tests/unit/datasets/test_seed_dataset_provider.py index d61e2291a2..9b116288fe 100644 --- a/tests/unit/datasets/test_seed_dataset_provider.py +++ b/tests/unit/datasets/test_seed_dataset_provider.py @@ -1,13 +1,23 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import textwrap +from dataclasses import fields as dc_fields +from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import pytest +import yaml from pyrit.datasets import SeedDatasetProvider +from pyrit.datasets.seed_datasets.local.local_dataset_loader import _LocalDatasetLoader from pyrit.datasets.seed_datasets.remote.darkbench_dataset import _DarkBenchDataset from pyrit.datasets.seed_datasets.remote.harmbench_dataset import _HarmBenchDataset +from pyrit.datasets.seed_datasets.seed_metadata import ( + SeedDatasetFilter, + SeedDatasetLoadTime, + SeedDatasetMetadata, +) from pyrit.models import SeedDataset, SeedObjective, SeedPrompt @@ -60,29 +70,33 @@ async def fetch_dataset(self): assert "DynamicTestProvider" in providers assert providers["DynamicTestProvider"] == DynamicTestProvider - def test_get_all_dataset_names(self): + @pytest.mark.asyncio + async def test_get_all_dataset_names(self): """Test getting all dataset names.""" # Mock the registry to ensure deterministic results - mock_provider_cls = MagicMock() + mock_provider_cls = MagicMock(__name__="TestProvider") mock_provider_instance = mock_provider_cls.return_value mock_provider_instance.dataset_name = "test_dataset" + mock_provider_instance._parse_metadata = AsyncMock(return_value=None) with patch.dict(SeedDatasetProvider._registry, {"TestProvider": mock_provider_cls}, clear=True): - names = SeedDatasetProvider.get_all_dataset_names() + names = await SeedDatasetProvider.get_all_dataset_names_async() assert names == ["test_dataset"] @pytest.mark.asyncio async def test_fetch_datasets_async(self): """Test fetching all datasets.""" # Mock providers - mock_provider1 = MagicMock() + mock_provider1 = MagicMock(__name__="P1") mock_provider1.return_value.dataset_name = "d1" + mock_provider1.return_value._parse_metadata = AsyncMock(return_value=None) mock_provider1.return_value.fetch_dataset = AsyncMock( return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) - mock_provider2 = MagicMock() + mock_provider2 = MagicMock(__name__="P2") mock_provider2.return_value.dataset_name = "d2" + mock_provider2.return_value._parse_metadata = AsyncMock(return_value=None) mock_provider2.return_value.fetch_dataset = AsyncMock( return_value=SeedDataset(seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") ) @@ -94,14 +108,16 @@ async def test_fetch_datasets_async(self): @pytest.mark.asyncio async def test_fetch_datasets_async_with_filter(self): """Test fetching datasets with filter.""" - mock_provider1 = MagicMock() + mock_provider1 = MagicMock(__name__="P1") mock_provider1.return_value.dataset_name = "d1" + mock_provider1.return_value._parse_metadata = AsyncMock(return_value=None) mock_provider1.return_value.fetch_dataset = AsyncMock( return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) - mock_provider2 = MagicMock() + mock_provider2 = MagicMock(__name__="P2") mock_provider2.return_value.dataset_name = "d2" + mock_provider2.return_value._parse_metadata = AsyncMock(return_value=None) mock_provider2.return_value.fetch_dataset = AsyncMock(side_effect=Exception("Should not be called")) with patch.dict(SeedDatasetProvider._registry, {"P1": mock_provider1, "P2": mock_provider2}, clear=True): @@ -112,14 +128,16 @@ async def test_fetch_datasets_async_with_filter(self): @pytest.mark.asyncio async def test_fetch_datasets_async_invalid_dataset_name(self): """Test that fetch_datasets_async raises ValueError for invalid dataset names.""" - mock_provider1 = MagicMock() + mock_provider1 = MagicMock(__name__="P1") mock_provider1.return_value.dataset_name = "d1" + mock_provider1.return_value._parse_metadata = AsyncMock(return_value=None) mock_provider1.return_value.fetch_dataset = AsyncMock( return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) - mock_provider2 = MagicMock() + mock_provider2 = MagicMock(__name__="P2") mock_provider2.return_value.dataset_name = "d2" + mock_provider2.return_value._parse_metadata = AsyncMock(return_value=None) mock_provider2.return_value.fetch_dataset = AsyncMock( return_value=SeedDataset(seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") ) @@ -236,3 +254,621 @@ async def test_fetch_dataset_with_custom_config(self, mock_darkbench_data): assert call_kwargs["dataset_name"] == "custom/darkbench" assert call_kwargs["config"] == "custom_config" assert call_kwargs["split"] == "test" + + +class TestMetadataParsingRemote: + """Test metadata parsing and filter matching for remote providers.""" + + @pytest.mark.asyncio + async def test_parse_metadata_from_class_attrs(self): + """Test _parse_metadata correctly extracts class-level metadata attributes.""" + loader = _HarmBenchDataset() + metadata = await loader._parse_metadata() + assert metadata is not None + assert metadata.tags == {"default", "safety"} + assert metadata.size == {"large"} + assert metadata.modalities == {"text"} + assert metadata.harm_categories == {"cybercrime", "illegal", "harmful", "chemical_biological", "harassment"} + # source_type is not declared as a class attribute on HarmBench; + # load_time inherits the UNINITIALIZED default from SeedDatasetProvider base class + assert metadata.source_type is None + assert metadata.load_time == {SeedDatasetLoadTime.UNINITIALIZED} + + def test_all_tag(self): + """Filter with tags={'all'} matches any metadata.""" + metadata = SeedDatasetMetadata(tags={"safety"}) + filters = SeedDatasetFilter(tags={"all"}) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) + + def test_tags(self): + """Tag filter uses set intersection.""" + metadata = SeedDatasetMetadata(tags={"safety", "default"}) + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, dataset_filter=SeedDatasetFilter(tags={"safety"}) + ) + assert not SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, dataset_filter=SeedDatasetFilter(tags={"unrelated"}) + ) + + def test_sizes(self): + """Size filter checks membership in the sizes list.""" + metadata = SeedDatasetMetadata(size={"large"}) + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, + dataset_filter=SeedDatasetFilter(size={"large", "huge"}), + ) + assert not SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, + dataset_filter=SeedDatasetFilter(size={"small"}), + ) + + def test_modalities(self): + """Modality filter uses set intersection.""" + metadata = SeedDatasetMetadata(modalities={"text", "image"}) + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, + dataset_filter=SeedDatasetFilter(modalities={"text"}), + ) + assert not SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, + dataset_filter=SeedDatasetFilter(modalities={"audio"}), + ) + + def test_sources(self): + """Source filter checks membership.""" + metadata = SeedDatasetMetadata(source_type={"remote"}) + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, + dataset_filter=SeedDatasetFilter(source_type={"remote"}), + ) + assert not SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, + dataset_filter=SeedDatasetFilter(source_type={"local"}), + ) + + def test_ranks(self): + """Load time filter checks membership.""" + metadata = SeedDatasetMetadata(load_time={SeedDatasetLoadTime.FAST}) + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, + dataset_filter=SeedDatasetFilter(load_time={SeedDatasetLoadTime.FAST}), + ) + assert not SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, + dataset_filter=SeedDatasetFilter(load_time={SeedDatasetLoadTime.SLOW}), + ) + + def test_harm_categories(self): + """Harm category filter uses set intersection.""" + metadata = SeedDatasetMetadata(harm_categories={"violence", "cybercrime"}) + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, + dataset_filter=SeedDatasetFilter(harm_categories={"violence"}), + ) + assert not SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, + dataset_filter=SeedDatasetFilter(harm_categories={"unrelated"}), + ) + + def test_empty_filter(self): + """Empty filter (all None) matches any metadata.""" + metadata = SeedDatasetMetadata(tags={"safety"}, size="large") + filters = SeedDatasetFilter() + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) + + @pytest.mark.asyncio + async def test_no_metadata(self): + """Provider without metadata is skipped when filters are applied.""" + mock_provider_cls = MagicMock(__name__="NoProv") + mock_provider_instance = mock_provider_cls.return_value + mock_provider_instance.dataset_name = "no_metadata" + mock_provider_instance._parse_metadata = AsyncMock(return_value=None) + + with patch.dict(SeedDatasetProvider._registry, {"NoProv": mock_provider_cls}, clear=True): + names = await SeedDatasetProvider.get_all_dataset_names_async(filters=SeedDatasetFilter(tags={"safety"})) + assert names == [] + + +class TestStrictMatchFiltering: + """Test strict_match behavior in SeedDatasetFilter.""" + + def test_strict_tags_all_present_matches(self): + """strict_match requires ALL filter tags to be present in metadata.""" + metadata = SeedDatasetMetadata(tags={"safety", "default", "curated"}) + filters = SeedDatasetFilter(tags={"safety", "default"}, strict_match=True) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) + + def test_strict_tags_partial_overlap_fails(self): + """strict_match rejects if metadata is missing any requested tag.""" + metadata = SeedDatasetMetadata(tags={"safety"}) + filters = SeedDatasetFilter(tags={"safety", "default"}, strict_match=True) + assert not SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) + + def test_nonstrict_tags_partial_overlap_passes(self): + """Without strict_match, any tag overlap is sufficient.""" + metadata = SeedDatasetMetadata(tags={"safety"}) + filters = SeedDatasetFilter(tags={"safety", "default"}, strict_match=False) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) + + def test_strict_harm_categories_all_present_matches(self): + """strict_match requires ALL filter harm_categories present in metadata.""" + metadata = SeedDatasetMetadata(harm_categories={"violence", "cybercrime", "illegal"}) + filters = SeedDatasetFilter(harm_categories={"violence", "cybercrime"}, strict_match=True) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) + + def test_strict_harm_categories_partial_fails(self): + """strict_match rejects if metadata is missing any requested harm category.""" + metadata = SeedDatasetMetadata(harm_categories={"violence"}) + filters = SeedDatasetFilter(harm_categories={"violence", "cybercrime"}, strict_match=True) + assert not SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) + + def test_strict_modalities_all_present_matches(self): + """strict_match requires ALL filter modalities present in metadata.""" + metadata = SeedDatasetMetadata(modalities={"text", "image", "audio"}) + filters = SeedDatasetFilter(modalities={"text", "image"}, strict_match=True) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) + + def test_strict_modalities_partial_fails(self): + """strict_match rejects if metadata is missing any requested modality.""" + metadata = SeedDatasetMetadata(modalities={"text"}) + filters = SeedDatasetFilter(modalities={"text", "image"}, strict_match=True) + assert not SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) + + def test_strict_size_unchanged(self): + """strict_match doesn't change size behavior — still membership check.""" + metadata = SeedDatasetMetadata(size={"large"}) + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, + dataset_filter=SeedDatasetFilter(size={"large"}, strict_match=True), + ) + assert not SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, + dataset_filter=SeedDatasetFilter(size={"small"}, strict_match=True), + ) + + def test_strict_cross_axis_and(self): + """strict_match with multiple axes: all must match.""" + metadata = SeedDatasetMetadata( + tags={"safety", "default"}, + size="large", + harm_categories={"violence", "cybercrime"}, + ) + # Both axes satisfied + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, + dataset_filter=SeedDatasetFilter( + tags={"safety"}, + harm_categories={"violence"}, + strict_match=True, + ), + ) + # harm_categories axis fails (missing "illegal") + assert not SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, + dataset_filter=SeedDatasetFilter( + tags={"safety"}, + harm_categories={"violence", "illegal"}, + strict_match=True, + ), + ) + + def test_strict_all_tag_still_bypasses(self): + """tags={'all'} still bypasses everything even with strict_match.""" + metadata = SeedDatasetMetadata(tags={"safety"}) + filters = SeedDatasetFilter(tags={"all"}, strict_match=True) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) + + def test_strict_default_plus_other_tags_requires_both(self): + """With strict_match, 'default' is a normal tag — all must be present.""" + metadata = SeedDatasetMetadata(tags={"default", "safety"}) + # Both present → match + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, + dataset_filter=SeedDatasetFilter(tags={"default", "safety"}, strict_match=True), + ) + # Missing "curated" → reject + assert not SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, + dataset_filter=SeedDatasetFilter(tags={"default", "safety", "curated"}, strict_match=True), + ) + + def test_nonstrict_default_is_shortcut(self): + """Without strict_match, 'default' in filter tags is a shortcut match.""" + # Dataset has "default" tag → matches even without other filter tags present + metadata = SeedDatasetMetadata(tags={"default"}) + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, + dataset_filter=SeedDatasetFilter(tags={"default", "nonexistent"}), + ) + + def test_strict_default_without_tag_on_dataset_fails(self): + """With strict_match, dataset must actually have 'default' in tags.""" + metadata = SeedDatasetMetadata(tags={"default", "safety"}, load_time=SeedDatasetLoadTime.FAST) + # Without strict, "default" shortcut matches because metadata has "default" tag + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, + dataset_filter=SeedDatasetFilter(tags={"default", "curated"}), + ) + # With strict, ALL filter tags must be in metadata — "curated" is missing + assert not SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, + dataset_filter=SeedDatasetFilter(tags={"default", "curated"}, strict_match=True), + ) + + +class TestFilterValidation: + """Test that invalid or contradictory filter configurations are caught early.""" + + def test_all_with_strict_match_warns(self, caplog): + """'all' + strict_match logs a warning since strict has no effect.""" + SeedDatasetFilter(tags={"all"}, strict_match=True) + assert "strict_match has no effect" in caplog.text + + def test_all_with_other_tags_warns(self, caplog): + """'all' combined with other tags logs a warning.""" + SeedDatasetFilter(tags={"all", "safety"}) + assert "other tags will be ignored" in caplog.text + + def test_all_with_other_fields_warns(self, caplog): + """'all' combined with size/modality/etc logs a warning.""" + SeedDatasetFilter(tags={"all"}, size={"large"}) + assert "other fields will be ignored" in caplog.text + + def test_all_alone_no_warning(self, caplog): + """'all' by itself does not warn.""" + SeedDatasetFilter(tags={"all"}) + assert caplog.text == "" + + def test_all_bypasses_match_filter_entirely(self): + """'all' returns True from _match_filter regardless of metadata content.""" + # Metadata with no overlap to any filter field + metadata = SeedDatasetMetadata( + tags={"unrelated"}, + size="tiny", + modalities={"audio"}, + harm_categories={"nothing"}, + ) + # Filter that would normally reject everything about this metadata + filters = SeedDatasetFilter( + tags={"all"}, + size={"huge"}, + modalities={"text"}, + harm_categories={"violence"}, + strict_match=True, + ) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) + + @pytest.mark.asyncio + async def test_all_includes_datasets_without_metadata(self): + """'all' in get_all_dataset_names_async includes providers with no metadata.""" + mock_cls = MagicMock(__name__="BareProv") + mock_cls.return_value.dataset_name = "bare" + mock_cls.return_value._parse_metadata = AsyncMock(return_value=None) + + with patch.dict(SeedDatasetProvider._registry, {"Bare": mock_cls}, clear=True): + # Without 'all', bare datasets are skipped + names = await SeedDatasetProvider.get_all_dataset_names_async( + filters=SeedDatasetFilter(tags={"safety"}), + ) + assert names == [] + + # With 'all', bare datasets are included + names = await SeedDatasetProvider.get_all_dataset_names_async( + filters=SeedDatasetFilter(tags={"all"}), + ) + assert names == ["bare"] + + @pytest.mark.asyncio + async def test_all_skips_match_filter_call(self): + """'all' in get_all_dataset_names_async doesn't call _match_filter at all.""" + mock_cls = MagicMock(__name__="Prov") + mock_cls.return_value.dataset_name = "test" + mock_cls.return_value._parse_metadata = AsyncMock(return_value=None) + + with ( + patch.dict(SeedDatasetProvider._registry, {"P": mock_cls}, clear=True), + patch.object(SeedDatasetProvider, "_match_filter_to_metadata") as mock_match, + ): + await SeedDatasetProvider.get_all_dataset_names_async( + filters=SeedDatasetFilter(tags={"all"}), + ) + mock_match.assert_not_called() + + +class TestMetadataParsingLocal: + """Test metadata parsing and filter matching for local YAML providers.""" + + def _make_loader(self, yaml_path): + """Create a _LocalDatasetLoader bypassing SeedDataset pre-loading.""" + loader = _LocalDatasetLoader.__new__(_LocalDatasetLoader) + loader.file_path = yaml_path + loader._dataset_name = yaml_path.stem + return loader + + def _write_yaml(self, tmp_path, name, content): + """Write a .prompt YAML file and return its path.""" + path = tmp_path / f"{name}.prompt" + path.write_text(content) + return path + + @pytest.mark.asyncio + async def test_parse_metadata_extracts_fields(self, tmp_path): + """Test _parse_metadata correctly extracts metadata fields from YAML.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + harm_categories: + - violence + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = await loader._parse_metadata() + assert metadata is not None + assert metadata.harm_categories == {"violence"} + + @pytest.mark.asyncio + async def test_all_tag(self, tmp_path): + """Filter with tags={'all'} matches regardless of metadata types.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + tags: + - safety + harm_categories: + - violence + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = await loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(tags={"all"}) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) + + @pytest.mark.asyncio + async def test_tags(self, tmp_path): + """YAML produces tags as list; set intersection in _match_filter expects a set.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + tags: + - safety + - default + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = await loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(tags={"safety"}) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) + + @pytest.mark.asyncio + async def test_sizes(self, tmp_path): + """YAML produces size as string; _match_filter compares against enum values.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + size: large + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = await loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(size={"large"}) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) + + @pytest.mark.asyncio + async def test_modalities(self, tmp_path): + """YAML produces modalities as list of strings; _match_filter uses enum values.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + modalities: + - text + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = await loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(modalities={"text"}) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) + + @pytest.mark.asyncio + async def test_sources(self, tmp_path): + """YAML produces source_type as string; _match_filter compares against enum values.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + source_type: remote + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = await loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(source_type={"remote"}) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) + + @pytest.mark.asyncio + async def test_ranks(self, tmp_path): + """YAML produces load_time as string; _match_filter compares against enum values.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + load_time: fast + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = await loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(load_time={SeedDatasetLoadTime.FAST}) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) + + @pytest.mark.asyncio + async def test_harm_categories(self, tmp_path): + """Both YAML and filter use list[str], so intersection works correctly.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + harm_categories: + - violence + - cybercrime + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = await loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(harm_categories={"violence"}) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) + + @pytest.mark.asyncio + async def test_empty_filter(self, tmp_path): + """Empty filter matches everything.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + harm_categories: + - violence + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = await loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter() + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) + + @pytest.mark.asyncio + async def test_no_metadata(self, tmp_path): + """YAML without any metadata fields returns None from _parse_metadata.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = await loader._parse_metadata() + assert metadata is None + + +class TestLocalDatasetMetadataCollisions: + """ + Regression tests that scan every real .prompt file under seed_datasets/local + to verify _parse_metadata does not crash from field-name collisions between + the YAML schema and SeedDatasetMetadata. + + The previous `source` field collision (URLs parsed as SeedDatasetSourceType) + is the motivating example. + """ + + @staticmethod + def _get_local_prompt_files() -> list: + """Collect all .prompt and .yaml files under the local datasets directory.""" + local_dir = Path(__file__).resolve().parents[3] / "pyrit" / "datasets" / "seed_datasets" / "local" + return sorted(local_dir.glob("**/*.prompt")) + sorted(local_dir.glob("**/*.yaml")) + + @pytest.mark.parametrize("prompt_file", _get_local_prompt_files.__func__(), ids=lambda p: p.stem) + @pytest.mark.asyncio + async def test_parse_metadata_does_not_crash(self, prompt_file): + """_parse_metadata must not raise on any real local dataset file.""" + loader = _LocalDatasetLoader.__new__(_LocalDatasetLoader) + loader.file_path = prompt_file + loader._dataset_name = prompt_file.stem + + metadata = await loader._parse_metadata() + # metadata can be None (no matching fields) or a valid SeedDatasetMetadata + if metadata is not None: + assert isinstance(metadata, SeedDatasetMetadata) + + @pytest.mark.parametrize("prompt_file", _get_local_prompt_files.__func__(), ids=lambda p: p.stem) + def test_no_yaml_key_shadows_metadata_field_with_wrong_type(self, prompt_file): + """ + If a YAML top-level key matches a SeedDatasetMetadata field name, the + coerced value must be the correct type (enum, set, list) — not a raw + string or other primitive that would silently break filtering. + """ + with open(prompt_file, encoding="utf-8") as f: + data = yaml.safe_load(f) + + if not isinstance(data, dict): + return + + metadata_field_names = {fld.name for fld in dc_fields(SeedDatasetMetadata)} + overlapping_keys = metadata_field_names & data.keys() + + if not overlapping_keys: + return + + # Coerce and construct — must not raise + loader = _LocalDatasetLoader.__new__(_LocalDatasetLoader) + loader.file_path = prompt_file + loader._dataset_name = prompt_file.stem + + raw = {k: data[k] for k in overlapping_keys} + coerced = SeedDatasetMetadata._coerce_metadata_values(raw_metadata=raw) + metadata = SeedDatasetMetadata(**coerced) + + # Verify coerced types match expectations + expected_types = { + "tags": (set, type(None)), + "size": (set, type(None)), + "modalities": (set, type(None)), + "source_type": (set, type(None)), + "load_time": (set, type(None)), + "harm_categories": (set, type(None)), + } + for key in overlapping_keys: + value = getattr(metadata, key) + valid_types = expected_types.get(key) + if valid_types: + assert isinstance(value, valid_types), ( + f"Field '{key}' in {prompt_file.name} has type {type(value).__name__}, " + f"expected one of {valid_types}" + ) diff --git a/tests/unit/setup/test_load_default_datasets.py b/tests/unit/setup/test_load_default_datasets.py index 54655f86e5..d701a64450 100644 --- a/tests/unit/setup/test_load_default_datasets.py +++ b/tests/unit/setup/test_load_default_datasets.py @@ -187,7 +187,7 @@ async def test_all_required_datasets_available_in_seed_provider(self) -> None: exists in the SeedDatasetProvider registry. """ # Get all available dataset names from SeedDatasetProvider - available_datasets = set(SeedDatasetProvider.get_all_dataset_names()) + available_datasets = set(await SeedDatasetProvider.get_all_dataset_names_async()) # Get ScenarioRegistry to discover all scenarios registry = ScenarioRegistry.get_registry_singleton()