Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/maxdiffusion/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
WAN2_1 = "wan2.1"
WAN2_2 = "wan2.2"
LTX2_VIDEO = "ltx2_video"
LTX2_3 = "ltx2.3"

WAN_MODEL = WAN2_1

Expand Down
161 changes: 161 additions & 0 deletions src/maxdiffusion/configs/ltx2_3_video.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
#hardware
Comment thread
prishajain1 marked this conversation as resolved.
# --- Hardware / runtime ---
hardware: 'tpu'
skip_jax_distributed_system: False
# Attention kernel selection. NOTE(review): 'a2v'/'v2a' presumably pick the
# kernels for the audio-to-video and video-to-audio cross-attention paths —
# confirm against the pipeline implementation.
attention: 'flash'
a2v_attention_kernel: 'flash'
v2a_attention_kernel: 'dot_product'
attention_sharding_uniform: True
precision: 'bf16'
scan_layers: True
names_which_can_be_saved: []
names_which_can_be_offloaded: []
remat_policy: "NONE"

jax_cache_dir: ''
# Dtypes for model parameters and activations respectively.
weights_dtype: 'bfloat16'
activations_dtype: 'bfloat16'

run_name: 'ltx2_inference'
output_dir: ''
config_path: ''
save_config_to_gcs: False

# Checkpoints
max_sequence_length: 1024
sampler: "from_checkpoint"

# Generation parameters (aligned with Diffusers LTX-2.3 docs: use_cross_timestep, modality + audio CFG)
global_batch_size_to_train_on: 1
num_inference_steps: 40
# Classifier-free guidance for the video branch.
guidance_scale: 3.0
guidance_rescale: 0.7
# Separate CFG scale/rescale for the audio branch.
audio_guidance_scale: 7.0
audio_guidance_rescale: 0.7
# Spatio-temporal guidance (STG); applied at the transformer blocks listed in
# spatio_temporal_guidance_blocks below.
stg_scale: 1.0
audio_stg_scale: 1.0
modality_scale: 3.0
audio_modality_scale: 3.0
use_cross_timestep: true
spatio_temporal_guidance_blocks: [28]
fps: 24
pipeline_type: multi-scale
prompt: "A man in a brightly lit room talks on a vintage telephone. In a low, heavy voice, he says, 'I understand. I won't call again. Goodbye.' He hangs up the receiver and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is brightly lit by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a dramatic movie."
negative_prompt: "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."
# Output resolution and clip length (num_frames at the fps above).
height: 512
width: 768
# VAE decode controls: timestep and noise injected at decode time.
decode_timestep: 0.05
decode_noise_scale: 0.025
noise_scale: 0.0
num_frames: 121
quantization: "int8"
scan_diffusion_loop: True
use_bwe: True
# --- Parallelism / sharding ---
mesh_axes: ['data', 'fsdp', 'context', 'tensor']
# Logical-to-physical axis mapping consumed by the sharding machinery.
logical_axis_rules: [
                      ['batch', ['data', 'fsdp']],
                      ['activation_batch', ['data', 'fsdp']],
                      ['activation_self_attn_heads', ['context', 'tensor']],
                      ['activation_cross_attn_q_length', ['context', 'tensor']],
                      ['activation_length', 'context'],
                      ['activation_heads', 'tensor'],
                      ['mlp','tensor'],
                      ['embed', ['context', 'fsdp']],
                      ['heads', 'tensor'],
                      ['norm', 'tensor'],
                      ['conv_batch', ['data', 'context', 'fsdp']],
                      ['out_channels', 'tensor'],
                      ['conv_out', 'context'],
                    ]
data_sharding: ['data', 'fsdp', 'context', 'tensor']

dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1

# Flash-attention tile sizes. NOTE(review): presumably must divide the padded
# sequence lengths — confirm against the kernel's constraints.
flash_block_sizes: {
  block_q: 2048,
  block_kv: 2048,
  block_kv_compute: 1024,
  block_q_dkv: 2048,
  block_kv_dkv: 2048,
  block_kv_dkv_compute: 2048,
  use_fused_bwd_kernel: True,
}
flash_min_seq_length: 4096
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: 1
ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1
enable_profiler: False

