Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 12 additions & 15 deletions src/spatialdata/_core/query/spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from functools import singledispatch
from typing import TYPE_CHECKING, Any

import dask.array as da
import dask.dataframe as dd
import numpy as np
from dask.dataframe import DataFrame as DaskDataFrame
Expand Down Expand Up @@ -385,7 +384,7 @@ def _bounding_box_mask_points(
axes: tuple[str, ...],
min_coordinate: list[Number] | ArrayLike,
max_coordinate: list[Number] | ArrayLike,
) -> da.Array:
) -> list[ArrayLike]:
"""Compute a mask that is true for the points inside axis-aligned bounding boxes.

Parameters
Expand Down Expand Up @@ -427,12 +426,9 @@ def _bounding_box_mask_points(
continue
min_value = min_coordinate[box, axis_index]
max_value = max_coordinate[box, axis_index]
box_masks.append(
points[axis_name].gt(min_value).to_dask_array(lengths=True)
& points[axis_name].lt(max_value).to_dask_array(lengths=True)
)
bounding_box_mask = da.stack(box_masks, axis=-1)
in_bounding_box_masks.append(da.all(bounding_box_mask, axis=1))
box_masks.append(points[axis_name].gt(min_value).compute() & points[axis_name].lt(max_value).compute())
bounding_box_mask = np.stack(box_masks, axis=-1)
in_bounding_box_masks.append(np.all(bounding_box_mask, axis=1))
return in_bounding_box_masks


Expand Down Expand Up @@ -673,19 +669,20 @@ def _(
)

if not (len_df := len(in_intrinsic_bounding_box)) == (len_bb := len(min_coordinate)):
raise ValueError(f"Lenght of dataframe `{len_df}` is not equal to the number of bounding boxes `{len_bb}`.")
raise ValueError(
f"Length of list of dataframes `{len_df}` is not equal to the number of bounding boxes axes `{len_bb}`."
)
points_in_intrinsic_bounding_box: list[DaskDataFrame | None] = []
points_pd = points.compute()
attrs = points.attrs.copy()
for mask in in_intrinsic_bounding_box:
if mask.sum() == 0:
for mask_np in in_intrinsic_bounding_box:
if mask_np.sum() == 0:
points_in_intrinsic_bounding_box.append(None)
else:
# TODO there is a problem when mixing dask dataframe graph with dask array graph. Need to compute for now.
# we can't compute either mask or points as when we calculate either one of them
# test_query_points_multiple_partitions will fail as the mask will be used to index each partition.
# However, if we compute and then create the dask array again we get the mixed dask graph problem.
mask_np = mask.compute()
filtered_pd = points_pd[mask_np]
points_filtered = dd.from_pandas(filtered_pd, npartitions=points.npartitions)
points_filtered.attrs.update(attrs)
Expand Down Expand Up @@ -724,9 +721,9 @@ def _(
min_coordinate=min_c, # type: ignore[arg-type]
max_coordinate=max_c, # type: ignore[arg-type]
)
if len(bounding_box_mask) == 1:
bounding_box_mask = bounding_box_mask[0]
bounding_box_indices = np.where(bounding_box_mask.compute())[0]
if len(bounding_box_mask) != 1:
raise ValueError(f"Expected a single mask, got {len(bounding_box_mask)} masks. Please report this bug.")
bounding_box_indices = np.where(bounding_box_mask[0])[0]

if len(bounding_box_indices) == 0:
output.append(None)
Expand Down
Loading