Add LTX-2.3 text-to-video generation support (#402)
Conversation
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This Pull Request successfully introduces support for LTX-2.3 text-to-video generation. It includes significant updates to the transformer architecture (gated attention, cross-modal modulation) and the denoising pipeline (4-way batched denoising for STG/CFG/MIG). The implementation is high-quality and integrates well with the existing LTX-2 infrastructure.
🔍 General Feedback
- Redundant Patch File: The `scratch_diff.patch` file was likely added by mistake and should be removed before merging.
- Robustness: A few areas in the pipeline (like the `audio_channels` fallback and upsampler parameter inference) could be made more robust to handle different model versions and naming conventions.
- Optimization: The use of `nnx.jit` for the vocoder and the optimized sequence length in smoke tests are excellent additions for performance and stability.
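The robustness point about the `audio_channels` fallback could look something like the sketch below. The helper name `infer_audio_channels`, the default of 2, and the `DummyCfg` stand-in are hypothetical illustrations, not the PR's actual code:

```python
def infer_audio_channels(config, default: int = 2) -> int:
    """Hypothetical helper: read `audio_channels` from a model config,
    falling back to a default when the attribute is absent or None
    (e.g. a checkpoint from a model version that predates the audio branch)."""
    value = getattr(config, "audio_channels", None)
    return value if value is not None else default


class DummyCfg:  # stand-in for a loaded model config
    audio_channels = 1


print(infer_audio_channels(DummyCfg()))  # config provides the value -> 1
print(infer_audio_channels(object()))    # attribute missing -> default 2
```

The same `getattr`-with-default pattern covers both older checkpoints and renamed config keys without branching on model version strings.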
This Pull Request introduces comprehensive support for LTX-2.3 text-to-video generation, including the end-to-end pipeline, model updates, and a new vocoder with bandwidth extension (BWE). The implementation correctly handles complex features like Spatio-Temporal Guidance (STG) and Modality Isolation Guidance (MIG) using a 4-way batched denoising approach in JAX.
🔍 General Feedback
- STG/MIG Logic: The implementation of the 4-way split denoising logic and the corresponding delta formulations for guidance is impressive and aligns well with the LTX-2.3 technical requirements.
- Efficiency: Utilizing
nnx.scanfor the denoising loop ensures optimal performance on TPU/GPU hardware. - Redundancy: I identified some redundant initializations and assignments in the transformer and autoencoder models that should be cleaned up.
- Parameter Initialization: Double-check the usage of
nnx.Paramwithkernel_init, asnnx.Paramtypically only accepts the data tensor and might ignore additional keyword arguments.
```python
num_mod_params=num_mod_params,
use_additional_conditions=False,
dtype=self.dtype,
weights_dtype=self.weights_dtype,
```
🟡 This block is redundant as it exactly duplicates the initialization of `prompt_adaln` and `audio_prompt_adaln` already performed in lines 743-756.
```python
num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
num_mel_bins = self.audio_vae.config.mel_bins
```
🟠 Similar to the `__init__` check, this will crash if `audio_vae` is None. A fallback value (e.g., 64 or 128) or a conditional check is needed.

```diff
- num_mel_bins = self.audio_vae.config.mel_bins
+ num_mel_bins = self.audio_vae.config.mel_bins if self.audio_vae is not None else 128
```
```python
def convert_to_vel(lat, x0, sigma_t):
    return (lat - x0) / sigma_t


def scan_body(carry, inputs):
```
🟡 The current logic ties the 4-way guidance pass (STG + MIG) strictly to `do_cfg` and `do_stg`. If a user enables `stg_scale > 0` but sets `guidance_scale = 1.0`, the pipeline will fall back to a 1-pass (or 2-pass if CFG is somehow active elsewhere) execution, and the STG/MIG masks will not be applied. Consider decoupling these or adding a check if either guidance is requested.

```diff
  def scan_body(carry, inputs):
+     do_cfg = guidance_scale > 1.0
+     do_stg = stg_scale > 0.0
```
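One way to decouple the two guidance paths is sketched below in plain Python. The flag names follow the PR's `guidance_scale`/`stg_scale` naming, but the `guidance_passes` helper and the pass labels are illustrative assumptions, not the PR's implementation:

```python
def guidance_passes(guidance_scale: float, stg_scale: float) -> list:
    """Illustrative helper: decide which denoising passes to batch, treating
    CFG and STG/MIG as independent switches instead of requiring both."""
    do_cfg = guidance_scale > 1.0
    do_stg = stg_scale > 0.0
    passes = ["cond"]                       # the conditional pass always runs
    if do_cfg:
        passes.append("uncond")             # CFG needs the unconditional pass
    if do_stg:
        passes += ["perturb", "isolated"]   # STG and MIG perturbation passes
    return passes


print(guidance_passes(1.0, 0.0))  # ['cond']
print(guidance_passes(4.0, 0.0))  # ['cond', 'uncond']
print(guidance_passes(1.0, 1.0))  # ['cond', 'perturb', 'isolated']
print(guidance_passes(4.0, 1.0))  # full 4-way batch
```

With this structure, `stg_scale > 0` alone is enough to trigger the perturbation passes, addressing the fallback concern in the comment above.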
```python
encoder_attention_mask=encoder_attention_mask,
audio_encoder_attention_mask=audio_encoder_attention_mask,
perturbation_mask=mask,
)
```
🔴 The `modality_mask` is missing in the non-scan (else) path of the transformer forward pass. This will prevent Modality Isolation Guidance (MIG) from working correctly when `scan_layers=False` is set in the configuration.

```diff
  audio_encoder_attention_mask=audio_encoder_attention_mask,
  perturbation_mask=mask,
+ modality_mask=modality_mask,
  )
```
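One way to keep the scan and non-scan call sites from drifting apart is to assemble the keyword arguments once and unpack them in both branches. This is a hedged sketch; the argument names mirror the PR, but the `build_block_kwargs` helper is hypothetical:

```python
def build_block_kwargs(encoder_attention_mask, audio_encoder_attention_mask,
                       perturbation_mask, modality_mask):
    """Collect the mask arguments once so both the scan and non-scan paths
    of the transformer forward pass receive the identical set, avoiding
    bugs where one branch silently drops an argument."""
    return dict(
        encoder_attention_mask=encoder_attention_mask,
        audio_encoder_attention_mask=audio_encoder_attention_mask,
        perturbation_mask=perturbation_mask,
        modality_mask=modality_mask,
    )


kwargs = build_block_kwargs("enc_mask", "aud_mask", "pert_mask", "mod_mask")
# Either branch would then call: block(hidden_states, **kwargs)
print(sorted(kwargs))  # all four mask arguments present in both branches
```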
```python
v2a_attention_kernel: str = "dot_product",
flash_block_sizes: BlockSizes = None,
flash_min_seq_length: int = 4096,
perturbed_attn: bool = False,
```
```python
config_path = config.upsampler_model_path
if config_path == "Lightricks/LTX-2.3":
    config_path = "Lightricks/LTX-2"
```
```python
filename = getattr(config, "upsampler_filename", None)
```
```diff
- k1, k2, k3, k4 = jax.random.split(key, 4)
+ k1, k2, k3, k4, k5, k6 = jax.random.split(key, 6)

  self.cross_attn_mod = cross_attn_mod
```
```diff
  self.cross_attn_mod = cross_attn_mod
+ self.scale_shift_table = nnx.Param(
+     jax.random.normal(k1, (table_size, self.dim), dtype=weights_dtype) / jnp.sqrt(self.dim)
+ )
```
```python
inject_noise = tuple(reversed(inject_noise))
upsample_residual = tuple(reversed(upsample_residual))
upsample_factor = tuple(reversed(upsample_factor))
upsample_type = upsample_type
```
🟡 The `upsample_type = upsample_type` self-assignment is a no-op and can be removed.

```diff
- upsample_type = upsample_type
```
```python
)

# Two independent connectors
self.per_modality_projections = per_modality_projections
```
🟡 Suggested change: also store `caption_channels` as an attribute here.

```diff
- self.per_modality_projections = per_modality_projections
+ self.caption_channels = caption_channels
+ self.per_modality_projections = per_modality_projections
```
This PR introduces end-to-end pipeline and model changes to support the LTX-2.3 multi-modal (audio-video) transformer model. It enables integrated text-to-audio-video generation using Gemma-based text conditioning, latent upsamplers, and vocoders.
Key architectural changes
- Gated attention (`to_gate_logits`) applied to all attention operations in the block (Self-Video, Self-Audio, Prompt-Cross, and Modal-Cross).
- Prompt adaptive layer norm (`self.prompt_adaln`). For this specific cross-attention modulation, it derives scale and shift parameters directly from the continuous noise level (sigma).
- Per-modality text projections (`per_modality_projections=True`). Instead of a shared feature extractor, it applies per-token RMS normalization to the raw hidden states and passes them through two separate linear projection layers (`video_text_proj_in` and `audio_text_proj_in`) before sending them to the respective video and audio connectors.
- Vocoder with bandwidth extension (`LTX2VocoderWithBWE`).

Files added/modified
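The gated-attention idea can be sketched in plain Python. The name `gate_logits` echoes the PR's `to_gate_logits` projection, but the sigmoid gate and the elementwise scaling shown here are assumptions for illustration, not the PR's verbatim implementation:

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def gated_attention_output(attn_out, gate_logits):
    """Sketch of output gating: each attention output element is scaled by a
    learned gate in (0, 1) computed from gate logits (mirroring the role of
    the PR's `to_gate_logits` projection; the sigmoid here is an assumption)."""
    return [a * sigmoid(g) for a, g in zip(attn_out, gate_logits)]


out = gated_attention_output([1.0, 1.0, 2.0], [0.0, 10.0, 0.0])
print(out)  # zero logit halves the output; a large logit passes it nearly unchanged
```

Because the gate saturates smoothly between 0 and 1, the block can learn to suppress individual attention paths (e.g. a cross-modal path) without a hard on/off switch.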
- `ltx2_3_video.yml`: New config file for LTX-2.3
- `vocoder_ltx2.py`: Added support for the BWE vocoder
- `ltx2_pipeline.py`: Enabled 4-way sliced batched inference (Uncond, Cond, Perturb, Isolated) and integrated velocity/x0 conversion delta equations with guidance rescaling.
- `transformer_ltx2.py`: Propagated modality/perturbation masks to transformer blocks and integrated prompt adaptive layer norms.
- `generate_ltx2.py`, `pyconfig.py`, `common_types.py`: Added support for LTX-2.3
- `ltx2_utils.py`: Added support to load new LTX-2.3-specific weights
- `attention_ltx2.py`: Added support for gated attention and perturbed attention
- `autoencoder_kl_ltx2.py`: Added support for different `upsample_type` values
- `embeddings_connector_ltx2.py`: Added gated attention configuration (`gated_attn`) support to intermediate transformer block connectors.
- `feature_extractor_ltx2.py`: Added support for the `per_modality_projections` parameter
- `text_encoders.py`: Implemented dual-modality parallel text connector routing, token-wise RMS scaling, and independent video-audio linear projections.

Sample outputs
In addition, we also tested with `scan_diffusion_loop = True` and `scan_diffusion_loop = False`.
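The two settings should be numerically interchangeable: one unrolls the denoising loop as an ordinary Python loop, the other threads a carry through a scan. A minimal pure-Python sketch of that equivalence, with a hypothetical `step` function standing in for one denoising step:

```python
from functools import reduce


def step(latents, sigma):
    """Hypothetical stand-in for one denoising step: shrink the latents
    by a sigma-dependent factor (purely illustrative dynamics)."""
    return latents * (1.0 - 0.5 * sigma)


sigmas = [1.0, 0.5, 0.25]

# Unrolled loop (scan_diffusion_loop = False): a plain Python for-loop.
latents = 1.0
for s in sigmas:
    latents = step(latents, s)

# Scan-style loop (scan_diffusion_loop = True): the same computation as a
# fold over the sigma schedule, mirroring how a scan threads the carry.
scanned = reduce(step, sigmas, 1.0)

print(latents, scanned)  # both paths produce the identical result
```

Testing both paths is worthwhile because the scan version compiles the loop body once, which can surface shape or dtype mismatches that an unrolled loop hides.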