Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
310 changes: 280 additions & 30 deletions scripts/convert_ltx2_to_diffusers.py

Large diffs are not rendered by default.

88 changes: 68 additions & 20 deletions src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def forward(


# Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d
class LTXVideoDownsampler3d(nn.Module):
class LTX2VideoDownsampler3d(nn.Module):
def __init__(
self,
in_channels: int,
Expand Down Expand Up @@ -285,10 +285,11 @@ def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Ten


# Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d
class LTXVideoUpsampler3d(nn.Module):
class LTX2VideoUpsampler3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int | None = None,
stride: int | tuple[int, int, int] = 1,
residual: bool = False,
upscale_factor: int = 1,
Expand All @@ -300,7 +301,8 @@ def __init__(
self.residual = residual
self.upscale_factor = upscale_factor

out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
out_channels = out_channels or in_channels
out_channels = (out_channels * stride[0] * stride[1] * stride[2]) // upscale_factor

self.conv = LTX2VideoCausalConv3d(
in_channels=in_channels,
Expand Down Expand Up @@ -408,7 +410,7 @@ def __init__(
)
elif downsample_type == "spatial":
self.downsamplers.append(
LTXVideoDownsampler3d(
LTX2VideoDownsampler3d(
in_channels=in_channels,
out_channels=out_channels,
stride=(1, 2, 2),
Expand All @@ -417,7 +419,7 @@ def __init__(
)
elif downsample_type == "temporal":
self.downsamplers.append(
LTXVideoDownsampler3d(
LTX2VideoDownsampler3d(
in_channels=in_channels,
out_channels=out_channels,
stride=(2, 1, 1),
Expand All @@ -426,7 +428,7 @@ def __init__(
)
elif downsample_type == "spatiotemporal":
self.downsamplers.append(
LTXVideoDownsampler3d(
LTX2VideoDownsampler3d(
in_channels=in_channels,
out_channels=out_channels,
stride=(2, 2, 2),
Expand Down Expand Up @@ -580,6 +582,7 @@ def __init__(
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
spatio_temporal_scale: bool = True,
upsample_type: str = "spatiotemporal",
inject_noise: bool = False,
timestep_conditioning: bool = False,
upsample_residual: bool = False,
Expand Down Expand Up @@ -609,17 +612,38 @@ def __init__(

self.upsamplers = None
if spatio_temporal_scale:
self.upsamplers = nn.ModuleList(
[
LTXVideoUpsampler3d(
out_channels * upscale_factor,
self.upsamplers = nn.ModuleList()

if upsample_type == "spatial":
self.upsamplers.append(
LTX2VideoUpsampler3d(
in_channels=out_channels * upscale_factor,
stride=(1, 2, 2),
residual=upsample_residual,
upscale_factor=upscale_factor,
spatial_padding_mode=spatial_padding_mode,
)
)
elif upsample_type == "temporal":
self.upsamplers.append(
LTX2VideoUpsampler3d(
in_channels=out_channels * upscale_factor,
stride=(2, 1, 1),
residual=upsample_residual,
upscale_factor=upscale_factor,
spatial_padding_mode=spatial_padding_mode,
)
)
elif upsample_type == "spatiotemporal":
self.upsamplers.append(
LTX2VideoUpsampler3d(
in_channels=out_channels * upscale_factor,
stride=(2, 2, 2),
residual=upsample_residual,
upscale_factor=upscale_factor,
spatial_padding_mode=spatial_padding_mode,
)
]
)
)

resnets = []
for _ in range(num_layers):
Expand Down Expand Up @@ -716,7 +740,7 @@ def __init__(
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
),
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True),
spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True, True),
layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2),
downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
patch_size: int = 4,
Expand All @@ -726,6 +750,9 @@ def __init__(
spatial_padding_mode: str = "zeros",
):
super().__init__()
num_encoder_blocks = len(layers_per_block)
if isinstance(spatio_temporal_scaling, bool):
spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_encoder_blocks - 1)

self.patch_size = patch_size
self.patch_size_t = patch_size_t
Expand Down Expand Up @@ -860,19 +887,27 @@ def __init__(
in_channels: int = 128,
out_channels: int = 3,
block_out_channels: tuple[int, ...] = (256, 512, 1024),
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True),
spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True),
layers_per_block: tuple[int, ...] = (5, 5, 5, 5),
upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
patch_size: int = 4,
patch_size_t: int = 1,
resnet_norm_eps: float = 1e-6,
is_causal: bool = False,
inject_noise: tuple[bool, ...] = (False, False, False),
inject_noise: bool | tuple[bool, ...] = (False, False, False),
timestep_conditioning: bool = False,
upsample_residual: tuple[bool, ...] = (True, True, True),
upsample_residual: bool | tuple[bool, ...] = (True, True, True),
upsample_factor: tuple[bool, ...] = (2, 2, 2),
spatial_padding_mode: str = "reflect",
) -> None:
super().__init__()
num_decoder_blocks = len(layers_per_block)
if isinstance(spatio_temporal_scaling, bool):
spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_decoder_blocks - 1)
if isinstance(inject_noise, bool):
inject_noise = (inject_noise,) * num_decoder_blocks
if isinstance(upsample_residual, bool):
upsample_residual = (upsample_residual,) * (num_decoder_blocks - 1)

self.patch_size = patch_size
self.patch_size_t = patch_size_t
Expand Down Expand Up @@ -917,6 +952,7 @@ def __init__(
num_layers=layers_per_block[i + 1],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
upsample_type=upsample_type[i],
inject_noise=inject_noise[i + 1],
timestep_conditioning=timestep_conditioning,
upsample_residual=upsample_residual[i],
Expand Down Expand Up @@ -1058,11 +1094,12 @@ def __init__(
decoder_block_out_channels: tuple[int, ...] = (256, 512, 1024),
layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2),
decoder_layers_per_block: tuple[int, ...] = (5, 5, 5, 5),
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True),
decoder_spatio_temporal_scaling: tuple[bool, ...] = (True, True, True),
decoder_inject_noise: tuple[bool, ...] = (False, False, False, False),
spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True, True),
decoder_spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True),
decoder_inject_noise: bool | tuple[bool, ...] = (False, False, False, False),
downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
upsample_residual: tuple[bool, ...] = (True, True, True),
upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
upsample_residual: bool | tuple[bool, ...] = (True, True, True),
upsample_factor: tuple[int, ...] = (2, 2, 2),
timestep_conditioning: bool = False,
patch_size: int = 4,
Expand All @@ -1077,6 +1114,16 @@ def __init__(
temporal_compression_ratio: int = None,
) -> None:
super().__init__()
num_encoder_blocks = len(layers_per_block)
num_decoder_blocks = len(decoder_layers_per_block)
if isinstance(spatio_temporal_scaling, bool):
spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_encoder_blocks - 1)
if isinstance(decoder_spatio_temporal_scaling, bool):
decoder_spatio_temporal_scaling = (decoder_spatio_temporal_scaling,) * (num_decoder_blocks - 1)
if isinstance(decoder_inject_noise, bool):
decoder_inject_noise = (decoder_inject_noise,) * num_decoder_blocks
if isinstance(upsample_residual, bool):
upsample_residual = (upsample_residual,) * (num_decoder_blocks - 1)

self.encoder = LTX2VideoEncoder3d(
in_channels=in_channels,
Expand All @@ -1098,6 +1145,7 @@ def __init__(
block_out_channels=decoder_block_out_channels,
spatio_temporal_scaling=decoder_spatio_temporal_scaling,
layers_per_block=decoder_layers_per_block,
upsample_type=upsample_type,
patch_size=patch_size,
patch_size_t=patch_size_t,
resnet_norm_eps=resnet_norm_eps,
Expand Down
Loading
Loading