Skip to content

Commit 0191cfd

Browse files
committed
cleanup
1 parent 655b8e6 commit 0191cfd

10 files changed

Lines changed: 55 additions & 56 deletions

File tree

src/diffusers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -923,8 +923,8 @@
923923
BriaFiboTransformer2DModel,
924924
BriaTransformer2DModel,
925925
CacheMixin,
926-
ChromaTransformer2DModel,
927926
ChromaRadianceTransformer2DModel,
927+
ChromaTransformer2DModel,
928928
ChronoEditTransformer3DModel,
929929
CogVideoXTransformer3DModel,
930930
CogView3PlusTransformer2DModel,
@@ -1133,8 +1133,8 @@
11331133
BriaFiboPipeline,
11341134
BriaPipeline,
11351135
ChromaImg2ImgPipeline,
1136-
ChromaRadiancePipeline,
11371136
ChromaPipeline,
1137+
ChromaRadiancePipeline,
11381138
ChronoEditPipeline,
11391139
CLIPImageProjection,
11401140
CogVideoXFunControlPipeline,

src/diffusers/loaders/single_file_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@
3131
convert_animatediff_checkpoint_to_diffusers,
3232
convert_auraflow_transformer_checkpoint_to_diffusers,
3333
convert_autoencoder_dc_checkpoint_to_diffusers,
34-
convert_chroma_transformer_checkpoint_to_diffusers,
3534
convert_chroma_radiance_transformer_checkpoint_to_diffusers,
35+
convert_chroma_transformer_checkpoint_to_diffusers,
3636
convert_controlnet_checkpoint,
3737
convert_cosmos_transformer_checkpoint_to_diffusers,
3838
convert_flux2_transformer_checkpoint_to_diffusers,

src/diffusers/loaders/single_file_utils.py

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -663,9 +663,7 @@ def infer_diffusers_model_type(checkpoint):
663663
model_type = "flux-2-dev"
664664

665665
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
666-
if any(
667-
c in checkpoint for c in ["distilled_guidance_layer.in_proj.bias"]
668-
):
666+
if any(c in checkpoint for c in ["distilled_guidance_layer.in_proj.bias"]):
669667
# Should be updated once a repo exists
670668
# if any(h in checkpoint for h in ["nerf_blocks.0.param_generator.bias"]):
671669
# model_type = "chroma-radiance"
@@ -3556,6 +3554,7 @@ def swap_scale_shift(weight):
35563554

35573555
return converted_state_dict
35583556

3557+
35593558
def convert_chroma_radiance_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
35603559
converted_state_dict = {}
35613560
keys = list(checkpoint.keys())
@@ -3715,30 +3714,20 @@ def swap_scale_shift(weight):
37153714
# output projections.
37163715
converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
37173716
converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")
3718-
3717+
37193718
# nerf
3720-
3721-
converted_state_dict["nerf.nerf_embedder.embedder.0.bias"] = checkpoint.pop(
3722-
"nerf_image_embedder.embedder.0.bias"
3723-
)
3719+
3720+
converted_state_dict["nerf.nerf_embedder.embedder.0.bias"] = checkpoint.pop("nerf_image_embedder.embedder.0.bias")
37243721
converted_state_dict["nerf.nerf_embedder.embedder.0.weight"] = checkpoint.pop(
37253722
"nerf_image_embedder.embedder.0.weight"
37263723
)
3727-
converted_state_dict["nerf.final_layer.conv.bias"] = checkpoint.pop(
3728-
"nerf_final_layer_conv.conv.bias"
3729-
)
3730-
converted_state_dict["nerf.final_layer.conv.weight"] = checkpoint.pop(
3731-
"nerf_final_layer_conv.conv.weight"
3732-
)
3733-
converted_state_dict["nerf.final_layer.norm.weight"] = checkpoint.pop(
3734-
"nerf_final_layer_conv.norm.scale"
3735-
)
3724+
converted_state_dict["nerf.final_layer.conv.bias"] = checkpoint.pop("nerf_final_layer_conv.conv.bias")
3725+
converted_state_dict["nerf.final_layer.conv.weight"] = checkpoint.pop("nerf_final_layer_conv.conv.weight")
3726+
converted_state_dict["nerf.final_layer.norm.weight"] = checkpoint.pop("nerf_final_layer_conv.norm.scale")
37363727

