Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ Release Notes
Upcoming Version
----------------

* When passing DataArray bounds to ``add_variables`` with explicit ``coords``, the ``coords`` parameter now defines the variable's coordinates. DataArray bounds are validated against these coords (raises ``ValueError`` on mismatch) and broadcast to missing dimensions. Previously, the ``coords`` parameter was silently ignored for DataArray inputs.

Version 0.6.4
--------------

Expand Down
86 changes: 81 additions & 5 deletions linopy/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import operator
import os
from collections.abc import Callable, Generator, Hashable, Iterable, Sequence
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping, Sequence
from functools import partial, reduce, wraps
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
Expand Down Expand Up @@ -128,6 +128,66 @@ def get_from_iterable(lst: DimsLike | None, index: int) -> Any | None:
return lst[index] if 0 <= index < len(lst) else None


def _coords_to_mapping(coords: CoordsLike, dims: DimsLike | None = None) -> Mapping:
"""
Normalize coords to a Mapping.

If coords is already a Mapping, return as-is. If it's a sequence,
convert to a dict using dims (if provided) or Index names.
"""
if isinstance(coords, Mapping):
return coords
seq = list(coords)
if dims is not None:
dim_names: list[str] = [dims] if isinstance(dims, str) else list(dims) # type: ignore[arg-type]
return dict(zip(dim_names, seq))
return {c.name: c for c in seq if hasattr(c, "name") and c.name}


def _validate_dataarray_coords(
arr: DataArray,
coords: CoordsLike | None = None,
dims: DimsLike | None = None,
allow_extra_dims: bool = False,
) -> DataArray:
"""
Validate a DataArray's coordinates against expected coords.

For shared dimensions, coordinates must match exactly. Extra
dimensions in the DataArray raise unless allow_extra_dims is True.
Missing dimensions are broadcast via expand_dims.
"""
if coords is None:
return arr

expected = _coords_to_mapping(coords, dims)
if not expected:
return arr

if not allow_extra_dims:
extra = set(arr.dims) - set(expected)
if extra:
raise ValueError(f"DataArray has extra dimensions not in coords: {extra}")

expand = {}
for k, v in expected.items():
if k not in arr.dims:
expand[k] = v
continue
actual = arr.coords[k]
v_idx = v if isinstance(v, pd.Index) else pd.Index(v)
if not actual.to_index().equals(v_idx):
raise ValueError(
f"Coordinates for dimension '{k}' do not match: "
f"expected {v_idx.tolist()}, got {actual.values.tolist()}"
)

if expand:
arr = arr.expand_dims(expand)

return arr


