Refactor AttentionProcessorSkipHook to Support Custom STG Logic #13220
sayakpaul left a comment
Thanks for working on this. It looks quite clean for what it is supposed to support!
src/diffusers/hooks/layer_skip.py
```python
def __init__(
    self,
    skip_processor_output_fn: Callable,
    skip_attn_scores_fn: Callable | None = None,
```
Should it go at the end to preserve backward compatibility (in case someone initialized this class with positional arguments only)?
No strong opinions here.
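For what it's worth, one way to keep existing positional call sites safe is to make the new argument keyword-only. A minimal sketch (the class body is simplified from the diff above):

```python
from typing import Callable, Optional


class AttentionProcessorSkipHook:
    """Sketch only: a bare `*` before `skip_attn_scores_fn` means it can only
    be passed by keyword, so any caller that used positional arguments with
    the old signature keeps working unchanged."""

    def __init__(
        self,
        skip_processor_output_fn: Callable,
        *,
        skip_attn_scores_fn: Optional[Callable] = None,
    ):
        self.skip_processor_output_fn = skip_processor_output_fn
        # STG default: return the values as attention output
        self.skip_attn_scores_fn = skip_attn_scores_fn or (lambda attn, q, k, v: v)
```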
```python
    super().__init__()
    self.skip_processor_output_fn = skip_processor_output_fn
    # STG default: return the values as attention output
    self.skip_attn_scores_fn = skip_attn_scores_fn or (lambda attn, q, k, v: v)
```
Is it safe to assume that `skip_attn_scores_fn` will only take those four inputs (`attn`, `q`, `k`, `v`) and always return `v` as the output?
`lambda attn, q, k, v: v` is only intended to be the default `skip_attn_scores_fn` (unless I made a mistake); it performs standard STG by returning the values `v`. I don't think it constrains the signature or behavior of a user-provided `skip_attn_scores_fn`. (And we can always implement the skip logic manually in the `if self._skip_attn_scores:` branch of the attention processor itself if necessary.)
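To make that concrete, here's a hypothetical custom `skip_attn_scores_fn` that does more than return `v` unchanged. The `head_gates` attribute is invented for illustration (not part of the PR); the function keeps the `(attn, q, k, v)` signature:

```python
import torch


def gated_value_skip(attn, q, k, v):
    # Hypothetical custom skip fn: apply per-head gates to the values
    # before they reach the output projection. Assumes `attn.head_gates`
    # is a tensor of shape [num_heads]; v is [batch, heads, seq, dim].
    gates = attn.head_gates.view(1, -1, 1, 1)
    return gates * v
```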
```python
if processor_supports_skip_fn:
    output = self.fn_ref.original_forward(*args, **kwargs)
else:
    # Fallback to torch native SDPA intercept approach
    with AttentionScoreSkipFunctionMode():
        output = self.fn_ref.original_forward(*args, **kwargs)
```
Why do we have to condition like this?
We condition on `processor_supports_skip_fn` here in case the attention processor doesn't define a `_skip_attn_scores` attribute. If it doesn't, we fall back to the current behavior, which is to intercept the `torch.nn.functional.scaled_dot_product_attention` call and return the value tensor from there. (The `AttentionScoreSkipFunctionMode` context manager performs the interception.)
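For readers unfamiliar with the intercept approach, here's a rough sketch of how such a context manager can work, assuming it is built on `torch.overrides.TorchFunctionMode` (the actual implementation in `diffusers.hooks` may differ):

```python
import torch
import torch.nn.functional as F
from torch.overrides import TorchFunctionMode


class AttentionScoreSkipFunctionMode(TorchFunctionMode):
    """Sketch: while this mode is active, any call to
    F.scaled_dot_product_attention returns the value tensor unchanged,
    i.e. the STG "skip attention scores" behavior."""

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.scaled_dot_product_attention:
            return args[2]  # (query, key, value) -> value
        return func(*args, **kwargs)
```

This only catches backends that route through `F.scaled_dot_product_attention`; backends such as `flash-attn` bypass that function, which is exactly why the `_skip_attn_scores` path in this PR is needed.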
```python
        # Fallback to torch native SDPA intercept approach
        with AttentionScoreSkipFunctionMode():
            output = self.fn_ref.original_forward(*args, **kwargs)
    finally:
```
Should also not raise an exception for user clarity?
I'm confused by this question. Is your suggestion to catch any exceptions in an `except` block? Maybe something like this?

```python
try:
    ...
except Exception as e:
    logger.error(f"Tried to skip attn scores but got error {e}", exc_info=True)
    raise
finally:
    # Clean up if necessary
    if processor_supports_skip_fn:
        module.processor._skip_attn_scores = False
        module.processor._skip_attn_scores_fn = None
```
What does this PR do?
This PR refactors `AttentionProcessorSkipHook` to support attention processors with a `_skip_attn_scores` flag, which is meant to allow attention processors to specify custom spatio-temporal guidance (STG)-like logic. The idea is that an `AttentionProcessor` can check this flag and apply the skip logic itself. This is motivated by supporting the following use cases:

- The current implementation intercepts calls to `F.scaled_dot_product_attention`, but other attention backends like `flash-attn` won't call this function. The PR allows `SkipLayerGuidance` to work with any attention backend. (If the `AttnProcessor` doesn't have a `_skip_attn_scores` attribute, we fall back to the current intercept approach.)
- Custom skip logic can be supplied via `_skip_attn_scores_fn`. For example, LTX-2.3 applies learned per-head gates to the `v` values before passing them to the attention output projection `to_out`. This can be implemented with a custom `_skip_attn_scores_fn`.

Inspired by #13217, particularly #13217 (comment).
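Sketched concretely, a flag-aware processor might look roughly like the following. This is a hypothetical minimal sketch, not the PR's code; real diffusers processors also handle projections, masks, and head reshaping:

```python
import torch
import torch.nn.functional as F


class SkipAwareAttnProcessor:
    """Hypothetical sketch of a processor exposing the PR's _skip_attn_scores
    flag. The hook would flip the flag (and optionally set a custom fn)
    before calling the module's forward."""

    _skip_attn_scores = False
    _skip_attn_scores_fn = None

    def __call__(self, attn, query, key, value):
        if self._skip_attn_scores:
            # STG default: hand the values straight to the output path
            fn = self._skip_attn_scores_fn or (lambda attn, q, k, v: v)
            return fn(attn, query, key, value)
        return F.scaled_dot_product_attention(query, key, value)
```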
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@yiyixuxu
@sayakpaul