
Commit ecd100f

UT for src/guidellm/data/deserializers/file.py

Signed-off-by: guangli.bao <[email protected]>
1 parent fce8858 commit ecd100f

File tree: 2 files changed, +366 -0 lines changed

.github/actions/run-tox/action.yml (1 addition, 0 deletions)

@@ -24,6 +24,7 @@ runs:
       - name: Install dependencies
         run: |
           pip install tox tox-pdm
+          pip install numpy==2.1.3 h5py==3.9.0
         shell: bash
       - name: Run tox
         run: |
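
The new pins can be checked locally before running the suite; a minimal sketch (illustrative only, not part of the commit):

    import importlib.metadata as md

    # Expect the versions pinned in the CI step above.
    for pkg in ("numpy", "h5py"):
        print(pkg, md.version(pkg))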
Lines changed: 365 additions & 0 deletions

@@ -0,0 +1,365 @@
import csv
import io
from pathlib import Path

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from datasets import Dataset, DatasetDict
from pyarrow import ipc

from guidellm.data.deserializers.deserializer import DataNotSupportedError
from guidellm.data.deserializers.file import (
    ArrowFileDatasetDeserializer,
    CSVFileDatasetDeserializer,
    DBFileDatasetDeserializer,
    HDF5FileDatasetDeserializer,
    JSONFileDatasetDeserializer,
    ParquetFileDatasetDeserializer,
    TarFileDatasetDeserializer,
    TextFileDatasetDeserializer,
)


def processor_factory():
    # Shared stand-in: these tests exercise the file deserializers without a
    # real processor, so the factory deliberately returns None.
    return None

###################
# Tests text file deserializer
###################


@pytest.mark.sanity
def test_text_file_deserializer_success(tmp_path):
    # Arrange: create a temp text file
    file_path = tmp_path / "sample.txt"
    file_content = ["hello\n", "world\n"]
    file_path.write_text("".join(file_content))

    deserializer = TextFileDatasetDeserializer()

    dataset = deserializer(
        data=file_path,
        processor_factory=processor_factory(),
        random_seed=123,
    )

    # Assert
    assert isinstance(dataset, Dataset)
    assert dataset["text"] == file_content
    assert len(dataset) == 2


@pytest.mark.parametrize(
    "invalid_data",
    [
        123,  # Not a path
        None,  # Not a path
        {"file": "abc.txt"},  # Wrong type
    ],
)
@pytest.mark.sanity
def test_text_file_deserializer_invalid_type(invalid_data):
    deserializer = TextFileDatasetDeserializer()

    with pytest.raises(DataNotSupportedError):
        deserializer(
            data=invalid_data,
            processor_factory=processor_factory(),
            random_seed=0,
        )


@pytest.mark.sanity
def test_text_file_deserializer_file_not_exists(tmp_path):
    deserializer = TextFileDatasetDeserializer()
    non_existent_file = tmp_path / "missing.txt"

    with pytest.raises(DataNotSupportedError):
        deserializer(
            data=non_existent_file,
            processor_factory=processor_factory(),
            random_seed=0,
        )


@pytest.mark.sanity
def test_text_file_deserializer_not_a_file(tmp_path):
    deserializer = TextFileDatasetDeserializer()
    directory = tmp_path / "folder"
    directory.mkdir()

    with pytest.raises(DataNotSupportedError):
        deserializer(
            data=directory,
            processor_factory=processor_factory(),
            random_seed=0,
        )


@pytest.mark.sanity
def test_text_file_deserializer_invalid_file_extension(tmp_path):
    deserializer = TextFileDatasetDeserializer()

    file_path = tmp_path / "data.ttl"
    file_path.write_text("hello")

    with pytest.raises(DataNotSupportedError):
        deserializer(
            data=file_path,
            processor_factory=processor_factory(),
            random_seed=0,
        )

###################
# Tests parquet file deserializer
###################


def create_parquet_file(path: Path):
    # Arrange: create a minimal parquet file
    table = pa.Table.from_pydict({"text": ["hello", "world"]})
    pq.write_table(table, path)


@pytest.mark.sanity
def test_parquet_file_deserializer_success(tmp_path):
    file_path = tmp_path / "sample.parquet"
    create_parquet_file(file_path)

    deserializer = ParquetFileDatasetDeserializer()

    dataset = deserializer(
        data=file_path,
        processor_factory=processor_factory(),
        random_seed=42,
    )

    # Assert
    assert isinstance(dataset, DatasetDict)
    assert dataset["train"].column_names == ["text"]
    assert dataset["train"]["text"] == ["hello", "world"]
    assert len(dataset["train"]["text"]) == 2


@pytest.mark.sanity
def test_parquet_file_deserializer_file_not_exists(tmp_path):
    deserializer = ParquetFileDatasetDeserializer()
    missing_file = tmp_path / "missing.parquet"

    with pytest.raises(DataNotSupportedError):
        deserializer(
            data=missing_file,
            processor_factory=processor_factory(),
            random_seed=3,
        )

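# Why the parquet/CSV/JSON tests index dataset["train"]: loading a local data
# file through Hugging Face datasets wraps the rows in a DatasetDict keyed by
# split, with "train" as the default split. A hedged sketch (assuming the
# deserializer delegates to datasets.load_dataset):
#
#   from datasets import load_dataset
#   dd = load_dataset("parquet", data_files=str(file_path))
#   assert list(dd.keys()) == ["train"]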
###################
# Tests csv file deserializer
###################


def create_csv_file(path: Path):
    """Helper to create a minimal csv file."""
    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerow(["text"])
    writer.writerow(["hello world"])
    with path.open("w") as f:
        f.write(output.getvalue())


@pytest.mark.sanity
def test_csv_file_deserializer_success(tmp_path):
    # Arrange: create a temp csv file
    file_path = tmp_path / "sample.csv"
    create_csv_file(file_path)

    deserializer = CSVFileDatasetDeserializer()

    dataset = deserializer(
        data=file_path,
        processor_factory=processor_factory(),
        random_seed=43,
    )

    # Assert
    assert isinstance(dataset, DatasetDict)
    assert dataset["train"]["text"] == ["hello world"]
    assert len(dataset["train"]) == 1

###################
# Tests json file deserializer
###################


@pytest.mark.sanity
def test_json_file_deserializer_success(tmp_path):
    # Arrange: create a temp json file
    file_path = tmp_path / "sample.json"
    file_content = '{"text": "hello world"}\n'
    file_path.write_text(file_content)

    deserializer = JSONFileDatasetDeserializer()

    dataset = deserializer(
        data=file_path,
        processor_factory=processor_factory(),
        random_seed=123,
    )

    # Assert
    assert isinstance(dataset, DatasetDict)
    assert dataset["train"]["text"] == ["hello world"]
    assert len(dataset) == 1

###################
# Tests arrow file deserializer
###################


@pytest.mark.sanity
def test_arrow_file_deserializer_success(tmp_path):
    # Arrange: create a temp arrow file
    table = pa.Table.from_pydict({"text": ["hello", "world"]})
    file_path = tmp_path / "sample.arrow"

    with (
        pa.OSFile(str(file_path), "wb") as sink,
        ipc.RecordBatchFileWriter(sink, table.schema) as writer,
    ):
        writer.write_table(table)

    deserializer = ArrowFileDatasetDeserializer()

    dataset = deserializer(
        data=file_path,
        processor_factory=processor_factory(),
        random_seed=42,
    )

    # Assert
    assert isinstance(dataset, DatasetDict)
    assert "train" in dataset
    assert isinstance(dataset["train"], Dataset)
    assert dataset["train"].num_rows == 2

###################
# Tests HDF5 file deserializer
###################


@pytest.mark.sanity
def test_hdf5_file_deserializer_success(tmp_path):
    df_sample = pd.DataFrame({"text": ["hello", "world"]})
    file_path = tmp_path / "sample.h5"
    df_sample.to_hdf(str(file_path), key="data", mode="w", format="fixed")

    deserializer = HDF5FileDatasetDeserializer()

    dataset = deserializer(
        data=file_path,
        processor_factory=processor_factory(),
        random_seed=1,
    )

    # Assert
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 2
    assert dataset["text"] == ["hello", "world"]

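# Local note (an assumption, not part of the commit): pandas.DataFrame.to_hdf
# is backed by PyTables, so running this test outside CI also needs the
# "tables" package alongside the h5py pin added in the action above:
#
#   pip install tables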
###################
# Tests DB file deserializer
###################


@pytest.mark.skip(reason="issue: #492")
def test_db_file_deserializer_success(monkeypatch, tmp_path):
    import sqlite3

    def create_sqlite_db(path: Path):
        conn = sqlite3.connect(path)
        cur = conn.cursor()
        cur.execute("CREATE TABLE samples (text TEXT)")
        cur.execute("INSERT INTO samples (text) VALUES ('hello')")
        cur.execute("INSERT INTO samples (text) VALUES ('world')")
        conn.commit()
        conn.close()

    # Arrange: create a valid .db file
    db_path = tmp_path / "sample.db"
    create_sqlite_db(db_path)

    # Arrange: mock Dataset.from_sql to return a fixed dataset
    mocked_ds = Dataset.from_dict({"text": ["hello", "world"]})

    def mock_from_sql(sql, con, **kwargs):
        assert sql == "SELECT * FROM samples"
        assert con == str(db_path)
        return mocked_ds

    monkeypatch.setattr("datasets.Dataset.from_sql", mock_from_sql)

    deserializer = DBFileDatasetDeserializer()

    dataset = deserializer(
        data=db_path,
        processor_factory=processor_factory(),
        random_seed=1,
    )

    # Assert: result is of type Dataset
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 2
    assert dataset["text"] == ["hello", "world"]

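# Note on the skip above (a guess at context, not stated in the commit):
# datasets.Dataset.from_sql accepts a SQLAlchemy connectable or a database
# URI string; for a SQLite file the URI form is usually "sqlite:///<path>"
# rather than the bare path the mock asserts. A minimal unmocked sketch:
#
#   ds = Dataset.from_sql("SELECT * FROM samples", con=f"sqlite:///{db_path}")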
###################
# Tests Tar file deserializer
###################


def create_simple_tar(tar_path: str):
    import tarfile

    # Create the tar file in write mode
    with tarfile.open(tar_path, "w") as tar:
        # Content to be added to the tar file
        content = b"hello world\nthis is a tar file\n"

        # Wrap the bytes in a file-like object
        data_stream = io.BytesIO(content)

        # TarInfo carries the file description info (name and size)
        info = tarfile.TarInfo(name="sample.txt")
        info.size = len(content)

        # Write the file to the tar archive
        tar.addfile(info, data_stream)


@pytest.mark.sanity
def test_tar_file_deserializer_success(tmp_path):
    file_path = tmp_path / "sample.tar"
    create_simple_tar(file_path)

    deserializer = TarFileDatasetDeserializer()

    dataset = deserializer(
        data=file_path,
        processor_factory=processor_factory(),
        random_seed=43,
    )

    assert isinstance(dataset, DatasetDict)
    assert "train" in dataset
    assert isinstance(dataset["train"], Dataset)
    assert dataset["train"].num_rows == 1
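
# To run just the tests in this commit locally (an assumed invocation; the
# exact tox/pytest entry point depends on the repo configuration):
#
#   pytest -m sanity -k "file_deserializer"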