37373728
for i in range(num_nerf_layers):
37383729
block_prefix = f"nerf.blocks.{i}."
3739-
converted_state_dict[f"{block_prefix}norm.weight"] = checkpoint.pop(
3740-
f"nerf_blocks.{i}.norm.scale"
3741-
)
3730+
converted_state_dict[f"{block_prefix}norm.weight"] = checkpoint.pop(f"nerf_blocks.{i}.norm.scale")
37423731
converted_state_dict[f"{block_prefix}param_generator.bias"] = checkpoint.pop(
37433732
f"nerf_blocks.{i}.param_generator.bias"
37443733
)
@@ -3747,16 +3736,13 @@ def swap_scale_shift(weight):
37473736
)
37483737

37493738
# patch
3750-
3751-
converted_state_dict["x_embedder_patch.bias"] = checkpoint.pop(
3752-
"img_in_patch.bias"
3753-
)
3754-
converted_state_dict["x_embedder_patch.weight"] = checkpoint.pop(
3755-
"img_in_patch.weight"
3756-
)
3739+
3740+
converted_state_dict["x_embedder_patch.bias"] = checkpoint.pop("img_in_patch.bias")
3741+
converted_state_dict["x_embedder_patch.weight"] = checkpoint.pop("img_in_patch.weight")
37573742

37583743
return converted_state_dict
37593744

3745+
37603746
def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
37613747
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
37623748

src/diffusers/models/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,10 @@
8686
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
8787
_import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
8888
_import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"]
89-
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel", "ChromaRadianceTransformer2DModel"]
89+
_import_structure["transformers.transformer_chroma"] = [
90+
"ChromaTransformer2DModel",
91+
"ChromaRadianceTransformer2DModel",
92+
]
9093
_import_structure["transformers.transformer_chronoedit"] = ["ChronoEditTransformer3DModel"]
9194
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
9295
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
@@ -184,8 +187,8 @@
184187
AuraFlowTransformer2DModel,
185188
BriaFiboTransformer2DModel,
186189
BriaTransformer2DModel,
187-
ChromaTransformer2DModel,
188190
ChromaRadianceTransformer2DModel,
191+
ChromaTransformer2DModel,
189192
ChronoEditTransformer3DModel,
190193
CogVideoXTransformer3DModel,
191194
CogView3PlusTransformer2DModel,

src/diffusers/models/transformers/transformer_chroma.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3737

