Skip to content

Commit ac01d89

Browse files
authored
Add initial EasyList class (#193)
1 parent c737d47 commit ac01d89

3 files changed

Lines changed: 805 additions & 7 deletions

File tree

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
11
from .based_base import BasedBase
22
from .collection_base import CollectionBase
3+
from .easy_list import EasyList
34
from .model_base import ModelBase
45
from .new_base import NewBase
56
from .obj_base import ObjBase
67

7-
__all__ = [
8-
BasedBase,
9-
CollectionBase,
10-
ObjBase,
11-
ModelBase,
12-
NewBase,
13-
]
8+
__all__ = [BasedBase, CollectionBase, ObjBase, ModelBase, NewBase, EasyList]
Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
# SPDX-FileCopyrightText: 2025 EasyScience contributors <core@easyscience.software>
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
# © 2021-2025 Contributors to the EasyScience project <https://github.com/easyScience/EasyScience
4+
5+
from __future__ import annotations
6+
7+
import copy
8+
import warnings
9+
from collections.abc import MutableSequence
10+
from typing import Any
11+
from typing import Callable
12+
from typing import Dict
13+
from typing import Iterable
14+
from typing import List
15+
from typing import Optional
16+
from typing import Type
17+
from typing import TypeVar
18+
from typing import overload
19+
20+
from easyscience.io.serializer_base import SerializerBase
21+
22+
from .new_base import NewBase
23+
24+
ProtectedType_ = TypeVar('ProtectedType', bound=NewBase)
25+
26+
27+
class EasyList(NewBase, MutableSequence[ProtectedType_]):
28+
# If we were to inherit from List instead of MutableSequence,
29+
# we would have to overwrite "extend", "remove", "__iadd__", "count", "append", "__iter__" and "clear"
30+
def __init__(
31+
self,
32+
*args: ProtectedType_ | list[ProtectedType_],
33+
protected_types: list[Type[NewBase]] | Type[NewBase] | None = None,
34+
unique_name: Optional[str] = None,
35+
display_name: Optional[str] = None,
36+
**kwargs: Any,
37+
):
38+
"""
39+
Initialize the EasyList.
40+
:param args: Initial items to add to the list
41+
:param protected_types: Types that are allowed in the list. Can be a single NewBase subclass or a list of them.
42+
If None, defaults to [NewBase].
43+
:param unique_name: Optional unique name for the list
44+
:param display_name: Optional display name for the list
45+
"""
46+
super().__init__(unique_name=unique_name, display_name=display_name)
47+
if protected_types is None:
48+
self._protected_types = [NewBase]
49+
elif isinstance(protected_types, type) and issubclass(protected_types, NewBase):
50+
self._protected_types = [protected_types]
51+
elif isinstance(protected_types, Iterable) and all(issubclass(t, NewBase) for t in protected_types):
52+
self._protected_types = list(protected_types)
53+
else:
54+
raise TypeError('protected_types must be a NewBase subclass or an iterable of NewBase subclasses')
55+
self._data: List[ProtectedType_] = []
56+
57+
# Add initial items
58+
for item in args:
59+
if isinstance(item, list):
60+
for sub_item in item:
61+
self.append(sub_item)
62+
else:
63+
self.append(item)
64+
65+
# For deserialization, the dict can't contain an *args, so we check for 'data' in kwargs
66+
if 'data' in kwargs:
67+
data = kwargs.pop('data')
68+
for item in data:
69+
self.append(item)
70+
71+
# MutableSequence abstract methods
72+
73+
# Use @overload to provide precise type hints for different __getitem__ argument types
74+
@overload
75+
def __getitem__(self, idx: int) -> ProtectedType_: ...
76+
@overload
77+
def __getitem__(self, idx: slice) -> 'EasyList[ProtectedType_]': ...
78+
@overload
79+
def __getitem__(self, idx: str) -> ProtectedType_: ...
80+
def __getitem__(self, idx: int | slice | str) -> ProtectedType_ | 'EasyList[ProtectedType_]':
81+
"""
82+
Get an item by index, slice, or unique_name.
83+
84+
:param idx: Index, slice, or unique_name of the item
85+
:return: The item or a new EasyList for slices
86+
"""
87+
if isinstance(idx, int):
88+
return self._data[idx]
89+
elif isinstance(idx, slice):
90+
return self.__class__(self._data[idx], protected_types=self._protected_types)
91+
elif isinstance(idx, str):
92+
element = next((r for r in self._data if self._get_key(r) == idx), None)
93+
if element is not None:
94+
return element
95+
raise KeyError(f'No item with unique name "{idx}" found')
96+
else:
97+
raise TypeError('Index must be an int, slice, or str')
98+
99+
@overload
100+
def __setitem__(self, idx: int, value: ProtectedType_) -> None: ...
101+
@overload
102+
def __setitem__(self, idx: slice, value: Iterable[ProtectedType_]) -> None: ...
103+
104+
def __setitem__(self, idx: int | slice, value: ProtectedType_ | Iterable[ProtectedType_]) -> None:
105+
"""
106+
Set an item at an index.
107+
108+
:param idx: Index to set
109+
:param value: New value
110+
"""
111+
if isinstance(idx, int):
112+
if not isinstance(value, tuple(self._protected_types)):
113+
raise TypeError(f'Items must be one of {self._protected_types}, got {type(value)}')
114+
if value is not self._data[idx] and value in self:
115+
warnings.warn(f'Item with unique name "{self._get_key(value)}" already in EasyList, it will be ignored')
116+
return
117+
self._data[idx] = value
118+
elif isinstance(idx, slice):
119+
if not isinstance(value, Iterable):
120+
raise TypeError('Value must be an iterable for slice assignment')
121+
replaced = self._data[idx]
122+
new_values = list(value)
123+
if len(new_values) != len(replaced):
124+
raise ValueError('Length of new values must match the length of the slice being replaced')
125+
for i, v in enumerate(new_values):
126+
if not isinstance(v, tuple(self._protected_types)):
127+
raise TypeError(f'Items must be one of {self._protected_types}, got {type(v)}')
128+
if v in self and replaced[i] is not v:
129+
warnings.warn(f'Item with unique name "{v.unique_name}" already in EasyList, it will be ignored')
130+
new_values[i] = replaced[i] # Keep the original value if the new one is a duplicate
131+
self._data[idx] = new_values
132+
else:
133+
raise TypeError('Index must be an int or slice')
134+
135+
def __delitem__(self, idx: int | slice | str) -> None:
136+
"""
137+
Delete an item by index, slice, or name.
138+
139+
:param idx: Index, slice, or name of item to delete
140+
"""
141+
if isinstance(idx, (int, slice)):
142+
del self._data[idx]
143+
elif isinstance(idx, str):
144+
for i, item in enumerate(self._data):
145+
if self._get_key(item) == idx:
146+
del self._data[i]
147+
return
148+
raise KeyError(f'No item with unique name "{idx}" found')
149+
else:
150+
raise TypeError('Index must be an int, slice, or str')
151+
152+
def __len__(self) -> int:
153+
"""Return the number of items in the collection."""
154+
return len(self._data)
155+
156+
def insert(self, index: int, value: ProtectedType_) -> None:
157+
"""
158+
Insert an item at an index.
159+
160+
:param index: Index to insert at
161+
:param value: Item to insert
162+
"""
163+
if not isinstance(index, int):
164+
raise TypeError('Index must be an integer')
165+
elif not isinstance(value, tuple(self._protected_types)):
166+
raise TypeError(f'Items must be one of {self._protected_types}, got {type(value)}')
167+
if value in self:
168+
warnings.warn(f'Item with unique name "{self._get_key(value)}" already in EasyList, it will be ignored')
169+
return
170+
self._data.insert(index, value)
171+
172+
def _get_key(self, obj) -> str:
173+
"""
174+
Get the unique name of an object.
175+
Can be overridden to use a different attribute as the key.
176+
:param object: Object to get the key for
177+
:return: The key of the object
178+
:rtype: str
179+
"""
180+
return obj.unique_name
181+
182+
# Overwriting methods
183+
184+
def __repr__(self) -> str:
185+
return f'{self.__class__.__name__} of length {len(self)} of type(s) {self._protected_types}'
186+
187+
def __contains__(self, item: ProtectedType_ | str) -> bool:
188+
if isinstance(item, str):
189+
return any(self._get_key(r) == item for r in self._data)
190+
return item in self._data
191+
192+
def __reversed__(self):
193+
return self._data.__reversed__()
194+
195+
def sort(self, key: Callable[[ProtectedType_], Any] = None, reverse: bool = False) -> None:
196+
"""
197+
Sort the collection according to the given key function.
198+
199+
:param key: Mapping function to sort by
200+
:param reverse: Whether to reverse the sort
201+
"""
202+
self._data.sort(reverse=reverse, key=key)
203+
204+
def index(self, value: ProtectedType_ | str, start: int = 0, stop: int = None) -> int:
205+
if stop is None:
206+
stop = len(self._data)
207+
if isinstance(value, str):
208+
for i in range(start, min(stop, len(self._data))):
209+
if self._get_key(self._data[i]) == value:
210+
return i
211+
raise ValueError(f'{value} is not in EasyList')
212+
return self._data.index(value, start, stop)
213+
214+
def pop(self, index: int | str = -1) -> ProtectedType_:
215+
"""
216+
Remove and return an item at the given index or unique_name.
217+
218+
:param index: Index or unique_name of the item to remove
219+
:return: The removed item
220+
"""
221+
if isinstance(index, int):
222+
return self._data.pop(index)
223+
elif isinstance(index, str):
224+
for i, item in enumerate(self._data):
225+
if self._get_key(item) == index:
226+
return self._data.pop(i)
227+
raise KeyError(f'No item with unique name "{index}" found')
228+
else:
229+
raise TypeError('Index must be an int or str')
230+
231+
# Serialization support
232+
233+
def to_dict(self) -> dict:
234+
"""
235+
Convert the EasyList to a dictionary for serialization.
236+
237+
:return: Dictionary representation of the EasyList
238+
"""
239+
dict_repr = super().to_dict()
240+
if self._protected_types != [NewBase]:
241+
dict_repr['protected_types'] = [
242+
{'@module': cls_.__module__, '@class': cls_.__name__} for cls_ in self._protected_types
243+
] # noqa: E501
244+
dict_repr['data'] = [item.to_dict() for item in self._data]
245+
return dict_repr
246+
247+
@classmethod
248+
def from_dict(cls, obj_dict: Dict[str, Any]) -> NewBase:
249+
"""
250+
Re-create an EasyScience object from a full encoded dictionary.
251+
252+
:param obj_dict: dictionary containing the serialized contents (from `SerializerDict`) of an EasyScience object
253+
:return: Reformed EasyScience object
254+
"""
255+
if not SerializerBase._is_serialized_easyscience_object(obj_dict):
256+
raise ValueError('Input must be a dictionary representing an EasyScience EasyList object.')
257+
temp_dict = copy.deepcopy(obj_dict) # Make a copy to avoid mutating the input
258+
if temp_dict['@class'] == cls.__name__:
259+
if 'protected_types' in temp_dict:
260+
protected_types = temp_dict.pop('protected_types')
261+
for i, type_dict in enumerate(protected_types):
262+
if '@module' in type_dict and '@class' in type_dict:
263+
modname = type_dict['@module']
264+
classname = type_dict['@class']
265+
mod = __import__(modname, globals(), locals(), [classname], 0)
266+
if hasattr(mod, classname):
267+
cls_ = getattr(mod, classname)
268+
protected_types[i] = cls_
269+
else:
270+
raise ImportError(f'Could not import class {classname} from module {modname}')
271+
else:
272+
raise ValueError(
273+
'Each protected type must be a serialized EasyScience class with @module and @class keys'
274+
) # noqa: E501
275+
else:
276+
protected_types = None
277+
kwargs = SerializerBase.deserialize_dict(temp_dict)
278+
data = kwargs.pop('data', [])
279+
return cls(data, protected_types=protected_types, **kwargs)
280+
else:
281+
raise ValueError(f'Class name in dictionary does not match the expected class: {cls.__name__}.')

0 commit comments

Comments
 (0)