
Refactor AttentionProcessorSkipHook to Support Custom STG Logic#13220

Open
dg845 wants to merge 2 commits into main from
skip-layer-guider-refactor

Conversation

Collaborator

@dg845 dg845 commented Mar 7, 2026

What does this PR do?

This PR refactors AttentionProcessorSkipHook to support attention processors with a _skip_attn_scores flag, which is meant to allow attention processors to specify custom spatio-temporal guidance (STG)-like logic. The idea is that an AttentionProcessor could then be implemented as follows:

class NewAttnProcessor:
    ...
    _skip_attn_scores = False
    _skip_attn_scores_fn = None

    def __call__(self, attn, hidden_states, ...):
        ...
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        ...
        if self._skip_attn_scores:
            # Defaults to returning `value` as in standard STG
            hidden_states = self._skip_attn_scores_fn(attn, query, key, value)
        else:
            ...
            hidden_states = dispatch_attention_fn(query, key, value, ...)
        ...
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states

This is motivated by supporting the following use cases:

  1. Supporting all attention backends: the current implementation intercepts calls to F.scaled_dot_product_attention. But other attention backends like flash-attn won't call this function. The PR allows SkipLayerGuidance to work with any attention backend. (If the AttnProcessor doesn't have a _skip_attn_scores attribute, we will fall back to the current intercept approach.)
  2. Custom Skip Logic: we can define custom logic in the _skip_attn_scores_fn. For example, LTX-2.3 applies learned per-head gates to the values before passing them to the attention output projection to_out. This can be implemented in _skip_attn_scores_fn.
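To make use case 2 concrete, a custom skip function could apply per-head gates to the values before the hook hands them to to_out. This is only a sketch: the gate handling below is hypothetical (not the actual LTX implementation), and it assumes value has shape (batch, heads, seq_len, head_dim) with gates closed over rather than added to the (attn, query, key, value) signature:

```python
import torch


def make_gated_skip_fn(gates: torch.Tensor):
    # `gates` is a hypothetical per-head gate tensor of shape (heads,);
    # in LTX the gates would be learned parameters living on the model.
    def skip_fn(attn, query, key, value):
        # Scale each head's values by its gate; the hook then feeds the
        # result straight to attn.to_out, bypassing the attention scores.
        return value * gates.view(1, -1, 1, 1)

    return skip_fn


# Example: 2 batch items, 4 heads, 8 tokens, head_dim 16
values = torch.ones(2, 4, 8, 16)
skip_fn = make_gated_skip_fn(torch.tensor([0.0, 0.5, 1.0, 2.0]))
gated = skip_fn(None, None, None, values)
```

Closing over the gates keeps the (attn, query, key, value) signature intact, so such a function could be passed anywhere the default STG function is accepted.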

Inspired by #13217, particularly #13217 (comment).

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@yiyixuxu
@sayakpaul

…n attn processors to allow custom STG-style logic
Member

@sayakpaul sayakpaul left a comment


Thanks for working on this. It looks quite clean for what it is supposed to support!

def __init__(
    self,
    skip_processor_output_fn: Callable,
    skip_attn_scores_fn: Callable | None = None,
Member

Should it go at the end to preserve backward compatibility (in case someone initialized this class with positional arguments only)?

No strong opinions here.

super().__init__()
self.skip_processor_output_fn = skip_processor_output_fn
# STG default: return the values as attention output
self.skip_attn_scores_fn = skip_attn_scores_fn or (lambda attn, q, k, v: v)
Member

Is it safe to assume that the skip_attn_scores_fn will only take those four (attn, q, k, v) as inputs and always return v as the output?

Collaborator Author

lambda attn, q, k, v: v is intended to be the default skip_attn_scores_fn (unless I made a mistake), which performs standard STG by returning the values v. I don't think this constrains the signature or behavior of skip_attn_scores_fn if we pass it as an argument (and we can always manually implement the skip logic in the if self._skip_attn_scores: branch in the attention processor itself if necessary).
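As a tiny illustration of that point (plain Python, names purely illustrative): any callable honoring the (attn, query, key, value) signature can be dropped in, and the default simply echoes the values back:

```python
# Default: standard STG just returns the values unchanged.
default_skip_fn = lambda attn, query, key, value: value


# A custom drop-in only needs the same signature; the 0.5
# scaling here is purely illustrative, not a real STG variant.
def scaled_skip_fn(attn, query, key, value):
    return value * 0.5
```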

Comment on lines +137 to +142
if processor_supports_skip_fn:
    output = self.fn_ref.original_forward(*args, **kwargs)
else:
    # Fallback to torch native SDPA intercept approach
    with AttentionScoreSkipFunctionMode():
        output = self.fn_ref.original_forward(*args, **kwargs)
Member

Why do we have to condition like this?

Collaborator Author

@dg845 dg845 Mar 7, 2026

We condition on processor_supports_skip_fn here in case the attention processor doesn't define a _skip_attn_scores attribute. If it doesn't, we fall back to the current behavior, which is to intercept the torch.nn.functional.scaled_dot_product_attention call and return the value from there. (The AttentionScoreSkipFunctionMode context manager performs the interception.)
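That dispatch can be sketched in pure Python. Everything below is a stand-in, not the actual diffusers code: attention_score_skip_mode mimics AttentionScoreSkipFunctionMode (which intercepts torch.nn.functional.scaled_dot_product_attention), and the two processor classes are hypothetical:

```python
from contextlib import contextmanager


@contextmanager
def attention_score_skip_mode():
    # Stand-in for AttentionScoreSkipFunctionMode, which intercepts
    # torch.nn.functional.scaled_dot_product_attention calls.
    yield


class LegacyProcessor:
    # No _skip_attn_scores attribute: triggers the SDPA-intercept fallback.
    def __call__(self):
        return "sdpa-intercept path"


class SkipAwareProcessor:
    _skip_attn_scores = False

    def __call__(self):
        return "skip-fn path" if self._skip_attn_scores else "normal path"


def run_skipped(processor):
    processor_supports_skip_fn = hasattr(processor, "_skip_attn_scores")
    if processor_supports_skip_fn:
        processor._skip_attn_scores = True
        try:
            return processor()
        finally:
            # Always reset the flag, even if the forward raises.
            processor._skip_attn_scores = False
    with attention_score_skip_mode():
        return processor()
```

The hasattr check is what lets old processors keep working unchanged while new ones opt in to the skip-function path.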

    # Fallback to torch native SDPA intercept approach
    with AttentionScoreSkipFunctionMode():
        output = self.fn_ref.original_forward(*args, **kwargs)
finally:
Member

Should also not raise an exception for user clarity?

Collaborator Author

I'm confused by this question; is your suggestion to catch any exceptions in an except block? Maybe something like

            try:
                ...
            except Exception as e:
                logger.error(f"Tried to skip attn scores but got error {e}", exc_info=True)
                raise
            finally:
                # Clean up if necessary
                if processor_supports_skip_fn:
                    module.processor._skip_attn_scores = False
                    module.processor._skip_attn_scores_fn = None

?

