77 PreprocessorRegistry ,
88)
99
10- __all__ = ["AudioEncoder" , "ImageEncoder" , "PreprocessEncoder" , "VideoEncoder " ]
10+ __all__ = ["MediaEncoder " ]
1111
1212
13- class PreprocessEncoder (DatasetPreprocessor ):
13+ @PreprocessorRegistry .register ("encode_media" )
14+ class MediaEncoder (DatasetPreprocessor ):
def __init__(self, encode_kwargs: dict[str, Any] | None = None) -> None:
    """
    Split the optional per-modality encoding options into separate attributes.

    :param encode_kwargs: Optional mapping that may carry ``"audio"``,
        ``"image"`` and ``"video"`` entries. Each entry is stored on the
        instance as the matching ``encode_*_kwargs`` attribute; a missing
        entry, or a missing/empty mapping, yields an empty dict instead.
    """
    # Collapse ``None`` (or an empty mapping) into one empty lookup table so
    # the three lookups below need no further guarding.
    options = encode_kwargs if encode_kwargs else {}

    self.encode_audio_kwargs = options.get("audio", {})
    self.encode_image_kwargs = options.get("image", {})
    self.encode_video_kwargs = options.get("video", {})
25+
1426 @staticmethod
1527 def encode_audio (* args , ** kwargs ):
1628 from guidellm .extras .audio import encode_audio
@@ -29,14 +41,6 @@ def encode_video(*args, **kwargs):
2941
3042 return encode_video (* args , ** kwargs )
3143
32-
33- @PreprocessorRegistry .register ("encode_audio" )
34- class AudioEncoder (PreprocessEncoder ):
35- def __init__ (self , encode_kwargs : dict [str , Any ] | None = None ) -> None :
36- self .encode_audio_kwargs = (
37- encode_kwargs .get ("audio" , {}) if encode_kwargs else {}
38- )
39-
4044 def __call__ (self , columns : dict [str , list [Any ]]) -> dict [str , list [Any ]]:
4145 if columns .get ("audio_column" ):
4246 encoded_audio = []
@@ -49,17 +53,6 @@ def __call__(self, columns: dict[str, list[Any]]) -> dict[str, list[Any]]:
4953 )
5054 columns ["audio_column" ] = encoded_audio
5155
52- return columns
53-
54-
55- @PreprocessorRegistry .register ("encode_image" )
56- class ImageEncoder (PreprocessEncoder ):
57- def __init__ (self , encode_kwargs : dict [str , Any ] | None = None ) -> None :
58- self .encode_image_kwargs = (
59- encode_kwargs .get ("image" , {}) if encode_kwargs else {}
60- )
61-
62- def __call__ (self , columns : dict [str , list [Any ]]) -> dict [str , list [Any ]]:
6356 if columns .get ("image_column" ):
6457 encoded_images = []
6558 for image in columns ["image_column" ]:
@@ -71,17 +64,6 @@ def __call__(self, columns: dict[str, list[Any]]) -> dict[str, list[Any]]:
7164 )
7265 columns ["image_column" ] = encoded_images
7366
74- return columns
75-
76-
77- @PreprocessorRegistry .register ("encode_video" )
78- class VideoEncoder (PreprocessEncoder ):
79- def __init__ (self , encode_kwargs : dict [str , Any ] | None = None ) -> None :
80- self .encode_video_kwargs = (
81- encode_kwargs .get ("video" , {}) if encode_kwargs else {}
82- )
83-
84- def __call__ (self , columns : dict [str , list [Any ]]) -> dict [str , list [Any ]]:
8567 if columns .get ("video_column" ):
8668 encoded_videos = []
8769 for video in columns ["video_column" ]:
0 commit comments