-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Add missing channel_wise parameter to RandScaleIntensityFixedMean #8741
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
f996a93
6516499
5918a07
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -601,6 +601,7 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
| factors: Sequence[float] | float = 0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fixed_mean: bool = True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| preserve_range: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| channel_wise: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dtype: DtypeLike = np.float32, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -611,8 +612,8 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
| fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| to ensure that the output has the same mean as the input. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| channel of the image if True. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| channel of the image if True. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dtype: output data type, if None, same as input image. defaults to float32. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -626,29 +627,51 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.factor = self.factors[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.fixed_mean = fixed_mean | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.preserve_range = preserve_range | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.channel_wise = channel_wise | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.dtype = dtype | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.scaler = ScaleIntensityFixedMean( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| factor=self.factor, fixed_mean=self.fixed_mean, preserve_range=self.preserve_range, dtype=self.dtype | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| factor=self.factor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fixed_mean=self.fixed_mean, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| preserve_range=self.preserve_range, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| channel_wise=self.channel_wise, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dtype=self.dtype, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def randomize(self, data: Any | None = None) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| super().randomize(None) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not self._do_transform: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.channel_wise: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.factor = [self.R.uniform(low=self.factors[0], high=self.factors[1]) for _ in range(data.shape[0])] # type: ignore | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
This might be slightly better. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Apply the transform to `img`. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| img = convert_to_tensor(img, track_meta=get_track_meta()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if randomize: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.randomize() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.randomize(img) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not self._do_transform: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return convert_data_type(img, dtype=self.dtype)[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.channel_wise: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for i, d in enumerate(img): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out_channel = ScaleIntensityFixedMean( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| factor=self.factor[i], # type: ignore | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fixed_mean=self.fixed_mean, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| preserve_range=self.preserve_range, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dtype=self.dtype, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| )(d[None])[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out.append(out_channel) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ret: NdarrayOrTensor = torch.stack(out) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ret = convert_to_dst_type(ret, dst=img, dtype=self.dtype or img.dtype)[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return ret | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+661
to
+673
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
I think this is a little more readable, the type ignore for |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return self.scaler(img, self.factor) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -669,6 +669,7 @@ def __init__( | |
| factors: Sequence[float] | float, | ||
| fixed_mean: bool = True, | ||
| preserve_range: bool = False, | ||
| channel_wise: bool = False, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same issue here, despite the docstring argument being incorrectly included in this order. |
||
| prob: float = 0.1, | ||
| dtype: DtypeLike = np.float32, | ||
| allow_missing_keys: bool = False, | ||
|
|
@@ -683,8 +684,8 @@ def __init__( | |
| fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling | ||
| to ensure that the output has the same mean as the input. | ||
| channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied | ||
| on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the | ||
| channel of the image if True. | ||
| on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the | ||
| channel of the image if True. | ||
| dtype: output data type, if None, same as input image. defaults to float32. | ||
| allow_missing_keys: don't raise exception if key is missing. | ||
|
|
||
|
|
@@ -694,7 +695,12 @@ def __init__( | |
| self.fixed_mean = fixed_mean | ||
| self.preserve_range = preserve_range | ||
| self.scaler = RandScaleIntensityFixedMean( | ||
| factors=factors, fixed_mean=self.fixed_mean, preserve_range=preserve_range, dtype=dtype, prob=1.0 | ||
| factors=factors, | ||
| fixed_mean=self.fixed_mean, | ||
| preserve_range=preserve_range, | ||
| channel_wise=channel_wise, | ||
| dtype=dtype, | ||
| prob=1.0, | ||
| ) | ||
|
|
||
| def set_random_state( | ||
|
|
@@ -712,8 +718,15 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N | |
| d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) | ||
| return d | ||
|
|
||
| # expect all the specified keys have same spatial shape and share same random factors | ||
| first_key: Hashable = self.first_key(d) | ||
| if first_key == (): | ||
| for key in self.key_iterator(d): | ||
| d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) | ||
| return d | ||
|
|
||
| # all the keys share the same random scale factor | ||
| self.scaler.randomize(None) | ||
| self.scaler.randomize(d[first_key]) | ||
| for key in self.key_iterator(d): | ||
| d[key] = self.scaler(d[key], randomize=False) | ||
| return d | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add new arguments to the end of the argument list so that positional arguments aren't affected. I realise the argument was meant to be here given the docstring ordering, but if users call this constructor with positional arguments this will break compatibility.