# ML Diagnostics settings
enable_ml_diagnostics: True
profiler_gcs_path: ""
enable_ondemand_xprof: True
skip_first_n_steps_for_profiler: 0
profiler_steps: 5

replicate_vae: False

allow_split_physical_axes: False
learning_rate_schedule_steps: -1
max_train_steps: 500
# --- Model / checkpoint source ---
pretrained_model_name_or_path: 'dg845/LTX-2.3-Diffusers'
model_name: "ltx2.3"
model_type: "T2V"
unet_checkpoint: ''
checkpoint_dir: ""
# --- Dataset (training-path settings; unused for pure inference runs) ---
dataset_name: ''
train_split: 'train'
dataset_type: 'tfrecord'
cache_latents_text_encoder_outputs: True
per_device_batch_size: 1.0
compile_topology_num_slices: -1
# --- Quantization (qwix) ---
quantization_local_shard_count: -1
use_qwix_quantization: False
weight_quantization_calibration_method: "absmax"
act_quantization_calibration_method: "absmax"
bwd_quantization_calibration_method: "absmax"
qwix_module_path: ".*"
jit_initializers: True
enable_single_replica_ckpt_restoring: False
seed: 10
audio_format: "s16"

# LoRA parameters
enable_lora: False

# Distilled LoRA
# lora_config: {
#  lora_model_name_or_path: ["Lightricks/LTX-2"],
#  weight_name: ["ltx-2-19b-distilled-lora-384.safetensors"],
#  adapter_name: ["distilled-lora-384"],
#  rank: [384]
# }

# Standard LoRA
lora_config: {
  lora_model_name_or_path: ["Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In"],
  weight_name: ["ltx-2-19b-lora-camera-control-dolly-in.safetensors"],
  adapter_name: ["camera-control-dolly-in"],
  rank: [32]
}


# LTX-2 Latent Upsampler
run_latent_upsampler: False
upsampler_model_path: "Lightricks/LTX-2.3"
# Following upsampler files are supported:
# ltx-2.3-spatial-upscaler-x2-1.0.safetensors
# ltx-2.3-spatial-upscaler-x2-1.1.safetensors
# ltx-2.3-spatial-upscaler-x1.5-1.0.safetensors
# ltx-2.3-temporal-upscaler-x2-1.0.safetensors
upsampler_filename: "ltx-2.3-spatial-upscaler-x2-1.0.safetensors"
upsampler_spatial_patch_size: 1
upsampler_temporal_patch_size: 1
upsampler_adain_factor: 0.0
upsampler_tone_map_compression_ratio: 0.0
upsampler_rational_spatial_scale: 2.0
upsampler_output_type: "pil"
14 changes: 13 additions & 1 deletion src/maxdiffusion/configs/ltx2_video.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ sampler: "from_checkpoint"
global_batch_size_to_train_on: 1
num_inference_steps: 40
guidance_scale: 3.0
guidance_rescale: 0.0
audio_guidance_scale: 0.0
audio_guidance_rescale: 0.0
stg_scale: 0.0
audio_stg_scale: 0.0
modality_scale: 1.0
audio_modality_scale: 1.0
use_cross_timestep: false
spatio_temporal_guidance_blocks: []
noise_scale: 1.0
fps: 24
pipeline_type: multi-scale
prompt: "A man in a brightly lit room talks on a vintage telephone. In a low, heavy voice, he says, 'I understand. I won't call again. Goodbye.' He hangs up the receiver and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is brightly lit by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a dramatic movie."
Expand Down Expand Up @@ -87,13 +97,15 @@ enable_profiler: False

# ML Diagnostics settings
enable_ml_diagnostics: True
profiler_gcs_path: "gs://mehdy/profiler/ml_diagnostics"
profiler_gcs_path: ""
enable_ondemand_xprof: True
skip_first_n_steps_for_profiler: 0
profiler_steps: 5

