Skip to content

Commit 655b8e6

Browse files
committed
make inference run
1 parent 54f3bbf commit 655b8e6

2 files changed

Lines changed: 8 additions & 20 deletions

File tree

src/diffusers/models/transformers/transformer_chroma.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def __call__(
7878
):
7979
batch_size, channels, height, width = pixels.shape
8080

81-
pixels = nn.functional.unfold(pixels, kernel_size=self.patch_size, stride=self.patch_size)
81+
pixels = nn.functional.unfold(pixels, kernel_size=patch_size, stride=patch_size)
8282
pixels = pixels.transpose(1, 2)
8383

8484
hidden = latents.reshape(batch_size * num_patches, self.transformer_hidden_size)
@@ -91,7 +91,7 @@ def __call__(
9191
for block in self.blocks:
9292
latents_dct = block(latents_dct, hidden)
9393

94-
latents_dct = latents_dct.transpose(1, 2).reshape(batch_size, -1).transpose(1, 2)
94+
latents_dct = latents_dct.transpose(1, 2).reshape(batch_size, num_patches, -1).transpose(1, 2)
9595
latents_dct = nn.functional.fold(
9696
latents_dct,
9797
output_size=(height, width),
@@ -129,11 +129,10 @@ def __call__(self, inputs: torch.Tensor) -> torch.Tensor:
129129
batch, pixels, channels = inputs.shape
130130
patch_size = int(pixels**0.5)
131131
input_dtype = inputs.dtype
132-
inputs = inputs.to(dtype=self.embedder[0].weight.dtype)
133-
dct = self.fetch_pos(patch_size, inputs.device, self.dtype)
134-
dct = dct.repeat(batch, 1, 1)
132+
dct = self.fetch_pos(patch_size)
133+
dct = dct.repeat(batch, 1, 1).to(dtype=input_dtype, device=inputs.device)
135134
inputs = torch.cat((inputs, dct), dim=-1)
136-
return self.embedder(inputs).to(dtype=input_dtype)
135+
return self.embedder(inputs)
137136

138137
class NerfGLUBlock(nn.Module):
139138
def __init__(self, transformer_hidden_size: int, nerf_hidden_size: int, mlp_ratio, eps):
@@ -933,7 +932,8 @@ def forward(
933932
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
934933
`tuple` where the first element is the sample tensor.
935934
"""
936-
print("states", hidden_states.shape, encoder_hidden_states.shape)
935+
print(self.device)
936+
pixels = hidden_states.to(self.device)
937937
if joint_attention_kwargs is not None:
938938
joint_attention_kwargs = joint_attention_kwargs.copy()
939939
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
@@ -948,23 +948,16 @@ def forward(
948948
logger.warning(
949949
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
950950
)
951-
952-
pixels = nn.functional.unfold(hidden_states, kernel_size=self.config.patch_size, stride=self.config.patch_size)
953-
pixels = pixels.transpose(1, 2)
954-
print("pixels", pixels.shape)
955951
hidden_states = self.x_embedder_patch(hidden_states)
956-
print("img_patch:", hidden_states.shape)
957952
num_patches = hidden_states.shape[2] * hidden_states.shape[3]
958953
hidden_states = hidden_states.flatten(2).transpose(1, 2)
959-
print(hidden_states.shape)
960954

961955
timestep = timestep.to(hidden_states.dtype) * 1000
962956

963957
input_vec = self.time_text_embed(timestep)
964958
pooled_temb = self.distilled_guidance_layer(input_vec)
965959

966960
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
967-
print(encoder_hidden_states.shape)
968961

969962
if txt_ids.ndim == 3:
970963
logger.warning(

src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ def __init__(
195195
image_encoder=image_encoder,
196196
feature_extractor=feature_extractor,
197197
)
198+
self.image_processor = VaeImageProcessor()
198199
self.default_sample_size = 1024
199200

200201
def _get_t5_prompt_embeds(
@@ -891,12 +892,6 @@ def __call__(
891892
if output_type == "latent":
892893
image = latents
893894
else:
894-
#image = self.transformer.nerf(
895-
# pixels,
896-
# latents,
897-
# self.transformer.config.patch_size,
898-
# num_patches,
899-
#)
900895
#image = self._unpack_latents(image, height, width)
901896
image = self.image_processor.postprocess(latents, output_type=output_type)
902897

0 commit comments

Comments (0)