Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion linopy/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np
import pandas as pd
import polars as pl
from numpy import arange, signedinteger
from numpy import arange, nan, signedinteger
from xarray import DataArray, Dataset, apply_ufunc, broadcast
from xarray import align as xr_align
from xarray.core import dtypes, indexing
Expand Down Expand Up @@ -1393,3 +1393,51 @@ def is_constant(x: SideLike) -> bool:
"Expected a constant, variable, or expression on the constraint side, "
f"got {type(x)}."
)


def series_to_lookup_array(s: pd.Series) -> np.ndarray:
"""
Convert an integer-indexed Series to a dense numpy lookup array.

Non-negative indices are placed at their corresponding positions;
negative indices are ignored. Gaps are filled with NaN.

Parameters
----------
s : pd.Series
Series with an integer index.

Returns
-------
np.ndarray
Dense array of length ``max(index) + 1``.
"""
max_idx = max(int(s.index.max()), 0)
arr = np.full(max_idx + 1, nan)
mask = s.index >= 0
arr[s.index[mask]] = s.values[mask]
return arr


def lookup_vals(arr: np.ndarray, idx: np.ndarray) -> np.ndarray:
"""
Look up values from a dense array by integer labels.

Negative labels and labels beyond the array length map to NaN.

Parameters
----------
arr : np.ndarray
Dense lookup array (e.g. from :func:`series_to_lookup_array`).
idx : np.ndarray
Integer label indices.

Returns
-------
np.ndarray
Array of looked-up values with the same shape as *idx*.
"""
valid = (idx >= 0) & (idx < len(arr))
vals = np.full(idx.shape, nan)
vals[valid] = arr[idx[valid]]
return vals
28 changes: 14 additions & 14 deletions linopy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@
assign_multiindex_safe,
best_int,
broadcast_mask,
lookup_vals,
maybe_replace_signs,
replace_by_map,
series_to_lookup_array,
set_int_index,
to_path,
)
Expand Down Expand Up @@ -1591,26 +1593,24 @@ def solve(
sol = set_int_index(sol)
sol.loc[-1] = nan

for name, var in self.variables.items():
idx = np.ravel(var.labels)
try:
vals = sol[idx].values.reshape(var.labels.shape)
except KeyError:
vals = sol.reindex(idx).values.reshape(var.labels.shape)
var.solution = xr.DataArray(vals, var.coords)
sol_arr = series_to_lookup_array(sol)

for _, var in self.variables.items():
vals = lookup_vals(sol_arr, np.ravel(var.labels))
var.solution = xr.DataArray(vals.reshape(var.labels.shape), var.coords)

if not result.solution.dual.empty:
dual = result.solution.dual.copy()
dual = set_int_index(dual)
dual.loc[-1] = nan

for name, con in self.constraints.items():
idx = np.ravel(con.labels)
try:
vals = dual[idx].values.reshape(con.labels.shape)
except KeyError:
vals = dual.reindex(idx).values.reshape(con.labels.shape)
con.dual = xr.DataArray(vals, con.labels.coords)
dual_arr = series_to_lookup_array(dual)

for _, con in self.constraints.items():
vals = lookup_vals(dual_arr, np.ravel(con.labels))
con.dual = xr.DataArray(
vals.reshape(con.labels.shape), con.labels.coords
)

return result.status.status.value, result.status.termination_condition.value
finally:
Expand Down
73 changes: 73 additions & 0 deletions test/test_solution_lookup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import numpy as np
import pandas as pd
from numpy import nan

from linopy.common import lookup_vals, series_to_lookup_array


class TestSeriesToLookupArray:
def test_basic(self) -> None:
s = pd.Series([10.0, 20.0, 30.0], index=pd.Index([0, 1, 2]))
arr = series_to_lookup_array(s)
np.testing.assert_array_equal(arr, [10.0, 20.0, 30.0])

def test_with_negative_index(self) -> None:
s = pd.Series([nan, 10.0, 20.0], index=pd.Index([-1, 0, 2]))
arr = series_to_lookup_array(s)
assert arr[0] == 10.0
assert np.isnan(arr[1])
assert arr[2] == 20.0

def test_sparse_index(self) -> None:
s = pd.Series([5.0, 7.0], index=pd.Index([0, 100]))
arr = series_to_lookup_array(s)
assert len(arr) == 101
assert arr[0] == 5.0
assert arr[100] == 7.0
assert np.isnan(arr[50])

def test_only_negative_index(self) -> None:
s = pd.Series([nan], index=pd.Index([-1]))
arr = series_to_lookup_array(s)
assert len(arr) == 1
assert np.isnan(arr[0])


class TestLookupVals:
def test_basic(self) -> None:
arr = np.array([10.0, 20.0, 30.0])
idx = np.array([0, 1, 2])
result = lookup_vals(arr, idx)
np.testing.assert_array_equal(result, [10.0, 20.0, 30.0])

def test_negative_labels_become_nan(self) -> None:
arr = np.array([10.0, 20.0])
idx = np.array([0, -1, 1, -1])
result = lookup_vals(arr, idx)
assert result[0] == 10.0
assert np.isnan(result[1])
assert result[2] == 20.0
assert np.isnan(result[3])

def test_out_of_range_labels_become_nan(self) -> None:
arr = np.array([10.0, 20.0])
idx = np.array([0, 1, 999])
result = lookup_vals(arr, idx)
assert result[0] == 10.0
assert result[1] == 20.0
assert np.isnan(result[2])

def test_all_negative(self) -> None:
arr = np.array([10.0])
idx = np.array([-1, -1, -1])
result = lookup_vals(arr, idx)
assert all(np.isnan(result))

def test_no_mutation_of_source(self) -> None:
arr = np.array([10.0, 20.0, 30.0])
idx1 = np.array([-1, 1])
idx2 = np.array([0, 2])
lookup_vals(arr, idx1)
result2 = lookup_vals(arr, idx2)
np.testing.assert_array_equal(result2, [10.0, 30.0])
np.testing.assert_array_equal(arr, [10.0, 20.0, 30.0])
Loading