38+
3839
class Nerf(nn.Module):
3940
def __init__(
4041
self,
@@ -44,7 +45,7 @@ def __init__(
4445
transformer_hidden_size: int,
4546
max_freqs: int,
4647
mlp_ratio: int,
47-
eps = 1e-6,
48+
eps=1e-6,
4849
):
4950
super().__init__()
5051
self.nerf_embedder = NerfEmbedder(
@@ -69,6 +70,7 @@ def __init__(
6970
eps=eps,
7071
)
7172
self.transformer_hidden_size = transformer_hidden_size
73+
7274
def __call__(
7375
self,
7476
pixels,
@@ -77,20 +79,20 @@ def __call__(
7779
num_patches,
7880
):
7981
batch_size, channels, height, width = pixels.shape
80-
82+
8183
pixels = nn.functional.unfold(pixels, kernel_size=patch_size, stride=patch_size)
8284
pixels = pixels.transpose(1, 2)
83-
85+
8486
hidden = latents.reshape(batch_size * num_patches, self.transformer_hidden_size)
8587
pixels = pixels.reshape(batch_size * num_patches, channels, patch_size**2).transpose(1, 2)
86-
88+
8789
# Get pixel embeddings
8890
latents_dct = self.nerf_embedder(pixels)
89-
91+
9092
# Pass through blocks
9193
for block in self.blocks:
9294
latents_dct = block(latents_dct, hidden)
93-
95+
9496
latents_dct = latents_dct.transpose(1, 2).reshape(batch_size, num_patches, -1).transpose(1, 2)
9597
latents_dct = nn.functional.fold(
9698
latents_dct,
@@ -100,6 +102,7 @@ def __call__(
100102
)
101103
return self.final_layer(latents_dct)
102104

105+
103106
class NerfEmbedder(nn.Module):
104107
def __init__(
105108
self,
@@ -111,6 +114,7 @@ def __init__(
111114
self.max_freqs = max_freqs
112115
self.hidden_size = hidden_size
113116
self.embedder = nn.Sequential(nn.Linear(in_channels + max_freqs**2, hidden_size))
117+
114118
def fetch_pos(self, patch_size) -> torch.Tensor:
115119
pos_x = torch.linspace(0, 1, patch_size)
116120
pos_y = torch.linspace(0, 1, patch_size)
@@ -123,8 +127,9 @@ def fetch_pos(self, patch_size) -> torch.Tensor:
123127
coeffs = (1 + freqs_x * freqs_y) ** -1
124128
dct_x = torch.cos(pos_x * freqs_x * torch.pi)
125129
dct_y = torch.cos(pos_y * freqs_y * torch.pi)
126-
dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2)
130+
dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs**2)
127131
return dct
132+
128133
def __call__(self, inputs: torch.Tensor) -> torch.Tensor:
129134
batch, pixels, channels = inputs.shape
130135
patch_size = int(pixels**0.5)
@@ -134,13 +139,15 @@ def __call__(self, inputs: torch.Tensor) -> torch.Tensor:
134139
inputs = torch.cat((inputs, dct), dim=-1)
135140
return self.embedder(inputs)
136141

142+
137143
class NerfGLUBlock(nn.Module):
138144
def __init__(self, transformer_hidden_size: int, nerf_hidden_size: int, mlp_ratio, eps):
139145
super().__init__()
140146
total_params = 3 * nerf_hidden_size**2 * mlp_ratio
141147
self.param_generator = nn.Linear(transformer_hidden_size, total_params)
142148
self.norm = RMSNorm(nerf_hidden_size, eps=eps)
143149
self.mlp_ratio = mlp_ratio
150+
144151
def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
145152
batch_size, num_x, hidden_size_x = x.shape
146153
mlp_params = self.param_generator(s)
@@ -156,6 +163,7 @@ def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
156163
x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2)
157164
return x + res_x
158165

166+
159167
class NerfFinalLayer(nn.Module):
160168
def __init__(self, hidden_size: int, out_channels: int, eps):
161169
super().__init__()
@@ -166,9 +174,11 @@ def __init__(self, hidden_size: int, out_channels: int, eps):
166174
kernel_size=3,
167175
padding=1,
168176
)
177+
169178
def forward(self, x: torch.Tensor) -> torch.Tensor:
170179
return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1))
171180

