Description
In order to make function evaluation properly lazy, all functions (e.g. matmul) should be implemented as LazyUDFs.
Secondly, the lazyexpr machinery of compute should loop over chunks of the result. Each function must then decide which slices of the operands are needed to form the corresponding chunk of the result (as matmul currently does internally). Upon evaluation, although the expression is still evaluated term by term, a function such as matmul does not compute its full result before proceeding to the next term, only the chunk(s) of output that are actually needed. This gives a higher chance of cache hits (which is not the case currently for the eager execution of linalg functions and reductions).
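For illustration, here is a minimal sketch of what wrapping matmul as a LazyUDF could look like. `lazy_matmul` and `matmul_udf` are hypothetical names, and the sketch assumes `blosc2.lazyudf`'s per-chunk callback convention (`func(inputs, output, offset)`) with the operands captured by closure:

```python
import numpy as np
import blosc2


def lazy_matmul(a, b):
    # Hypothetical wrapper (not current blosc2 API): expose matmul as a
    # LazyUDF so that evaluation is deferred and proceeds chunk by chunk.
    def matmul_udf(inputs, output, offset):
        # Called once per chunk of the result; `offset` locates the chunk
        # within the full output array and `output` is the buffer to fill.
        i, j = offset
        m, n = output.shape
        # Only the row block of `a` and the column block of `b` needed for
        # this chunk are read (and hence decompressed), never the full operands.
        output[:] = a[i : i + m, :] @ b[:, j : j + n]

    # Operands are captured by closure; shape/dtype describe the lazy result.
    return blosc2.lazyudf(
        matmul_udf,
        (),
        dtype=np.result_type(a.dtype, b.dtype),
        shape=(a.shape[0], b.shape[1]),
    )
```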
Something in this spirit has already been implemented for cumsum/cumprod when axis=None (see this code):
```python
# Special case for cumulative operations with axis = None
if reduce_args["axis"] is None and reduce_op in {ReduceOp.CUMULATIVE_PROD, ReduceOp.CUMULATIVE_SUM}:
    # res_out_ is just None; out is initialized to all 0s (sum) or 1s (prod)
    out, res_out_ = convert_none_out(dtype, reduce_op, reduced_shape)
    # reduced_shape is just a one-element tuple
    chunklen = out.chunks[0] if hasattr(out, "chunks") else chunks[-1]
    # The carry propagates the running total (sum) or running product (prod)
    # across chunk boundaries; it must start at the identity element of the op
    carry = 0 if reduce_op == ReduceOp.CUMULATIVE_SUM else 1
    for cidx in range(0, reduced_shape[0] // chunklen):
        # Map the flat range [cidx * chunklen, (cidx + 1) * chunklen) back
        # to an N-dimensional slice of the operands
        slice_starts = np.unravel_index(cidx * chunklen, shape)
        slice_stops = np.unravel_index((cidx + 1) * chunklen, shape)
        cslice = tuple(
            slice(start, stop) for start, stop in zip(slice_starts, slice_stops, strict=True)
        )
        # Decompress only the operand slices needed for this output chunk
        _get_chunk_operands(operands, cslice, chunk_operands, shape)
        result, _ = _get_result(expression, chunk_operands, ne_args, where)
        result = np.require(result, requirements="C")
        if reduce_op == ReduceOp.CUMULATIVE_SUM:
            res = np.cumulative_sum(result, axis=None) + carry
        else:
            res = np.cumulative_prod(result, axis=None) * carry
        carry = res[-1]
        out[cidx * chunklen + include_initial : (cidx + 1) * chunklen + include_initial] = res
```
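Note that only the scalar carry has to cross chunk boundaries, so a single chunk of operands and a single chunk of output are resident in memory at any time.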
It should also be implemented for fancy indexing with slices (see #441).
It should even be possible to handle nested cases like "matmul(sum(a, axis=1), b)" by passing the desired slice from matmul down to sum, which then treats the requested slice as its own desired output slice and handles it accordingly.
This would avoid large in-memory temporaries that currently arise from eagerly executed linear algebra functions, e.g. in "matmul(a, b) + b"; a sketch of such chunkwise evaluation is given below.
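As a rough illustration, here is a minimal sketch of a chunkwise driver loop for "matmul(a, b) + b". `iter_chunk_slices` is a hypothetical helper yielding one slice tuple per output chunk, and `lazy_matmul` is the sketch from above; it relies on LazyUDF `__getitem__` computing only the requested slice:

```python
# Hypothetical driver loop: evaluate "matmul(a, b) + b" one output chunk at a
# time, so no full-size temporary for the matmul result is ever materialized.
expr = lazy_matmul(a, b)  # lazy_matmul: the hypothetical sketch above

for cslice in iter_chunk_slices(out.shape, out.chunks):  # hypothetical helper
    # Each term computes/reads only the part needed for this chunk: getitem
    # on the LazyUDF evaluates just cslice of the matmul, and the elementwise
    # "+ b" only needs b[cslice].
    out[cslice] = expr[cslice] + b[cslice]
```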
Problems:
1 - reductions could be a problem: for example, in "sum(a) + a" the scalar "sum(a)" would be recomputed for every chunk of the output. This could be avoided with some kind of result cache for reductions (see the sketch after this list).
2 - naturally, numexpr will always be faster for plain elementwise expressions, since it compiles them into bytecode. We should therefore make sure that, whenever possible, numexpr is still used preferentially (i.e. for most elementwise functions).
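For problem 1, a minimal sketch of such a reduction cache, assuming the reduction subexpression can serve as a stable key (`cached_reduce` and the keying scheme are hypothetical):

```python
import numpy as np

# Hypothetical per-evaluation cache, filled during one compute() pass and
# discarded afterwards, so "sum(a)" in "sum(a) + a" is computed only once
# instead of once per output chunk.
_reduction_cache: dict[str, np.ndarray] = {}


def cached_reduce(key: str, compute):
    # `key` would identify the reduction subexpression (e.g. its string form
    # plus operand ids); `compute` performs the full reduction on a cache miss.
    if key not in _reduction_cache:
        _reduction_cache[key] = compute()
    return _reduction_cache[key]


# Usage inside the chunk loop (sketch): the reduction runs once, then every
# subsequent chunk reuses the cached value.
# total = cached_reduce(f"sum-{id(a)}", lambda: a.sum())
```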