Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 16 additions & 19 deletions src/MaxText/layers/attention_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1303,27 +1303,24 @@ def wrap_flash_attention(
decoder_segment_ids_tuple = splash_attention_kernel.SegmentIds(decoder_segment_ids_q, decoder_segment_ids_kv)
else:
decoder_segment_ids_tuple = None

if self.config.use_tokamax_splash:
if self.config.use_sparse_indexer and index_mask is not None:
# Construct the splash kernel call with dynamic mask
def dynamic_mask_splash_kernel(q, k, v, segment, sinks, index_mask):
splash_kernel = tokamax_splash_kernel.make_dynamic_splash_mha(
mask=index_mask,
config=sa_config,
)
kernel = partial(splash_kernel, max_logit_value=max_logit_value)
return kernel(q, k, v, segment, sinks=sinks)

# Iterate over batch dimension for (query, key, value, segment, sinks, mask)
attn_fn = jax.vmap(dynamic_mask_splash_kernel, (0, 0, 0, 0, None, 0))
index_mask = jnp.isclose(index_mask, 0.0)
attention_output = attn_fn(query, key, value, decoder_segment_ids_tuple, sinks, index_mask)
else:
kernel = partial(splash_kernel, max_logit_value=max_logit_value)
attention_output = jax.vmap(lambda q, k, v, d, s: kernel(q, k, v, d, sinks=s), in_axes=(0, 0, 0, 0, None))(
query, key, value, decoder_segment_ids_tuple, sinks
)
q_len = query.shape[2]
kv_len = key.shape[2]

row_ids = jax.lax.broadcasted_iota(jnp.int32, (q_len, kv_len), 0)
col_ids = jax.lax.broadcasted_iota(jnp.int32, (q_len, kv_len), 1)
causal_bool_mask = row_ids >= col_ids

splash_kernel = tokamax_splash_kernel.make_dynamic_splash_mha(
mask=causal_bool_mask,
config=sa_config,
)
kernel = partial(splash_kernel, max_logit_value=max_logit_value)

attention_output = jax.vmap(lambda q, k, v, d, s: kernel(q, k, v, d, sinks=s), in_axes=(0, 0, 0, 0, None))(
query, key, value, decoder_segment_ids_tuple, sinks
)
elif self.config.use_jax_splash:
materialized_mask = jnp.asarray(mask[:, :])
attention_output = jax_flash_attention.flash_attention_block_masked(
Expand Down
Loading