diff --git a/pyrit/common/yaml_loadable.py b/pyrit/common/yaml_loadable.py index a31dcf189..4f6aa8890 100644 --- a/pyrit/common/yaml_loadable.py +++ b/pyrit/common/yaml_loadable.py @@ -32,14 +32,17 @@ def from_yaml_file(cls: type[T], file: Union[Path | str]) -> T: FileNotFoundError: If the input YAML file path does not exist. ValueError: If the YAML file is invalid. """ - file = verify_and_resolve_path(file) - try: - yaml_data = yaml.safe_load(file.read_text("utf-8")) - except yaml.YAMLError as exc: - raise ValueError(f"Invalid YAML file '{file}': {exc}") from exc - - # If this class provides a from_dict factory, use it; - # otherwise, just instantiate directly with **yaml_data - if hasattr(cls, "from_dict") and callable(getattr(cls, "from_dict")): # noqa: B009 - return cls.from_dict(yaml_data) # type: ignore[attr-defined, no-any-return] + file = verify_and_resolve_path(file) + try: + yaml_data = yaml.safe_load(file.read_text("utf-8")) + except yaml.YAMLError as exc: + raise ValueError(f"Invalid YAML file '{file}': {exc}") from exc + + if yaml_data is None: + raise ValueError(f"YAML file '{file}' is empty.") + + # If this class provides a from_dict factory, use it; + # otherwise, just instantiate directly with **yaml_data + if hasattr(cls, "from_dict") and callable(getattr(cls, "from_dict")): # noqa: B009 + return cls.from_dict(yaml_data) # type: ignore[attr-defined, no-any-return] return cls(**yaml_data) diff --git a/tests/unit/models/test_seed.py b/tests/unit/models/test_seed.py index 32414e0f7..a37286480 100644 --- a/tests/unit/models/test_seed.py +++ b/tests/unit/models/test_seed.py @@ -183,6 +183,14 @@ def test_seed_dataset_initialization_with_yaml_objective(): assert len(dataset.seeds) == 3 +def test_seed_dataset_from_empty_yaml_file_raises_value_error(tmp_path): + empty_file = tmp_path / "empty.prompt" + empty_file.write_text("", encoding="utf-8") + + with pytest.raises(ValueError, match="is empty"): + SeedDataset.from_yaml_file(empty_file) + + def test_seed_dataset_get_values(): dataset = SeedDataset.from_yaml_file( pathlib.Path(DATASETS_PATH) / "seed_datasets" / "local" / "airt" / "illegal.prompt" diff --git a/tests/unit/setup/test_configuration_loader.py b/tests/unit/setup/test_configuration_loader.py index b4e737bf5..a1588023e 100644 --- a/tests/unit/setup/test_configuration_loader.py +++ b/tests/unit/setup/test_configuration_loader.py @@ -190,6 +190,14 @@ def test_from_yaml_file(self): finally: pathlib.Path(yaml_path).unlink() + def test_from_empty_yaml_file_raises_value_error(self, tmp_path): + """Test that an empty YAML file raises a clear ValueError.""" + yaml_path = tmp_path / "empty.yaml" + yaml_path.write_text("", encoding="utf-8") + + with pytest.raises(ValueError, match="is empty"): + ConfigurationLoader.from_yaml_file(yaml_path) + def test_get_default_config_path(self): """Test get_default_config_path returns expected path.""" default_path = ConfigurationLoader.get_default_config_path()