diff --git a/pyrit/common/csv_helper.py b/pyrit/common/csv_helper.py index 48a9b9dd7..2fb831149 100644 --- a/pyrit/common/csv_helper.py +++ b/pyrit/common/csv_helper.py @@ -24,6 +24,9 @@ def write_csv(file: IO[Any], examples: list[dict[str, str]]) -> None: file: A file-like object opened for writing CSV data. examples (List[Dict[str, str]]): List of dictionaries to write as CSV rows. """ + if not examples: + return + writer = csv.DictWriter(file, fieldnames=examples[0].keys()) writer.writeheader() writer.writerows(examples) diff --git a/tests/unit/datasets/test_remote_dataset_loader.py b/tests/unit/datasets/test_remote_dataset_loader.py index d0052a4c7..a550fa279 100644 --- a/tests/unit/datasets/test_remote_dataset_loader.py +++ b/tests/unit/datasets/test_remote_dataset_loader.py @@ -72,3 +72,12 @@ def test_write_cache_creates_directories(self, tmp_path): loader._write_cache(cache_file=cache_file, examples=data, file_type="json") assert cache_file.exists() + + def test_write_cache_csv_allows_empty_examples(self, tmp_path): + loader = ConcreteRemoteLoader() + cache_file = tmp_path / "empty.csv" + + loader._write_cache(cache_file=cache_file, examples=[], file_type="csv") + + assert cache_file.exists() + assert cache_file.read_text(encoding="utf-8") == ""