Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions pyrit/common/yaml_loadable.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,17 @@ def from_yaml_file(cls: type[T], file: Union[Path | str]) -> T:
FileNotFoundError: If the input YAML file path does not exist.
ValueError: If the YAML file is invalid.
"""
file = verify_and_resolve_path(file)
try:
yaml_data = yaml.safe_load(file.read_text("utf-8"))
except yaml.YAMLError as exc:
raise ValueError(f"Invalid YAML file '{file}': {exc}") from exc

# If this class provides a from_dict factory, use it;
# otherwise, just instantiate directly with **yaml_data
if hasattr(cls, "from_dict") and callable(getattr(cls, "from_dict")): # noqa: B009
return cls.from_dict(yaml_data) # type: ignore[attr-defined, no-any-return]
file = verify_and_resolve_path(file)
try:
yaml_data = yaml.safe_load(file.read_text("utf-8"))
except yaml.YAMLError as exc:
raise ValueError(f"Invalid YAML file '{file}': {exc}") from exc

if yaml_data is None:
raise ValueError(f"YAML file '{file}' is empty.")

# If this class provides a from_dict factory, use it;
# otherwise, just instantiate directly with **yaml_data
if hasattr(cls, "from_dict") and callable(getattr(cls, "from_dict")): # noqa: B009
return cls.from_dict(yaml_data) # type: ignore[attr-defined, no-any-return]
return cls(**yaml_data)
8 changes: 8 additions & 0 deletions tests/unit/models/test_seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,14 @@ def test_seed_dataset_initialization_with_yaml_objective():
assert len(dataset.seeds) == 3


def test_seed_dataset_from_empty_yaml_file_raises_value_error(tmp_path):
empty_file = tmp_path / "empty.prompt"
empty_file.write_text("", encoding="utf-8")

with pytest.raises(ValueError, match="is empty"):
SeedDataset.from_yaml_file(empty_file)


def test_seed_dataset_get_values():
dataset = SeedDataset.from_yaml_file(
pathlib.Path(DATASETS_PATH) / "seed_datasets" / "local" / "airt" / "illegal.prompt"
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/setup/test_configuration_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,14 @@ def test_from_yaml_file(self):
finally:
pathlib.Path(yaml_path).unlink()

def test_from_empty_yaml_file_raises_value_error(self, tmp_path):
"""Test that an empty YAML file raises a clear ValueError."""
yaml_path = tmp_path / "empty.yaml"
yaml_path.write_text("", encoding="utf-8")

with pytest.raises(ValueError, match="is empty"):
ConfigurationLoader.from_yaml_file(yaml_path)

def test_get_default_config_path(self):
"""Test get_default_config_path returns expected path."""
default_path = ConfigurationLoader.get_default_config_path()
Expand Down
Loading