@@ -84,6 +84,17 @@ def get_package_versions() -> dict[str, str]:
8484
8585
8686class BenchmarkTest (ABC ):
87+ # Direct mapping of library names to their transform method names
88+ LIBRARY_ATTR_MAP : dict [str , str ] = {
89+ "albucore" : "albucore_transform" ,
90+ "opencv" : "opencv_transform" ,
91+ "numpy" : "numpy_transform" ,
92+ "lut" : "lut_transform" ,
93+ "kornia-rs" : "kornia_transform" ,
94+ "torchvision" : "torchvision_transform" ,
95+ "simsimd" : "simsimd_transform" ,
96+ }
97+
8798 def __init__ (self , num_channels : int ) -> None :
8899 self .num_channels = num_channels
89100 self .img_type = None
@@ -131,30 +142,17 @@ def simsimd(self, img: np.ndarray) -> np.ndarray:
131142 return clip (self .simsimd_transform (img ), img .dtype )
132143
133144 def is_supported_by (self , library : str ) -> bool :
134- library_attr_map = {
135- "albucore" : "albucore_transform" ,
136- "opencv" : "opencv_transform" ,
137- "numpy" : "numpy_transform" ,
138- "lut" : "lut_transform" ,
139- "kornia-rs" : "kornia_transform" ,
140- "torchvision" : "torchvision_transform" ,
141- "simsimd" : "simsimd_transform" ,
142- }
143-
144- # Check if the library is in the map
145- if library in library_attr_map :
146- attrs = library_attr_map [library ]
147- # Ensure attrs is a list for uniform processing
148- if not isinstance (attrs , list ):
149- attrs = [attrs ] # type: ignore[assignment]
150- # Return True if any of the specified attributes exist
151- return any (hasattr (self , attr ) for attr in attrs )
152-
153- # Fallback: checks if the class has an attribute with the library's name
154- return hasattr (self , library )
145+ # Check if the library has a specific mapping
146+ transform_attr = self .LIBRARY_ATTR_MAP .get (library , f"{ library } _transform" )
147+
148+ # Return True if the transform method exists
149+ return hasattr (self , transform_attr )
155150
156151 def run (self , library : str , imgs : list [np .ndarray ]) -> list [np .ndarray ] | None :
157- transform = getattr (self , library )
152+ # Get the appropriate transform method
153+ transform_attr = self .LIBRARY_ATTR_MAP .get (library , f"{ library } _transform" )
154+ transform = getattr (self , transform_attr )
155+
158156 transformed_images = []
159157 for img in imgs :
160158 result = transform (img )
0 commit comments