Skip to content

Commit 167d84a

Browse files
committed
Merge media encoders
Signed-off-by: Samuel Monson <[email protected]>
1 parent 687f702 commit 167d84a

File tree

3 files changed

+17
-40
lines changed

3 files changed

+17
-40
lines changed

src/guidellm/benchmark/schemas/generative/entrypoints.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,7 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
207207
Field(
208208
default_factory=lambda: [ # type: ignore [arg-type]
209209
"generative_column_mapper",
210-
"encode_audio",
211-
"encode_image",
212-
"encode_video",
210+
"encode_media",
213211
],
214212
description="List of dataset preprocessors to apply in order",
215213
)
Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .encoders import AudioEncoder, ImageEncoder, PreprocessEncoder, VideoEncoder
1+
from .encoders import MediaEncoder
22
from .mappers import GenerativeColumnMapper
33
from .preprocessor import (
44
DataDependentPreprocessor,
@@ -7,12 +7,9 @@
77
)
88

99
__all__ = [
10-
"AudioEncoder",
1110
"DataDependentPreprocessor",
1211
"DatasetPreprocessor",
1312
"GenerativeColumnMapper",
14-
"ImageEncoder",
15-
"PreprocessEncoder",
13+
"MediaEncoder",
1614
"PreprocessorRegistry",
17-
"VideoEncoder",
1815
]

src/guidellm/data/preprocessors/encoders.py

Lines changed: 14 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,22 @@
77
PreprocessorRegistry,
88
)
99

10-
__all__ = ["AudioEncoder", "ImageEncoder", "PreprocessEncoder", "VideoEncoder"]
10+
__all__ = ["MediaEncoder"]
1111

1212

13-
class PreprocessEncoder(DatasetPreprocessor):
13+
@PreprocessorRegistry.register("encode_media")
14+
class MediaEncoder(DatasetPreprocessor):
15+
def __init__(self, encode_kwargs: dict[str, Any] | None = None) -> None:
16+
self.encode_audio_kwargs = (
17+
encode_kwargs.get("audio", {}) if encode_kwargs else {}
18+
)
19+
self.encode_image_kwargs = (
20+
encode_kwargs.get("image", {}) if encode_kwargs else {}
21+
)
22+
self.encode_video_kwargs = (
23+
encode_kwargs.get("video", {}) if encode_kwargs else {}
24+
)
25+
1426
@staticmethod
1527
def encode_audio(*args, **kwargs):
1628
from guidellm.extras.audio import encode_audio
@@ -29,14 +41,6 @@ def encode_video(*args, **kwargs):
2941

3042
return encode_video(*args, **kwargs)
3143

32-
33-
@PreprocessorRegistry.register("encode_audio")
34-
class AudioEncoder(PreprocessEncoder):
35-
def __init__(self, encode_kwargs: dict[str, Any] | None = None) -> None:
36-
self.encode_audio_kwargs = (
37-
encode_kwargs.get("audio", {}) if encode_kwargs else {}
38-
)
39-
4044
def __call__(self, columns: dict[str, list[Any]]) -> dict[str, list[Any]]:
4145
if columns.get("audio_column"):
4246
encoded_audio = []
@@ -49,17 +53,6 @@ def __call__(self, columns: dict[str, list[Any]]) -> dict[str, list[Any]]:
4953
)
5054
columns["audio_column"] = encoded_audio
5155

52-
return columns
53-
54-
55-
@PreprocessorRegistry.register("encode_image")
56-
class ImageEncoder(PreprocessEncoder):
57-
def __init__(self, encode_kwargs: dict[str, Any] | None = None) -> None:
58-
self.encode_image_kwargs = (
59-
encode_kwargs.get("image", {}) if encode_kwargs else {}
60-
)
61-
62-
def __call__(self, columns: dict[str, list[Any]]) -> dict[str, list[Any]]:
6356
if columns.get("image_column"):
6457
encoded_images = []
6558
for image in columns["image_column"]:
@@ -71,17 +64,6 @@ def __call__(self, columns: dict[str, list[Any]]) -> dict[str, list[Any]]:
7164
)
7265
columns["image_column"] = encoded_images
7366

74-
return columns
75-
76-
77-
@PreprocessorRegistry.register("encode_video")
78-
class VideoEncoder(PreprocessEncoder):
79-
def __init__(self, encode_kwargs: dict[str, Any] | None = None) -> None:
80-
self.encode_video_kwargs = (
81-
encode_kwargs.get("video", {}) if encode_kwargs else {}
82-
)
83-
84-
def __call__(self, columns: dict[str, list[Any]]) -> dict[str, list[Any]]:
8567
if columns.get("video_column"):
8668
encoded_videos = []
8769
for video in columns["video_column"]:

0 commit comments

Comments
 (0)