39 changes: 34 additions & 5 deletions monai/data/dataset.py
@@ -232,6 +232,7 @@ def __init__(
reset_ops_id: bool = True,
track_meta: bool = False,
weights_only: bool = True,
in_memory: bool = False,
) -> None:
"""
Args:
@@ -273,6 +274,10 @@ def __init__(
other safe objects. Setting this to `False` is required for loading `MetaTensor`
objects saved with `track_meta=True`, however this creates the possibility of remote
code execution through `torch.load` so be aware of the security implications of doing so.
in_memory: if `True`, keep each pre-processed item in an in-memory dictionary after its first access.
This combines the benefits of persistent storage (data survives restarts) with faster RAM access.
When an item is accessed, it is first loaded from the disk cache (if available) and then stored in memory.
Defaults to `False`.

Raises:
ValueError: When both `track_meta=True` and `weights_only=True`, since this combination
@@ -299,6 +304,13 @@ def __init__(
)
self.track_meta = track_meta
self.weights_only = weights_only
self.in_memory = in_memory
self._memory_cache: dict[str, Any] = {}

@property
def memory_cache_size(self) -> int:
"""Return the number of items currently stored in the in-memory cache."""
return len(self._memory_cache)

def set_transform_hash(self, hash_xform_func: Callable[..., bytes]):
"""Get hashable transforms, and then hash them. Hashable transforms
@@ -326,6 +338,7 @@ def set_data(self, data: Sequence):

"""
self.data = data
self._memory_cache = {}
if self.cache_dir is not None and self.cache_dir.exists():
shutil.rmtree(self.cache_dir, ignore_errors=True)
self.cache_dir.mkdir(parents=True, exist_ok=True)
@@ -389,14 +402,24 @@ def _cachecheck(self, item_transformed):

"""
hashfile = None
# compute cache key once for both disk and memory caching
data_item_md5 = self.hash_func(item_transformed).decode("utf-8")
data_item_md5 += self.transform_hash
cache_key = f"{data_item_md5}.pt"

if self.cache_dir is not None:
data_item_md5 = self.hash_func(item_transformed).decode("utf-8")
data_item_md5 += self.transform_hash
hashfile = self.cache_dir / f"{data_item_md5}.pt"
hashfile = self.cache_dir / cache_key

# check in-memory cache first
if self.in_memory and cache_key in self._memory_cache:
return self._memory_cache[cache_key]

if hashfile is not None and hashfile.is_file(): # cache hit
try:
return torch.load(hashfile, weights_only=self.weights_only)
_item_transformed = torch.load(hashfile, weights_only=self.weights_only)
if self.in_memory:
self._memory_cache[cache_key] = _item_transformed
return _item_transformed
except PermissionError as e:
if sys.platform != "win32":
raise e
@@ -409,15 +432,19 @@ def _cachecheck(self, item_transformed):

_item_transformed = self._pre_transform(deepcopy(item_transformed)) # keep the original hashed
if hashfile is None:
if self.in_memory:
self._memory_cache[cache_key] = _item_transformed
return _item_transformed
Comment on lines 434 to 437
⚠️ Potential issue | 🟡 Minor

Memory cache stores different types in pure RAM vs hybrid mode.

Line 436 stores _item_transformed (unconverted), but line 462 stores _item_converted (tensor-converted). This causes type inconsistency between pure RAM mode (cache_dir=None) and hybrid mode.

For consistency, consider storing converted data in both paths:

Proposed fix
         if hashfile is None:
             if self.in_memory:
-                self._memory_cache[cache_key] = _item_transformed
+                _item_converted = convert_to_tensor(_item_transformed, convert_numeric=False, track_meta=self.track_meta)
+                self._memory_cache[cache_key] = _item_converted
             return _item_transformed
🤖 Prompt for AI Agents
In `@monai/data/dataset.py` around lines 434 - 437, The memory cache currently
stores _item_transformed when hashfile is None but stores _item_converted later
for hybrid mode, causing type inconsistency; change the branch that checks if
hashfile is None to store _item_converted (not _item_transformed) in
self._memory_cache when self.in_memory is true and cache_key is available (so
both pure RAM and hybrid modes cache the converted/tensor form consistently),
ensuring _item_converted is created/available before assignment and using the
same cache_key lookup used elsewhere.

# Convert to tensor for disk storage (and memory cache consistency)
_item_converted = convert_to_tensor(_item_transformed, convert_numeric=False, track_meta=self.track_meta)
try:
# NOTE: Writing to a temporary directory and then using a nearly atomic rename operation
# to make the cache more robust to manual killing of parent process
# which may leave partially written cache files in an incomplete state
with tempfile.TemporaryDirectory() as tmpdirname:
temp_hash_file = Path(tmpdirname) / hashfile.name
torch.save(
obj=convert_to_tensor(_item_transformed, convert_numeric=False, track_meta=self.track_meta),
obj=_item_converted,
f=temp_hash_file,
pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
pickle_protocol=self.pickle_protocol,
@@ -431,6 +458,8 @@ def _cachecheck(self, item_transformed):
pass
except PermissionError: # project-monai/monai issue #3613
pass
if self.in_memory:
self._memory_cache[cache_key] = _item_converted
return _item_transformed

def _transform(self, index: int):
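
For context, here is a minimal usage sketch of the new in_memory option (not part of this diff); the data, transform chain, and cache path below are illustrative placeholders:

import numpy as np
from monai.data import PersistentDataset
from monai.transforms import Compose, ScaleIntensityd

# Hypothetical in-memory data and a simple dictionary transform, for illustration only.
data = [{"image": np.arange(16, dtype=np.float32).reshape(4, 4) + i} for i in range(3)]
transform = Compose([ScaleIntensityd(keys="image")])

# in_memory=True keeps a RAM copy of each item after its first access,
# while the on-disk cache in cache_dir still survives process restarts.
ds = PersistentDataset(data=data, transform=transform, cache_dir="./cache", in_memory=True)

first = ds[0]                 # computed (or read from the disk cache), then stored in RAM
again = ds[0]                 # served from the in-memory dictionary
print(ds.memory_cache_size)   # -> 1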
110 changes: 110 additions & 0 deletions tests/data/test_persistentdataset.py
@@ -15,6 +15,7 @@
import os
import tempfile
import unittest
from pathlib import Path

import nibabel as nib
import numpy as np
@@ -200,6 +201,115 @@ def test_track_meta_and_weights_only(self, track_meta, weights_only, expected_er
im = test_dataset[0]["image"]
self.assertIsInstance(im, expected_type)

def test_in_memory_cache(self):
"""Test in_memory caching feature that combines persistent storage with RAM caching."""
items = [[list(range(i))] for i in range(5)]

with tempfile.TemporaryDirectory() as tempdir:
# First, create the persistent cache
ds1 = PersistentDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, in_memory=False)
# Access all items to populate disk cache
_ = list(ds1)

# Now create a new dataset with in_memory=True
ds2 = PersistentDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, in_memory=True)

# Memory cache should be empty initially
self.assertEqual(ds2.memory_cache_size, 0)

# Access items - they should be loaded from disk and cached in memory
_ = ds2[0]
self.assertEqual(ds2.memory_cache_size, 1)

_ = ds2[1]
self.assertEqual(ds2.memory_cache_size, 2)

# Access all items
_ = list(ds2)
self.assertEqual(ds2.memory_cache_size, 5)

# Accessing same item again should use memory cache (same result)
result1 = ds2[0]
result2 = ds2[0]
self.assertEqual(result1, result2)

# Test set_data clears in-memory cache
ds2.set_data(items[:3])
self.assertEqual(ds2.memory_cache_size, 0)

def test_in_memory_without_cache_dir(self):
"""Test in_memory caching works even without a cache_dir (pure RAM cache)."""
items = [[list(range(i))] for i in range(3)]

ds = PersistentDataset(data=items, transform=_InplaceXform(), cache_dir=None, in_memory=True)

# Memory cache should be empty initially
self.assertEqual(ds.memory_cache_size, 0)

# Access items - they should be cached in memory
_ = ds[0]
self.assertEqual(ds.memory_cache_size, 1)

_ = list(ds)
self.assertEqual(ds.memory_cache_size, 3)

def test_automatic_hybrid_caching(self):
"""
Test that in_memory=True provides automatic hybrid caching:
- ALL samples automatically persist to disk
- ALL samples automatically cache to RAM after first access
- No manual specification of which samples go where (unlike torchdatasets)
- Simulates restart scenario: disk cache survives, RAM cache rebuilds automatically
"""
items = [[list(range(i))] for i in range(5)]

with tempfile.TemporaryDirectory() as tempdir:
# === First "session": populate both disk and RAM cache ===
ds1 = PersistentDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, in_memory=True)

# Access all items - should automatically cache to BOTH disk AND RAM
for i in range(len(items)):
_ = ds1[i]

# Verify: ALL samples are in RAM (automatic, no manual specification)
self.assertEqual(ds1.memory_cache_size, 5)

# Verify: ALL samples are on disk (count .pt files)
cache_files = list(Path(tempdir).glob("*.pt"))
self.assertEqual(len(cache_files), 5)

# === Simulate "restart": new dataset instance, same cache_dir ===
# This is the key benefit over CacheDataset - disk cache survives restart
ds2 = PersistentDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, in_memory=True)

# RAM cache starts empty (simulating fresh process)
self.assertEqual(ds2.memory_cache_size, 0)

# Access all items - should load from disk and automatically cache to RAM
results = [ds2[i] for i in range(len(items))]

# Verify: ALL samples now in RAM again (automatic rebuild from disk)
self.assertEqual(ds2.memory_cache_size, 5)

# Verify: Results are correct (transformed by _InplaceXform)
for i, result in enumerate(results):
if i == 0:
expected = [[1]] # empty list -> append 1
else:
expected = [[np.pi] + list(range(1, i))] # data[0] = 0 + np.pi
self.assertEqual(result, expected)

# === Verify RAM cache provides fast repeated access ===
# Accessing same items again should hit RAM cache (same objects)
for i in range(len(items)):
result1 = ds2[i]
result2 = ds2[i]
# Should return equivalent results
self.assertEqual(result1, result2)

# RAM cache size unchanged (no duplicate entries)
self.assertEqual(ds2.memory_cache_size, 5)


if __name__ == "__main__":
unittest.main()