@@ -37,16 +37,19 @@ class ExtractDataKeyFromMetaKeyd(MapTransform):
3737 keys: keys to be transferred from meta to data
3838 meta_key: the meta key where all the meta-data is stored
3939 allow_missing_keys: don't raise exception if key is missing
40+ image_only: if True, only extract metadata from MetaTensor images to avoid duplication
4041
4142 Example:
4243 When the fastMRI dataset is loaded, "kspace" is stored in the data dictionary,
4344 but the ground-truth image with the key "reconstruction_rss" is stored in the meta data.
4445 In this case, ExtractDataKeyFromMetaKeyd moves "reconstruction_rss" to data.
46+ For MetaTensor objects, setting image_only=True prevents extracting redundant metadata.
4547 """
4648
47- def __init__ (self , keys : KeysCollection , meta_key : str , allow_missing_keys : bool = False ) -> None :
49+ def __init__ (self , keys : KeysCollection , meta_key : str , allow_missing_keys : bool = False , image_only : bool = False ) -> None :
4850 MapTransform .__init__ (self , keys , allow_missing_keys )
4951 self .meta_key = meta_key
52+ self .image_only = image_only
5053
5154 def __call__ (self , data : Mapping [Hashable , NdarrayOrTensor ]) -> dict [Hashable , Tensor ]:
5255 """
@@ -60,7 +63,12 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, T
6063 d = dict (data )
6164 for key in self .keys :
6265 if key in d [self .meta_key ]:
63- d [key ] = d [self .meta_key ][key ] # type: ignore
66+ extracted_value = d [self .meta_key ][key ]
67+ # When image_only is True, skip if the extracted value is a MetaTensor
68+ # to preserve metadata associations
69+ if self .image_only and isinstance (extracted_value , MetaTensor ):
70+ continue
71+ d [key ] = extracted_value # type: ignore
6472 elif not self .allow_missing_keys :
6573 raise KeyError (
6674 f"Key `{ key } ` of transform `{ self .__class__ .__name__ } ` was missing in the meta data"
0 commit comments