replicate_vae: False

use_bwe: False

allow_split_physical_axes: False
learning_rate_schedule_steps: -1
max_train_steps: 500
Expand Down
13 changes: 12 additions & 1 deletion src/maxdiffusion/generate_ltx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,20 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
num_frames=config.num_frames,
num_inference_steps=config.num_inference_steps,
guidance_scale=guidance_scale,
guidance_rescale=getattr(config, "guidance_rescale", 0.0),
generator=generator,
frame_rate=getattr(config, "fps", 24.0),
decode_timestep=getattr(config, "decode_timestep", 0.0),
decode_noise_scale=getattr(config, "decode_noise_scale", None),
max_sequence_length=getattr(config, "max_sequence_length", 1024),
audio_guidance_scale=getattr(config, "audio_guidance_scale", None),
audio_guidance_rescale=getattr(config, "audio_guidance_rescale", None),
stg_scale=getattr(config, "stg_scale", 0.0),
audio_stg_scale=getattr(config, "audio_stg_scale", None),
modality_scale=getattr(config, "modality_scale", 1.0),
audio_modality_scale=getattr(config, "audio_modality_scale", None),
use_cross_timestep=getattr(config, "use_cross_timestep", None),
noise_scale=getattr(config, "noise_scale", 1.0),
dtype=jnp.bfloat16 if getattr(config, "activations_dtype", "bfloat16") == "bfloat16" else jnp.float32,
output_type=getattr(config, "upsampler_output_type", "pil"),
)
Expand Down Expand Up @@ -220,7 +229,9 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):

saved_video_path = []
audio_sample_rate = (
getattr(pipeline.vocoder.config, "output_sampling_rate", 24000) if hasattr(pipeline, "vocoder") else 24000
getattr(pipeline.vocoder.config, "output_sampling_rate", 24000)
if getattr(pipeline, "vocoder", None) is not None
else 24000
)
fps = getattr(config, "fps", 24)

Expand Down
26 changes: 26 additions & 0 deletions src/maxdiffusion/models/ltx2/attention_ltx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ def __init__(
qkv_sharding_spec: Optional[tuple] = None,
out_sharding_spec: Optional[tuple] = None,
out_bias_sharding_spec: Optional[tuple] = None,
gated_attn: bool = False,
):
self.heads = heads
self.rope_type = rope_type
Expand Down Expand Up @@ -444,6 +445,17 @@ def __init__(
else:
self.dropout_layer = None

if gated_attn:
self.to_gate_logits = nnx.Linear(
query_dim,
heads,
use_bias=True,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")),
bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",)),
rngs=rngs,
dtype=dtype,
)

self.attention_op = NNXAttentionOp(
mesh=mesh,
attention_kernel=attention_kernel,
Expand All @@ -464,6 +476,7 @@ def __call__(
attention_mask: Optional[Array] = None,
rotary_emb: Optional[Tuple[Array, Array]] = None,
k_rotary_emb: Optional[Tuple[Array, Array]] = None,
perturbation_mask: Optional[Array] = None,
) -> Array:
# Determine context (Self or Cross)
context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
Expand Down Expand Up @@ -507,6 +520,19 @@ def __call__(
# NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)

if perturbation_mask is not None:
# value is [B, S, InnerDim]
# attn_output is [B, S, InnerDim]
attn_output = value + perturbation_mask * (attn_output - value)

if getattr(self, "to_gate_logits", None) is not None:
gate_logits = self.to_gate_logits(hidden_states)
b, s, _ = attn_output.shape
attn_output = attn_output.reshape(b, s, self.heads, self.dim_head)
gates = 2.0 * jax.nn.sigmoid(gate_logits)
attn_output = attn_output * jnp.expand_dims(gates, axis=-1)
attn_output = attn_output.reshape(b, s, -1)

# 7. Output Projection
hidden_states = self.to_out(attn_output)

Expand Down
Loading
Loading