def pandas_to_dataarray(
arr: pd.DataFrame | pd.Series,
coords: CoordsLike | None = None,
Expand Down Expand Up @@ -163,8 +223,7 @@ def pandas_to_dataarray(
]
if coords is not None:
pandas_coords = dict(zip(dims, arr.axes))
if isinstance(coords, Sequence):
coords = dict(zip(dims, coords))
coords = _coords_to_mapping(coords, dims)
shared_dims = set(pandas_coords.keys()) & set(coords.keys())
non_aligned = []
for dim in shared_dims:
Expand Down Expand Up @@ -224,7 +283,7 @@ def numpy_to_dataarray(
dims = [get_from_iterable(dims, i) or f"dim_{i}" for i in range(ndim)]

if isinstance(coords, list) and dims is not None and len(dims):
coords = dict(zip(dims, coords))
coords = _coords_to_mapping(coords, dims)

return DataArray(arr, coords=coords, dims=dims, **kwargs)

Expand All @@ -233,6 +292,7 @@ def as_dataarray(
arr: Any,
coords: CoordsLike | None = None,
dims: DimsLike | None = None,
allow_extra_dims: bool = False,
**kwargs: Any,
) -> DataArray:
"""
Expand All @@ -246,13 +306,27 @@ def as_dataarray(
The coordinates for the DataArray. If None, default coordinates will be used.
dims (Union[list, None]):
The dimensions for the DataArray. If None, the dimensions will be automatically generated.
allow_extra_dims:
If False (default), raise ValueError when a DataArray input has
dimensions not present in coords. Set to True to allow extra
dimensions for broadcasting.
**kwargs:
Additional keyword arguments to be passed to the DataArray constructor.

Returns
-------
DataArray:
The converted DataArray.

Raises
------
ValueError
If arr is a DataArray and coords is provided: raised when
coordinates for shared dimensions do not match, or when the
DataArray has dimensions not covered by coords (unless
allow_extra_dims is True). The DataArray's dimensions may be
a subset of coords — missing dimensions are added via
expand_dims.
"""
if isinstance(arr, pd.Series | pd.DataFrame):
arr = pandas_to_dataarray(arr, coords=coords, dims=dims, **kwargs)
Expand All @@ -264,7 +338,9 @@ def as_dataarray(
arr = DataArray(float(arr), coords=coords, dims=dims, **kwargs)
elif isinstance(arr, int | float | str | bool | list):
arr = DataArray(arr, coords=coords, dims=dims, **kwargs)

elif isinstance(arr, DataArray):
if coords is not None:
arr = _validate_dataarray_coords(arr, coords, dims, allow_extra_dims)
elif not isinstance(arr, DataArray):
supported_types = [
np.number,
Expand Down
8 changes: 6 additions & 2 deletions linopy/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,9 @@ def _multiply_by_linear_expression(
def _multiply_by_constant(
self: GenericExpression, other: ConstantLike
) -> GenericExpression:
multiplier = as_dataarray(other, coords=self.coords, dims=self.coord_dims)
multiplier = as_dataarray(
other, coords=self.coords, dims=self.coord_dims, allow_extra_dims=True
)
coeffs = self.coeffs * multiplier
assert all(coeffs.sizes[d] == s for d, s in self.coeffs.sizes.items())
const = self.const * multiplier
Expand Down Expand Up @@ -1413,7 +1415,9 @@ def __matmul__(
Matrix multiplication with other, similar to xarray dot.
"""
if not isinstance(other, LinearExpression | variables.Variable):
other = as_dataarray(other, coords=self.coords, dims=self.coord_dims)
other = as_dataarray(
other, coords=self.coords, dims=self.coord_dims, allow_extra_dims=True
)

common_dims = list(set(self.coord_dims).intersection(other.dims))
return (self * other).sum(dim=common_dims)
Expand Down
24 changes: 17 additions & 7 deletions linopy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,10 +470,12 @@ def add_variables(
Upper bound of the variable(s). Ignored if `binary` is True.
The default is inf.
coords : list/xarray.Coordinates, optional
The coords of the variable array.
These are directly passed to the DataArray creation of
`lower` and `upper`. For every single combination of
coordinates a optimization variable is added to the model.
The coords of the variable array. For every single
combination of coordinates an optimization variable is
added to the model. Data for `lower`, `upper` and `mask`
is fitted to these coords: shared dimensions must have
matching coordinates, and missing dimensions are broadcast.
A ValueError is raised if the data is not compatible.
The default is None.
name : str, optional
Reference name of the added variables. The default None results in
Expand All @@ -495,7 +497,9 @@ def add_variables(
------
ValueError
If neither lower bound and upper bound have coordinates, nor
`coords` are directly given.
`coords` are directly given. Also raised if `lower` or
`upper` are DataArrays whose coordinates do not match the
provided `coords`.

Returns
-------
Expand Down Expand Up @@ -551,7 +555,10 @@ def add_variables(
self._check_valid_dim_names(data)

if mask is not None:
mask = as_dataarray(mask, coords=data.coords, dims=data.dims).astype(bool)
# TODO: Simplify when removing Future Warning from broadcast_mask
if not isinstance(mask, DataArray):
mask = as_dataarray(mask, coords=data.coords, dims=data.dims)
mask = mask.astype(bool)
mask = broadcast_mask(mask, data.labels)

# Auto-mask based on NaN in bounds (use numpy for speed)
Expand Down Expand Up @@ -746,7 +753,10 @@ def add_constraints(
(data,) = xr.broadcast(data, exclude=[TERM_DIM])

if mask is not None:
mask = as_dataarray(mask, coords=data.coords, dims=data.dims).astype(bool)
# TODO: Simplify when removing Future Warning from broadcast_mask
if not isinstance(mask, DataArray):
mask = as_dataarray(mask, coords=data.coords, dims=data.dims)
mask = mask.astype(bool)
mask = broadcast_mask(mask, data.labels)

# Auto-mask based on null expressions or NaN RHS (use numpy for speed)
Expand Down
4 changes: 3 additions & 1 deletion linopy/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,9 @@ def to_linexpr(
linopy.LinearExpression
Linear expression with the variables and coefficients.
"""
coefficient = as_dataarray(coefficient, coords=self.coords, dims=self.dims)
coefficient = as_dataarray(
coefficient, coords=self.coords, dims=self.dims, allow_extra_dims=True
)
ds = Dataset({"coeffs": coefficient, "vars": self.labels}).expand_dims(
TERM_DIM, -1
)
Expand Down
80 changes: 80 additions & 0 deletions test/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,86 @@ def test_as_dataarray_with_dataarray_default_dims_coords() -> None:
assert list(da_out.coords["dim2"].values) == list(da_in.coords["dim2"].values)


def test_as_dataarray_with_dataarray_coord_mismatch() -> None:
da_in = DataArray([1, 2, 3], dims=["x"], coords={"x": [10, 20, 30]})
with pytest.raises(ValueError, match="do not match"):
as_dataarray(da_in, coords={"x": [10, 20, 40]})


def test_as_dataarray_with_dataarray_extra_dims() -> None:
da_in = DataArray([[1, 2], [3, 4]], dims=["x", "y"])
with pytest.raises(ValueError, match="extra dimensions"):
as_dataarray(da_in, coords={"x": [0, 1]})


def test_as_dataarray_with_dataarray_extra_dims_allowed() -> None:
da_in = DataArray(
[[1, 2], [3, 4]],
dims=["x", "y"],
coords={"x": [0, 1], "y": [0, 1]},
)
da_out = as_dataarray(da_in, coords={"x": [0, 1]}, allow_extra_dims=True)
assert da_out.dims == da_in.dims
assert da_out.shape == da_in.shape


def test_as_dataarray_with_dataarray_broadcast() -> None:
da_in = DataArray([1, 2], dims=["x"], coords={"x": ["a", "b"]})
da_out = as_dataarray(
da_in, coords={"x": ["a", "b"], "y": [1, 2, 3]}, dims=["x", "y"]
)
assert set(da_out.dims) == {"x", "y"}
assert da_out.sizes["y"] == 3


def test_as_dataarray_with_dataarray_no_coords() -> None:
da_in = DataArray([1, 2, 3], dims=["x"], coords={"x": [10, 20, 30]})
da_out = as_dataarray(da_in)
assert_equal(da_out, da_in)


def test_as_dataarray_with_dataarray_sequence_coords() -> None:
da_in = DataArray([1, 2], dims=["x"], coords={"x": [0, 1]})
idx = pd.RangeIndex(2, name="x")
da_out = as_dataarray(da_in, coords=[idx], dims=["x"])
assert list(da_out.coords["x"].values) == [0, 1]


def test_as_dataarray_with_dataarray_sequence_coords_mismatch() -> None:
da_in = DataArray([1, 2], dims=["x"], coords={"x": [0, 1]})
idx = pd.RangeIndex(3, name="x")
with pytest.raises(ValueError, match="do not match"):
as_dataarray(da_in, coords=[idx], dims=["x"])


def test_add_variables_with_dataarray_bounds_and_coords() -> None:
model = Model()
time = pd.RangeIndex(5, name="time")
lower = DataArray([0, 0, 0, 0, 0], dims=["time"], coords={"time": range(5)})
var = model.add_variables(lower=lower, coords=[time], name="x")
assert var.shape == (5,)
assert list(var.data.coords["time"].values) == list(range(5))


def test_add_variables_with_dataarray_bounds_coord_mismatch() -> None:
model = Model()
time = pd.RangeIndex(5, name="time")
lower = DataArray([0, 0, 0], dims=["time"], coords={"time": [0, 1, 2]})
with pytest.raises(ValueError, match="do not match"):
model.add_variables(lower=lower, coords=[time], name="x")


def test_add_variables_with_dataarray_bounds_broadcast() -> None:
model = Model()
time = pd.RangeIndex(3, name="time")
space = pd.Index(["a", "b"], name="space")
lower = DataArray([0, 0, 0], dims=["time"], coords={"time": range(3)})
var = model.add_variables(lower=lower, coords=[time, space], name="x")
assert set(var.data.dims) == {"time", "space"}
assert var.data.sizes["time"] == 3
assert var.data.sizes["space"] == 2


def test_as_dataarray_with_unsupported_type() -> None:
with pytest.raises(TypeError):
as_dataarray(lambda x: 1, dims=["dim1"], coords=[["a"]])
Expand Down