Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,7 @@ def __init__(
factors: Sequence[float] | float = 0,
fixed_mean: bool = True,
preserve_range: bool = False,
channel_wise: bool = False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add new arguments to the end of the argument list so that positional arguments aren't affected. I realise the argument was meant to be here given the docstring ordering, but if users call this constructor with positional arguments this will break compatibility.

dtype: DtypeLike = np.float32,
) -> None:
"""
Expand All @@ -611,8 +612,8 @@ def __init__(
fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling
to ensure that the output has the same mean as the input.
channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied
on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
channel of the image if True.
on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
channel of the image if True.
dtype: output data type, if None, same as input image. defaults to float32.

"""
Expand All @@ -626,29 +627,51 @@ def __init__(
self.factor = self.factors[0]
self.fixed_mean = fixed_mean
self.preserve_range = preserve_range
self.channel_wise = channel_wise
self.dtype = dtype

self.scaler = ScaleIntensityFixedMean(
factor=self.factor, fixed_mean=self.fixed_mean, preserve_range=self.preserve_range, dtype=self.dtype
factor=self.factor,
fixed_mean=self.fixed_mean,
preserve_range=self.preserve_range,
channel_wise=self.channel_wise,
dtype=self.dtype,
)

def randomize(self, data: Any | None = None) -> None:
super().randomize(None)
if not self._do_transform:
return None
self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1])
if self.channel_wise:
self.factor = [self.R.uniform(low=self.factors[0], high=self.factors[1]) for _ in range(data.shape[0])] # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.factor = [self.R.uniform(low=self.factors[0], high=self.factors[1]) for _ in range(data.shape[0])] # type: ignore
self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1], size=data.shape[:1]) # type: ignore

This might be slightly better.

else:
self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1])

def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
if randomize:
self.randomize()
self.randomize(img)

if not self._do_transform:
return convert_data_type(img, dtype=self.dtype)[0]

if self.channel_wise:
out = []
for i, d in enumerate(img):
out_channel = ScaleIntensityFixedMean(
factor=self.factor[i], # type: ignore
fixed_mean=self.fixed_mean,
preserve_range=self.preserve_range,
dtype=self.dtype,
)(d[None])[0]
out.append(out_channel)
ret: NdarrayOrTensor = torch.stack(out)
ret = convert_to_dst_type(ret, dst=img, dtype=self.dtype or img.dtype)[0]
return ret
Comment on lines +661 to +673
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if self.channel_wise:
out = []
for i, d in enumerate(img):
out_channel = ScaleIntensityFixedMean(
factor=self.factor[i], # type: ignore
fixed_mean=self.fixed_mean,
preserve_range=self.preserve_range,
dtype=self.dtype,
)(d[None])[0]
out.append(out_channel)
ret: NdarrayOrTensor = torch.stack(out)
ret = convert_to_dst_type(ret, dst=img, dtype=self.dtype or img.dtype)[0]
return ret
if self.channel_wise:
out = []
for i, d in enumerate(img):
scale_trans = ScaleIntensityFixedMean(
factor=float(self.factor[i]),
fixed_mean=self.fixed_mean,
preserve_range=self.preserve_range,
dtype=self.dtype,
)
out.append(scale_trans(d[None]))
ret: NdarrayOrTensor = torch.cat(out)
ret = convert_to_dst_type(ret, dst=img, dtype=self.dtype or img.dtype)[0]
return ret

I think this is a little more readable, the type ignore for factor may still be needed though.


return self.scaler(img, self.factor)


