diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 9dee954af6d0..7ef322e3039a 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -17,7 +17,7 @@ LTX2Pipeline, LTX2VideoTransformer3DModel, ) -from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder +from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder, LTX2VocoderWithBWE from diffusers.utils.import_utils import is_accelerate_available @@ -44,6 +44,12 @@ "k_norm": "norm_k", } +LTX_2_3_TRANSFORMER_KEYS_RENAME_DICT = { + **LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT, + "audio_prompt_adaln_single": "audio_prompt_adaln", + "prompt_adaln_single": "prompt_adaln", +} + LTX_2_0_VIDEO_VAE_RENAME_DICT = { # Encoder "down_blocks.0": "down_blocks.0", @@ -72,6 +78,13 @@ "per_channel_statistics.std-of-means": "latents_std", } +LTX_2_3_VIDEO_VAE_RENAME_DICT = { + **LTX_2_0_VIDEO_VAE_RENAME_DICT, + # Decoder extra blocks + "up_blocks.7": "up_blocks.3.upsamplers.0", + "up_blocks.8": "up_blocks.3", +} + LTX_2_0_AUDIO_VAE_RENAME_DICT = { "per_channel_statistics.mean-of-means": "latents_mean", "per_channel_statistics.std-of-means": "latents_std", @@ -84,10 +97,34 @@ "conv_post": "conv_out", } -LTX_2_0_TEXT_ENCODER_RENAME_DICT = { +LTX_2_3_VOCODER_RENAME_DICT = { + # Handle upsamplers ("ups" --> "upsamplers") due to name clash + "resblocks": "resnets", + "conv_pre": "conv_in", + "conv_post": "conv_out", + "act_post": "act_out", + "downsample.lowpass": "downsample", +} + +LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = { + "connectors.": "", + "video_embeddings_connector": "video_connector", + "audio_embeddings_connector": "audio_connector", + "transformer_1d_blocks": "transformer_blocks", + "text_embedding_projection.aggregate_embed": "text_proj_in", + # Attention QK Norms + "q_norm": "norm_q", + "k_norm": "norm_k", +} + +LTX_2_3_CONNECTORS_KEYS_RENAME_DICT = { + "connectors.": "", 
"video_embeddings_connector": "video_connector", "audio_embeddings_connector": "audio_connector", "transformer_1d_blocks": "transformer_blocks", + # LTX-2.3 uses per-modality embedding projections + "text_embedding_projection.audio_aggregate_embed": "audio_text_proj_in", + "text_embedding_projection.video_aggregate_embed": "video_text_proj_in", # Attention QK Norms "q_norm": "norm_q", "k_norm": "norm_k", @@ -129,23 +166,24 @@ def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: dict[str return +def convert_ltx2_3_vocoder_upsamplers(key: str, state_dict: dict[str, Any]) -> None: + # Skip if not a weight, bias + if ".weight" not in key and ".bias" not in key: + return + + if ".ups." in key: + new_key = key.replace(".ups.", ".upsamplers.") + param = state_dict.pop(key) + state_dict[new_key] = param + return + + LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = { "video_embeddings_connector": remove_keys_inplace, "audio_embeddings_connector": remove_keys_inplace, "adaln_single": convert_ltx2_transformer_adaln_single, } -LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = { - "connectors.": "", - "video_embeddings_connector": "video_connector", - "audio_embeddings_connector": "audio_connector", - "transformer_1d_blocks": "transformer_blocks", - "text_embedding_projection.aggregate_embed": "text_proj_in", - # Attention QK Norms - "q_norm": "norm_q", - "k_norm": "norm_k", -} - LTX_2_0_VAE_SPECIAL_KEYS_REMAP = { "per_channel_statistics.channel": remove_keys_inplace, "per_channel_statistics.mean-of-stds": remove_keys_inplace, @@ -155,13 +193,19 @@ def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: dict[str LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {} +LTX_2_3_VOCODER_SPECIAL_KEYS_REMAP = { + ".ups.": convert_ltx2_3_vocoder_upsamplers, +} + +LTX_2_0_CONNECTORS_SPECIAL_KEYS_REMAP = {} + def split_transformer_and_connector_state_dict(state_dict: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: connector_prefixes = ( "video_embeddings_connector", 
"audio_embeddings_connector", "transformer_1d_blocks", - "text_embedding_projection.aggregate_embed", + "text_embedding_projection", "connectors.", "video_connector", "audio_connector", @@ -225,7 +269,7 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str, special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP elif version == "2.0": config = { - "model_id": "diffusers-internal-dev/new-ltx-model", + "model_id": "Lightricks/LTX-2", "diffusers_config": { "in_channels": 128, "out_channels": 128, @@ -238,6 +282,8 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str, "pos_embed_max_pos": 20, "base_height": 2048, "base_width": 2048, + "gated_attn": False, + "cross_attn_mod": False, "audio_in_channels": 128, "audio_out_channels": 128, "audio_patch_size": 1, @@ -249,6 +295,8 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str, "audio_pos_embed_max_pos": 20, "audio_sampling_rate": 16000, "audio_hop_length": 160, + "audio_gated_attn": False, + "audio_cross_attn_mod": False, "num_layers": 48, "activation_fn": "gelu-approximate", "qk_norm": "rms_norm_across_heads", @@ -263,10 +311,62 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str, "timestep_scale_multiplier": 1000, "cross_attn_timestep_scale_multiplier": 1000, "rope_type": "split", + "use_prompt_embeddings": True, + "perturbed_attn": False, }, } rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP + elif version == "2.3": + config = { + "model_id": "Lightricks/LTX-2.3", + "diffusers_config": { + "in_channels": 128, + "out_channels": 128, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 32, + "attention_head_dim": 128, + "cross_attention_dim": 4096, + "vae_scale_factors": (8, 32, 32), + "pos_embed_max_pos": 20, + "base_height": 2048, + "base_width": 2048, + "gated_attn": True, + "cross_attn_mod": True, + "audio_in_channels": 
128, + "audio_out_channels": 128, + "audio_patch_size": 1, + "audio_patch_size_t": 1, + "audio_num_attention_heads": 32, + "audio_attention_head_dim": 64, + "audio_cross_attention_dim": 2048, + "audio_scale_factor": 4, + "audio_pos_embed_max_pos": 20, + "audio_sampling_rate": 16000, + "audio_hop_length": 160, + "audio_gated_attn": True, + "audio_cross_attn_mod": True, + "num_layers": 48, + "activation_fn": "gelu-approximate", + "qk_norm": "rms_norm_across_heads", + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "caption_channels": 3840, + "attention_bias": True, + "attention_out_bias": True, + "rope_theta": 10000.0, + "rope_double_precision": True, + "causal_offset": 1, + "timestep_scale_multiplier": 1000, + "cross_attn_timestep_scale_multiplier": 1000, + "rope_type": "split", + "use_prompt_embeddings": False, + "perturbed_attn": True, + }, + } + rename_dict = LTX_2_3_TRANSFORMER_KEYS_RENAME_DICT + special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP return config, rename_dict, special_keys_remap @@ -293,7 +393,7 @@ def get_ltx2_connectors_config(version: str) -> tuple[dict[str, Any], dict[str, } elif version == "2.0": config = { - "model_id": "diffusers-internal-dev/new-ltx-model", + "model_id": "Lightricks/LTX-2", "diffusers_config": { "caption_channels": 3840, "text_proj_in_factor": 49, @@ -301,20 +401,52 @@ def get_ltx2_connectors_config(version: str) -> tuple[dict[str, Any], dict[str, "video_connector_attention_head_dim": 128, "video_connector_num_layers": 2, "video_connector_num_learnable_registers": 128, + "video_gated_attn": False, "audio_connector_num_attention_heads": 30, "audio_connector_attention_head_dim": 128, "audio_connector_num_layers": 2, "audio_connector_num_learnable_registers": 128, + "audio_gated_attn": False, "connector_rope_base_seq_len": 4096, "rope_theta": 10000.0, "rope_double_precision": True, "causal_temporal_positioning": False, "rope_type": "split", + "per_modality_projections": False, + "proj_bias": False, }, } - - 
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT - special_keys_remap = {} + rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT + special_keys_remap = LTX_2_0_CONNECTORS_SPECIAL_KEYS_REMAP + elif version == "2.3": + config = { + "model_id": "Lightricks/LTX-2.3", + "diffusers_config": { + "caption_channels": 3840, + "text_proj_in_factor": 49, + "video_connector_num_attention_heads": 32, + "video_connector_attention_head_dim": 128, + "video_connector_num_layers": 8, + "video_connector_num_learnable_registers": 128, + "video_gated_attn": True, + "audio_connector_num_attention_heads": 32, + "audio_connector_attention_head_dim": 64, + "audio_connector_num_layers": 8, + "audio_connector_num_learnable_registers": 128, + "audio_gated_attn": True, + "connector_rope_base_seq_len": 4096, + "rope_theta": 10000.0, + "rope_double_precision": True, + "causal_temporal_positioning": False, + "rope_type": "split", + "per_modality_projections": True, + "video_hidden_dim": 4096, + "audio_hidden_dim": 2048, + "proj_bias": True, + }, + } + rename_dict = LTX_2_3_CONNECTORS_KEYS_RENAME_DICT + special_keys_remap = LTX_2_0_CONNECTORS_SPECIAL_KEYS_REMAP return config, rename_dict, special_keys_remap @@ -416,7 +548,7 @@ def get_ltx2_video_vae_config( special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP elif version == "2.0": config = { - "model_id": "diffusers-internal-dev/dummy-ltx2", + "model_id": "Lightricks/LTX-2", "diffusers_config": { "in_channels": 3, "out_channels": 3, @@ -435,6 +567,7 @@ def get_ltx2_video_vae_config( "decoder_spatio_temporal_scaling": (True, True, True), "decoder_inject_noise": (False, False, False, False), "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + "upsample_type": ("spatiotemporal", "spatiotemporal", "spatiotemporal"), "upsample_residual": (True, True, True), "upsample_factor": (2, 2, 2), "timestep_conditioning": timestep_conditioning, @@ -451,6 +584,44 @@ def get_ltx2_video_vae_config( } rename_dict = 
LTX_2_0_VIDEO_VAE_RENAME_DICT special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP + elif version == "2.3": + config = { + "model_id": "Lightricks/LTX-2.3", + "diffusers_config": { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "block_out_channels": (256, 512, 1024, 1024), + "down_block_types": ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + "decoder_block_out_channels": (256, 512, 512, 1024), + "layers_per_block": (4, 6, 4, 2, 2), + "decoder_layers_per_block": (4, 6, 4, 2, 2), + "spatio_temporal_scaling": (True, True, True, True), + "decoder_spatio_temporal_scaling": (True, True, True, True), + "decoder_inject_noise": (False, False, False, False, False), + "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + "upsample_type": ("spatiotemporal", "spatiotemporal", "temporal", "spatial"), + "upsample_residual": (True, True, True, True), + "upsample_factor": (2, 2, 1, 2), + "timestep_conditioning": timestep_conditioning, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + "encoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + "decoder_spatial_padding_mode": "zeros", + "spatial_compression_ratio": 32, + "temporal_compression_ratio": 8, + }, + } + rename_dict = LTX_2_3_VIDEO_VAE_RENAME_DICT + special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP return config, rename_dict, special_keys_remap @@ -485,7 +656,7 @@ def convert_ltx2_video_vae( def get_ltx2_audio_vae_config(version: str) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: if version == "2.0": config = { - "model_id": "diffusers-internal-dev/new-ltx-model", + "model_id": "Lightricks/LTX-2", "diffusers_config": { "base_channels": 128, "output_channels": 2, @@ -508,6 +679,31 @@ def get_ltx2_audio_vae_config(version: str) -> tuple[dict[str, Any], dict[str, A } rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT special_keys_remap = 
LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP + elif version == "2.3": + config = { + "model_id": "Lightricks/LTX-2.3", + "diffusers_config": { + "base_channels": 128, + "output_channels": 2, + "ch_mult": (1, 2, 4), + "num_res_blocks": 2, + "attn_resolutions": None, + "in_channels": 2, + "resolution": 256, + "latent_channels": 8, + "norm_type": "pixel", + "causality_axis": "height", + "dropout": 0.0, + "mid_block_add_attention": False, + "sample_rate": 16000, + "mel_hop_length": 160, + "is_causal": True, + "mel_bins": 64, + "double_z": True, + }, # Same config as LTX-2.0 + } + rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT + special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP return config, rename_dict, special_keys_remap @@ -540,7 +736,7 @@ def convert_ltx2_audio_vae(original_state_dict: dict[str, Any], version: str) -> def get_ltx2_vocoder_config(version: str) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: if version == "2.0": config = { - "model_id": "diffusers-internal-dev/new-ltx-model", + "model_id": "Lightricks/LTX-2", "diffusers_config": { "in_channels": 128, "hidden_channels": 1024, @@ -549,21 +745,71 @@ def get_ltx2_vocoder_config(version: str) -> tuple[dict[str, Any], dict[str, Any "upsample_factors": [6, 5, 2, 2, 2], "resnet_kernel_sizes": [3, 7, 11], "resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "act_fn": "leaky_relu", "leaky_relu_negative_slope": 0.1, + "antialias": False, + "final_act_fn": "tanh", + "final_bias": True, "output_sampling_rate": 24000, }, } rename_dict = LTX_2_0_VOCODER_RENAME_DICT special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP + elif version == "2.3": + config = { + "model_id": "Lightricks/LTX-2.3", + "diffusers_config": { + "in_channels": 128, + "hidden_channels": 1536, + "out_channels": 2, + "upsample_kernel_sizes": [11, 4, 4, 4, 4, 4], + "upsample_factors": [5, 2, 2, 2, 2, 2], + "resnet_kernel_sizes": [3, 7, 11], + "resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "act_fn": "snakebeta", + 
"leaky_relu_negative_slope": 0.1, + "antialias": True, + "antialias_ratio": 2, + "antialias_kernel_size": 12, + "final_act_fn": None, + "final_bias": False, + "bwe_in_channels": 128, + "bwe_hidden_channels": 512, + "bwe_out_channels": 2, + "bwe_upsample_kernel_sizes": [12, 11, 4, 4, 4], + "bwe_upsample_factors": [6, 5, 2, 2, 2], + "bwe_resnet_kernel_sizes": [3, 7, 11], + "bwe_resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "bwe_act_fn": "snakebeta", + "bwe_leaky_relu_negative_slope": 0.1, + "bwe_antialias": True, + "bwe_antialias_ratio": 2, + "bwe_antialias_kernel_size": 12, + "bwe_final_act_fn": None, + "bwe_final_bias": False, + "filter_length": 512, + "hop_length": 80, + "window_length": 512, + "num_mel_channels": 64, + "input_sampling_rate": 16000, + "output_sampling_rate": 48000, + }, + } + rename_dict = LTX_2_3_VOCODER_RENAME_DICT + special_keys_remap = LTX_2_3_VOCODER_SPECIAL_KEYS_REMAP return config, rename_dict, special_keys_remap def convert_ltx2_vocoder(original_state_dict: dict[str, Any], version: str) -> dict[str, Any]: config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version) diffusers_config = config["diffusers_config"] + if version == "2.3": + vocoder_cls = LTX2VocoderWithBWE + else: + vocoder_cls = LTX2Vocoder with init_empty_weights(): - vocoder = LTX2Vocoder.from_config(diffusers_config) + vocoder = vocoder_cls.from_config(diffusers_config) # Handle official code --> diffusers key remapping via the remap dict for key in list(original_state_dict.keys()): @@ -651,13 +897,17 @@ def get_model_state_dict_from_combined_ckpt(combined_ckpt: dict[str, Any], prefi model_state_dict = {} for param_name, param in combined_ckpt.items(): if param_name.startswith(prefix): - model_state_dict[param_name.replace(prefix, "")] = param + model_state_dict[param_name.removeprefix(prefix)] = param if prefix == "model.diffusion_model.": # Some checkpoints store the text connector projection outside the diffusion model prefix. 
- connector_key = "text_embedding_projection.aggregate_embed.weight" - if connector_key in combined_ckpt and connector_key not in model_state_dict: - model_state_dict[connector_key] = combined_ckpt[connector_key] + connector_prefixes = ["text_embedding_projection"] + for param_name, param in combined_ckpt.items(): + for prefix in connector_prefixes: + if param_name.startswith(prefix): + # Check to make sure we're not overwriting an existing key + if param_name not in model_state_dict: + model_state_dict[param_name] = combined_ckpt[param_name] return model_state_dict @@ -686,7 +936,7 @@ def none_or_str(value: str): "--version", type=str, default="2.0", - choices=["test", "2.0"], + choices=["test", "2.0", "2.3"], help="Version of the LTX 2.0 model", ) @@ -787,7 +1037,7 @@ def main(args): args.audio_vae, args.dit, args.vocoder, - args.text_encoder, + args.connectors, args.full_pipeline, args.upsample_pipeline, ] @@ -852,7 +1102,7 @@ def main(args): if not args.full_pipeline: tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer")) - if args.latent_upsampler or args.full_pipeline or args.upsample_pipeline: + if args.latent_upsampler or args.upsample_pipeline: original_latent_upsampler_ckpt = load_hub_or_local_checkpoint( repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename ) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py index 7c04bd715c25..36cf73e4bb2a 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py @@ -237,7 +237,7 @@ def forward( # Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d -class LTXVideoDownsampler3d(nn.Module): +class LTX2VideoDownsampler3d(nn.Module): def __init__( self, in_channels: int, @@ -285,10 +285,11 @@ def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Ten # Like LTX 1.0 LTXVideoUpsampler3d, but uses 
new causal Conv3d -class LTXVideoUpsampler3d(nn.Module): +class LTX2VideoUpsampler3d(nn.Module): def __init__( self, in_channels: int, + out_channels: int | None = None, stride: int | tuple[int, int, int] = 1, residual: bool = False, upscale_factor: int = 1, @@ -300,7 +301,8 @@ def __init__( self.residual = residual self.upscale_factor = upscale_factor - out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor + out_channels = out_channels or in_channels + out_channels = (out_channels * stride[0] * stride[1] * stride[2]) // upscale_factor self.conv = LTX2VideoCausalConv3d( in_channels=in_channels, @@ -408,7 +410,7 @@ def __init__( ) elif downsample_type == "spatial": self.downsamplers.append( - LTXVideoDownsampler3d( + LTX2VideoDownsampler3d( in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), @@ -417,7 +419,7 @@ def __init__( ) elif downsample_type == "temporal": self.downsamplers.append( - LTXVideoDownsampler3d( + LTX2VideoDownsampler3d( in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), @@ -426,7 +428,7 @@ def __init__( ) elif downsample_type == "spatiotemporal": self.downsamplers.append( - LTXVideoDownsampler3d( + LTX2VideoDownsampler3d( in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), @@ -580,6 +582,7 @@ def __init__( resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", spatio_temporal_scale: bool = True, + upsample_type: str = "spatiotemporal", inject_noise: bool = False, timestep_conditioning: bool = False, upsample_residual: bool = False, @@ -609,17 +612,38 @@ def __init__( self.upsamplers = None if spatio_temporal_scale: - self.upsamplers = nn.ModuleList( - [ - LTXVideoUpsampler3d( - out_channels * upscale_factor, + self.upsamplers = nn.ModuleList() + + if upsample_type == "spatial": + self.upsamplers.append( + LTX2VideoUpsampler3d( + in_channels=out_channels * upscale_factor, + stride=(1, 2, 2), + residual=upsample_residual, + upscale_factor=upscale_factor, + 
spatial_padding_mode=spatial_padding_mode, + ) + ) + elif upsample_type == "temporal": + self.upsamplers.append( + LTX2VideoUpsampler3d( + in_channels=out_channels * upscale_factor, + stride=(2, 1, 1), + residual=upsample_residual, + upscale_factor=upscale_factor, + spatial_padding_mode=spatial_padding_mode, + ) + ) + elif upsample_type == "spatiotemporal": + self.upsamplers.append( + LTX2VideoUpsampler3d( + in_channels=out_channels * upscale_factor, stride=(2, 2, 2), residual=upsample_residual, upscale_factor=upscale_factor, spatial_padding_mode=spatial_padding_mode, ) - ] - ) + ) resnets = [] for _ in range(num_layers): @@ -716,7 +740,7 @@ def __init__( "LTX2VideoDownBlock3D", "LTX2VideoDownBlock3D", ), - spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True), + spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True, True), layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2), downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), patch_size: int = 4, @@ -726,6 +750,9 @@ def __init__( spatial_padding_mode: str = "zeros", ): super().__init__() + num_encoder_blocks = len(layers_per_block) + if isinstance(spatio_temporal_scaling, bool): + spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_encoder_blocks - 1) self.patch_size = patch_size self.patch_size_t = patch_size_t @@ -860,19 +887,27 @@ def __init__( in_channels: int = 128, out_channels: int = 3, block_out_channels: tuple[int, ...] = (256, 512, 1024), - spatio_temporal_scaling: tuple[bool, ...] = (True, True, True), + spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True), layers_per_block: tuple[int, ...] = (5, 5, 5, 5), + upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"), patch_size: int = 4, patch_size_t: int = 1, resnet_norm_eps: float = 1e-6, is_causal: bool = False, - inject_noise: tuple[bool, ...] = (False, False, False), + inject_noise: bool | tuple[bool, ...] 
= (False, False, False), timestep_conditioning: bool = False, - upsample_residual: tuple[bool, ...] = (True, True, True), + upsample_residual: bool | tuple[bool, ...] = (True, True, True), upsample_factor: tuple[bool, ...] = (2, 2, 2), spatial_padding_mode: str = "reflect", ) -> None: super().__init__() + num_decoder_blocks = len(layers_per_block) + if isinstance(spatio_temporal_scaling, bool): + spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_decoder_blocks - 1) + if isinstance(inject_noise, bool): + inject_noise = (inject_noise,) * num_decoder_blocks + if isinstance(upsample_residual, bool): + upsample_residual = (upsample_residual,) * (num_decoder_blocks - 1) self.patch_size = patch_size self.patch_size_t = patch_size_t @@ -917,6 +952,7 @@ def __init__( num_layers=layers_per_block[i + 1], resnet_eps=resnet_norm_eps, spatio_temporal_scale=spatio_temporal_scaling[i], + upsample_type=upsample_type[i], inject_noise=inject_noise[i + 1], timestep_conditioning=timestep_conditioning, upsample_residual=upsample_residual[i], @@ -1058,11 +1094,12 @@ def __init__( decoder_block_out_channels: tuple[int, ...] = (256, 512, 1024), layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2), decoder_layers_per_block: tuple[int, ...] = (5, 5, 5, 5), - spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True), - decoder_spatio_temporal_scaling: tuple[bool, ...] = (True, True, True), - decoder_inject_noise: tuple[bool, ...] = (False, False, False, False), + spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True, True), + decoder_spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True), + decoder_inject_noise: bool | tuple[bool, ...] = (False, False, False, False), downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), - upsample_residual: tuple[bool, ...] = (True, True, True), + upsample_type: tuple[str, ...] 
= ("spatiotemporal", "spatiotemporal", "spatiotemporal"), + upsample_residual: bool | tuple[bool, ...] = (True, True, True), upsample_factor: tuple[int, ...] = (2, 2, 2), timestep_conditioning: bool = False, patch_size: int = 4, @@ -1077,6 +1114,16 @@ def __init__( temporal_compression_ratio: int = None, ) -> None: super().__init__() + num_encoder_blocks = len(layers_per_block) + num_decoder_blocks = len(decoder_layers_per_block) + if isinstance(spatio_temporal_scaling, bool): + spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_encoder_blocks - 1) + if isinstance(decoder_spatio_temporal_scaling, bool): + decoder_spatio_temporal_scaling = (decoder_spatio_temporal_scaling,) * (num_decoder_blocks - 1) + if isinstance(decoder_inject_noise, bool): + decoder_inject_noise = (decoder_inject_noise,) * num_decoder_blocks + if isinstance(upsample_residual, bool): + upsample_residual = (upsample_residual,) * (num_decoder_blocks - 1) self.encoder = LTX2VideoEncoder3d( in_channels=in_channels, @@ -1098,6 +1145,7 @@ def __init__( block_out_channels=decoder_block_out_channels, spatio_temporal_scaling=decoder_spatio_temporal_scaling, layers_per_block=decoder_layers_per_block, + upsample_type=upsample_type, patch_size=patch_size, patch_size_t=patch_size_t, resnet_norm_eps=resnet_norm_eps, diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index db949ca34a1f..e734e4a46b0d 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -178,6 +178,10 @@ def __call__( if encoder_hidden_states is None: encoder_hidden_states = hidden_states + if attn.to_gate_logits is not None: + # Calculate gate logits on original hidden_states + gate_logits = attn.to_gate_logits(hidden_states) + query = attn.to_q(hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -212,6 +216,112 @@ def __call__( hidden_states = 
hidden_states.flatten(2, 3) hidden_states = hidden_states.to(query.dtype) + if attn.to_gate_logits is not None: + hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) # [B, T, H, D] + # The factor of 2.0 is so that if the gates logits are zero-initialized the initial gates are all 1 + gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H] + hidden_states = hidden_states * gates.unsqueeze(-1) + hidden_states = hidden_states.flatten(2, 3) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class LTX2PerturbedAttnProcessor: + r""" + Processor which implements attention with perturbation masking and per-head gating for LTX-2.X models. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if is_torch_version("<", "2.0"): + raise ValueError( + "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation." + ) + + def __call__( + self, + attn: "LTX2Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + perturbation_mask: torch.Tensor | None = None, + all_perturbed: bool | None = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + if attn.to_gate_logits is not None: + # Calculate gate logits on original hidden_states + gate_logits = attn.to_gate_logits(hidden_states) + + value = 
attn.to_v(encoder_hidden_states) + if all_perturbed is None: + all_perturbed = torch.all(perturbation_mask == 0) if perturbation_mask is not None else False + + if all_perturbed: + # Skip attention, use the value projection value + hidden_states = value + else: + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if query_rotary_emb is not None: + if attn.rope_type == "interleaved": + query = apply_interleaved_rotary_emb(query, query_rotary_emb) + key = apply_interleaved_rotary_emb( + key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb + ) + elif attn.rope_type == "split": + query = apply_split_rotary_emb(query, query_rotary_emb) + key = apply_split_rotary_emb( + key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb + ) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if perturbation_mask is not None: + value = value.flatten(2, 3) + hidden_states = torch.lerp(value, hidden_states, perturbation_mask) + + if attn.to_gate_logits is not None: + hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) # [B, T, H, D] + # The factor of 2.0 is so that if the gates logits are zero-initialized the initial gates are all 1 + gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H] + hidden_states = hidden_states * gates.unsqueeze(-1) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) return hidden_states @@ -224,7 +334,7 @@ class LTX2Attention(torch.nn.Module, 
AttentionModuleMixin): """ _default_processor_cls = LTX2AudioVideoAttnProcessor - _available_processors = [LTX2AudioVideoAttnProcessor] + _available_processors = [LTX2AudioVideoAttnProcessor, LTX2PerturbedAttnProcessor] def __init__( self, @@ -240,6 +350,7 @@ def __init__( norm_eps: float = 1e-6, norm_elementwise_affine: bool = True, rope_type: str = "interleaved", + apply_gated_attention: bool = False, processor=None, ): super().__init__() @@ -266,6 +377,12 @@ def __init__( self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) self.to_out.append(torch.nn.Dropout(dropout)) + if apply_gated_attention: + # Per head gate values + self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True) + else: + self.to_gate_logits = None + if processor is None: processor = self._default_processor_cls() self.set_processor(processor) @@ -321,6 +438,10 @@ def __init__( audio_num_attention_heads: int, audio_attention_head_dim, audio_cross_attention_dim: int, + video_gated_attn: bool = False, + video_cross_attn_adaln: bool = False, + audio_gated_attn: bool = False, + audio_cross_attn_adaln: bool = False, qk_norm: str = "rms_norm_across_heads", activation_fn: str = "gelu-approximate", attention_bias: bool = True, @@ -328,9 +449,16 @@ def __init__( eps: float = 1e-6, elementwise_affine: bool = False, rope_type: str = "interleaved", + perturbed_attn: bool = False, ): super().__init__() + self.perturbed_attn = perturbed_attn + if perturbed_attn: + attn_processor_cls = LTX2PerturbedAttnProcessor + else: + attn_processor_cls = LTX2AudioVideoAttnProcessor + # 1. 
Self-Attention (video and audio) self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) self.attn1 = LTX2Attention( @@ -343,6 +471,8 @@ def __init__( out_bias=attention_out_bias, qk_norm=qk_norm, rope_type=rope_type, + apply_gated_attention=video_gated_attn, + processor=attn_processor_cls(), ) self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) @@ -356,6 +486,8 @@ def __init__( out_bias=attention_out_bias, qk_norm=qk_norm, rope_type=rope_type, + apply_gated_attention=audio_gated_attn, + processor=attn_processor_cls(), ) # 2. Prompt Cross-Attention @@ -370,6 +502,8 @@ def __init__( out_bias=attention_out_bias, qk_norm=qk_norm, rope_type=rope_type, + apply_gated_attention=video_gated_attn, + processor=attn_processor_cls(), ) self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) @@ -383,6 +517,8 @@ def __init__( out_bias=attention_out_bias, qk_norm=qk_norm, rope_type=rope_type, + apply_gated_attention=audio_gated_attn, + processor=attn_processor_cls(), ) # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention @@ -398,6 +534,8 @@ def __init__( out_bias=attention_out_bias, qk_norm=qk_norm, rope_type=rope_type, + apply_gated_attention=video_gated_attn, + processor=attn_processor_cls(), ) # Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video @@ -412,6 +550,8 @@ def __init__( out_bias=attention_out_bias, qk_norm=qk_norm, rope_type=rope_type, + apply_gated_attention=audio_gated_attn, + processor=attn_processor_cls(), ) # 4. Feedforward layers @@ -422,14 +562,36 @@ def __init__( self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn) # 5. 
Per-Layer Modulation Parameters - # Self-Attention / Feedforward AdaLayerNorm-Zero mod params - self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) - self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5) + # Self-Attention (attn1) / Feedforward AdaLayerNorm-Zero mod params + # 6 base mod params for text cross-attn K,V; if cross_attn_adaln, also has mod params for Q + self.video_cross_attn_adaln = video_cross_attn_adaln + self.audio_cross_attn_adaln = audio_cross_attn_adaln + video_mod_param_num = 9 if self.video_cross_attn_adaln else 6 + audio_mod_param_num = 9 if self.audio_cross_attn_adaln else 6 + self.scale_shift_table = nn.Parameter(torch.randn(video_mod_param_num, dim) / dim**0.5) + self.audio_scale_shift_table = nn.Parameter(torch.randn(audio_mod_param_num, audio_dim) / audio_dim**0.5) + + # Prompt cross-attn (attn2) additional modulation params + self.cross_attn_adaln = video_cross_attn_adaln or audio_cross_attn_adaln + if self.cross_attn_adaln: + self.prompt_scale_shift_table = nn.Parameter(torch.randn(2, dim)) + self.audio_prompt_scale_shift_table = nn.Parameter(torch.randn(2, audio_dim)) # Per-layer a2v, v2a Cross-Attention mod params self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim)) self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, audio_dim)) + @staticmethod + def get_mod_params( + scale_shift_table: torch.Tensor, temb: torch.Tensor, batch_size: int + ) -> tuple[torch.Tensor, ...]: + num_ada_params = scale_shift_table.shape[0] + ada_values = scale_shift_table[None, None].to(temb.device) + temb.reshape( + batch_size, temb.shape[1], num_ada_params, -1 + ) + ada_params = ada_values.unbind(dim=2) + return ada_params + def forward( self, hidden_states: torch.Tensor, @@ -442,143 +604,181 @@ def forward( temb_ca_audio_scale_shift: torch.Tensor, temb_ca_gate: torch.Tensor, temb_ca_audio_gate: torch.Tensor, + temb_prompt: torch.Tensor | None = None, + 
temb_prompt_audio: torch.Tensor | None = None, video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, ca_video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, ca_audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, encoder_attention_mask: torch.Tensor | None = None, audio_encoder_attention_mask: torch.Tensor | None = None, + self_attention_mask: torch.Tensor | None = None, + audio_self_attention_mask: torch.Tensor | None = None, a2v_cross_attention_mask: torch.Tensor | None = None, v2a_cross_attention_mask: torch.Tensor | None = None, + use_a2v_cross_attention: bool = True, + use_v2a_cross_attention: bool = True, + perturbation_mask: torch.Tensor | None = None, + all_perturbed: bool | None = None, ) -> torch.Tensor: batch_size = hidden_states.size(0) # 1. Video and Audio Self-Attention - norm_hidden_states = self.norm1(hidden_states) + # 1.1. Video Self-Attention + video_ada_params = self.get_mod_params(self.scale_shift_table, temb, batch_size) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = video_ada_params[:6] + if self.video_cross_attn_adaln: + shift_text_q, scale_text_q, gate_text_q = video_ada_params[6:9] - num_ada_params = self.scale_shift_table.shape[0] - ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape( - batch_size, temb.size(1), num_ada_params, -1 - ) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) + norm_hidden_states = self.norm1(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa - attn_hidden_states = self.attn1( - hidden_states=norm_hidden_states, - encoder_hidden_states=None, - query_rotary_emb=video_rotary_emb, - ) + video_self_attn_args = { + "hidden_states": norm_hidden_states, + "encoder_hidden_states": None, + "query_rotary_emb": video_rotary_emb, + "attention_mask": self_attention_mask, + } + if self.perturbed_attn: 
+ video_self_attn_args["perturbation_mask"] = perturbation_mask + video_self_attn_args["all_perturbed"] = all_perturbed + + attn_hidden_states = self.attn1(**video_self_attn_args) hidden_states = hidden_states + attn_hidden_states * gate_msa - norm_audio_hidden_states = self.audio_norm1(audio_hidden_states) - - num_audio_ada_params = self.audio_scale_shift_table.shape[0] - audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape( - batch_size, temb_audio.size(1), num_audio_ada_params, -1 - ) + # 1.2. Audio Self-Attention + audio_ada_params = self.get_mod_params(self.audio_scale_shift_table, temb_audio, batch_size) audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = ( - audio_ada_values.unbind(dim=2) + audio_ada_params[:6] ) + if self.audio_cross_attn_adaln: + audio_shift_text_q, audio_scale_text_q, audio_gate_text_q = audio_ada_params[6:9] + + norm_audio_hidden_states = self.audio_norm1(audio_hidden_states) norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa - attn_audio_hidden_states = self.audio_attn1( - hidden_states=norm_audio_hidden_states, - encoder_hidden_states=None, - query_rotary_emb=audio_rotary_emb, - ) + audio_self_attn_args = { + "hidden_states": norm_audio_hidden_states, + "encoder_hidden_states": None, + "query_rotary_emb": audio_rotary_emb, + "attention_mask": audio_self_attention_mask, + } + if self.perturbed_attn: + audio_self_attn_args["perturbation_mask"] = perturbation_mask + audio_self_attn_args["all_perturbed"] = all_perturbed + + attn_audio_hidden_states = self.audio_attn1(**audio_self_attn_args) audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa - # 2. Video and Audio Cross-Attention with the text embeddings + # 2. 
Video and Audio Cross-Attention with the text embeddings (Q: Video or Audio; K,V: Text) + if self.cross_attn_adaln: + video_prompt_ada_params = self.get_mod_params(self.prompt_scale_shift_table, temb_prompt, batch_size) + shift_text_kv, scale_text_kv = video_prompt_ada_params + + audio_prompt_ada_params = self.get_mod_params( + self.audio_prompt_scale_shift_table, temb_prompt_audio, batch_size + ) + audio_shift_text_kv, audio_scale_text_kv = audio_prompt_ada_params + + # 2.1. Video-Text Cross-Attention (Q: Video; K,V: Text) norm_hidden_states = self.norm2(hidden_states) + if self.video_cross_attn_adaln: + norm_hidden_states = norm_hidden_states * (1 + scale_text_q) + shift_text_q + if self.cross_attn_adaln: + encoder_hidden_states = encoder_hidden_states * (1 + scale_text_kv) + shift_text_kv + attn_hidden_states = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, query_rotary_emb=None, attention_mask=encoder_attention_mask, ) + if self.video_cross_attn_adaln: + attn_hidden_states = attn_hidden_states * gate_text_q hidden_states = hidden_states + attn_hidden_states + # 2.2. Audio-Text Cross-Attention norm_audio_hidden_states = self.audio_norm2(audio_hidden_states) + if self.audio_cross_attn_adaln: + norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_text_q) + audio_shift_text_q + if self.cross_attn_adaln: + audio_encoder_hidden_states = audio_encoder_hidden_states * (1 + audio_scale_text_kv) + audio_shift_text_kv + attn_audio_hidden_states = self.audio_attn2( norm_audio_hidden_states, encoder_hidden_states=audio_encoder_hidden_states, query_rotary_emb=None, attention_mask=audio_encoder_attention_mask, ) + if self.audio_cross_attn_adaln: + attn_audio_hidden_states = attn_audio_hidden_states * audio_gate_text_q audio_hidden_states = audio_hidden_states + attn_audio_hidden_states # 3.
Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention - norm_hidden_states = self.audio_to_video_norm(hidden_states) - norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states) - - # Combine global and per-layer cross attention modulation parameters - # Video - video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :] - video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :] - - video_ca_scale_shift_table = ( - video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype) - + temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1) - ).unbind(dim=2) - video_ca_gate = ( - video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype) - + temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1) - ).unbind(dim=2) - - video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table - a2v_gate = video_ca_gate[0].squeeze(2) - - # Audio - audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :] - audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :] - - audio_ca_scale_shift_table = ( - audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype) - + temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1) - ).unbind(dim=2) - audio_ca_gate = ( - audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype) - + temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1) - ).unbind(dim=2) - - audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table - v2a_gate = audio_ca_gate[0].squeeze(2) - - # Audio-to-Video Cross Attention: Q: Video; K,V: Audio - mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze( - 2 - ) - mod_norm_audio_hidden_states = norm_audio_hidden_states * ( - 1 + audio_a2v_ca_scale.squeeze(2) - ) + 
audio_a2v_ca_shift.squeeze(2) - - a2v_attn_hidden_states = self.audio_to_video_attn( - mod_norm_hidden_states, - encoder_hidden_states=mod_norm_audio_hidden_states, - query_rotary_emb=ca_video_rotary_emb, - key_rotary_emb=ca_audio_rotary_emb, - attention_mask=a2v_cross_attention_mask, - ) + if use_a2v_cross_attention or use_v2a_cross_attention: + norm_hidden_states = self.audio_to_video_norm(hidden_states) + norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states) - hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states + # 3.1. Combine global and per-layer cross attention modulation parameters + # Video + video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :] + video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :] - # Video-to-Audio Cross Attention: Q: Audio; K,V: Video - mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze( - 2 - ) - mod_norm_audio_hidden_states = norm_audio_hidden_states * ( - 1 + audio_v2a_ca_scale.squeeze(2) - ) + audio_v2a_ca_shift.squeeze(2) - - v2a_attn_hidden_states = self.video_to_audio_attn( - mod_norm_audio_hidden_states, - encoder_hidden_states=mod_norm_hidden_states, - query_rotary_emb=ca_audio_rotary_emb, - key_rotary_emb=ca_video_rotary_emb, - attention_mask=v2a_cross_attention_mask, - ) + video_ca_ada_params = self.get_mod_params(video_per_layer_ca_scale_shift, temb_ca_scale_shift, batch_size) + video_ca_gate_param = self.get_mod_params(video_per_layer_ca_gate, temb_ca_gate, batch_size) + + video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_ada_params + a2v_gate = video_ca_gate_param[0].squeeze(2) - audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states + # Audio + audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :] + audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :] + + 
audio_ca_ada_params = self.get_mod_params( + audio_per_layer_ca_scale_shift, temb_ca_audio_scale_shift, batch_size + ) + audio_ca_gate_param = self.get_mod_params(audio_per_layer_ca_gate, temb_ca_audio_gate, batch_size) + + audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_ada_params + v2a_gate = audio_ca_gate_param[0].squeeze(2) + + # 3.2. Audio-to-Video Cross Attention: Q: Video; K,V: Audio + if use_a2v_cross_attention: + mod_norm_hidden_states = norm_hidden_states * ( + 1 + video_a2v_ca_scale.squeeze(2) + ) + video_a2v_ca_shift.squeeze(2) + mod_norm_audio_hidden_states = norm_audio_hidden_states * ( + 1 + audio_a2v_ca_scale.squeeze(2) + ) + audio_a2v_ca_shift.squeeze(2) + + a2v_attn_hidden_states = self.audio_to_video_attn( + mod_norm_hidden_states, + encoder_hidden_states=mod_norm_audio_hidden_states, + query_rotary_emb=ca_video_rotary_emb, + key_rotary_emb=ca_audio_rotary_emb, + attention_mask=a2v_cross_attention_mask, + ) + + hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states + + # 3.3. Video-to-Audio Cross Attention: Q: Audio; K,V: Video + if use_v2a_cross_attention: + mod_norm_hidden_states = norm_hidden_states * ( + 1 + video_v2a_ca_scale.squeeze(2) + ) + video_v2a_ca_shift.squeeze(2) + mod_norm_audio_hidden_states = norm_audio_hidden_states * ( + 1 + audio_v2a_ca_scale.squeeze(2) + ) + audio_v2a_ca_shift.squeeze(2) + + v2a_attn_hidden_states = self.video_to_audio_attn( + mod_norm_audio_hidden_states, + encoder_hidden_states=mod_norm_hidden_states, + query_rotary_emb=ca_audio_rotary_emb, + key_rotary_emb=ca_video_rotary_emb, + attention_mask=v2a_cross_attention_mask, + ) + + audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states # 4. 
Feedforward norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp @@ -918,6 +1118,8 @@ def __init__( pos_embed_max_pos: int = 20, base_height: int = 2048, base_width: int = 2048, + gated_attn: bool = False, + cross_attn_mod: bool = False, audio_in_channels: int = 128, # Audio Arguments audio_out_channels: int | None = 128, audio_patch_size: int = 1, @@ -929,6 +1131,8 @@ def __init__( audio_pos_embed_max_pos: int = 20, audio_sampling_rate: int = 16000, audio_hop_length: int = 160, + audio_gated_attn: bool = False, + audio_cross_attn_mod: bool = False, num_layers: int = 48, # Shared arguments activation_fn: str = "gelu-approximate", qk_norm: str = "rms_norm_across_heads", @@ -943,6 +1147,8 @@ def __init__( timestep_scale_multiplier: int = 1000, cross_attn_timestep_scale_multiplier: int = 1000, rope_type: str = "interleaved", + use_prompt_embeddings=True, + perturbed_attn: bool = False, ) -> None: super().__init__() @@ -956,17 +1162,25 @@ def __init__( self.audio_proj_in = nn.Linear(audio_in_channels, audio_inner_dim) # 2. Prompt embeddings - self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) - self.audio_caption_projection = PixArtAlphaTextProjection( - in_features=caption_channels, hidden_size=audio_inner_dim - ) + if use_prompt_embeddings: + # LTX-2.0; LTX-2.3 uses per-modality feature projections in the connector instead + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, hidden_size=audio_inner_dim + ) # 3. Timestep Modulation Params and Embedding + self.prompt_modulation = cross_attn_mod or audio_cross_attn_mod # used by LTX-2.3 + # 3.1. 
Global Timestep Modulation Parameters (except for cross-attention) and timestep + size embedding # time_embed and audio_time_embed calculate both the timestep embedding and (global) modulation parameters - self.time_embed = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=6, use_additional_conditions=False) + video_time_emb_mod_params = 9 if cross_attn_mod else 6 + audio_time_emb_mod_params = 9 if audio_cross_attn_mod else 6 + self.time_embed = LTX2AdaLayerNormSingle( + inner_dim, num_mod_params=video_time_emb_mod_params, use_additional_conditions=False + ) self.audio_time_embed = LTX2AdaLayerNormSingle( - audio_inner_dim, num_mod_params=6, use_additional_conditions=False + audio_inner_dim, num_mod_params=audio_time_emb_mod_params, use_additional_conditions=False ) # 3.2. Global Cross Attention Modulation Parameters @@ -995,6 +1209,13 @@ def __init__( self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) self.audio_scale_shift_table = nn.Parameter(torch.randn(2, audio_inner_dim) / audio_inner_dim**0.5) + # 3.4. Prompt Scale/Shift Modulation parameters (LTX-2.3) + if self.prompt_modulation: + self.prompt_adaln = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=2, use_additional_conditions=False) + self.audio_prompt_adaln = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=2, use_additional_conditions=False + ) + # 4. 
Rotary Positional Embeddings (RoPE) # Self-Attention self.rope = LTX2AudioVideoRotaryPosEmbed( @@ -1071,6 +1292,10 @@ def __init__( audio_num_attention_heads=audio_num_attention_heads, audio_attention_head_dim=audio_attention_head_dim, audio_cross_attention_dim=audio_cross_attention_dim, + video_gated_attn=gated_attn, + video_cross_attn_adaln=cross_attn_mod, + audio_gated_attn=audio_gated_attn, + audio_cross_attn_adaln=audio_cross_attn_mod, qk_norm=qk_norm, activation_fn=activation_fn, attention_bias=attention_bias, @@ -1078,6 +1303,7 @@ def __init__( eps=norm_eps, elementwise_affine=norm_elementwise_affine, rope_type=rope_type, + perturbed_attn=perturbed_attn, ) for _ in range(num_layers) ] @@ -1101,8 +1327,12 @@ def forward( audio_encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, audio_timestep: torch.LongTensor | None = None, + sigma: torch.Tensor | None = None, + audio_sigma: torch.Tensor | None = None, encoder_attention_mask: torch.Tensor | None = None, audio_encoder_attention_mask: torch.Tensor | None = None, + self_attention_mask: torch.Tensor | None = None, + audio_self_attention_mask: torch.Tensor | None = None, num_frames: int | None = None, height: int | None = None, width: int | None = None, @@ -1110,6 +1340,9 @@ def forward( audio_num_frames: int | None = None, video_coords: torch.Tensor | None = None, audio_coords: torch.Tensor | None = None, + isolate_modalities: bool = False, + spatio_temporal_guidance_blocks: list[int] | None = None, + perturbation_mask: torch.Tensor | None = None, attention_kwargs: dict[str, Any] | None = None, return_dict: bool = True, ) -> torch.Tensor: @@ -1131,10 +1364,19 @@ def forward( audio_timestep (`torch.Tensor`, *optional*): Input timestep of shape `(batch_size,)` or `(batch_size, num_audio_tokens)` for audio modulation params. This is only used by certain pipelines such as the I2V pipeline. + sigma (`torch.Tensor`, *optional*): + Input scaled timestep of shape (batch_size,). 
Used for video prompt cross attention modulation in + models such as LTX-2.3. + audio_sigma (`torch.Tensor`, *optional*): + Input scaled timestep of shape (batch_size,). Used for audio prompt cross attention modulation in + models such as LTX-2.3. If `sigma` is supplied but `audio_sigma` is not, `audio_sigma` will be set to + the provided `sigma` value. encoder_attention_mask (`torch.Tensor`, *optional*): Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`. audio_encoder_attention_mask (`torch.Tensor`, *optional*): Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)` for audio modeling. + self_attention_mask (`torch.Tensor`, *optional*): + Optional multiplicative self-attention mask of shape `(batch_size, seq_len, seq_len)`. num_frames (`int`, *optional*): The number of latent video frames. Used if calculating the video coordinates for RoPE. height (`int`, *optional*): @@ -1152,6 +1394,17 @@ def forward( audio_coords (`torch.Tensor`, *optional*): The audio coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape `(batch_size, 1, num_audio_tokens, 2)`. If not supplied, this will be calculated inside `forward`. + isolate_modalities (`bool`, *optional*, defaults to `False`): + Whether to isolate each modality by turning off cross-modality (audio-to-video and video-to-audio) + cross attention (for all blocks). Use for modality guidance in LTX-2.3. + spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`): + The transformer block indices at which to apply spatio-temporal guidance (STG), which shortcuts the + self-attention operations by simply using the values rather than the full scaled dot-product attention + (SDPA) operation. If `None` or empty, STG will not be applied to any block. + perturbation_mask (`torch.Tensor`, *optional*): + Perturbation mask for STG of shape `(batch_size,)` or `(batch_size, 1, 1)`. 
Should be 0 at batch + elements where STG should be applied and 1 elsewhere. If STG is being used but `perturbation_mask` is + not supplied, will default to applying STG (perturbing) all batch elements. attention_kwargs (`dict[str, Any]`, *optional*): Optional dict of keyword args to be passed to the attention processor. return_dict (`bool`, *optional*, defaults to `True`): @@ -1165,6 +1418,7 @@ def forward( """ # Determine timestep for audio. audio_timestep = audio_timestep if audio_timestep is not None else timestep + audio_sigma = audio_sigma if audio_sigma is not None else sigma # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: @@ -1175,6 +1429,32 @@ def forward( audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)) * -10000.0 audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1) + if self_attention_mask is not None and self_attention_mask.ndim == 3: + # Convert to additive attention mask in log-space where 0 (masked) values get mapped to a large negative + # number and positive values are mapped to their logarithm. + dtype_finfo = torch.finfo(hidden_states.dtype) + additive_self_attn_mask = torch.full_like(self_attention_mask, dtype_finfo.min, dtype=hidden_states.dtype) + unmasked_entries = self_attention_mask > 0 + if torch.any(unmasked_entries): + additive_self_attn_mask[unmasked_entries] = torch.log( + self_attention_mask[unmasked_entries].clamp(min=dtype_finfo.tiny) + ).to(hidden_states.dtype) + self_attention_mask = additive_self_attn_mask.unsqueeze(1) # [batch_size, 1, seq_len, seq_len] + + if audio_self_attention_mask is not None and audio_self_attention_mask.ndim == 3: + # Convert to additive attention mask in log-space where 0 (masked) values get mapped to a large negative + # number and positive values are mapped to their logarithm.
+ dtype_finfo = torch.finfo(hidden_states.dtype) + additive_self_attn_mask = torch.full_like( + audio_self_attention_mask, dtype_finfo.min, dtype=hidden_states.dtype + ) + unmasked_entries = audio_self_attention_mask > 0 + if torch.any(unmasked_entries): + additive_self_attn_mask[unmasked_entries] = torch.log( + audio_self_attention_mask[unmasked_entries].clamp(min=dtype_finfo.tiny) + ).to(hidden_states.dtype) + audio_self_attention_mask = additive_self_attn_mask.unsqueeze(1) # [batch_size, 1, seq_len, seq_len] + batch_size = hidden_states.size(0) # 1. Prepare RoPE positional embeddings @@ -1223,6 +1503,19 @@ def forward( temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1)) audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1)) + if self.prompt_modulation: + # LTX-2.3 + temb_prompt, _ = self.prompt_adaln( + sigma.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + temb_prompt_audio, _ = self.audio_prompt_adaln( + audio_sigma.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype + ) + temb_prompt = temb_prompt.view(batch_size, -1, temb_prompt.size(-1)) + temb_prompt_audio = temb_prompt_audio.view(batch_size, -1, temb_prompt_audio.size(-1)) + else: + temb_prompt = temb_prompt_audio = None + # 3.2. Prepare global modality cross attention modulation parameters video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift( timestep.flatten(), @@ -1254,15 +1547,30 @@ def forward( ) audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1]) - # 4. Prepare prompt embeddings - encoder_hidden_states = self.caption_projection(encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) + # 4. 
Prepare prompt embeddings (LTX-2.0) + if self.config.use_prompt_embeddings: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) - audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states) - audio_encoder_hidden_states = audio_encoder_hidden_states.view(batch_size, -1, audio_hidden_states.size(-1)) + audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states) + audio_encoder_hidden_states = audio_encoder_hidden_states.view( + batch_size, -1, audio_hidden_states.size(-1) + ) # 5. Run transformer blocks - for block in self.transformer_blocks: + spatio_temporal_guidance_blocks = spatio_temporal_guidance_blocks or [] + if len(spatio_temporal_guidance_blocks) > 0 and perturbation_mask is None: + # If STG is being used and perturbation_mask is not supplied, default to perturbing all batch elements. + perturbation_mask = torch.zeros((batch_size,), device=hidden_states.device) + if perturbation_mask is not None and perturbation_mask.ndim == 1: + perturbation_mask = perturbation_mask[:, None, None] # unsqueeze to 3D to broadcast with hidden_states + all_perturbed = torch.all(perturbation_mask == 0) if perturbation_mask is not None else False + stg_blocks = set(spatio_temporal_guidance_blocks) + + for block_idx, block in enumerate(self.transformer_blocks): + block_perturbation_mask = perturbation_mask if block_idx in stg_blocks else None + block_all_perturbed = all_perturbed if block_idx in stg_blocks else False + if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states, audio_hidden_states = self._gradient_checkpointing_func( block, @@ -1276,12 +1584,22 @@ def forward( audio_cross_attn_scale_shift, video_cross_attn_a2v_gate, audio_cross_attn_v2a_gate, + temb_prompt, + temb_prompt_audio, video_rotary_emb, audio_rotary_emb, video_cross_attn_rotary_emb, audio_cross_attn_rotary_emb, encoder_attention_mask,
audio_encoder_attention_mask, + self_attention_mask, + audio_self_attention_mask, + None, # a2v_cross_attention_mask + None, # v2a_cross_attention_mask + not isolate_modalities, # use_a2v_cross_attention + not isolate_modalities, # use_v2a_cross_attention + block_perturbation_mask, + block_all_perturbed, ) else: hidden_states, audio_hidden_states = block( @@ -1295,12 +1613,22 @@ def forward( temb_ca_audio_scale_shift=audio_cross_attn_scale_shift, temb_ca_gate=video_cross_attn_a2v_gate, temb_ca_audio_gate=audio_cross_attn_v2a_gate, + temb_prompt=temb_prompt, + temb_prompt_audio=temb_prompt_audio, video_rotary_emb=video_rotary_emb, audio_rotary_emb=audio_rotary_emb, ca_video_rotary_emb=video_cross_attn_rotary_emb, ca_audio_rotary_emb=audio_cross_attn_rotary_emb, encoder_attention_mask=encoder_attention_mask, audio_encoder_attention_mask=audio_encoder_attention_mask, + self_attention_mask=self_attention_mask, + audio_self_attention_mask=audio_self_attention_mask, + a2v_cross_attention_mask=None, + v2a_cross_attention_mask=None, + use_a2v_cross_attention=not isolate_modalities, + use_v2a_cross_attention=not isolate_modalities, + perturbation_mask=block_perturbation_mask, + all_perturbed=block_all_perturbed, ) # 6. 
Output layers (including unpatchification) diff --git a/src/diffusers/pipelines/ltx2/__init__.py b/src/diffusers/pipelines/ltx2/__init__.py index d6a408d5c546..7177faaf3486 100644 --- a/src/diffusers/pipelines/ltx2/__init__.py +++ b/src/diffusers/pipelines/ltx2/__init__.py @@ -28,7 +28,7 @@ _import_structure["pipeline_ltx2_condition"] = ["LTX2ConditionPipeline"] _import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"] _import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"] - _import_structure["vocoder"] = ["LTX2Vocoder"] + _import_structure["vocoder"] = ["LTX2Vocoder", "LTX2VocoderWithBWE"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -44,7 +44,7 @@ from .pipeline_ltx2_condition import LTX2ConditionPipeline from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline - from .vocoder import LTX2Vocoder + from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE else: import sys diff --git a/src/diffusers/pipelines/ltx2/connectors.py b/src/diffusers/pipelines/ltx2/connectors.py index 4b2a81a9dc2c..3f721a2cfbf9 100644 --- a/src/diffusers/pipelines/ltx2/connectors.py +++ b/src/diffusers/pipelines/ltx2/connectors.py @@ -1,3 +1,5 @@ +import math + import torch import torch.nn as nn import torch.nn.functional as F @@ -9,6 +11,79 @@ from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor +def per_layer_masked_mean_norm( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: str | torch.device, + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, +): + """ + Performs per-batch per-layer normalization using a masked mean and range on per-layer text encoder hidden_states. + Respects the padding of the hidden states. 
+ + Args: + text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): + Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). + sequence_lengths (`torch.Tensor of shape `(batch_size,)`): + The number of valid (non-padded) tokens for each batch instance. + device: (`str` or `torch.device`, *optional*): + torch device to place the resulting embeddings on + padding_side: (`str`, *optional*, defaults to `"left"`): + Whether the text tokenizer performs padding on the `"left"` or `"right"`. + scale_factor (`int`, *optional*, defaults to `8`): + Scaling factor to multiply the normalized hidden states by. + eps (`float`, *optional*, defaults to `1e-6`): + A small positive value for numerical stability when performing normalization. + + Returns: + `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: + Normed and flattened text encoder hidden states. + """ + batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + # Create padding mask + token_indices = torch.arange(seq_len, device=device).unsqueeze(0) + if padding_side == "right": + # For right padding, valid tokens are from 0 to sequence_length-1 + mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] + elif padding_side == "left": + # For left padding, valid tokens are from (T - sequence_length) to T-1 + start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] + mask = token_indices >= start_indices # [B, T] + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] + + # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) + masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) + num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) + masked_mean = 
masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) + + # Compute min/max over non-padding positions of shape (batch_size, 1, 1, num_layers) + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + # Normalization + normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) + normalized_hidden_states = normalized_hidden_states * scale_factor + + # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) + normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) + return normalized_hidden_states + + +def per_token_rms_norm(text_encoder_hidden_states: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + variance = torch.mean(text_encoder_hidden_states**2, dim=2, keepdim=True) + norm_text_encoder_hidden_states = text_encoder_hidden_states * torch.rsqrt(variance + eps) + return norm_text_encoder_hidden_states + + class LTX2RotaryPosEmbed1d(nn.Module): """ 1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors.
@@ -106,6 +181,7 @@ def __init__( activation_fn: str = "gelu-approximate", eps: float = 1e-6, rope_type: str = "interleaved", + apply_gated_attention: bool = False, ): super().__init__() @@ -115,8 +191,9 @@ def __init__( heads=num_attention_heads, kv_heads=num_attention_heads, dim_head=attention_head_dim, - processor=LTX2AudioVideoAttnProcessor(), rope_type=rope_type, + apply_gated_attention=apply_gated_attention, + processor=LTX2AudioVideoAttnProcessor(), ) self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) @@ -160,6 +237,7 @@ def __init__( eps: float = 1e-6, causal_temporal_positioning: bool = False, rope_type: str = "interleaved", + gated_attention: bool = False, ): super().__init__() self.num_attention_heads = num_attention_heads @@ -188,6 +266,7 @@ def __init__( num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, rope_type=rope_type, + apply_gated_attention=gated_attention, ) for _ in range(num_layers) ] @@ -260,24 +339,36 @@ class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin): @register_to_config def __init__( self, - caption_channels: int, - text_proj_in_factor: int, - video_connector_num_attention_heads: int, - video_connector_attention_head_dim: int, - video_connector_num_layers: int, - video_connector_num_learnable_registers: int | None, - audio_connector_num_attention_heads: int, - audio_connector_attention_head_dim: int, - audio_connector_num_layers: int, - audio_connector_num_learnable_registers: int | None, - connector_rope_base_seq_len: int, - rope_theta: float, - rope_double_precision: bool, - causal_temporal_positioning: bool, + caption_channels: int = 3840, # default Gemma-3-12B text encoder hidden_size + text_proj_in_factor: int = 49, # num_layers + 1 for embedding layer = 48 + 1 for Gemma-3-12B + video_connector_num_attention_heads: int = 30, + video_connector_attention_head_dim: int = 128, + video_connector_num_layers: int = 2, + video_connector_num_learnable_registers: int | None = 
128, + video_gated_attn: bool = False, + audio_connector_num_attention_heads: int = 30, + audio_connector_attention_head_dim: int = 128, + audio_connector_num_layers: int = 2, + audio_connector_num_learnable_registers: int | None = 128, + audio_gated_attn: bool = False, + connector_rope_base_seq_len: int = 4096, + rope_theta: float = 10000.0, + rope_double_precision: bool = True, + causal_temporal_positioning: bool = False, rope_type: str = "interleaved", + per_modality_projections: bool = False, + video_hidden_dim: int = 4096, + audio_hidden_dim: int = 2048, + proj_bias: bool = False, ): super().__init__() - self.text_proj_in = nn.Linear(caption_channels * text_proj_in_factor, caption_channels, bias=False) + text_encoder_dim = caption_channels * text_proj_in_factor + if per_modality_projections: + self.video_text_proj_in = nn.Linear(text_encoder_dim, video_hidden_dim, bias=proj_bias) + self.audio_text_proj_in = nn.Linear(text_encoder_dim, audio_hidden_dim, bias=proj_bias) + else: + self.text_proj_in = nn.Linear(text_encoder_dim, caption_channels, bias=proj_bias) + self.video_connector = LTX2ConnectorTransformer1d( num_attention_heads=video_connector_num_attention_heads, attention_head_dim=video_connector_attention_head_dim, @@ -288,6 +379,7 @@ def __init__( rope_double_precision=rope_double_precision, causal_temporal_positioning=causal_temporal_positioning, rope_type=rope_type, + gated_attention=video_gated_attn, ) self.audio_connector = LTX2ConnectorTransformer1d( num_attention_heads=audio_connector_num_attention_heads, @@ -299,26 +391,86 @@ def __init__( rope_double_precision=rope_double_precision, causal_temporal_positioning=causal_temporal_positioning, rope_type=rope_type, + gated_attention=audio_gated_attn, ) def forward( - self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, additive_mask: bool = False - ): - # Convert to additive attention mask, if necessary - if not additive_mask: - text_dtype = text_encoder_hidden_states.dtype - 
attention_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
-            attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max
-
-        text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states)
-
-        video_text_embedding, new_attn_mask = self.video_connector(text_encoder_hidden_states, attention_mask)
-
-        attn_mask = (new_attn_mask < 1e-6).to(torch.int64)
-        attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1)
-        video_text_embedding = video_text_embedding * attn_mask
-        new_attn_mask = attn_mask.squeeze(-1)
-
-        audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, attention_mask)
-
-        return video_text_embedding, audio_text_embedding, new_attn_mask
+    def forward(
+        self,
+        text_encoder_hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        padding_side: str = "left",
+        scale_factor: int = 8,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Given per-layer text encoder hidden_states, extracts features and runs per-modality connectors to get text
+        embeddings for the LTX-2.X DiT models.
+
+        Args:
+            text_encoder_hidden_states (`torch.Tensor`):
+                Per-layer text encoder hidden_states. Can either be 4D with shape `(batch_size, seq_len,
+                caption_channels, text_proj_in_factor)` or 3D with the last two dimensions flattened.
+            attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
+                Multiplicative binary attention mask where 1s indicate unmasked positions and 0s indicate masked
+                positions.
+            padding_side (`str`, *optional*, defaults to `"left"`):
+                The padding side used by the text encoder's tokenizer (either `"left"` or `"right"`). Defaults to
+                `"left"` as this is what the default Gemma3-12B text encoder uses. Only used if
+                `per_modality_projections` is `False` (LTX-2.0 models).
+            scale_factor (`int`, *optional*, defaults to `8`):
+                Scale factor for masked mean/range normalization. 
Only used if `per_modality_projections` is `False` + (LTX-2.0 models). + """ + if text_encoder_hidden_states.ndim == 3: + # Ensure shape is [batch_size, seq_len, caption_channels, text_proj_in_factor] + text_encoder_hidden_states = text_encoder_hidden_states.unflatten(2, (self.config.caption_channels, -1)) + + if self.config.per_modality_projections: + # LTX-2.3 + norm_text_encoder_hidden_states = per_token_rms_norm(text_encoder_hidden_states) + + norm_text_encoder_hidden_states = norm_text_encoder_hidden_states.flatten(2, 3) + bool_mask = attention_mask.bool().unsqueeze(-1) + norm_text_encoder_hidden_states = torch.where( + bool_mask, norm_text_encoder_hidden_states, torch.zeros_like(norm_text_encoder_hidden_states) + ) + + # Rescale norms with respect to video and audio dims for feature extractors + video_scale_factor = math.sqrt(self.config.video_hidden_dim / self.config.caption_channels) + video_norm_text_emb = norm_text_encoder_hidden_states * video_scale_factor + audio_scale_factor = math.sqrt(self.config.audio_hidden_dim / self.config.caption_channels) + audio_norm_text_emb = norm_text_encoder_hidden_states * audio_scale_factor + + # Per-Modality Feature extractors + video_text_emb_proj = self.video_text_proj_in(video_norm_text_emb) + audio_text_emb_proj = self.audio_text_proj_in(audio_norm_text_emb) + else: + # LTX-2.0 + sequence_lengths = attention_mask.sum(dim=-1) + norm_text_encoder_hidden_states = per_layer_masked_mean_norm( + text_hidden_states=text_encoder_hidden_states, + sequence_lengths=sequence_lengths, + device=text_encoder_hidden_states.device, + padding_side=padding_side, + scale_factor=scale_factor, + ) + + text_emb_proj = self.text_proj_in(norm_text_encoder_hidden_states) + video_text_emb_proj = text_emb_proj + audio_text_emb_proj = text_emb_proj + + # Convert to additive attention mask for connectors + text_dtype = video_text_emb_proj.dtype + attention_mask = (attention_mask.to(torch.int64) - 1).to(text_dtype) + attention_mask = 
attention_mask.reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + add_attn_mask = attention_mask * torch.finfo(text_dtype).max + + video_text_embedding, video_attn_mask = self.video_connector(video_text_emb_proj, add_attn_mask) + + # Convert video attn mask to binary (multiplicative) mask and mask video text embedding + binary_attn_mask = (video_attn_mask < 1e-6).to(torch.int64) + binary_attn_mask = binary_attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1) + video_text_embedding = video_text_embedding * binary_attn_mask + + audio_text_embedding, _ = self.audio_connector(audio_text_emb_proj, add_attn_mask) + + return video_text_embedding, audio_text_embedding, binary_attn_mask.squeeze(-1) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 037840360137..c7c02f5ae622 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -31,7 +31,7 @@ from ..pipeline_utils import DiffusionPipeline from .connectors import LTX2TextConnectors from .pipeline_output import LTX2PipelineOutput -from .vocoder import LTX2Vocoder +from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE if is_torch_xla_available(): @@ -221,7 +221,7 @@ def __init__( tokenizer: GemmaTokenizer | GemmaTokenizerFast, connectors: LTX2TextConnectors, transformer: LTX2VideoTransformer3DModel, - vocoder: LTX2Vocoder, + vocoder: LTX2Vocoder | LTX2VocoderWithBWE, ): super().__init__() @@ -268,73 +268,6 @@ def __init__( self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 ) - @staticmethod - def _pack_text_embeds( - text_hidden_states: torch.Tensor, - sequence_lengths: torch.Tensor, - device: str | torch.device, - padding_side: str = "left", - scale_factor: int = 8, - eps: float = 1e-6, - ) -> torch.Tensor: - """ - Packs and normalizes text encoder hidden states, respecting padding. 
Normalization is performed per-batch and - per-layer in a masked fashion (only over non-padded positions). - - Args: - text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): - Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). - sequence_lengths (`torch.Tensor of shape `(batch_size,)`): - The number of valid (non-padded) tokens for each batch instance. - device: (`str` or `torch.device`, *optional*): - torch device to place the resulting embeddings on - padding_side: (`str`, *optional*, defaults to `"left"`): - Whether the text tokenizer performs padding on the `"left"` or `"right"`. - scale_factor (`int`, *optional*, defaults to `8`): - Scaling factor to multiply the normalized hidden states by. - eps (`float`, *optional*, defaults to `1e-6`): - A small positive value for numerical stability when performing normalization. - - Returns: - `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: - Normed and flattened text encoder hidden states. 
- """ - batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape - original_dtype = text_hidden_states.dtype - - # Create padding mask - token_indices = torch.arange(seq_len, device=device).unsqueeze(0) - if padding_side == "right": - # For right padding, valid tokens are from 0 to sequence_length-1 - mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] - elif padding_side == "left": - # For left padding, valid tokens are from (T - sequence_length) to T-1 - start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] - mask = token_indices >= start_indices # [B, T] - else: - raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") - mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] - - # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) - masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) - num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) - masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) - - # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) - x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) - x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) - - # Normalization - normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) - normalized_hidden_states = normalized_hidden_states * scale_factor - - # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) - normalized_hidden_states = normalized_hidden_states.flatten(2) - mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) - normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) - normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) - return 
normalized_hidden_states - def _get_gemma_prompt_embeds( self, prompt: str | list[str], @@ -387,16 +320,7 @@ def _get_gemma_prompt_embeds( ) text_encoder_hidden_states = text_encoder_outputs.hidden_states text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) - sequence_lengths = prompt_attention_mask.sum(dim=-1) - - prompt_embeds = self._pack_text_embeds( - text_encoder_hidden_states, - sequence_lengths, - device=device, - padding_side=self.tokenizer.padding_side, - scale_factor=scale_factor, - ) - prompt_embeds = prompt_embeds.to(dtype=dtype) + prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) # Pack to 3D # duplicate text embeddings for each generation per prompt, using mps friendly method _, seq_len, _ = prompt_embeds.shape @@ -504,6 +428,9 @@ def check_inputs( negative_prompt_embeds=None, prompt_attention_mask=None, negative_prompt_attention_mask=None, + spatio_temporal_guidance_blocks=None, + stg_scale=None, + audio_stg_scale=None, ): if height % 32 != 0 or width % 32 != 0: raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") @@ -547,6 +474,12 @@ def check_inputs( f" {negative_prompt_attention_mask.shape}." ) + if ((stg_scale > 0.0) or (audio_stg_scale > 0.0)) and not spatio_temporal_guidance_blocks: + raise ValueError( + "Spatio-Temporal Guidance (STG) is specified but no STG blocks are supplied. Please supply a list of" + "block indices at which to apply STG in `spatio_temporal_guidance_blocks`" + ) + @staticmethod def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. 
@@ -757,9 +690,41 @@ def guidance_scale(self): def guidance_rescale(self): return self._guidance_rescale + @property + def stg_scale(self): + return self._stg_scale + + @property + def modality_scale(self): + return self._modality_scale + + @property + def audio_guidance_scale(self): + return self._audio_guidance_scale + + @property + def audio_guidance_rescale(self): + return self._audio_guidance_rescale + + @property + def audio_stg_scale(self): + return self._audio_stg_scale + + @property + def audio_modality_scale(self): + return self._audio_modality_scale + @property def do_classifier_free_guidance(self): - return self._guidance_scale > 1.0 + return (self._guidance_scale > 1.0) or (self._audio_guidance_scale > 1.0) + + @property + def do_spatio_temporal_guidance(self): + return (self._stg_scale > 0.0) or (self._audio_stg_scale > 0.0) + + @property + def do_modality_isolation_guidance(self): + return (self._modality_scale > 1.0) or (self._audio_modality_scale > 1.0) @property def num_timesteps(self): @@ -791,7 +756,14 @@ def __call__( sigmas: list[float] | None = None, timesteps: list[int] = None, guidance_scale: float = 4.0, + stg_scale: float = 0.0, + modality_scale: float = 1.0, guidance_rescale: float = 0.0, + audio_guidance_scale: float | None = None, + audio_stg_scale: float | None = None, + audio_modality_scale: float | None = None, + audio_guidance_rescale: float | None = None, + spatio_temporal_guidance_blocks: list[int] | None = None, noise_scale: float = 0.0, num_videos_per_prompt: int = 1, generator: torch.Generator | list[torch.Generator] | None = None, @@ -841,13 +813,47 @@ def __call__( Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale > 1`. 
Higher guidance scale encourages to generate images that are closely linked to
-                the text `prompt`, usually at the expense of lower image quality.
+                the text `prompt`, usually at the expense of lower image quality. Used for the video modality (there is
+                a separate value `audio_guidance_scale` for the audio modality).
+            stg_scale (`float`, *optional*, defaults to `0.0`):
+                Video guidance scale for Spatio-Temporal Guidance (STG), proposed in [Spatiotemporal Skip Guidance for
+                Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664). STG uses a CFG-like estimate
+                where we move the sample away from a weak sample from a perturbed version of the denoising model.
+                Enabling STG will result in an additional denoising model forward pass; the default value of `0.0`
+                means that STG is disabled.
+            modality_scale (`float`, *optional*, defaults to `1.0`):
+                Video guidance scale for LTX-2.X modality isolation guidance, where we move the sample away from a
+                weaker sample generated by the denoising model with cross-modality (audio-to-video and video-to-audio)
+                cross attention disabled using a CFG-like estimate. Enabling modality guidance will result in an
+                additional denoising model forward pass; the default value of `1.0` means that modality guidance is
+                disabled.
             guidance_rescale (`float`, *optional*, defaults to 0.0):
                 Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
                 Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
                 [Common Diffusion Noise Schedules and Sample Steps are
                 Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
-                using zero terminal SNR.
+                using zero terminal SNR. Used for the video modality.
+            audio_guidance_scale (`float`, *optional*, defaults to `None`):
+                Audio guidance scale for CFG with respect to the negative prompt. 
The CFG update rule is the same for + video and audio, but they can use different values for the guidance scale. The LTX-2.X authors suggest + that the `audio_guidance_scale` should be higher relative to the video `guidance_scale` (e.g. for + LTX-2.3 they suggest 3.0 for video and 7.0 for audio). If `None`, defaults to the video value + `guidance_scale`. + audio_stg_scale (`float`, *optional*, defaults to `None`): + Audio guidance scale for STG. As with CFG, the STG update rule is otherwise the same for video and + audio. For LTX-2.3, a value of 1.0 is suggested for both video and audio. If `None`, defaults to the + video value `stg_scale`. + audio_modality_scale (`float`, *optional*, defaults to `None`): + Audio guidance scale for LTX-2.X modality isolation guidance. As with CFG, the modality guidance rule + is otherwise the same for video and audio. For LTX-2.3, a value of 3.0 is suggested for both video and + audio. If `None`, defaults to the video value `modality_scale`. + audio_guidance_rescale (`float`, *optional*, defaults to `None`): + A separate guidance rescale factor for the audio modality. If `None`, defaults to the video value + `guidance_rescale`. + spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`): + The zero-indexed transformer block indices at which to apply STG. Must be supplied if STG is used + (`stg_scale` or `audio_stg_scale` is greater than `0`). A value of `[29]` is recommended for LTX-2.0 + and `[28]` is recommended for LTX-2.3. noise_scale (`float`, *optional*, defaults to `0.0`): The interpolation factor between random noise and denoised latents at each timestep. Applying noise to the `latents` and `audio_latents` before continue denoising. 
@@ -910,6 +916,11 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + audio_guidance_scale = audio_guidance_scale or guidance_scale + audio_stg_scale = audio_stg_scale or stg_scale + audio_modality_scale = audio_modality_scale or modality_scale + audio_guidance_rescale = audio_guidance_rescale or guidance_rescale + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt=prompt, @@ -920,10 +931,21 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + stg_scale=stg_scale, + audio_stg_scale=audio_stg_scale, ) + # Per-modality guidance scales (video, audio) self._guidance_scale = guidance_scale + self._stg_scale = stg_scale + self._modality_scale = modality_scale self._guidance_rescale = guidance_rescale + self._audio_guidance_scale = audio_guidance_scale + self._audio_stg_scale = audio_stg_scale + self._audio_modality_scale = audio_modality_scale + self._audio_guidance_rescale = audio_guidance_rescale + self._attention_kwargs = attention_kwargs self._interrupt = False self._current_timestep = None @@ -960,9 +982,11 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) - additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + tokenizer_padding_side = "left" # Padding side for default Gemma3-12B text encoder + if getattr(self, "tokenizer", None) is not None: + tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left") connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( - prompt_embeds, additive_attention_mask, 
additive_mask=True + prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side ) # 4. Prepare latent variables @@ -1069,11 +1093,6 @@ def __call__( self._num_timesteps = len(timesteps) # 6. Prepare micro-conditions - rope_interpolation_scale = ( - self.vae_temporal_compression_ratio / frame_rate, - self.vae_spatial_compression_ratio, - self.vae_spatial_compression_ratio, - ) # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop video_coords = self.transformer.rope.prepare_video_coords( latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate @@ -1111,8 +1130,11 @@ def __call__( encoder_hidden_states=connector_prompt_embeds, audio_encoder_hidden_states=connector_audio_prompt_embeds, timestep=timestep, + sigma=timestep, # Used by LTX-2.3 encoder_attention_mask=connector_attention_mask, audio_encoder_attention_mask=connector_attention_mask, + self_attention_mask=None, + audio_self_attention_mask=None, num_frames=latent_num_frames, height=latent_height, width=latent_width, @@ -1120,7 +1142,9 @@ def __call__( audio_num_frames=audio_num_frames, video_coords=video_coords, audio_coords=audio_coords, - # rope_interpolation_scale=rope_interpolation_scale, + isolate_modalities=False, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, attention_kwargs=attention_kwargs, return_dict=False, ) @@ -1128,24 +1152,134 @@ def __call__( noise_pred_audio = noise_pred_audio.float() if self.do_classifier_free_guidance: - noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) - noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( - noise_pred_video_text - noise_pred_video_uncond - ) + noise_pred_video_uncond_text, noise_pred_video = noise_pred_video.chunk(2) + # Use delta formulation as it works more nicely with multiple guidance terms + video_cfg_delta = (self.guidance_scale - 1) * (noise_pred_video - noise_pred_video_uncond_text) - 
noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) - noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( - noise_pred_audio_text - noise_pred_audio_uncond + noise_pred_audio_uncond_text, noise_pred_audio = noise_pred_audio.chunk(2) + audio_cfg_delta = (self.audio_guidance_scale - 1) * ( + noise_pred_audio - noise_pred_audio_uncond_text ) - if self.guidance_rescale > 0: - # Based on 3.4. in https://huggingface.co/papers/2305.08891 - noise_pred_video = rescale_noise_cfg( - noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale + # Get positive values from merged CFG inputs in case we need to do other DiT forward passes + if self.do_spatio_temporal_guidance or self.do_modality_isolation_guidance: + if i == 0: + # Only split values that remain constant throughout the loop once + video_prompt_embeds = connector_prompt_embeds.chunk(2, dim=0)[1] + audio_prompt_embeds = connector_audio_prompt_embeds.chunk(2, dim=0)[1] + prompt_attn_mask = connector_attention_mask.chunk(2, dim=0)[1] + + video_pos_ids = video_coords.chunk(2, dim=0)[0] + audio_pos_ids = audio_coords.chunk(2, dim=0)[0] + + # Split values that vary each denoising loop iteration + timestep = timestep.chunk(2, dim=0)[0] + else: + video_cfg_delta = audio_cfg_delta = 0 + + video_prompt_embeds = connector_prompt_embeds + audio_prompt_embeds = connector_audio_prompt_embeds + prompt_attn_mask = connector_attention_mask + + video_pos_ids = video_coords + audio_pos_ids = audio_coords + + if self.do_spatio_temporal_guidance: + with self.transformer.cache_context("uncond_stg"): + noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer( + hidden_states=latents.to(dtype=prompt_embeds.dtype), + audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype), + encoder_hidden_states=video_prompt_embeds, + audio_encoder_hidden_states=audio_prompt_embeds, + timestep=timestep, + sigma=timestep, # Used by LTX-2.3 + 
encoder_attention_mask=prompt_attn_mask, + audio_encoder_attention_mask=prompt_attn_mask, + self_attention_mask=None, + audio_self_attention_mask=None, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + isolate_modalities=False, + # Use STG at given blocks to perturb model + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + perturbation_mask=None, + attention_kwargs=attention_kwargs, + return_dict=False, ) - noise_pred_audio = rescale_noise_cfg( - noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale + noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float() + noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float() + + video_stg_delta = self.stg_scale * (noise_pred_video - noise_pred_video_uncond_stg) + audio_stg_delta = self.audio_stg_scale * (noise_pred_audio - noise_pred_audio_uncond_stg) + else: + video_stg_delta = audio_stg_delta = 0 + + if self.do_modality_isolation_guidance: + with self.transformer.cache_context("uncond_modality"): + noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self.transformer( + hidden_states=latents.to(dtype=prompt_embeds.dtype), + audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype), + encoder_hidden_states=video_prompt_embeds, + audio_encoder_hidden_states=audio_prompt_embeds, + timestep=timestep, + sigma=timestep, # Used by LTX-2.3 + encoder_attention_mask=prompt_attn_mask, + audio_encoder_attention_mask=prompt_attn_mask, + self_attention_mask=None, + audio_self_attention_mask=None, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + # Turn off A2V and V2A cross attn to isolate video and audio modalities + isolate_modalities=True, + 
spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + attention_kwargs=attention_kwargs, + return_dict=False, ) + noise_pred_video_uncond_modality = noise_pred_video_uncond_modality.float() + noise_pred_audio_uncond_modality = noise_pred_audio_uncond_modality.float() + + video_modality_delta = (self.modality_scale - 1) * ( + noise_pred_video - noise_pred_video_uncond_modality + ) + audio_modality_delta = (self.audio_modality_scale - 1) * ( + noise_pred_audio - noise_pred_audio_uncond_modality + ) + else: + video_modality_delta = audio_modality_delta = 0 + + # Now apply all guidance terms + noise_pred_video_g = noise_pred_video + video_cfg_delta + video_stg_delta + video_modality_delta + noise_pred_audio_g = noise_pred_audio + audio_cfg_delta + audio_stg_delta + audio_modality_delta + + # Apply LTX-2.X guidance rescaling + if self.guidance_rescale > 0: + video_rescale = self.guidance_rescale + cond_std = noise_pred_video.std(dim=list(range(1, noise_pred_video.ndim)), keepdim=True) + guided_std = noise_pred_video_g.std(dim=list(range(1, noise_pred_video_g.ndim)), keepdim=True) + rescale_factor = video_rescale * (cond_std / guided_std) + (1 - video_rescale) + noise_pred_video = noise_pred_video_g * rescale_factor + else: + noise_pred_video = noise_pred_video_g + + if self.audio_guidance_rescale > 0: + audio_rescale = self.audio_guidance_rescale + cond_std = noise_pred_audio.std(dim=list(range(1, noise_pred_audio.ndim)), keepdim=True) + guided_std = noise_pred_audio_g.std(dim=list(range(1, noise_pred_audio_g.ndim)), keepdim=True) + rescale_factor = audio_rescale * (cond_std / guided_std) + (1 - audio_rescale) + noise_pred_audio = noise_pred_audio_g * rescale_factor + else: + noise_pred_audio = noise_pred_audio_g # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0] diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py 
b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index 4c451330f439..e7875aa5426e 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -33,7 +33,7 @@ from ..pipeline_utils import DiffusionPipeline from .connectors import LTX2TextConnectors from .pipeline_output import LTX2PipelineOutput -from .vocoder import LTX2Vocoder +from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE if is_torch_xla_available(): @@ -254,7 +254,7 @@ def __init__( tokenizer: GemmaTokenizer | GemmaTokenizerFast, connectors: LTX2TextConnectors, transformer: LTX2VideoTransformer3DModel, - vocoder: LTX2Vocoder, + vocoder: LTX2Vocoder | LTX2VocoderWithBWE, ): super().__init__() @@ -300,74 +300,6 @@ def __init__( self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 ) - @staticmethod - # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds - def _pack_text_embeds( - text_hidden_states: torch.Tensor, - sequence_lengths: torch.Tensor, - device: str | torch.device, - padding_side: str = "left", - scale_factor: int = 8, - eps: float = 1e-6, - ) -> torch.Tensor: - """ - Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and - per-layer in a masked fashion (only over non-padded positions). - - Args: - text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): - Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). - sequence_lengths (`torch.Tensor of shape `(batch_size,)`): - The number of valid (non-padded) tokens for each batch instance. - device: (`str` or `torch.device`, *optional*): - torch device to place the resulting embeddings on - padding_side: (`str`, *optional*, defaults to `"left"`): - Whether the text tokenizer performs padding on the `"left"` or `"right"`. 
- scale_factor (`int`, *optional*, defaults to `8`): - Scaling factor to multiply the normalized hidden states by. - eps (`float`, *optional*, defaults to `1e-6`): - A small positive value for numerical stability when performing normalization. - - Returns: - `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: - Normed and flattened text encoder hidden states. - """ - batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape - original_dtype = text_hidden_states.dtype - - # Create padding mask - token_indices = torch.arange(seq_len, device=device).unsqueeze(0) - if padding_side == "right": - # For right padding, valid tokens are from 0 to sequence_length-1 - mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] - elif padding_side == "left": - # For left padding, valid tokens are from (T - sequence_length) to T-1 - start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] - mask = token_indices >= start_indices # [B, T] - else: - raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") - mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] - - # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) - masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) - num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) - masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) - - # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) - x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) - x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) - - # Normalization - normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) - normalized_hidden_states = normalized_hidden_states * scale_factor - - # Pack the hidden 
states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) - normalized_hidden_states = normalized_hidden_states.flatten(2) - mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) - normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) - normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) - return normalized_hidden_states - # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds def _get_gemma_prompt_embeds( self, @@ -421,16 +353,7 @@ def _get_gemma_prompt_embeds( ) text_encoder_hidden_states = text_encoder_outputs.hidden_states text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) - sequence_lengths = prompt_attention_mask.sum(dim=-1) - - prompt_embeds = self._pack_text_embeds( - text_encoder_hidden_states, - sequence_lengths, - device=device, - padding_side=self.tokenizer.padding_side, - scale_factor=scale_factor, - ) - prompt_embeds = prompt_embeds.to(dtype=dtype) + prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) # Pack to 3D # duplicate text embeddings for each generation per prompt, using mps friendly method _, seq_len, _ = prompt_embeds.shape @@ -541,6 +464,9 @@ def check_inputs( negative_prompt_attention_mask=None, latents=None, audio_latents=None, + spatio_temporal_guidance_blocks=None, + stg_scale=None, + audio_stg_scale=None, ): if height % 32 != 0 or width % 32 != 0: raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") @@ -597,6 +523,12 @@ def check_inputs( f" using the `_unpack_audio_latents` method)." ) + if ((stg_scale > 0.0) or (audio_stg_scale > 0.0)) and not spatio_temporal_guidance_blocks: + raise ValueError( + "Spatio-Temporal Guidance (STG) is specified but no STG blocks are supplied. 
Please supply a list of " + "block indices at which to apply STG in `spatio_temporal_guidance_blocks`" + ) + @staticmethod # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: @@ -992,9 +924,41 @@ def guidance_scale(self): def guidance_rescale(self): return self._guidance_rescale + @property + def stg_scale(self): + return self._stg_scale + + @property + def modality_scale(self): + return self._modality_scale + + @property + def audio_guidance_scale(self): + return self._audio_guidance_scale + + @property + def audio_guidance_rescale(self): + return self._audio_guidance_rescale + + @property + def audio_stg_scale(self): + return self._audio_stg_scale + + @property + def audio_modality_scale(self): + return self._audio_modality_scale + @property def do_classifier_free_guidance(self): - return self._guidance_scale > 1.0 + return (self._guidance_scale > 1.0) or (self._audio_guidance_scale > 1.0) + + @property + def do_spatio_temporal_guidance(self): + return (self._stg_scale > 0.0) or (self._audio_stg_scale > 0.0) + + @property + def do_modality_isolation_guidance(self): + return (self._modality_scale > 1.0) or (self._audio_modality_scale > 1.0) @property def num_timesteps(self): @@ -1027,7 +991,14 @@ def __call__( sigmas: list[float] | None = None, timesteps: list[float] | None = None, guidance_scale: float = 4.0, + stg_scale: float = 0.0, + modality_scale: float = 1.0, guidance_rescale: float = 0.0, + audio_guidance_scale: float | None = None, + audio_stg_scale: float | None = None, + audio_modality_scale: float | None = None, + audio_guidance_rescale: float | None = None, + spatio_temporal_guidance_blocks: list[int] | None = None, noise_scale: float | None = None, num_videos_per_prompt: int | None = 1, generator: torch.Generator | list[torch.Generator] | None = None, @@ -1079,13 +1050,47 @@ def __call__( 
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to - the text `prompt`, usually at the expense of lower image quality. + the text `prompt`, usually at the expense of lower image quality. Used for the video modality (there is + a separate value `audio_guidance_scale` for the audio modality). + stg_scale (`float`, *optional*, defaults to `0.0`): + Video guidance scale for Spatio-Temporal Guidance (STG), proposed in [Spatiotemporal Skip Guidance for + Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664). STG uses a CFG-like estimate + where we move the sample away from a weak sample from a perturbed version of the denoising model. + Enabling STG will result in an additional denoising model forward pass; the default value of `0.0` + means that STG is disabled. + modality_scale (`float`, *optional*, defaults to `1.0`): + Video guidance scale for LTX-2.X modality isolation guidance, where we move the sample away from a + weaker sample generated by the denoising model with cross-modality (audio-to-video and video-to-audio) + cross attention disabled using a CFG-like estimate. Enabling modality guidance will result in an + additional denoising model forward pass; the default value of `1.0` means that modality guidance is + disabled. guidance_rescale (`float`, *optional*, defaults to 0.0): Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when - using zero terminal SNR. + using zero terminal SNR. 
Used for the video modality. + audio_guidance_scale (`float`, *optional*, defaults to `None`): + Audio guidance scale for CFG with respect to the negative prompt. The CFG update rule is the same for + video and audio, but they can use different values for the guidance scale. The LTX-2.X authors suggest + that the `audio_guidance_scale` should be higher relative to the video `guidance_scale` (e.g. for + LTX-2.3 they suggest 3.0 for video and 7.0 for audio). If `None`, defaults to the video value + `guidance_scale`. + audio_stg_scale (`float`, *optional*, defaults to `None`): + Audio guidance scale for STG. As with CFG, the STG update rule is otherwise the same for video and + audio. For LTX-2.3, a value of 1.0 is suggested for both video and audio. If `None`, defaults to the + video value `stg_scale`. + audio_modality_scale (`float`, *optional*, defaults to `None`): + Audio guidance scale for LTX-2.X modality isolation guidance. As with CFG, the modality guidance rule + is otherwise the same for video and audio. For LTX-2.3, a value of 3.0 is suggested for both video and + audio. If `None`, defaults to the video value `modality_scale`. + audio_guidance_rescale (`float`, *optional*, defaults to `None`): + A separate guidance rescale factor for the audio modality. If `None`, defaults to the video value + `guidance_rescale`. + spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`): + The zero-indexed transformer block indices at which to apply STG. Must be supplied if STG is used + (`stg_scale` or `audio_stg_scale` is greater than `0`). A value of `[29]` is recommended for LTX-2.0 + and `[28]` is recommended for LTX-2.3. noise_scale (`float`, *optional*, defaults to `None`): The interpolation factor between random noise and denoised latents at each timestep. Applying noise to the `latents` and `audio_latents` before continue denoising. 
If not set, will be inferred from the @@ -1149,6 +1154,11 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + audio_guidance_scale = audio_guidance_scale or guidance_scale + audio_stg_scale = audio_stg_scale or stg_scale + audio_modality_scale = audio_modality_scale or modality_scale + audio_guidance_rescale = audio_guidance_rescale or guidance_rescale + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt=prompt, @@ -1161,10 +1171,21 @@ def __call__( negative_prompt_attention_mask=negative_prompt_attention_mask, latents=latents, audio_latents=audio_latents, + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + stg_scale=stg_scale, + audio_stg_scale=audio_stg_scale, ) + # Per-modality guidance scales (video, audio) self._guidance_scale = guidance_scale + self._stg_scale = stg_scale + self._modality_scale = modality_scale self._guidance_rescale = guidance_rescale + self._audio_guidance_scale = audio_guidance_scale + self._audio_stg_scale = audio_stg_scale + self._audio_modality_scale = audio_modality_scale + self._audio_guidance_rescale = audio_guidance_rescale + self._attention_kwargs = attention_kwargs self._interrupt = False self._current_timestep = None @@ -1208,9 +1229,11 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) - additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + tokenizer_padding_side = "left" # Padding side for default Gemma3-12B text encoder + if getattr(self, "tokenizer", None) is not None: + tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left") connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( - prompt_embeds, additive_attention_mask, 
additive_mask=True + prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side ) # 4. Prepare latent variables @@ -1301,11 +1324,6 @@ def __call__( self._num_timesteps = len(timesteps) # 6. Prepare micro-conditions - rope_interpolation_scale = ( - self.vae_temporal_compression_ratio / frame_rate, - self.vae_spatial_compression_ratio, - self.vae_spatial_compression_ratio, - ) # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop video_coords = self.transformer.rope.prepare_video_coords( latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate @@ -1344,8 +1362,11 @@ def __call__( audio_encoder_hidden_states=connector_audio_prompt_embeds, timestep=video_timestep, audio_timestep=timestep, + sigma=timestep, # Used by LTX-2.3 encoder_attention_mask=connector_attention_mask, audio_encoder_attention_mask=connector_attention_mask, + self_attention_mask=None, + audio_self_attention_mask=None, num_frames=latent_num_frames, height=latent_height, width=latent_width, @@ -1353,7 +1374,9 @@ def __call__( audio_num_frames=audio_num_frames, video_coords=video_coords, audio_coords=audio_coords, - # rope_interpolation_scale=rope_interpolation_scale, + isolate_modalities=False, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, attention_kwargs=attention_kwargs, return_dict=False, ) @@ -1361,24 +1384,137 @@ def __call__( noise_pred_audio = noise_pred_audio.float() if self.do_classifier_free_guidance: - noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) - noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( - noise_pred_video_text - noise_pred_video_uncond - ) + noise_pred_video_uncond_text, noise_pred_video = noise_pred_video.chunk(2) + # Use delta formulation as it works more nicely with multiple guidance terms + video_cfg_delta = (self.guidance_scale - 1) * (noise_pred_video - noise_pred_video_uncond_text) - 
noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) - noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( - noise_pred_audio_text - noise_pred_audio_uncond + noise_pred_audio_uncond_text, noise_pred_audio = noise_pred_audio.chunk(2) + audio_cfg_delta = (self.audio_guidance_scale - 1) * ( + noise_pred_audio - noise_pred_audio_uncond_text ) - if self.guidance_rescale > 0: - # Based on 3.4. in https://huggingface.co/papers/2305.08891 - noise_pred_video = rescale_noise_cfg( - noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale + # Get positive values from merged CFG inputs in case we need to do other DiT forward passes + if self.do_spatio_temporal_guidance or self.do_modality_isolation_guidance: + if i == 0: + # Only split values that remain constant throughout the loop once + video_prompt_embeds = connector_prompt_embeds.chunk(2, dim=0)[1] + audio_prompt_embeds = connector_audio_prompt_embeds.chunk(2, dim=0)[1] + prompt_attn_mask = connector_attention_mask.chunk(2, dim=0)[1] + + video_pos_ids = video_coords.chunk(2, dim=0)[0] + audio_pos_ids = audio_coords.chunk(2, dim=0)[0] + + # Split values that vary each denoising loop iteration + timestep = timestep.chunk(2, dim=0)[0] + video_timestep = video_timestep.chunk(2, dim=0)[0] + else: + video_cfg_delta = audio_cfg_delta = 0 + + video_prompt_embeds = connector_prompt_embeds + audio_prompt_embeds = connector_audio_prompt_embeds + prompt_attn_mask = connector_attention_mask + + video_pos_ids = video_coords + audio_pos_ids = audio_coords + + if self.do_spatio_temporal_guidance: + with self.transformer.cache_context("uncond_stg"): + noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer( + hidden_states=latents.to(dtype=prompt_embeds.dtype), + audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype), + encoder_hidden_states=video_prompt_embeds, + audio_encoder_hidden_states=audio_prompt_embeds, + timestep=video_timestep, + 
audio_timestep=timestep, + sigma=timestep, # Used by LTX-2.3 + encoder_attention_mask=prompt_attn_mask, + audio_encoder_attention_mask=prompt_attn_mask, + self_attention_mask=None, + audio_self_attention_mask=None, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + isolate_modalities=False, + # Use STG at given blocks to perturb model + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + perturbation_mask=None, + attention_kwargs=attention_kwargs, + return_dict=False, ) - noise_pred_audio = rescale_noise_cfg( - noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale + noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float() + noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float() + + video_stg_delta = self.stg_scale * (noise_pred_video - noise_pred_video_uncond_stg) + audio_stg_delta = self.audio_stg_scale * (noise_pred_audio - noise_pred_audio_uncond_stg) + else: + video_stg_delta = audio_stg_delta = 0 + + if self.do_modality_isolation_guidance: + with self.transformer.cache_context("uncond_modality"): + noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self.transformer( + hidden_states=latents.to(dtype=prompt_embeds.dtype), + audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype), + encoder_hidden_states=video_prompt_embeds, + audio_encoder_hidden_states=audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=timestep, + sigma=timestep, # Used by LTX-2.3 + encoder_attention_mask=prompt_attn_mask, + audio_encoder_attention_mask=prompt_attn_mask, + self_attention_mask=None, + audio_self_attention_mask=None, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + # Turn off A2V and V2A cross 
attn to isolate video and audio modalities + isolate_modalities=True, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + attention_kwargs=attention_kwargs, + return_dict=False, ) + noise_pred_video_uncond_modality = noise_pred_video_uncond_modality.float() + noise_pred_audio_uncond_modality = noise_pred_audio_uncond_modality.float() + + video_modality_delta = (self.modality_scale - 1) * ( + noise_pred_video - noise_pred_video_uncond_modality + ) + audio_modality_delta = (self.audio_modality_scale - 1) * ( + noise_pred_audio - noise_pred_audio_uncond_modality + ) + else: + video_modality_delta = audio_modality_delta = 0 + + # Now apply all guidance terms + noise_pred_video_g = noise_pred_video + video_cfg_delta + video_stg_delta + video_modality_delta + noise_pred_audio_g = noise_pred_audio + audio_cfg_delta + audio_stg_delta + audio_modality_delta + + # Apply LTX-2.X guidance rescaling + if self.guidance_rescale > 0: + video_rescale = self.guidance_rescale + cond_std = noise_pred_video.std(dim=list(range(1, noise_pred_video.ndim)), keepdim=True) + guided_std = noise_pred_video_g.std(dim=list(range(1, noise_pred_video_g.ndim)), keepdim=True) + rescale_factor = video_rescale * (cond_std / guided_std) + (1 - video_rescale) + noise_pred_video = noise_pred_video_g * rescale_factor + else: + noise_pred_video = noise_pred_video_g + + if self.audio_guidance_rescale > 0: + audio_rescale = self.audio_guidance_rescale + cond_std = noise_pred_audio.std(dim=list(range(1, noise_pred_audio.ndim)), keepdim=True) + guided_std = noise_pred_audio_g.std(dim=list(range(1, noise_pred_audio_g.ndim)), keepdim=True) + rescale_factor = audio_rescale * (cond_std / guided_std) + (1 - audio_rescale) + noise_pred_audio = noise_pred_audio_g * rescale_factor + else: + noise_pred_audio = noise_pred_audio_g # NOTE: use only the first chunk of conditioning mask in case it is duplicated for CFG bsz = noise_pred_video.size(0) diff --git 
a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index 83ba2cd7c685..5885f98e43ab 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -32,7 +32,7 @@ from ..pipeline_utils import DiffusionPipeline from .connectors import LTX2TextConnectors from .pipeline_output import LTX2PipelineOutput -from .vocoder import LTX2Vocoder +from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE if is_torch_xla_available(): @@ -224,7 +224,7 @@ def __init__( tokenizer: GemmaTokenizer | GemmaTokenizerFast, connectors: LTX2TextConnectors, transformer: LTX2VideoTransformer3DModel, - vocoder: LTX2Vocoder, + vocoder: LTX2Vocoder | LTX2VocoderWithBWE, ): super().__init__() @@ -271,74 +271,6 @@ def __init__( self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 ) - @staticmethod - # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds - def _pack_text_embeds( - text_hidden_states: torch.Tensor, - sequence_lengths: torch.Tensor, - device: str | torch.device, - padding_side: str = "left", - scale_factor: int = 8, - eps: float = 1e-6, - ) -> torch.Tensor: - """ - Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and - per-layer in a masked fashion (only over non-padded positions). - - Args: - text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): - Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). - sequence_lengths (`torch.Tensor of shape `(batch_size,)`): - The number of valid (non-padded) tokens for each batch instance. - device: (`str` or `torch.device`, *optional*): - torch device to place the resulting embeddings on - padding_side: (`str`, *optional*, defaults to `"left"`): - Whether the text tokenizer performs padding on the `"left"` or `"right"`. 
- scale_factor (`int`, *optional*, defaults to `8`): - Scaling factor to multiply the normalized hidden states by. - eps (`float`, *optional*, defaults to `1e-6`): - A small positive value for numerical stability when performing normalization. - - Returns: - `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: - Normed and flattened text encoder hidden states. - """ - batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape - original_dtype = text_hidden_states.dtype - - # Create padding mask - token_indices = torch.arange(seq_len, device=device).unsqueeze(0) - if padding_side == "right": - # For right padding, valid tokens are from 0 to sequence_length-1 - mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] - elif padding_side == "left": - # For left padding, valid tokens are from (T - sequence_length) to T-1 - start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] - mask = token_indices >= start_indices # [B, T] - else: - raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") - mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] - - # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) - masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) - num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) - masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) - - # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) - x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) - x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) - - # Normalization - normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) - normalized_hidden_states = normalized_hidden_states * scale_factor - - # Pack the hidden 
states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) - normalized_hidden_states = normalized_hidden_states.flatten(2) - mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) - normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) - normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) - return normalized_hidden_states - # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds def _get_gemma_prompt_embeds( self, @@ -392,16 +324,7 @@ def _get_gemma_prompt_embeds( ) text_encoder_hidden_states = text_encoder_outputs.hidden_states text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) - sequence_lengths = prompt_attention_mask.sum(dim=-1) - - prompt_embeds = self._pack_text_embeds( - text_encoder_hidden_states, - sequence_lengths, - device=device, - padding_side=self.tokenizer.padding_side, - scale_factor=scale_factor, - ) - prompt_embeds = prompt_embeds.to(dtype=dtype) + prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) # Pack to 3D # duplicate text embeddings for each generation per prompt, using mps friendly method _, seq_len, _ = prompt_embeds.shape @@ -511,6 +434,9 @@ def check_inputs( negative_prompt_embeds=None, prompt_attention_mask=None, negative_prompt_attention_mask=None, + spatio_temporal_guidance_blocks=None, + stg_scale=None, + audio_stg_scale=None, ): if height % 32 != 0 or width % 32 != 0: raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") @@ -554,6 +480,12 @@ def check_inputs( f" {negative_prompt_attention_mask.shape}." ) + if ((stg_scale > 0.0) or (audio_stg_scale > 0.0)) and not spatio_temporal_guidance_blocks: + raise ValueError( + "Spatio-Temporal Guidance (STG) is specified but no STG blocks are supplied. 
Please supply a list of " + "block indices at which to apply STG in `spatio_temporal_guidance_blocks`" + ) + @staticmethod # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: @@ -811,9 +743,41 @@ def guidance_scale(self): def guidance_rescale(self): return self._guidance_rescale + @property + def stg_scale(self): + return self._stg_scale + + @property + def modality_scale(self): + return self._modality_scale + + @property + def audio_guidance_scale(self): + return self._audio_guidance_scale + + @property + def audio_guidance_rescale(self): + return self._audio_guidance_rescale + + @property + def audio_stg_scale(self): + return self._audio_stg_scale + + @property + def audio_modality_scale(self): + return self._audio_modality_scale + @property def do_classifier_free_guidance(self): - return self._guidance_scale > 1.0 + return (self._guidance_scale > 1.0) or (self._audio_guidance_scale > 1.0) + + @property + def do_spatio_temporal_guidance(self): + return (self._stg_scale > 0.0) or (self._audio_stg_scale > 0.0) + + @property + def do_modality_isolation_guidance(self): + return (self._modality_scale > 1.0) or (self._audio_modality_scale > 1.0) @property def num_timesteps(self): @@ -846,7 +810,14 @@ def __call__( sigmas: list[float] | None = None, timesteps: list[int] | None = None, guidance_scale: float = 4.0, + stg_scale: float = 0.0, + modality_scale: float = 1.0, guidance_rescale: float = 0.0, + audio_guidance_scale: float | None = None, + audio_stg_scale: float | None = None, + audio_modality_scale: float | None = None, + audio_guidance_rescale: float | None = None, + spatio_temporal_guidance_blocks: list[int] | None = None, noise_scale: float = 0.0, num_videos_per_prompt: int = 1, generator: torch.Generator | list[torch.Generator] | None = None, @@ -898,13 +869,47 @@ def __call__( Guidance](https://huggingface.co/papers/2207.12598). 
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to - the text `prompt`, usually at the expense of lower image quality. + the text `prompt`, usually at the expense of lower image quality. Used for the video modality (there is + a separate value `audio_guidance_scale` for the audio modality). + stg_scale (`float`, *optional*, defaults to `0.0`): + Video guidance scale for Spatio-Temporal Guidance (STG), proposed in [Spatiotemporal Skip Guidance for + Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664). STG uses a CFG-like estimate + where we move the sample away from a weak sample from a perturbed version of the denoising model. + Enabling STG will result in an additional denoising model forward pass; the default value of `0.0` + means that STG is disabled. + modality_scale (`float`, *optional*, defaults to `1.0`): + Video guidance scale for LTX-2.X modality isolation guidance, where we move the sample away from a + weaker sample generated by the denoising model with cross-modality (audio-to-video and video-to-audio) + cross attention disabled using a CFG-like estimate. Enabling modality guidance will result in an + additional denoising model forward pass; the default value of `1.0` means that modality guidance is + disabled. guidance_rescale (`float`, *optional*, defaults to 0.0): Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when - using zero terminal SNR. + using zero terminal SNR. Used for the video modality. 
+ audio_guidance_scale (`float`, *optional*, defaults to `None`): + Audio guidance scale for CFG with respect to the negative prompt. The CFG update rule is the same for + video and audio, but they can use different values for the guidance scale. The LTX-2.X authors suggest + that the `audio_guidance_scale` should be higher relative to the video `guidance_scale` (e.g. for + LTX-2.3 they suggest 3.0 for video and 7.0 for audio). If `None`, defaults to the video value + `guidance_scale`. + audio_stg_scale (`float`, *optional*, defaults to `None`): + Audio guidance scale for STG. As with CFG, the STG update rule is otherwise the same for video and + audio. For LTX-2.3, a value of 1.0 is suggested for both video and audio. If `None`, defaults to the + video value `stg_scale`. + audio_modality_scale (`float`, *optional*, defaults to `None`): + Audio guidance scale for LTX-2.X modality isolation guidance. As with CFG, the modality guidance rule + is otherwise the same for video and audio. For LTX-2.3, a value of 3.0 is suggested for both video and + audio. If `None`, defaults to the video value `modality_scale`. + audio_guidance_rescale (`float`, *optional*, defaults to `None`): + A separate guidance rescale factor for the audio modality. If `None`, defaults to the video value + `guidance_rescale`. + spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`): + The zero-indexed transformer block indices at which to apply STG. Must be supplied if STG is used + (`stg_scale` or `audio_stg_scale` is greater than `0`). A value of `[29]` is recommended for LTX-2.0 + and `[28]` is recommended for LTX-2.3. noise_scale (`float`, *optional*, defaults to `0.0`): The interpolation factor between random noise and denoised latents at each timestep. Applying noise to the `latents` and `audio_latents` before continue denoising. 
@@ -967,6 +972,11 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + audio_guidance_scale = audio_guidance_scale or guidance_scale + audio_stg_scale = audio_stg_scale or stg_scale + audio_modality_scale = audio_modality_scale or modality_scale + audio_guidance_rescale = audio_guidance_rescale or guidance_rescale + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt=prompt, @@ -977,10 +987,21 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + stg_scale=stg_scale, + audio_stg_scale=audio_stg_scale, ) + # Per-modality guidance scales (video, audio) self._guidance_scale = guidance_scale + self._stg_scale = stg_scale + self._modality_scale = modality_scale self._guidance_rescale = guidance_rescale + self._audio_guidance_scale = audio_guidance_scale + self._audio_stg_scale = audio_stg_scale + self._audio_modality_scale = audio_modality_scale + self._audio_guidance_rescale = audio_guidance_rescale + self._attention_kwargs = attention_kwargs self._interrupt = False self._current_timestep = None @@ -1017,9 +1038,11 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) - additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + tokenizer_padding_side = "left" # Padding side for default Gemma3-12B text encoder + if getattr(self, "tokenizer", None) is not None: + tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left") connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( - prompt_embeds, additive_attention_mask, 
additive_mask=True + prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side ) # 4. Prepare latent variables @@ -1134,11 +1157,6 @@ def __call__( self._num_timesteps = len(timesteps) # 6. Prepare micro-conditions - rope_interpolation_scale = ( - self.vae_temporal_compression_ratio / frame_rate, - self.vae_spatial_compression_ratio, - self.vae_spatial_compression_ratio, - ) # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop video_coords = self.transformer.rope.prepare_video_coords( latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate @@ -1177,8 +1195,11 @@ def __call__( audio_encoder_hidden_states=connector_audio_prompt_embeds, timestep=video_timestep, audio_timestep=timestep, + sigma=timestep, # Used by LTX-2.3 encoder_attention_mask=connector_attention_mask, audio_encoder_attention_mask=connector_attention_mask, + self_attention_mask=None, + audio_self_attention_mask=None, num_frames=latent_num_frames, height=latent_height, width=latent_width, @@ -1186,7 +1207,9 @@ def __call__( audio_num_frames=audio_num_frames, video_coords=video_coords, audio_coords=audio_coords, - # rope_interpolation_scale=rope_interpolation_scale, + isolate_modalities=False, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, attention_kwargs=attention_kwargs, return_dict=False, ) @@ -1194,24 +1217,137 @@ def __call__( noise_pred_audio = noise_pred_audio.float() if self.do_classifier_free_guidance: - noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) - noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( - noise_pred_video_text - noise_pred_video_uncond - ) + noise_pred_video_uncond_text, noise_pred_video = noise_pred_video.chunk(2) + # Use delta formulation as it works more nicely with multiple guidance terms + video_cfg_delta = (self.guidance_scale - 1) * (noise_pred_video - noise_pred_video_uncond_text) - 
noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) - noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( - noise_pred_audio_text - noise_pred_audio_uncond + noise_pred_audio_uncond_text, noise_pred_audio = noise_pred_audio.chunk(2) + audio_cfg_delta = (self.audio_guidance_scale - 1) * ( + noise_pred_audio - noise_pred_audio_uncond_text ) - if self.guidance_rescale > 0: - # Based on 3.4. in https://huggingface.co/papers/2305.08891 - noise_pred_video = rescale_noise_cfg( - noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale + # Get positive values from merged CFG inputs in case we need to do other DiT forward passes + if self.do_spatio_temporal_guidance or self.do_modality_isolation_guidance: + if i == 0: + # Only split values that remain constant throughout the loop once + video_prompt_embeds = connector_prompt_embeds.chunk(2, dim=0)[1] + audio_prompt_embeds = connector_audio_prompt_embeds.chunk(2, dim=0)[1] + prompt_attn_mask = connector_attention_mask.chunk(2, dim=0)[1] + + video_pos_ids = video_coords.chunk(2, dim=0)[0] + audio_pos_ids = audio_coords.chunk(2, dim=0)[0] + + # Split values that vary each denoising loop iteration + timestep = timestep.chunk(2, dim=0)[0] + video_timestep = video_timestep.chunk(2, dim=0)[0] + else: + video_cfg_delta = audio_cfg_delta = 0 + + video_prompt_embeds = connector_prompt_embeds + audio_prompt_embeds = connector_audio_prompt_embeds + prompt_attn_mask = connector_attention_mask + + video_pos_ids = video_coords + audio_pos_ids = audio_coords + + if self.do_spatio_temporal_guidance: + with self.transformer.cache_context("uncond_stg"): + noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer( + hidden_states=latents.to(dtype=prompt_embeds.dtype), + audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype), + encoder_hidden_states=video_prompt_embeds, + audio_encoder_hidden_states=audio_prompt_embeds, + timestep=video_timestep, + 
audio_timestep=timestep, + sigma=timestep, # Used by LTX-2.3 + encoder_attention_mask=prompt_attn_mask, + audio_encoder_attention_mask=prompt_attn_mask, + self_attention_mask=None, + audio_self_attention_mask=None, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + isolate_modalities=False, + # Use STG at given blocks to perturb model + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + perturbation_mask=None, + attention_kwargs=attention_kwargs, + return_dict=False, ) - noise_pred_audio = rescale_noise_cfg( - noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale + noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float() + noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float() + + video_stg_delta = self.stg_scale * (noise_pred_video - noise_pred_video_uncond_stg) + audio_stg_delta = self.audio_stg_scale * (noise_pred_audio - noise_pred_audio_uncond_stg) + else: + video_stg_delta = audio_stg_delta = 0 + + if self.do_modality_isolation_guidance: + with self.transformer.cache_context("uncond_modality"): + noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self.transformer( + hidden_states=latents.to(dtype=prompt_embeds.dtype), + audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype), + encoder_hidden_states=video_prompt_embeds, + audio_encoder_hidden_states=audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=timestep, + sigma=timestep, # Used by LTX-2.3 + encoder_attention_mask=prompt_attn_mask, + audio_encoder_attention_mask=prompt_attn_mask, + self_attention_mask=None, + audio_self_attention_mask=None, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + # Turn off A2V and V2A cross 
attn to isolate video and audio modalities + isolate_modalities=True, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + attention_kwargs=attention_kwargs, + return_dict=False, ) + noise_pred_video_uncond_modality = noise_pred_video_uncond_modality.float() + noise_pred_audio_uncond_modality = noise_pred_audio_uncond_modality.float() + + video_modality_delta = (self.modality_scale - 1) * ( + noise_pred_video - noise_pred_video_uncond_modality + ) + audio_modality_delta = (self.audio_modality_scale - 1) * ( + noise_pred_audio - noise_pred_audio_uncond_modality + ) + else: + video_modality_delta = audio_modality_delta = 0 + + # Now apply all guidance terms + noise_pred_video_g = noise_pred_video + video_cfg_delta + video_stg_delta + video_modality_delta + noise_pred_audio_g = noise_pred_audio + audio_cfg_delta + audio_stg_delta + audio_modality_delta + + # Apply LTX-2.X guidance rescaling + if self.guidance_rescale > 0: + video_rescale = self.guidance_rescale + cond_std = noise_pred_video.std(dim=list(range(1, noise_pred_video.ndim)), keepdim=True) + guided_std = noise_pred_video_g.std(dim=list(range(1, noise_pred_video_g.ndim)), keepdim=True) + rescale_factor = video_rescale * (cond_std / guided_std) + (1 - video_rescale) + noise_pred_video = noise_pred_video_g * rescale_factor + else: + noise_pred_video = noise_pred_video_g + + if self.audio_guidance_rescale > 0: + audio_rescale = self.audio_guidance_rescale + cond_std = noise_pred_audio.std(dim=list(range(1, noise_pred_audio.ndim)), keepdim=True) + guided_std = noise_pred_audio_g.std(dim=list(range(1, noise_pred_audio_g.ndim)), keepdim=True) + rescale_factor = audio_rescale * (cond_std / guided_std) + (1 - audio_rescale) + noise_pred_audio = noise_pred_audio_g * rescale_factor + else: + noise_pred_audio = noise_pred_audio_g # compute the previous noisy sample x_t -> x_t-1 noise_pred_video = self._unpack_latents( diff --git a/src/diffusers/pipelines/ltx2/vocoder.py 
b/src/diffusers/pipelines/ltx2/vocoder.py index 551c3ac5980f..f0004f2ec02d 100644 --- a/src/diffusers/pipelines/ltx2/vocoder.py +++ b/src/diffusers/pipelines/ltx2/vocoder.py @@ -8,6 +8,209 @@ from ...models.modeling_utils import ModelMixin +def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor: + """ + Creates a Kaiser sinc kernel for low-pass filtering. + + Args: + cutoff (`float`): + Normalized frequency cutoff (relative to the sampling rate). Must be between 0 and 0.5 (the Nyquist + frequency). + half_width (`float`): + Used to determine the Kaiser window's beta parameter. + kernel_size: + Size of the Kaiser window (and ultimately the Kaiser sinc kernel). + + Returns: + `torch.Tensor` of shape `(kernel_size,)`: + The Kaiser sinc kernel. + """ + delta_f = 4 * half_width + half_size = kernel_size // 2 + amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if amplitude > 50.0: + beta = 0.1102 * (amplitude - 8.7) + elif amplitude >= 21.0: + beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0) + else: + beta = 0.0 + + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + even = kernel_size % 2 == 0 + time = torch.arange(-half_size, half_size) + 0.5 if even else torch.arange(kernel_size) - half_size + + if cutoff == 0.0: + filter = torch.zeros_like(time) + else: + time = 2 * cutoff * time + sinc = torch.where( + time == 0, + torch.ones_like(time), + torch.sin(math.pi * time) / math.pi / time, + ) + filter = 2 * cutoff * window * sinc + filter = filter / filter.sum() + return filter + + +class DownSample1d(nn.Module): + """1D low-pass filter for antialias downsampling.""" + + def __init__( + self, + ratio: int = 2, + kernel_size: int | None = None, + use_padding: bool = True, + padding_mode: str = "replicate", + persistent: bool = True, + ): + super().__init__() + self.ratio = ratio + self.kernel_size = kernel_size or int(6 * ratio // 2) * 2 + self.pad_left = self.kernel_size // 2 + 
(self.kernel_size % 2) - 1 + self.pad_right = self.kernel_size // 2 + self.use_padding = use_padding + self.padding_mode = padding_mode + + cutoff = 0.5 / ratio + half_width = 0.6 / ratio + low_pass_filter = kaiser_sinc_filter1d(cutoff, half_width, self.kernel_size) + self.register_buffer("filter", low_pass_filter.view(1, 1, self.kernel_size), persistent=persistent) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x expected shape: [batch_size, num_channels, hidden_dim] + num_channels = x.shape[1] + if self.use_padding: + x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) + x_filtered = F.conv1d(x, self.filter.expand(num_channels, -1, -1), stride=self.ratio, groups=num_channels) + return x_filtered + + +class UpSample1d(nn.Module): + def __init__( + self, + ratio: int = 2, + kernel_size: int | None = None, + window_type: str = "kaiser", + padding_mode: str = "replicate", + persistent: bool = True, + ): + super().__init__() + self.ratio = ratio + self.padding_mode = padding_mode + + if window_type == "hann": + rolloff = 0.99 + lowpass_filter_width = 6 + width = math.ceil(lowpass_filter_width / rolloff) + self.kernel_size = 2 * width * ratio + 1 + self.pad = width + self.pad_left = 2 * width * ratio + self.pad_right = self.kernel_size - ratio + + time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff + time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width) + window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2 + sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1) + else: + # Kaiser sinc filter is BigVGAN default + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.ratio + (self.kernel_size - self.ratio) // 2 + self.pad_right = self.pad * self.ratio + (self.kernel_size - self.ratio + 1) // 2 + + sinc_filter = kaiser_sinc_filter1d( + cutoff=0.5 / 
ratio, + half_width=0.6 / ratio, + kernel_size=self.kernel_size, + ) + + self.register_buffer("filter", sinc_filter.view(1, 1, self.kernel_size), persistent=persistent) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x expected shape: [batch_size, num_channels, hidden_dim] + num_channels = x.shape[1] + x = F.pad(x, (self.pad, self.pad), mode=self.padding_mode) + low_pass_filter = self.filter.to(dtype=x.dtype, device=x.device).expand(num_channels, -1, -1) + x = self.ratio * F.conv_transpose1d(x, low_pass_filter, stride=self.ratio, groups=num_channels) + return x[..., self.pad_left : -self.pad_right] + + +class AntiAliasAct1d(nn.Module): + """ + Antialiasing activation for a 1D signal: upsamples, applies an activation (usually snakebeta), and then downsamples + to avoid aliasing. + """ + + def __init__( + self, + act_fn: str | nn.Module, + ratio: int = 2, + kernel_size: int = 12, + **kwargs, + ): + super().__init__() + self.upsample = UpSample1d(ratio=ratio, kernel_size=kernel_size) + if isinstance(act_fn, str): + if act_fn == "snakebeta": + act_fn = SnakeBeta(**kwargs) + elif act_fn == "snake": + act_fn = SnakeBeta(use_beta=False, **kwargs) + else: + act_fn = nn.LeakyReLU(**kwargs) + self.act = act_fn + self.downsample = DownSample1d(ratio=ratio, kernel_size=kernel_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + return x + + +class SnakeBeta(nn.Module): + """ + Implements the Snake and SnakeBeta activations, which help with learning periodic patterns. 
+ """ + + def __init__( + self, + channels: int, + alpha: float = 1.0, + eps: float = 1e-9, + trainable_params: bool = True, + logscale: bool = True, + use_beta: bool = True, + ): + super().__init__() + self.eps = eps + self.logscale = logscale + self.use_beta = use_beta + + self.alpha = nn.Parameter(torch.zeros(channels) if self.logscale else torch.ones(channels) * alpha) + self.alpha.requires_grad = trainable_params + if use_beta: + self.beta = nn.Parameter(torch.zeros(channels) if self.logscale else torch.ones(channels) * alpha) + self.beta.requires_grad = trainable_params + + def forward(self, hidden_states: torch.Tensor, channel_dim: int = 1) -> torch.Tensor: + broadcast_shape = [1] * hidden_states.ndim + broadcast_shape[channel_dim] = -1 + alpha = self.alpha.view(broadcast_shape) + if self.use_beta: + beta = self.beta.view(broadcast_shape) + + if self.logscale: + alpha = torch.exp(alpha) + if self.use_beta: + beta = torch.exp(beta) + + amplitude = beta if self.use_beta else alpha + hidden_states = hidden_states + (1.0 / (amplitude + self.eps)) * torch.sin(hidden_states * alpha).pow(2) + return hidden_states + + class ResBlock(nn.Module): def __init__( self, @@ -15,12 +218,15 @@ def __init__( kernel_size: int = 3, stride: int = 1, dilations: tuple[int, ...] 
= (1, 3, 5), + act_fn: str = "leaky_relu", leaky_relu_negative_slope: float = 0.1, + antialias: bool = False, + antialias_ratio: int = 2, + antialias_kernel_size: int = 12, padding_mode: str = "same", ): super().__init__() self.dilations = dilations - self.negative_slope = leaky_relu_negative_slope self.convs1 = nn.ModuleList( [ @@ -28,6 +234,18 @@ def __init__( for dilation in dilations ] ) + self.acts1 = nn.ModuleList() + for _ in range(len(self.convs1)): + if act_fn == "snakebeta": + act = SnakeBeta(channels, use_beta=True) + elif act_fn == "snake": + act = SnakeBeta(channels, use_beta=False) + else: + act = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope) + + if antialias: + act = AntiAliasAct1d(act, ratio=antialias_ratio, kernel_size=antialias_kernel_size) + self.acts1.append(act) self.convs2 = nn.ModuleList( [ @@ -35,12 +253,24 @@ def __init__( for _ in range(len(dilations)) ] ) + self.acts2 = nn.ModuleList() + for _ in range(len(self.convs2)): + if act_fn == "snakebeta": + act = SnakeBeta(channels, use_beta=True) + elif act_fn == "snake": + act = SnakeBeta(channels, use_beta=False) + else: + act = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope) + + if antialias: + act = AntiAliasAct1d(act, ratio=antialias_ratio, kernel_size=antialias_kernel_size) + self.acts2.append(act) def forward(self, x: torch.Tensor) -> torch.Tensor: - for conv1, conv2 in zip(self.convs1, self.convs2): - xt = F.leaky_relu(x, negative_slope=self.negative_slope) + for act1, conv1, act2, conv2 in zip(self.acts1, self.convs1, self.acts2, self.convs2): + xt = act1(x) xt = conv1(xt) - xt = F.leaky_relu(xt, negative_slope=self.negative_slope) + xt = act2(xt) xt = conv2(xt) x = x + xt return x @@ -61,7 +291,13 @@ def __init__( upsample_factors: list[int] = [6, 5, 2, 2, 2], resnet_kernel_sizes: list[int] = [3, 7, 11], resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + act_fn: str = "leaky_relu", leaky_relu_negative_slope: float = 0.1, + antialias: bool = 
False, + antialias_ratio: int = 2, + antialias_kernel_size: int = 12, + final_act_fn: str | None = "tanh", # tanh, clamp, None + final_bias: bool = True, output_sampling_rate: int = 24000, ): super().__init__() @@ -69,7 +305,9 @@ def __init__( self.resnets_per_upsample = len(resnet_kernel_sizes) self.out_channels = out_channels self.total_upsample_factor = math.prod(upsample_factors) + self.act_fn = act_fn self.negative_slope = leaky_relu_negative_slope + self.final_act_fn = final_act_fn if self.num_upsample_layers != len(upsample_factors): raise ValueError( @@ -83,6 +321,13 @@ def __init__( f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively." ) + supported_act_fns = ["snakebeta", "snake", "leaky_relu"] + if self.act_fn not in supported_act_fns: + raise ValueError( + f"Unsupported activation function: {self.act_fn}. Currently supported values of `act_fn` are " + f"{supported_act_fns}." + ) + self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3) self.upsamplers = nn.ModuleList() @@ -103,15 +348,27 @@ def __init__( for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations): self.resnets.append( ResBlock( - output_channels, - kernel_size, + channels=output_channels, + kernel_size=kernel_size, dilations=dilations, + act_fn=act_fn, leaky_relu_negative_slope=leaky_relu_negative_slope, + antialias=antialias, + antialias_ratio=antialias_ratio, + antialias_kernel_size=antialias_kernel_size, ) ) input_channels = output_channels - self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3) + if act_fn == "snakebeta" or act_fn == "snake": + # Always use antialiasing + act_out = SnakeBeta(channels=output_channels, use_beta=True) + self.act_out = AntiAliasAct1d(act_out, ratio=antialias_ratio, kernel_size=antialias_kernel_size) + elif act_fn == "leaky_relu": + # NOTE: does NOT use self.negative_slope, following the original code + self.act_out = nn.LeakyReLU() + + self.conv_out = 
nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3, bias=final_bias) def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor: r""" @@ -139,7 +396,9 @@ def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch hidden_states = self.conv_in(hidden_states) for i in range(self.num_upsample_layers): - hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope) + if self.act_fn == "leaky_relu": + # Other activations are inside each upsampling block + hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope) hidden_states = self.upsamplers[i](hidden_states) # Run all resnets in parallel on hidden_states @@ -149,10 +408,190 @@ def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch hidden_states = torch.mean(resnet_outputs, dim=0) - # NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of - # 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended - hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01) + hidden_states = self.act_out(hidden_states) hidden_states = self.conv_out(hidden_states) - hidden_states = torch.tanh(hidden_states) + if self.final_act_fn == "tanh": + hidden_states = torch.tanh(hidden_states) + elif self.final_act_fn == "clamp": + hidden_states = torch.clamp(hidden_states, -1, 1) return hidden_states + + +class CausalSTFT(nn.Module): + """ + Performs a causal short-time Fourier transform (STFT) using causal Hann windows on a waveform. The DFT bases + multiplied by the Hann windows are pre-calculated and stored as buffers. For exact parity with training, the exact + buffers should be loaded from the checkpoint in bfloat16. 
+ """ + + def __init__(self, filter_length: int = 512, hop_length: int = 80, window_length: int = 512): + super().__init__() + self.hop_length = hop_length + self.window_length = window_length + n_freqs = filter_length // 2 + 1 + + self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length), persistent=True) + self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length), persistent=True) + + def forward(self, waveform: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + if waveform.ndim == 2: + waveform = waveform.unsqueeze(1) # [B, num_channels, num_samples] + + left_pad = max(0, self.window_length - self.hop_length) # causal: left-only + waveform = F.pad(waveform, (left_pad, 0)) + + spec = F.conv1d(waveform, self.forward_basis, stride=self.hop_length, padding=0) + n_freqs = spec.shape[1] // 2 + real, imag = spec[:, :n_freqs], spec[:, n_freqs:] + magnitude = torch.sqrt(real**2 + imag**2) + phase = torch.atan2(imag.float(), real.float()).to(dtype=real.dtype) + return magnitude, phase + + +class MelSTFT(nn.Module): + """ + Calculates a causal log-mel spectrogram from a waveform. Uses a pre-calculated mel filterbank, which should be + loaded from the checkpoint in bfloat16. 
+ """ + + def __init__( + self, + filter_length: int = 512, + hop_length: int = 80, + window_length: int = 512, + num_mel_channels: int = 64, + ): + super().__init__() + self.stft_fn = CausalSTFT(filter_length, hop_length, window_length) + + num_freqs = filter_length // 2 + 1 + self.register_buffer("mel_basis", torch.zeros(num_mel_channels, num_freqs), persistent=True) + + def forward(self, waveform: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + magnitude, phase = self.stft_fn(waveform) + energy = torch.norm(magnitude, dim=1) + mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude) + log_mel = torch.log(torch.clamp(mel, min=1e-5)) + return log_mel, magnitude, phase, energy + + +class LTX2VocoderWithBWE(ModelMixin, ConfigMixin): + """ + LTX-2.X vocoder with bandwidth extension (BWE) upsampling. The vocoder and the BWE module run in sequence, with the + BWE module upsampling the vocoder output waveform to a higher sampling rate. The BWE module itself has the same + architecture as the original vocoder. 
+ """ + + @register_to_config + def __init__( + self, + in_channels: int = 128, + hidden_channels: int = 1536, + out_channels: int = 2, + upsample_kernel_sizes: list[int] = [11, 4, 4, 4, 4, 4], + upsample_factors: list[int] = [5, 2, 2, 2, 2, 2], + resnet_kernel_sizes: list[int] = [3, 7, 11], + resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + act_fn: str = "snakebeta", + leaky_relu_negative_slope: float = 0.1, + antialias: bool = True, + antialias_ratio: int = 2, + antialias_kernel_size: int = 12, + final_act_fn: str | None = None, + final_bias: bool = False, + bwe_in_channels: int = 128, + bwe_hidden_channels: int = 512, + bwe_out_channels: int = 2, + bwe_upsample_kernel_sizes: list[int] = [12, 11, 4, 4, 4], + bwe_upsample_factors: list[int] = [6, 5, 2, 2, 2], + bwe_resnet_kernel_sizes: list[int] = [3, 7, 11], + bwe_resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + bwe_act_fn: str = "snakebeta", + bwe_leaky_relu_negative_slope: float = 0.1, + bwe_antialias: bool = True, + bwe_antialias_ratio: int = 2, + bwe_antialias_kernel_size: int = 12, + bwe_final_act_fn: str | None = None, + bwe_final_bias: bool = False, + filter_length: int = 512, + hop_length: int = 80, + window_length: int = 512, + num_mel_channels: int = 64, + input_sampling_rate: int = 16000, + output_sampling_rate: int = 48000, + ): + super().__init__() + + self.vocoder = LTX2Vocoder( + in_channels=in_channels, + hidden_channels=hidden_channels, + out_channels=out_channels, + upsample_kernel_sizes=upsample_kernel_sizes, + upsample_factors=upsample_factors, + resnet_kernel_sizes=resnet_kernel_sizes, + resnet_dilations=resnet_dilations, + act_fn=act_fn, + leaky_relu_negative_slope=leaky_relu_negative_slope, + antialias=antialias, + antialias_ratio=antialias_ratio, + antialias_kernel_size=antialias_kernel_size, + final_act_fn=final_act_fn, + final_bias=final_bias, + output_sampling_rate=input_sampling_rate, + ) + self.bwe_generator = LTX2Vocoder( + 
in_channels=bwe_in_channels, + hidden_channels=bwe_hidden_channels, + out_channels=bwe_out_channels, + upsample_kernel_sizes=bwe_upsample_kernel_sizes, + upsample_factors=bwe_upsample_factors, + resnet_kernel_sizes=bwe_resnet_kernel_sizes, + resnet_dilations=bwe_resnet_dilations, + act_fn=bwe_act_fn, + leaky_relu_negative_slope=bwe_leaky_relu_negative_slope, + antialias=bwe_antialias, + antialias_ratio=bwe_antialias_ratio, + antialias_kernel_size=bwe_antialias_kernel_size, + final_act_fn=bwe_final_act_fn, + final_bias=bwe_final_bias, + output_sampling_rate=output_sampling_rate, + ) + + self.mel_stft = MelSTFT( + filter_length=filter_length, + hop_length=hop_length, + window_length=window_length, + num_mel_channels=num_mel_channels, + ) + + self.resampler = UpSample1d( + ratio=output_sampling_rate // input_sampling_rate, + window_type="hann", + persistent=False, + ) + + def forward(self, mel_spec: torch.Tensor) -> torch.Tensor: + # 1. Run stage 1 vocoder to get low sampling rate waveform + x = self.vocoder(mel_spec) + batch_size, num_channels, num_samples = x.shape + + # Pad to exact multiple of hop_length for exact mel frame count + remainder = num_samples % self.config.hop_length + if remainder != 0: + x = F.pad(x, (0, self.config.hop_length - remainder)) + + # 2. Compute mel spectrogram on vocoder output + mel, _, _, _ = self.mel_stft(x.flatten(0, 1)) + mel = mel.unflatten(0, (-1, num_channels)) + + # 3. Run bandwidth extender (BWE) on new mel spectrogram + mel_for_bwe = mel.transpose(2, 3) # [B, C, num_mel_bins, num_frames] --> [B, C, num_frames, num_mel_bins] + residual = self.bwe_generator(mel_for_bwe) + + # 4. Residual connection with resampler + skip = self.resampler(x) + waveform = torch.clamp(residual + skip, -1, 1) + output_samples = num_samples * self.config.output_sampling_rate // self.config.input_sampling_rate + waveform = waveform[..., :output_samples] + return waveform