diff --git a/CMakeLists.txt b/CMakeLists.txt index 769eb14b..ea5922a4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -58,7 +58,8 @@ set(MINIEXPR_BUILD_BENCH OFF CACHE BOOL "Build miniexpr benchmarks" FORCE) FetchContent_Declare(miniexpr GIT_REPOSITORY https://github.com/Blosc/miniexpr.git - GIT_TAG 77d633cb2c134552da045b8d2cc0ad23908e6b9e + GIT_TAG b4cfa9c2dc26772ad2126e6a611f93daf050915f + #SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../miniexpr ) FetchContent_MakeAvailable(miniexpr) diff --git a/bench/ndarray/dsl-kernel-bench.py b/bench/ndarray/dsl-kernel-bench.py new file mode 100644 index 00000000..35e6b36c --- /dev/null +++ b/bench/ndarray/dsl-kernel-bench.py @@ -0,0 +1,240 @@ +####################################################################### +# Copyright (c) 2019-present, Blosc Development Team +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause +####################################################################### + +import contextlib +import time + +import numpy as np + +import blosc2 +import importlib + +lazyexpr_mod = importlib.import_module("blosc2.lazyexpr") + + +@blosc2.dsl_kernel +def kernel_loop1(x, y): + acc = 0.0 + for i in range(1): + if i % 2 == 0: + tmp = np.where(x < y, y + i, x - i) + else: + tmp = np.where(x > y, x + i, y - i) + acc = acc + tmp * (i + 1) + return acc + + +@blosc2.dsl_kernel +def kernel_loop2(x, y): + acc = 0.0 + for i in range(2): + if i % 2 == 0: + tmp = np.where(x < y, y + i, x - i) + else: + tmp = np.where(x > y, x + i, y - i) + acc = acc + tmp * (i + 1) + return acc + + +@blosc2.dsl_kernel +def kernel_loop4(x, y): + acc = 0.0 + for i in range(4): + if i % 2 == 0: + tmp = np.where(x < y, y + i, x - i) + else: + tmp = np.where(x > y, x + i, y - i) + acc = acc + tmp * (i + 1) + return acc + + +@blosc2.dsl_kernel +def kernel_loop4_heavy(x, y): + acc = 0.0 + for i in range(4): + if i % 2 == 0: + tmp = np.where(x < y, y + i, x - i) + else: + tmp = np.where(x > y, x + i, y - i) + acc = acc + tmp * (i + 1) 
+ (tmp * tmp) * 0.05 + return acc + + +@blosc2.dsl_kernel +def kernel_nested2(x, y): + acc = 0.0 + for i in range(2): + for j in range(2): + if (i + j) % 2 == 0: + tmp = np.where(x < y, y + i + j, x - i - j) + else: + tmp = np.where(x > y, x + i + j, y - i - j) + acc = acc + tmp * (i + j + 1) + return acc + + +def expr_for_steps(steps: int) -> str: + terms = [] + for i in range(steps): + if i % 2 == 0: + terms.append(f"where(x < y, y + {i}, x - {i}) * {i + 1}") + else: + terms.append(f"where(x > y, x + {i}, y - {i}) * {i + 1}") + return " + ".join(terms) + + +def expr_for_steps_heavy(steps: int) -> str: + terms = [] + for i in range(steps): + if i % 2 == 0: + term = f"where(x < y, y + {i}, x - {i})" + else: + term = f"where(x > y, x + {i}, y - {i})" + terms.append(f"{term} * {i + 1} + ({term} * {term}) * 0.05") + return " + ".join(terms) + + +def expr_nested2() -> str: + terms = [] + for i in range(2): + for j in range(2): + if (i + j) % 2 == 0: + term = f"where(x < y, y + {i + j}, x - {i + j})" + else: + term = f"where(x > y, x + {i + j}, y - {i + j})" + terms.append(f"{term} * {i + j + 1}") + return " + ".join(terms) + + +@contextlib.contextmanager +def miniexpr_enabled(enabled: bool): + old = lazyexpr_mod.try_miniexpr + lazyexpr_mod.try_miniexpr = enabled + try: + yield + finally: + lazyexpr_mod.try_miniexpr = old + + +def time_it(fn, niter=3): + best = None + for _ in range(niter): + t0 = time.perf_counter() + out = fn() + dt = time.perf_counter() - t0 + best = dt if best is None else min(best, dt) + return best, out + + +def bench_case(name, kernel, expr, a, b, dtype, gb): + if kernel.dsl_source is None: + raise RuntimeError(f"DSL extraction failed for {name}") + + with miniexpr_enabled(False): + lazy_expr_base = blosc2.lazyexpr(expr, {"x": a, "y": b}) + res_base = lazy_expr_base.compute() + base_time, _ = time_it(lambda: lazy_expr_base.compute()) + + with miniexpr_enabled(True): + lazy_expr_fast = blosc2.lazyexpr(expr, {"x": a, "y": b}) + _ = 
lazy_expr_fast.compute() + expr_time, _ = time_it(lambda: lazy_expr_fast.compute()) + + lazy_dsl = blosc2.lazyudf(kernel, (a, b), dtype=dtype) + res_dsl = lazy_dsl.compute() + dsl_time, _ = time_it(lambda: lazy_dsl.compute()) + + np.testing.assert_allclose(res_dsl[...], res_base[...], rtol=1e-5, atol=1e-6) + + return { + "case": name, + "baseline": base_time, + "lazyexpr": expr_time, + "dsl": dsl_time, + "baseline_gbps": gb / base_time, + "lazyexpr_gbps": gb / expr_time, + "dsl_gbps": gb / dsl_time, + } + + +def table_formatter(): + headers = [ + "Case", + "Base ms", + "Base GB/s", + "Expr ms", + "Expr GB/s", + "DSL ms", + "DSL GB/s", + "Expr/Base", + "DSL/Base", + ] + widths = [ + 12, + len(headers[1]), + len(headers[2]), + len(headers[3]), + len(headers[4]), + len(headers[5]), + len(headers[6]), + len(headers[7]), + len(headers[8]), + ] + align_right = {1, 2, 3, 4, 5, 6, 7, 8} + fmt_parts = [] + for i, w in enumerate(widths): + align = ">" if i in align_right else "<" + fmt_parts.append(f"{{:{align}{w}}}") + fmt = "|".join(fmt_parts) + sep = "+".join("-" * w for w in widths) + return headers, fmt, sep + + +def format_row(row): + base = row["baseline"] * 1000 + expr = row["lazyexpr"] * 1000 + dsl = row["dsl"] * 1000 + return [ + row["case"], + f"{base:.2f}", + f"{row['baseline_gbps']:.2f}", + f"{expr:.2f}", + f"{row['lazyexpr_gbps']:.2f}", + f"{dsl:.2f}", + f"{row['dsl_gbps']:.2f}", + f"{row['baseline'] / row['lazyexpr']:.2f}x", + f"{row['baseline'] / row['dsl']:.2f}x", + ] + + +def main(): + n = 10_000 + dtype = np.float32 + cparams = blosc2.CParams(codec=blosc2.Codec.BLOSCLZ, clevel=1) + + a = blosc2.linspace(0, 1, n * n, shape=(n, n), dtype=dtype, cparams=cparams) + b = blosc2.linspace(1, 0, n * n, shape=(n, n), dtype=dtype, cparams=cparams) + gb = a.nbytes * 3 / 1e9 + + cases = [ + ("loop1", kernel_loop1, expr_for_steps(1)), + ("loop2", kernel_loop2, expr_for_steps(2)), + ("loop4", kernel_loop4, expr_for_steps(4)), + ("loop4_heavy", kernel_loop4_heavy, 
expr_for_steps_heavy(4)), + ("nested2", kernel_nested2, expr_nested2()), + ] + + headers, fmt, sep = table_formatter() + print(fmt.format(*headers), flush=True) + print(sep, flush=True) + for name, kernel, expr in cases: + row = bench_case(name, kernel, expr, a, b, dtype, gb) + print(fmt.format(*format_row(row)), flush=True) + + +if __name__ == "__main__": + main() diff --git a/examples/ndarray/mandelbrot-dsl.ipynb b/examples/ndarray/mandelbrot-dsl.ipynb new file mode 100644 index 00000000..99cbfdb5 --- /dev/null +++ b/examples/ndarray/mandelbrot-dsl.ipynb @@ -0,0 +1,319 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "intro", + "metadata": {}, + "source": "# Mandelbrot With Blosc2 DSL vs Blosc2+Numba\n\nThis notebook compares two Blosc2-backed execution paths for Mandelbrot side-by-side:\n- `@blosc2.dsl_kernel` through `blosc2.lazyudf` (`blosc2+DSL`)\n- a Numba-compiled `lazyudf` kernel (`blosc2+numba`), following the pattern in `compute_udf_numba.py`\n\nThe previous native Numba implementation is moved earlier as a baseline and is still plotted for visual comparison.\n" + }, + { + "cell_type": "code", + "id": "imports", + "metadata": {}, + "source": [ + "import time\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from numba import njit, prange\n", + "\n", + "import blosc2" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "grid-setup", + "metadata": {}, + "source": [ + "# Problem size and Mandelbrot domain\n", + "WIDTH = 1200\n", + "HEIGHT = 800\n", + "MAX_ITER = 200\n", + "X_MIN, X_MAX = -2.0, 0.6\n", + "Y_MIN, Y_MAX = -1.1, 1.1\n", + "DTYPE = np.float32\n", + "\n", + "x = np.linspace(X_MIN, X_MAX, WIDTH, dtype=DTYPE)\n", + "y = np.linspace(Y_MIN, Y_MAX, HEIGHT, dtype=DTYPE)\n", + "cr_np, ci_np = np.meshgrid(x, y)\n", + "\n", + "# Keep compression overhead low for the timing comparison\n", + "cparams_fast = blosc2.CParams(codec=blosc2.Codec.LZ4, clevel=1)\n", + "cr_b2 = 
blosc2.asarray(cr_np, cparams=cparams_fast)\n", + "ci_b2 = blosc2.asarray(ci_np, cparams=cparams_fast)\n", + "\n", + "print(f\"grid: {cr_np.shape}, dtype: {cr_np.dtype}\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "dsl-kernel", + "metadata": {}, + "source": [ + "@blosc2.dsl_kernel\n", + "def mandelbrot_dsl(cr, ci, max_iter):\n", + " # max_iter is passed per-call through lazyudf inputs\n", + " zr = 0.0\n", + " zi = 0.0\n", + " escape_iter = float(max_iter)\n", + " for it in range(max_iter):\n", + " zr2 = zr * zr\n", + " zi2 = zi * zi\n", + " mag2 = zr2 + zi2\n", + "\n", + " active = escape_iter == float(max_iter)\n", + " just_escaped = (mag2 > 4.0) & active\n", + " escape_iter = np.where(just_escaped, it, escape_iter)\n", + "\n", + " active = escape_iter == float(max_iter)\n", + " if np.all(active == 0):\n", + " break\n", + "\n", + " zr_new = zr2 - zi2 + cr\n", + " zi_new = 2.0 * zr * zi + ci\n", + " zr = np.where(active, zr_new, zr)\n", + " zi = np.where(active, zi_new, zi)\n", + "\n", + " return escape_iter\n", + "\n", + "\n", + "if mandelbrot_dsl.dsl_source is None:\n", + " raise RuntimeError(\"DSL extraction failed. 
Re-run this cell in a file-backed notebook session.\")\n", + "\n", + "print(mandelbrot_dsl.dsl_source)" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "numba-kernel", + "metadata": {}, + "source": [ + "@njit(parallel=True, fastmath=False)\n", + "def mandelbrot_numba_native(cr, ci, max_iter):\n", + " h, w = cr.shape\n", + " out = np.empty((h, w), dtype=np.float32)\n", + " for iy in prange(h):\n", + " for ix in range(w):\n", + " zr = np.float32(0.0)\n", + " zi = np.float32(0.0)\n", + " escape_iter = np.float32(max_iter)\n", + " c_re = cr[iy, ix]\n", + " c_im = ci[iy, ix]\n", + " for it in range(max_iter):\n", + " zr2 = zr * zr\n", + " zi2 = zi * zi\n", + " if zr2 + zi2 > np.float32(4.0):\n", + " escape_iter = np.float32(it)\n", + " break\n", + " zr_new = zr2 - zi2 + c_re\n", + " zi_new = np.float32(2.0) * zr * zi + c_im\n", + " zr = zr_new\n", + " zi = zi_new\n", + " out[iy, ix] = escape_iter\n", + " return out\n", + "\n", + "\n", + "@njit(parallel=True, fastmath=False)\n", + "def mandelbrot_numba_lazyudf(inputs_tuple, output, offset):\n", + " cr = inputs_tuple[0]\n", + " ci = inputs_tuple[1]\n", + " max_iter = np.int32(MAX_ITER)\n", + " h, w = output.shape\n", + " for iy in prange(h):\n", + " for ix in range(w):\n", + " zr = np.float32(0.0)\n", + " zi = np.float32(0.0)\n", + " escape_iter = np.float32(max_iter)\n", + " c_re = cr[iy, ix]\n", + " c_im = ci[iy, ix]\n", + " for it in range(max_iter):\n", + " zr2 = zr * zr\n", + " zi2 = zi * zi\n", + " if zr2 + zi2 > np.float32(4.0):\n", + " escape_iter = np.float32(it)\n", + " break\n", + " zr_new = zr2 - zi2 + c_re\n", + " zi_new = np.float32(2.0) * zr * zi + c_im\n", + " zr = zr_new\n", + " zi = zi_new\n", + " output[iy, ix] = escape_iter" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "benchmark", + "metadata": {}, + "source": [ + "def best_time(func, repeats=3, warmup=1):\n", + " for _ in range(warmup):\n", + " func()\n", + " best = 
float(\"inf\")\n", + " best_out = None\n", + " for _ in range(repeats):\n", + " t0 = time.perf_counter()\n", + " out = func()\n", + " dt = time.perf_counter() - t0\n", + " if dt < best:\n", + " best = dt\n", + " best_out = out\n", + " return best, best_out\n", + "\n", + "\n", + "def run_numba_native():\n", + " return mandelbrot_numba_native(cr_np, ci_np, MAX_ITER)\n", + "\n", + "\n", + "def run_blosc2_numba():\n", + " lazy = blosc2.lazyudf(mandelbrot_numba_lazyudf, (cr_b2, ci_b2), dtype=np.float32, cparams=cparams_fast)\n", + " return lazy.compute()[...]\n", + "\n", + "\n", + "def run_dsl():\n", + " lazy = blosc2.lazyudf(mandelbrot_dsl, (cr_b2, ci_b2, MAX_ITER), dtype=np.float32, cparams=cparams_fast)\n", + " return lazy.compute()[...]\n", + "\n", + "\n", + "# Measure first iteration (includes one-time overhead, especially Numba JIT compile)\n", + "t0 = time.perf_counter()\n", + "_ = run_numba_native()\n", + "t_numba_native_first = time.perf_counter() - t0\n", + "\n", + "t0 = time.perf_counter()\n", + "_ = run_blosc2_numba()\n", + "t_b2_numba_first = time.perf_counter() - t0\n", + "\n", + "t0 = time.perf_counter()\n", + "_ = run_dsl()\n", + "t_dsl_first = time.perf_counter() - t0\n", + "\n", + "\n", + "t_numba_native, img_numba_native = best_time(run_numba_native, repeats=5, warmup=1)\n", + "t_b2_numba, img_b2_numba = best_time(run_blosc2_numba, repeats=3, warmup=1)\n", + "t_dsl, img_dsl = best_time(run_dsl, repeats=3, warmup=1)\n", + "\n", + "max_abs_diff_dsl_vs_b2_numba = float(np.max(np.abs(img_dsl - img_b2_numba)))\n", + "max_abs_diff_native_vs_b2_numba = float(np.max(np.abs(img_numba_native - img_b2_numba)))\n", + "max_abs_diff_native_vs_dsl = float(np.max(np.abs(img_numba_native - img_dsl)))\n", + "\n", + "print(\"First iteration timings (one-time overhead included):\")\n", + "print(f\"Native numba first run (baseline): {t_numba_native_first:.6f} s\")\n", + "print(f\"Blosc2+numba first run: {t_b2_numba_first:.6f} s\")\n", + "print(f\"Blosc2+DSL first run: 
{t_dsl_first:.6f} s\")\n", + "\n", + "print(\"\\nBest-time stats:\")\n", + "print(f\"Native numba time (best): {t_numba_native:.6f} s\")\n", + "print(f\"Blosc2+numba time (best): {t_b2_numba:.6f} s\")\n", + "print(f\"Blosc2+DSL time (best): {t_dsl:.6f} s\")\n", + "print(f\"Blosc2+numba / native: {t_b2_numba / t_numba_native:.2f}x\")\n", + "print(f\"Blosc2+DSL / native: {t_dsl / t_numba_native:.2f}x\")\n", + "print(f\"Blosc2+DSL / Blosc2+numba: {t_dsl / t_b2_numba:.2f}x\")\n", + "print(f\"max |dsl-b2_numba|: {max_abs_diff_dsl_vs_b2_numba:.6f}\")\n", + "print(f\"max |native-b2_numba|: {max_abs_diff_native_vs_b2_numba:.6f}\")\n", + "print(f\"max |native-dsl|: {max_abs_diff_native_vs_dsl:.6f}\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "plot", + "metadata": {}, + "source": [ + "fig, ax = plt.subplots(1, 2, figsize=(13, 5), constrained_layout=True)\n", + "\n", + "im0 = ax[0].imshow(\n", + " img_b2_numba,\n", + " cmap=\"magma\",\n", + " extent=(X_MIN, X_MAX, Y_MIN, Y_MAX),\n", + " origin=\"lower\",\n", + ")\n", + "ax[0].set_title(\"Mandelbrot (Blosc2+Numba)\")\n", + "ax[0].set_xlabel(\"Re(c)\")\n", + "ax[0].set_ylabel(\"Im(c)\")\n", + "fig.colorbar(im0, ax=ax[0], shrink=0.82, label=\"Escape iteration\")\n", + "\n", + "im1 = ax[1].imshow(\n", + " img_dsl,\n", + " cmap=\"magma\",\n", + " extent=(X_MIN, X_MAX, Y_MIN, Y_MAX),\n", + " origin=\"lower\",\n", + ")\n", + "ax[1].set_title(\"Mandelbrot (Blosc2+DSL)\")\n", + "ax[1].set_xlabel(\"Re(c)\")\n", + "ax[1].set_ylabel(\"Im(c)\")\n", + "fig.colorbar(im1, ax=ax[1], shrink=0.82, label=\"Escape iteration\")\n", + "\n", + "plt.show()\n", + "\n", + "fig, ax = plt.subplots(1, 1, figsize=(6.4, 5), constrained_layout=True)\n", + "im_base = ax.imshow(\n", + " img_numba_native,\n", + " cmap=\"magma\",\n", + " extent=(X_MIN, X_MAX, Y_MIN, Y_MAX),\n", + " origin=\"lower\",\n", + ")\n", + "ax.set_title(\"Mandelbrot (Native Numba baseline)\")\n", + "ax.set_xlabel(\"Re(c)\")\n", + 
"ax.set_ylabel(\"Im(c)\")\n", + "fig.colorbar(im_base, ax=ax, shrink=0.82, label=\"Escape iteration\")\n", + "\n", + "plt.show()" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "timing-bars", + "metadata": {}, + "source": [ + "labels = [\"Native Numba (baseline)\", \"Blosc2+Numba\", \"Blosc2+DSL\"]\n", + "first_times = [t_numba_native_first, t_b2_numba_first, t_dsl_first]\n", + "best_times = [t_numba_native, t_b2_numba, t_dsl]\n", + "\n", + "x = np.arange(len(labels))\n", + "width = 0.36\n", + "\n", + "fig, ax = plt.subplots(figsize=(9, 5), constrained_layout=True)\n", + "ax.bar(x - width / 2, first_times, width, label=\"First run\", color=\"#4C78A8\")\n", + "ax.bar(x + width / 2, best_times, width, label=\"Best run\", color=\"#F58518\")\n", + "\n", + "ax.set_xticks(x)\n", + "ax.set_xticklabels(labels)\n", + "ax.set_ylabel(\"Time (seconds)\")\n", + "ax.set_title(\"Mandelbrot Timings: Native Baseline vs Blosc2 Pipelines\")\n", + "ax.legend()\n", + "\n", + "for i, t in enumerate(first_times):\n", + " ax.text(i - width / 2, t, f\"{t:.3f}s\", ha=\"center\", va=\"bottom\")\n", + "for i, t in enumerate(best_times):\n", + " ax.text(i + width / 2, t, f\"{t:.3f}s\", ha=\"center\", va=\"bottom\")\n", + "\n", + "plt.show()" + ], + "outputs": [], + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/blosc2/__init__.py b/src/blosc2/__init__.py index c9c3ab36..896cd553 100644 --- a/src/blosc2/__init__.py +++ b/src/blosc2/__init__.py @@ -465,6 +465,7 @@ def _raise(exc): from .c2array import c2context, C2Array, URLPath +from .dsl_kernel import DSLKernel, dsl_kernel from .lazyexpr import ( LazyExpr, lazyudf, @@ -645,6 +646,7 @@ def _raise(exc): "EmbedStore", "Filter", "LazyArray", + "DSLKernel", "LazyExpr", "LazyUDF", 
"NDArray", @@ -760,6 +762,7 @@ def _raise(exc): "isnan", "jit", "lazyexpr", + "dsl_kernel", "lazyudf", "lazywhere", "less", diff --git a/src/blosc2/dsl_kernel.py b/src/blosc2/dsl_kernel.py new file mode 100644 index 00000000..98221c4b --- /dev/null +++ b/src/blosc2/dsl_kernel.py @@ -0,0 +1,742 @@ +####################################################################### +# Copyright (c) 2019-present, Blosc Development Team +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause +####################################################################### + +from __future__ import annotations + +import ast +import contextlib +import inspect +import os +import textwrap +from typing import ClassVar + +_PRINT_DSL_KERNEL = os.environ.get("PRINT_DSL_KERNEL", "").strip().lower() +_PRINT_DSL_KERNEL = _PRINT_DSL_KERNEL not in ("", "0", "false", "no", "off") + + +def _normalize_dsl_scalar(value): + # NumPy scalar-like values expose .item(); plain Python scalars do not. + if hasattr(value, "item") and callable(value.item): + with contextlib.suppress(Exception): + value = value.item() + if isinstance(value, bool): + return int(value) + if isinstance(value, int | float): + return value + raise TypeError("Unsupported scalar type for DSL miniexpr specialization") + + +class _DSLScalarSpecializer(ast.NodeTransformer): + def __init__(self, replacements: dict[str, int | float]): + self.replacements = replacements + + def visit_Name(self, node): + if isinstance(node.ctx, ast.Load) and node.id in self.replacements: + return ast.copy_location(ast.Constant(value=self.replacements[node.id]), node) + return node + + def visit_Call(self, node): + node = self.generic_visit(node) + if ( + isinstance(node.func, ast.Name) + and node.func.id in {"float", "int"} + and len(node.args) == 1 + and not node.keywords + and isinstance(node.args[0], ast.Constant) + and isinstance(node.args[0].value, int | float | bool) + ): + folded = float(node.args[0].value) if node.func.id == "float" else 
int(node.args[0].value) + return ast.copy_location(ast.Constant(value=folded), node) + return node + + +def specialize_dsl_miniexpr_inputs(expr_string: str, operands: dict): + """Inline scalar DSL operands as constants for miniexpr compilation.""" + scalar_replacements = {} + array_operands = {} + for name, value in operands.items(): + if hasattr(value, "shape") and value.shape == (): + scalar_replacements[name] = _normalize_dsl_scalar(value[()]) + continue + if isinstance(value, int | float | bool) or (hasattr(value, "item") and callable(value.item)): + try: + scalar_replacements[name] = _normalize_dsl_scalar(value) + continue + except TypeError: + pass + array_operands[name] = value + + if not scalar_replacements: + return expr_string, operands + + tree = ast.parse(expr_string) + tree = _DSLScalarSpecializer(scalar_replacements).visit(tree) + for node in tree.body: + if isinstance(node, ast.FunctionDef): + node.args.posonlyargs = [a for a in node.args.posonlyargs if a.arg not in scalar_replacements] + node.args.args = [a for a in node.args.args if a.arg not in scalar_replacements] + ast.fix_missing_locations(tree) + return ast.unparse(tree), array_operands + + +class DSLKernel: + """Wrap a Python function and optionally extract a miniexpr DSL kernel from it.""" + + def __init__(self, func): + self.func = func + self.__name__ = getattr(func, "__name__", self.__class__.__name__) + self.__qualname__ = getattr(func, "__qualname__", self.__name__) + self.__doc__ = getattr(func, "__doc__", None) + try: + sig = inspect.signature(func) + except (TypeError, ValueError): + sig = None + self._sig = sig + self._sig_has_varargs = False + self._sig_npositional = None + self._legacy_udf_signature = False + if sig is not None: + params = list(sig.parameters.values()) + positional_params = [p for p in params if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)] + self._sig_has_varargs = any(p.kind == p.VAR_POSITIONAL for p in params) + self._sig_npositional = 
len(positional_params) + # Preserve support for classic lazyudf signature: (inputs_tuple, output, offset) + if not self._sig_has_varargs and len(positional_params) == 3: + p2 = positional_params[1].name.lower() + p3 = positional_params[2].name.lower() + self._legacy_udf_signature = p2 in {"output", "out"} and p3 == "offset" + self.dsl_source = None + self.input_names = None + try: + dsl_source, input_names = self._extract_dsl(func) + except Exception: + dsl_source = None + input_names = None + self.dsl_source = dsl_source + self.input_names = input_names + + def _extract_dsl(self, func): + source = inspect.getsource(func) + source = textwrap.dedent(source) + tree = ast.parse(source) + func_node = None + for node in tree.body: + if isinstance(node, ast.FunctionDef) and node.name == func.__name__: + func_node = node + break + if func_node is None: + for node in tree.body: + if isinstance(node, ast.FunctionDef): + func_node = node + break + if func_node is None: + raise ValueError("No function definition found for DSL extraction") + + dsl_source_full = None + if _PRINT_DSL_KERNEL: + try: + dsl_source_full = _DSLBuilder().build(func_node) + func_name = getattr(func, "__name__", "") + print(f"[DSLKernel:{func_name}] dsl_source (full):") + print(dsl_source_full[0]) + except Exception as exc: + func_name = getattr(func, "__name__", "") + print(f"[DSLKernel:{func_name}] dsl_source (full) failed: {exc}") + + reducer = _DSLReducer() + reduced = reducer.reduce(func_node) + if reduced is not None: + if _PRINT_DSL_KERNEL: + func_name = getattr(func, "__name__", "") + print(f"[DSLKernel:{func_name}] reduced_expr:") + print(reduced[0]) + return reduced + + if dsl_source_full is not None: + return dsl_source_full + + builder = _DSLBuilder() + return builder.build(func_node) + + def __call__(self, inputs_tuple, output, offset=None): + if self._legacy_udf_signature: + return self.func(inputs_tuple, output, offset) + + n_inputs = len(inputs_tuple) + if self._sig is not None and ( + 
self._sig_npositional in (n_inputs, n_inputs + 1) or self._sig_has_varargs + ): + if self._sig_npositional == n_inputs + 1: + result = self.func(*inputs_tuple, offset) + else: + result = self.func(*inputs_tuple) + output[...] = result + return None + + try: + return self.func(inputs_tuple, output, offset) + except TypeError: + result = self.func(*inputs_tuple) + output[...] = result + return None + + +def dsl_kernel(func): + """Decorator to wrap a function in a DSLKernel.""" + + return DSLKernel(func) + + +class _DSLBuilder: + _binop_map: ClassVar[dict[type[ast.operator], str]] = { + ast.Add: "+", + ast.Sub: "-", + ast.Mult: "*", + ast.Div: "/", + ast.FloorDiv: "//", + ast.Mod: "%", + ast.Pow: "**", + ast.BitAnd: "&", + ast.BitOr: "|", + ast.BitXor: "^", + ast.LShift: "<<", + ast.RShift: ">>", + } + + _cmp_map: ClassVar[dict[type[ast.cmpop], str]] = { + ast.Eq: "==", + ast.NotEq: "!=", + ast.Lt: "<", + ast.LtE: "<=", + ast.Gt: ">", + ast.GtE: ">=", + } + + def __init__(self): + self._lines = [] + + def build(self, func_node: ast.FunctionDef): + input_names = self._args(func_node.args) + self._emit(f"def {func_node.name}({', '.join(input_names)}):", 0) + if not func_node.body: + raise ValueError("DSL kernel must have a body") + for stmt in func_node.body: + self._stmt(stmt, 4) + return "\n".join(self._lines), input_names + + def _emit(self, line: str, indent: int): + self._lines.append(" " * indent + line) + + def _args(self, args: ast.arguments): + if args.vararg or args.kwarg or args.kwonlyargs: + raise ValueError("DSL kernel does not support *args/**kwargs/kwonly args") + if args.defaults or args.kw_defaults: + raise ValueError("DSL kernel does not support default arguments") + names = [a.arg for a in (args.posonlyargs + args.args)] + if not names: + raise ValueError("DSL kernel must accept at least one argument") + return names + + def _stmt(self, node: ast.stmt, indent: int): + if isinstance(node, ast.Assign): + if len(node.targets) != 1 or not 
isinstance(node.targets[0], ast.Name): + raise ValueError("Only simple assignments are supported in DSL kernels") + target = node.targets[0].id + value = self._expr(node.value) + self._emit(f"{target} = {value}", indent) + return + if isinstance(node, ast.AugAssign): + if not isinstance(node.target, ast.Name): + raise ValueError("Only simple augmented assignments are supported") + target = node.target.id + op = self._binop(node.op) + value = self._expr(node.value) + self._emit(f"{target} = {target} {op} {value}", indent) + return + if isinstance(node, ast.Return): + if node.value is None: + raise ValueError("DSL kernel return must have a value") + value = self._expr(node.value) + self._emit(f"return {value}", indent) + return + if isinstance(node, ast.Expr): + value = self._expr(node.value) + self._emit(value, indent) + return + if isinstance(node, ast.If): + self._if_stmt(node, indent) + return + if isinstance(node, ast.For): + self._for_stmt(node, indent) + return + if isinstance(node, ast.Break): + self._emit("break", indent) + return + if isinstance(node, ast.Continue): + self._emit("continue", indent) + return + raise ValueError(f"Unsupported DSL statement: {type(node).__name__}") + + def _stmt_block(self, body, indent: int): + if not body: + raise ValueError("Empty blocks are not supported in DSL kernels") + i = 0 + while i < len(body): + stmt = body[i] + if ( + isinstance(stmt, ast.If) + and not stmt.orelse + and self._block_terminates(stmt.body) + and i + 1 < len(body) + and isinstance(body[i + 1], ast.If) + ): + merged = ast.If(test=stmt.test, body=stmt.body, orelse=[body[i + 1]]) + self._if_stmt(merged, indent) + i += 2 + continue + self._stmt(stmt, indent) + i += 1 + + def _block_terminates(self, body) -> bool: + if not body: + return False + return self._stmt_terminates(body[-1]) + + def _stmt_terminates(self, node: ast.stmt) -> bool: + if isinstance(node, (ast.Return, ast.Break, ast.Continue)): + return True + if isinstance(node, ast.If) and 
node.orelse: + return self._block_terminates(node.body) and self._block_terminates(node.orelse) + return False + + def _if_stmt(self, node: ast.If, indent: int): + current = node + first = True + while True: + prefix = "if" if first else "elif" + cond = self._expr(current.test) + self._emit(f"{prefix} {cond}:", indent) + self._stmt_block(current.body, indent + 4) + first = False + if current.orelse and len(current.orelse) == 1 and isinstance(current.orelse[0], ast.If): + current = current.orelse[0] + continue + break + if current.orelse: + self._emit("else:", indent) + self._stmt_block(current.orelse, indent + 4) + + def _for_stmt(self, node: ast.For, indent: int): + if node.orelse: + raise ValueError("for/else is not supported in DSL kernels") + if not isinstance(node.target, ast.Name): + raise ValueError("DSL for-loop target must be a simple name") + if not isinstance(node.iter, ast.Call): + raise ValueError("DSL for-loop must iterate over range()") + func_name = self._call_name(node.iter.func) + if func_name != "range": + raise ValueError("DSL for-loop must iterate over range()") + if node.iter.keywords or len(node.iter.args) != 1: + raise ValueError("DSL range() must take a single argument") + limit = self._expr(node.iter.args[0]) + self._emit(f"for {node.target.id} in range({limit}):", indent) + self._stmt_block(node.body, indent + 4) + + def _expr(self, node: ast.AST) -> str: # noqa: C901 + if isinstance(node, ast.Name): + return node.id + if isinstance(node, ast.Constant): + val = node.value + if isinstance(val, bool): + return "1" if val else "0" + if isinstance(val, int | float): + return repr(val) + raise ValueError("Unsupported constant in DSL expression") + if isinstance(node, ast.UnaryOp): + if isinstance(node.op, ast.UAdd): + return f"+{self._expr(node.operand)}" + if isinstance(node.op, ast.USub): + return f"-{self._expr(node.operand)}" + if isinstance(node.op, ast.Not): + return f"!{self._expr(node.operand)}" + raise ValueError("Unsupported unary 
operator in DSL expression") + if isinstance(node, ast.BinOp): + left = self._expr(node.left) + right = self._expr(node.right) + op = self._binop(node.op) + return f"({left} {op} {right})" + if isinstance(node, ast.BoolOp): + op = "&" if isinstance(node.op, ast.And) else "|" + values = [self._expr(v) for v in node.values] + expr = values[0] + for val in values[1:]: + expr = f"({expr} {op} {val})" + return expr + if isinstance(node, ast.Compare): + if len(node.ops) != 1 or len(node.comparators) != 1: + raise ValueError("Chained comparisons are not supported in DSL") + left = self._expr(node.left) + right = self._expr(node.comparators[0]) + op = self._cmpop(node.ops[0]) + return f"({left} {op} {right})" + if isinstance(node, ast.Call): + func_name = self._call_name(node.func) + if node.keywords: + raise ValueError("Keyword arguments are not supported in DSL calls") + args = ", ".join(self._expr(a) for a in node.args) + return f"{func_name}({args})" + if isinstance(node, ast.IfExp): + cond = self._expr(node.test) + body = self._expr(node.body) + orelse = self._expr(node.orelse) + return f"where({cond}, {body}, {orelse})" + raise ValueError(f"Unsupported DSL expression: {type(node).__name__}") + + def _call_name(self, node: ast.AST) -> str: + if isinstance(node, ast.Name): + return node.id + if ( + isinstance(node, ast.Attribute) + and isinstance(node.value, ast.Name) + and node.value.id in {"np", "numpy", "math"} + ): + return node.attr + raise ValueError("Unsupported call target in DSL") + + def _binop(self, op: ast.operator) -> str: + for k, v in self._binop_map.items(): + if isinstance(op, k): + return v + raise ValueError("Unsupported binary operator in DSL") + + def _cmpop(self, op: ast.cmpop) -> str: + for k, v in self._cmp_map.items(): + if isinstance(op, k): + return v + raise ValueError("Unsupported comparison in DSL") + + +class _DSLReducer: + _binop_map: ClassVar[dict[type[ast.operator], str]] = _DSLBuilder._binop_map + _cmp_map: 
def __init__(self, max_unroll: int = 64):
    """Create a reducer that symbolically evaluates a DSL kernel body.

    Parameters
    ----------
    max_unroll : int
        Upper bound on the trip count of a ``for i in range(n)`` loop that
        will be unrolled; larger loops make :meth:`reduce` give up.
    """
    # Name of a local variable -> expression string currently bound to it.
    self._env: dict[str, str] = {}
    # Name of a local variable -> its compile-time constant value, if known.
    self._const_env: dict[str, object] = {}
    # Expression produced by the first reachable ``return`` statement.
    self._return_expr: str | None = None
    self._max_unroll = max_unroll

def reduce(self, func_node: ast.FunctionDef):
    """Reduce *func_node* to ``(expression_string, input_names)``.

    Returns ``None`` whenever any statement cannot be reduced, so the
    caller can fall back to running the kernel as a plain callable.
    """
    input_names = self._args(func_node.args)
    if not func_node.body:
        return None
    for statement in func_node.body:
        if not self._stmt(statement):
            return None
        # Stop at the first reachable return; later statements are dead code.
        if self._return_expr is not None:
            break
    if self._return_expr is None:
        return None
    return self._return_expr, input_names

def _args(self, args: ast.arguments):
    """Validate the kernel signature and return its positional parameter names.

    Raises
    ------
    ValueError
        For ``*args``/``**kwargs``/keyword-only parameters, for default
        values, or for an empty parameter list.
    """
    if args.vararg or args.kwarg or args.kwonlyargs:
        raise ValueError("DSL kernel does not support *args/**kwargs/kwonly args")
    if args.defaults or args.kw_defaults:
        raise ValueError("DSL kernel does not support default arguments")
    names = [param.arg for param in (args.posonlyargs + args.args)]
    if not names:
        raise ValueError("DSL kernel must accept at least one argument")
    return names

def _restore_name(self, name: str, old_expr, old_const):
    """Put a loop variable's previous bindings back after unrolling.

    ``None`` means "was not bound before" (neither map ever stores None
    as a value), so the entry is removed instead of rewritten.
    """
    if old_expr is None:
        self._env.pop(name, None)
    else:
        self._env[name] = old_expr
    if old_const is None:
        self._const_env.pop(name, None)
    else:
        self._const_env[name] = old_const

def _stmt(self, node: ast.stmt) -> bool:  # noqa: C901
    """Interpret one statement symbolically.

    Returns ``True`` when the statement was absorbed into the symbolic
    state, ``False`` to request a fallback to the plain callable.
    """
    if isinstance(node, ast.Assign):
        # Only plain ``name = expr`` assignments are supported.
        if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
            return False
        target = node.targets[0].id
        self._env[target] = self._expr(node.value)
        const_val = self._const_eval(node.value)
        if const_val is None:
            self._const_env.pop(target, None)
        else:
            self._const_env[target] = const_val
        return True
    if isinstance(node, ast.AugAssign):
        if not isinstance(node.target, ast.Name):
            return False
        target = node.target.id
        op = self._binop(node.op)
        value = self._expr(node.value)
        # An unbound target falls back to its own name as the expression.
        left = self._env.get(target, target)
        left_const = self._const_env.get(target)
        right_const = self._const_eval(node.value)
        self._env[target] = self._simplify_binop_expr(op, left, value, left_const, right_const)
        if left_const is None or right_const is None:
            self._const_env.pop(target, None)
        else:
            self._const_env[target] = self._apply_binop(left_const, right_const, node.op)
        return True
    if isinstance(node, ast.Return):
        # A bare ``return`` has no expression to reduce to.
        if node.value is None:
            return False
        self._return_expr = self._expr(node.value)
        return True
    if isinstance(node, ast.If):
        # Only branches whose condition folds to a constant can be taken.
        test_val = self._const_eval(node.test)
        if test_val is None:
            return False
        branch = node.body if bool(test_val) else node.orelse
        if not branch:
            return True
        for statement in branch:
            if not self._stmt(statement):
                return False
            if self._return_expr is not None:
                return True
        return True
    if isinstance(node, ast.For):
        # Only ``for <name> in range(<const>)`` loops (no else clause) are
        # unrolled, and only up to self._max_unroll iterations.
        if node.orelse:
            return False
        if not isinstance(node.target, ast.Name):
            return False
        if not isinstance(node.iter, ast.Call):
            return False
        func_name = self._call_name(node.iter.func)
        if func_name != "range":
            return False
        if node.iter.keywords or len(node.iter.args) != 1:
            return False
        limit_val = self._const_eval(node.iter.args[0])
        if limit_val is None or not isinstance(limit_val, int):
            return False
        if limit_val < 0 or limit_val > self._max_unroll:
            return False
        loop_var = node.target.id
        old_env = self._env.get(loop_var)
        old_const = self._const_env.get(loop_var)
        for i in range(limit_val):
            # Bind the loop variable to the current iteration index.
            self._env[loop_var] = str(i)
            self._const_env[loop_var] = i
            for statement in node.body:
                if not self._stmt(statement):
                    self._restore_name(loop_var, old_env, old_const)
                    return False
                if self._return_expr is not None:
                    break
            if self._return_expr is not None:
                break
        self._restore_name(loop_var, old_env, old_const)
        return True
    return False
bool): + return "1" if const_val else "0" + return repr(const_val) + if isinstance(node, ast.Name): + if node.id in self._env: + val = self._env[node.id] + # Avoid double-wrapping if already parenthesized or is a function call + if (val.startswith("(") and val.endswith(")")) or "(" in val: + return val + return f"({val})" + return node.id + if isinstance(node, ast.Constant): + val = node.value + if isinstance(val, bool): + return "1" if val else "0" + if isinstance(val, int | float): + return repr(val) + raise ValueError("Unsupported constant in DSL expression") + if isinstance(node, ast.UnaryOp): + if isinstance(node.op, ast.UAdd): + return f"+{self._expr(node.operand)}" + if isinstance(node.op, ast.USub): + return f"-{self._expr(node.operand)}" + if isinstance(node.op, ast.Not): + return f"!{self._expr(node.operand)}" + raise ValueError("Unsupported unary operator in DSL expression") + if isinstance(node, ast.BinOp): + left = self._expr(node.left) + right = self._expr(node.right) + op = self._binop(node.op) + left_const = self._const_eval(node.left) + right_const = self._const_eval(node.right) + return self._simplify_binop_expr(op, left, right, left_const, right_const) + if isinstance(node, ast.BoolOp): + op = "&" if isinstance(node.op, ast.And) else "|" + values = [self._expr(v) for v in node.values] + expr = values[0] + for val in values[1:]: + expr = f"({expr} {op} {val})" + return expr + if isinstance(node, ast.Compare): + if len(node.ops) != 1 or len(node.comparators) != 1: + raise ValueError("Chained comparisons are not supported in DSL") + left = self._expr(node.left) + right = self._expr(node.comparators[0]) + op = self._cmpop(node.ops[0]) + return f"({left} {op} {right})" + if isinstance(node, ast.Call): + func_name = self._call_name(node.func) + if node.keywords: + raise ValueError("Keyword arguments are not supported in DSL calls") + args = ", ".join(self._expr(a) for a in node.args) + return f"{func_name}({args})" + if isinstance(node, ast.IfExp): + 
cond = self._expr(node.test) + body = self._expr(node.body) + orelse = self._expr(node.orelse) + return f"where({cond}, {body}, {orelse})" + raise ValueError(f"Unsupported DSL expression: {type(node).__name__}") + + def _call_name(self, node: ast.AST) -> str: + if isinstance(node, ast.Name): + return node.id + if ( + isinstance(node, ast.Attribute) + and isinstance(node.value, ast.Name) + and node.value.id in {"np", "numpy", "math"} + ): + return node.attr + raise ValueError("Unsupported call target in DSL") + + def _binop(self, op: ast.operator) -> str: + for k, v in self._binop_map.items(): + if isinstance(op, k): + return v + raise ValueError("Unsupported binary operator in DSL") + + def _cmpop(self, op: ast.cmpop) -> str: + for k, v in self._cmp_map.items(): + if isinstance(op, k): + return v + raise ValueError("Unsupported comparison in DSL") + + def _const_eval(self, node: ast.AST): # noqa: C901 + if isinstance(node, ast.Constant): + if isinstance(node.value, int | float | bool): + return node.value + return None + if isinstance(node, ast.Name): + return self._const_env.get(node.id) + if isinstance(node, ast.UnaryOp): + val = self._const_eval(node.operand) + if val is None: + return None + if isinstance(node.op, ast.UAdd): + return +val + if isinstance(node.op, ast.USub): + return -val + if isinstance(node.op, ast.Not): + return not val + return None + if isinstance(node, ast.BinOp): + left = self._const_eval(node.left) + right = self._const_eval(node.right) + if left is None or right is None: + return None + return self._apply_binop(left, right, node.op) + if isinstance(node, ast.BoolOp): + vals = [self._const_eval(v) for v in node.values] + if any(v is None for v in vals): + return None + if isinstance(node.op, ast.And): + return all(vals) + if isinstance(node.op, ast.Or): + return any(vals) + return None + if isinstance(node, ast.Compare): + if len(node.ops) != 1 or len(node.comparators) != 1: + return None + left = self._const_eval(node.left) + right = 
self._const_eval(node.comparators[0]) + if left is None or right is None: + return None + return self._apply_cmp(left, right, node.ops[0]) + return None + + def _apply_binop(self, left, right, op): + if isinstance(op, ast.Add): + return left + right + if isinstance(op, ast.Sub): + return left - right + if isinstance(op, ast.Mult): + return left * right + if isinstance(op, ast.Div): + return left / right + if isinstance(op, ast.FloorDiv): + return left // right + if isinstance(op, ast.Mod): + return left % right + if isinstance(op, ast.Pow): + return left**right + if isinstance(op, ast.BitAnd): + return left & right + if isinstance(op, ast.BitOr): + return left | right + if isinstance(op, ast.BitXor): + return left ^ right + if isinstance(op, ast.LShift): + return left << right + if isinstance(op, ast.RShift): + return left >> right + return None + + def _apply_cmp(self, left, right, op): + if isinstance(op, ast.Eq): + return left == right + if isinstance(op, ast.NotEq): + return left != right + if isinstance(op, ast.Lt): + return left < right + if isinstance(op, ast.LtE): + return left <= right + if isinstance(op, ast.Gt): + return left > right + if isinstance(op, ast.GtE): + return left >= right + return None + + def _simplify_binop_expr(self, op, left_expr, right_expr, left_const, right_const): + if op == "+": + if self._is_zero(left_const): + return right_expr + if self._is_zero(right_const): + return left_expr + if op == "-" and self._is_zero(right_const): + return left_expr + if op == "*": + if self._is_one(left_const): + return right_expr + if self._is_one(right_const): + return left_expr + return f"({left_expr} {op} {right_expr})" + + def _is_zero(self, value): + return isinstance(value, int | float | bool) and value == 0 + + def _is_one(self, value): + return isinstance(value, int | float | bool) and value == 1 diff --git a/src/blosc2/lazyexpr.py b/src/blosc2/lazyexpr.py index 6f53a249..e7a59cee 100644 --- a/src/blosc2/lazyexpr.py +++ 
b/src/blosc2/lazyexpr.py @@ -41,6 +41,8 @@ import blosc2 +from .dsl_kernel import DSLKernel, specialize_dsl_miniexpr_inputs + if blosc2._HAS_NUMBA: import numba from blosc2 import compute_chunks_blocks @@ -1262,8 +1264,11 @@ def fast_eval( # noqa: C901 # Use a local copy so we don't modify the global use_miniexpr = try_miniexpr - # Disable miniexpr for UDFs (callable expressions) - if callable(expression): + is_dsl = isinstance(expression, DSLKernel) and expression.dsl_source + expr_string = expression.dsl_source if is_dsl else expression + + # Disable miniexpr for UDFs (callable expressions), except DSL kernels + if callable(expression) and not is_dsl: use_miniexpr = False out = kwargs.pop("_output", None) @@ -1315,25 +1320,38 @@ def fast_eval( # noqa: C901 # WebAssembly does not support threading, so we cannot use the iter_disk option iter_disk = False + expr_string_miniexpr = expr_string + operands_miniexpr = operands + if use_miniexpr and is_dsl: + try: + expr_string_miniexpr, operands_miniexpr = specialize_dsl_miniexpr_inputs(expr_string, operands) + except Exception: + # If specialization fails, keep original expression/operands and let normal checks decide. + expr_string_miniexpr = expr_string + operands_miniexpr = operands + # Check whether we can use miniexpr if use_miniexpr: + all_ndarray_miniexpr = all( + isinstance(value, blosc2.NDArray) and value.shape != () for value in operands_miniexpr.values() + ) # Require aligned NDArray operands with identical chunk/block grid. 
- same_shape = all(hasattr(op, "shape") and op.shape == shape for op in operands.values()) - same_chunks = all(hasattr(op, "chunks") and op.chunks == chunks for op in operands.values()) - same_blocks = all(hasattr(op, "blocks") and op.blocks == blocks for op in operands.values()) + same_shape = all(hasattr(op, "shape") and op.shape == shape for op in operands_miniexpr.values()) + same_chunks = all(hasattr(op, "chunks") and op.chunks == chunks for op in operands_miniexpr.values()) + same_blocks = all(hasattr(op, "blocks") and op.blocks == blocks for op in operands_miniexpr.values()) if not (same_shape and same_chunks and same_blocks): use_miniexpr = False - if not (all_ndarray and out is None): + if not (all_ndarray_miniexpr and out is None): use_miniexpr = False has_complex = any( isinstance(op, blosc2.NDArray) and blosc2.isdtype(op.dtype, "complex floating") - for op in operands.values() + for op in operands_miniexpr.values() ) - if isinstance(expression, str) and has_complex: + if isinstance(expr_string_miniexpr, str) and has_complex: if sys.platform == "win32": # On Windows, miniexpr has issues with complex numbers use_miniexpr = False - if any(tok in expression for tok in ("!=", "==", "<=", ">=", "<", ">")): + if any(tok in expr_string_miniexpr for tok in ("!=", "==", "<=", ">=", "<", ">")): use_miniexpr = False if sys.platform == "win32" and use_miniexpr and not _MINIEXPR_WINDOWS_OVERRIDE: # Work around Windows miniexpr issues for integer outputs and dtype conversions. 
@@ -1341,7 +1359,7 @@ def fast_eval( # noqa: C901 use_miniexpr = False else: dtype_mismatch = any( - isinstance(op, blosc2.NDArray) and op.dtype != dtype for op in operands.values() + isinstance(op, blosc2.NDArray) and op.dtype != dtype for op in operands_miniexpr.values() ) if dtype_mismatch: use_miniexpr = False @@ -1351,7 +1369,7 @@ def fast_eval( # noqa: C901 # All values will be overwritten, so we can use an uninitialized array res_eval = blosc2.uninit(shape, dtype, chunks=chunks, blocks=blocks, cparams=cparams, **kwargs) try: - res_eval._set_pref_expr(expression, operands, fp_accuracy=fp_accuracy) + res_eval._set_pref_expr(expr_string_miniexpr, operands_miniexpr, fp_accuracy=fp_accuracy) # print("expr->miniexpr:", expression, fp_accuracy) # Data to compress is fetched from operands, so it can be uninitialized here data = np.empty(res_eval.schunk.chunksize, dtype=np.uint8) @@ -3526,7 +3544,13 @@ def __init__(self, func, inputs, dtype, shape=None, chunked_eval=True, **kwargs) # if 0 not in self._shape: # self.res_getitem._set_postf_udf(self.func, id(self.inputs)) - self.inputs_dict = {f"o{i}": obj for i, obj in enumerate(self.inputs)} + if isinstance(self.func, DSLKernel) and self.func.input_names: + if len(self.func.input_names) == len(self.inputs): + self.inputs_dict = dict(zip(self.func.input_names, self.inputs, strict=True)) + else: + self.inputs_dict = {f"o{i}": obj for i, obj in enumerate(self.inputs)} + else: + self.inputs_dict = {f"o{i}": obj for i, obj in enumerate(self.inputs)} @property def dtype(self): @@ -3734,10 +3758,16 @@ def save(self, urlpath=None, **kwargs): if value.schunk.urlpath is None: raise ValueError("To save a LazyArray, all operands must be stored on disk/network") operands[key] = value.schunk.urlpath + udf_func = self.func.func if isinstance(self.func, DSLKernel) else self.func + udf_name = getattr(udf_func, "__name__", self.func.__name__) + try: + udf_source = textwrap.dedent(inspect.getsource(udf_func)).lstrip() + except 
#######################################################################
# Copyright (c) 2019-present, Blosc Development Team
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
#######################################################################

import numpy as np
import pytest

import blosc2


def _make_arrays(shape=(8, 8), chunks=(4, 4), blocks=(2, 2)):
    """Build a pair of float32 NumPy arrays plus their blosc2 counterparts."""
    a = np.linspace(0, 1, num=np.prod(shape), dtype=np.float32).reshape(shape)
    b = np.linspace(1, 2, num=np.prod(shape), dtype=np.float32).reshape(shape)
    a2 = blosc2.asarray(a, chunks=chunks, blocks=blocks)
    b2 = blosc2.asarray(b, chunks=chunks, blocks=blocks)
    return a, b, a2, b2


def _make_int_arrays(shape=(8, 8), chunks=(4, 4), blocks=(2, 2)):
    """Build a pair of int32 NumPy arrays plus their blosc2 counterparts."""
    a = np.arange(np.prod(shape), dtype=np.int32).reshape(shape)
    b = np.arange(np.prod(shape), dtype=np.int32).reshape(shape) + 3
    a2 = blosc2.asarray(a, chunks=chunks, blocks=blocks)
    b2 = blosc2.asarray(b, chunks=chunks, blocks=blocks)
    return a, b, a2, b2


# A loop + data-dependent branch kernel that the DSL reducer can unroll.
@blosc2.dsl_kernel
def kernel_loop(x, y):
    acc = 0.0
    for i in range(2):
        if i % 2 == 0:
            tmp = np.where(x < y, y + i, x - i)
        else:
            tmp = np.where(x > y, x + i, y - i)
        acc = acc + tmp * (i + 1)
    return acc


# range() with two arguments is not reducible -> plain-callable fallback.
@blosc2.dsl_kernel
def kernel_fallback_range_2args(x, y):
    acc = 0.0
    for i in range(1, 3):
        acc = acc + x + y + i
    return acc


# Exercises the integer operator set (//, %, ^, &, |, <<, >>).
@blosc2.dsl_kernel
def kernel_integer_ops(x, y):
    acc = ((x + y) - (x * 2)) // 3
    acc = acc % 5
    acc = acc ^ (x & y)
    acc = acc | (x << 1)
    return acc + (y >> 1)


# continue/break and data-dependent where(): kept as a DSL function,
# not reduced to a single expression.
@blosc2.dsl_kernel
def kernel_control_flow_full(x, y):
    acc = x
    for i in range(4):
        if i == 0:
            acc = acc + y
            continue
        if i == 1:
            acc = acc - y
        else:
            acc = np.where(acc < y, acc + i, acc - i)
        if i == 3:
            break
    return acc


# Loop bound comes in as a scalar argument rather than a literal.
@blosc2.dsl_kernel
def kernel_loop_param(x, y, niter):
    acc = x
    for _i in range(niter):
        acc = np.where(acc < y, acc + 1, acc - 1)
    return acc


# Keyword arguments in calls are unsupported -> fallback.
@blosc2.dsl_kernel
def kernel_fallback_kw_call(x, y):
    return np.clip(x + y, a_min=0.5, a_max=2.5)


# for/else is unsupported -> fallback.
@blosc2.dsl_kernel
def kernel_fallback_for_else(x, y):
    acc = x
    for i in range(2):
        acc = acc + i
    else:
        acc = acc + y
    return acc


# Tuple unpacking is unsupported -> fallback.
@blosc2.dsl_kernel
def kernel_fallback_tuple_assign(x, y):
    lhs, rhs = x, y
    return lhs + rhs


def test_dsl_kernel_reduced_expr():
    # The unrollable kernel must reduce to a bare expression string.
    assert kernel_loop.dsl_source is not None
    assert "def " not in kernel_loop.dsl_source
    assert kernel_loop.input_names == ["x", "y"]

    a, b, a2, b2 = _make_arrays()
    expr = blosc2.lazyudf(kernel_loop, (a2, b2), dtype=a2.dtype, chunks=a2.chunks, blocks=a2.blocks)
    res = expr.compute()
    expected = kernel_loop.func(a, b)

    np.testing.assert_allclose(res[...], expected, rtol=1e-5, atol=1e-6)


def test_dsl_kernel_integer_ops_reduced_expr():
    # Integer operators must also reduce to a bare expression string.
    assert kernel_integer_ops.dsl_source is not None
    assert "def " not in kernel_integer_ops.dsl_source
    assert kernel_integer_ops.input_names == ["x", "y"]

    a, b, a2, b2 = _make_int_arrays()
    expr = blosc2.lazyudf(
        kernel_integer_ops,
        (a2, b2),
        dtype=a2.dtype,
        chunks=a2.chunks,
        blocks=a2.blocks,
    )
    res = expr.compute()
    expected = kernel_integer_ops.func(a, b)

    np.testing.assert_equal(res[...], expected)


def test_dsl_kernel_full_control_flow_kept_as_dsl_function():
    # continue/break force a full DSL function, not a reduced expression.
    assert kernel_control_flow_full.dsl_source is not None
    assert "def kernel_control_flow_full(x, y):" in kernel_control_flow_full.dsl_source
    assert "for i in range(4):" in kernel_control_flow_full.dsl_source
    assert "elif (i == 1):" in kernel_control_flow_full.dsl_source
    assert "continue" in kernel_control_flow_full.dsl_source
    assert "break" in kernel_control_flow_full.dsl_source
    assert "where(" in kernel_control_flow_full.dsl_source

    a, b, a2, b2 = _make_arrays()
    expr = blosc2.lazyudf(
        kernel_control_flow_full,
        (a2, b2),
        dtype=a2.dtype,
        chunks=a2.chunks,
        blocks=a2.blocks,
    )
    res = expr.compute()
    expected = kernel_control_flow_full.func(a, b)

    np.testing.assert_allclose(res[...], expected, rtol=1e-5, atol=1e-6)
def test_dsl_kernel_accepts_scalar_param_per_call():
    # A scalar parameter used as a loop bound keeps the kernel as a DSL
    # function parameterized on that argument.
    assert kernel_loop_param.dsl_source is not None
    assert "def kernel_loop_param(x, y, niter):" in kernel_loop_param.dsl_source
    assert "for _i in range(niter):" in kernel_loop_param.dsl_source
    assert kernel_loop_param.input_names == ["x", "y", "niter"]

    a, b, a2, b2 = _make_arrays()
    niter = 3
    expr = blosc2.lazyudf(
        kernel_loop_param,
        (a2, b2, niter),
        dtype=a2.dtype,
        chunks=a2.chunks,
        blocks=a2.blocks,
    )
    res = expr.compute()
    expected = kernel_loop_param.func(a, b, niter)

    np.testing.assert_allclose(res[...], expected, rtol=1e-5, atol=1e-6)


def test_dsl_kernel_scalar_param_keeps_miniexpr_fast_path(monkeypatch):
    if blosc2.IS_WASM:
        pytest.skip("miniexpr fast path is not available on WASM")

    import importlib

    lazyexpr_mod = importlib.import_module("blosc2.lazyexpr")
    old_try_miniexpr = lazyexpr_mod.try_miniexpr
    lazyexpr_mod.try_miniexpr = True

    # Spy on _set_pref_expr to capture the expression/operands that reach
    # the miniexpr fast path.
    original_set_pref_expr = blosc2.NDArray._set_pref_expr
    captured = {"calls": 0, "expr": None, "keys": None}

    def wrapped_set_pref_expr(self, expression, inputs, fp_accuracy, aux_reduc=None):
        captured["calls"] += 1
        captured["expr"] = expression.decode("utf-8") if isinstance(expression, bytes) else expression
        captured["keys"] = tuple(inputs.keys())
        return original_set_pref_expr(self, expression, inputs, fp_accuracy, aux_reduc)

    monkeypatch.setattr(blosc2.NDArray, "_set_pref_expr", wrapped_set_pref_expr)

    try:
        a, b, a2, b2 = _make_arrays(shape=(32, 32), chunks=(16, 16), blocks=(8, 8))
        niter = 3
        expr = blosc2.lazyudf(
            kernel_loop_param,
            (a2, b2, niter),
            dtype=a2.dtype,
        )
        res = expr.compute()
        expected = kernel_loop_param.func(a, b, niter)

        np.testing.assert_allclose(res[...], expected, rtol=1e-5, atol=1e-6)
        assert captured["calls"] >= 1
        # The scalar argument must be specialized away before reaching
        # miniexpr: only array operands remain, the loop bound is inlined,
        # and the loop variable name is preserved.
        assert captured["keys"] == ("x", "y")
        assert "def kernel_loop_param(x, y):" in captured["expr"]
        assert "for it in range(3):" not in captured["expr"]
        assert "for _i in range(3):" in captured["expr"]
        assert "range(niter)" not in captured["expr"]
        assert "float(niter)" not in captured["expr"]
    finally:
        lazyexpr_mod.try_miniexpr = old_try_miniexpr


@pytest.mark.parametrize(
    "kernel",
    [
        kernel_fallback_range_2args,
        kernel_fallback_kw_call,
        kernel_fallback_for_else,
        kernel_fallback_tuple_assign,
    ],
)
def test_dsl_kernel_flawed_syntax_detected_fallback_callable(kernel):
    # Unsupported constructs must be detected at decoration time and the
    # kernel kept as a plain callable (no DSL source, no input names).
    assert kernel.dsl_source is None
    assert kernel.input_names is None

    a, b, a2, b2 = _make_arrays()
    expr = blosc2.lazyudf(
        kernel,
        (a2, b2),
        dtype=a2.dtype,
        chunks=a2.chunks,
        blocks=a2.blocks,
    )
    res = expr.compute()
    expected = kernel.func(a, b)

    np.testing.assert_allclose(res[...], expected, rtol=1e-5, atol=1e-6)