Expand Down
21 changes: 17 additions & 4 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,7 @@ def __init__(
factors: Sequence[float] | float,
fixed_mean: bool = True,
preserve_range: bool = False,
channel_wise: bool = False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue here, despite the docstring argument being incorrectly included in this order.

prob: float = 0.1,
dtype: DtypeLike = np.float32,
allow_missing_keys: bool = False,
Expand All @@ -683,8 +684,8 @@ def __init__(
fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling
to ensure that the output has the same mean as the input.
channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied
on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
channel of the image if True.
on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
channel of the image if True.
dtype: output data type, if None, same as input image. defaults to float32.
allow_missing_keys: don't raise exception if key is missing.

Expand All @@ -694,7 +695,12 @@ def __init__(
self.fixed_mean = fixed_mean
self.preserve_range = preserve_range
self.scaler = RandScaleIntensityFixedMean(
factors=factors, fixed_mean=self.fixed_mean, preserve_range=preserve_range, dtype=dtype, prob=1.0
factors=factors,
fixed_mean=self.fixed_mean,
preserve_range=preserve_range,
channel_wise=channel_wise,
dtype=dtype,
prob=1.0,
)

def set_random_state(
Expand All @@ -712,8 +718,15 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
return d

# expect all the specified keys have same spatial shape and share same random factors
first_key: Hashable = self.first_key(d)
if first_key == ():
for key in self.key_iterator(d):
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
return d

# all the keys share the same random scale factor
self.scaler.randomize(None)
self.scaler.randomize(d[first_key])
for key in self.key_iterator(d):
d[key] = self.scaler(d[key], randomize=False)
return d
Expand Down
32 changes: 32 additions & 0 deletions tests/transforms/test_rand_scale_intensity_fixed_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,38 @@ def test_value(self, p):
expected = expected + mn
assert_allclose(result, expected, type_test="tensor", atol=1e-7)

@parameterized.expand([[p] for p in TEST_NDARRAYS])
def test_channel_wise(self, p):
scaler = RandScaleIntensityFixedMean(prob=1.0, factors=0.5, channel_wise=True)
scaler.set_random_state(seed=0)
im = p(self.imt)
result = scaler(im)
np.random.seed(0)
# simulate the randomize() of transform
np.random.random()
channel_num = self.imt.shape[0]
factor = [np.random.uniform(low=-0.5, high=0.5) for _ in range(channel_num)]
expected = np.stack(
[
np.asarray((self.imt[i] - self.imt[i].mean()) * (1 + factor[i]) + self.imt[i].mean())
for i in range(channel_num)
]
).astype(np.float32)
assert_allclose(result, p(expected), atol=1e-4, rtol=1e-4, type_test=False)

@parameterized.expand([[p] for p in TEST_NDARRAYS])
def test_channel_wise_preserve_range(self, p):
scaler = RandScaleIntensityFixedMean(
prob=1.0, factors=0.5, channel_wise=True, preserve_range=True, fixed_mean=True
)
scaler.set_random_state(seed=0)
im = p(self.imt)
result = scaler(im)
# verify output is within input range per channel
for c in range(self.imt.shape[0]):
assert float(result[c].min()) >= float(im[c].min()) - 1e-6
assert float(result[c].max()) <= float(im[c].max()) + 1e-6


if __name__ == "__main__":
unittest.main()
20 changes: 20 additions & 0 deletions tests/transforms/test_rand_scale_intensity_fixed_meand.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,26 @@ def test_value(self):
expected = expected + mn
assert_allclose(result[key], p(expected), type_test="tensor", atol=1e-6)

def test_channel_wise(self):
key = "img"
for p in TEST_NDARRAYS:
scaler = RandScaleIntensityFixedMeand(keys=[key], factors=0.5, prob=1.0, channel_wise=True)
scaler.set_random_state(seed=0)
im = p(self.imt)
result = scaler({key: im})
np.random.seed(0)
# simulate the randomize function of transform
np.random.random()
channel_num = self.imt.shape[0]
factor = [np.random.uniform(low=-0.5, high=0.5) for _ in range(channel_num)]
expected = np.stack(
[
np.asarray((self.imt[i] - self.imt[i].mean()) * (1 + factor[i]) + self.imt[i].mean())
for i in range(channel_num)
]
).astype(np.float32)
assert_allclose(result[key], p(expected), atol=1e-4, rtol=1e-4, type_test=False)


if __name__ == "__main__":
unittest.main()
Loading