181+
172182
class ChromaAdaLayerNormZeroPruned(nn.Module):
173183
r"""
174184
Norm layer adaptive layer norm zero (adaLN-Zero).
@@ -658,7 +668,7 @@ def forward(
658668
logger.warning(
659669
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
660670
)
661-
671+
662672
hidden_states = self.x_embedder(hidden_states)
663673

664674
timestep = timestep.to(hidden_states.dtype) * 1000
@@ -773,6 +783,7 @@ def forward(
773783

774784
return Transformer2DModelOutput(sample=output)
775785

786+
776787
class ChromaRadianceTransformer2DModel(
777788
ModelMixin,
778789
ConfigMixin,
@@ -850,7 +861,7 @@ def __init__(
850861
hidden_dim=approximator_hidden_dim,
851862
n_layers=approximator_layers,
852863
)
853-
864+
854865
self.nerf = Nerf(
855866
in_channels,
856867
nerf_layers,
@@ -859,7 +870,7 @@ def __init__(
859870
nerf_max_freqs,
860871
nerf_mlp_ratio,
861872
)
862-
873+
863874
self.x_embedder_patch = nn.Conv2d(
864875
in_channels,
865876
self.inner_dim,
@@ -932,7 +943,6 @@ def forward(
932943
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
933944
`tuple` where the first element is the sample tensor.
934945
"""
935-
print(self.device)
936946
pixels = hidden_states.to(self.device)
937947
if joint_attention_kwargs is not None:
938948
joint_attention_kwargs = joint_attention_kwargs.copy()

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,6 @@ def __call__(
115115
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
116116
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
117117

118-
print("query", query.shape, "key", key.shape, "value", value.shape)
119-
120118
hidden_states = dispatch_attention_fn(
121119
query,
122120
key,

src/diffusers/pipelines/chroma/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@
3333
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
3434
else:
3535
from .pipeline_chroma import ChromaPipeline
36-
from .pipeline_chroma_radiance import ChromaRadiancePipeline
3736
from .pipeline_chroma_img2img import ChromaImg2ImgPipeline
37+
from .pipeline_chroma_radiance import ChromaRadiancePipeline
3838
else:
3939
import sys
4040

src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,10 @@
2121

2222
from ...image_processor import PipelineImageInput, VaeImageProcessor
2323
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
24-
from ...models import AutoencoderKL, ChromaTransformer2DModel, ChromaRadianceTransformer2DModel
24+
from ...models import ChromaRadianceTransformer2DModel
2525
from ...schedulers import FlowMatchEulerDiscreteScheduler
2626
from ...utils import (
2727
USE_PEFT_BACKEND,
28-
deprecate,
2928
is_torch_xla_available,
3029
logging,
3130
replace_example_docstring,
@@ -147,6 +146,7 @@ def retrieve_timesteps(
147146
timesteps = scheduler.timesteps
148147
return timesteps, num_inference_steps
149148

149+
150150
class ChromaRadiancePipeline(
151151
DiffusionPipeline,
152152
FluxLoraLoaderMixin,
@@ -420,7 +420,6 @@ def check_inputs(
420420
callback_on_step_end_tensor_inputs=None,
421421
max_sequence_length=None,
422422
):
423-
424423
if callback_on_step_end_tensor_inputs is not None and not all(
425424
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
426425
):
@@ -503,9 +502,7 @@ def prepare_latents(
503502
latents=None,
504503
patch_size=2,
505504
):
506-
507505
shape = (batch_size, num_channels_latents, height, width)
508-
print(shape)
509506

510507
if latents is not None:
511508
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
@@ -518,9 +515,11 @@ def prepare_latents(
518515
)
519516

520517
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
521-
#latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
518+
# latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
522519

523-
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // patch_size, width // patch_size, device, dtype)
520+
latent_image_ids = self._prepare_latent_image_ids(
521+
batch_size, height // patch_size, width // patch_size, device, dtype
522+
)
524523

525524
return latents, latent_image_ids
526525

@@ -822,7 +821,6 @@ def __call__(
822821
batch_size * num_images_per_prompt,
823822
)
824823

825-
826824
# 6. Denoising loop
827825
with self.progress_bar(total=num_inference_steps) as progress_bar:
828826
for i, t in enumerate(timesteps):
@@ -887,12 +885,12 @@ def __call__(
887885

888886
self._current_timestep = None
889887

890-
# 7.
888+
# 7.
891889

892890
if output_type == "latent":
893891
image = latents
894892
else:
895-
#image = self._unpack_latents(image, height, width)
893+
# image = self._unpack_latents(image, height, width)
896894
image = self.image_processor.postprocess(latents, output_type=output_type)
897895

898896
# Offload all models

src/diffusers/utils/dummy_pt_objects.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,7 @@ def from_config(cls, *args, **kwargs):
662662
def from_pretrained(cls, *args, **kwargs):
663663
requires_backends(cls, ["torch"])
664664

665+
665666
class ChromaRadianceTransformer2DModel(metaclass=DummyObject):
666667
_backends = ["torch"]
667668

@@ -676,6 +677,7 @@ def from_config(cls, *args, **kwargs):
676677
def from_pretrained(cls, *args, **kwargs):
677678
requires_backends(cls, ["torch"])
678679

680+
679681
class ChronoEditTransformer3DModel(metaclass=DummyObject):
680682
_backends = ["torch"]
681683

0 commit comments

Comments
 (0)