@@ -78,7 +78,7 @@ def __call__(
     ):
         batch_size, channels, height, width = pixels.shape

-        pixels = nn.functional.unfold(pixels, kernel_size=self.patch_size, stride=self.patch_size)
+        pixels = nn.functional.unfold(pixels, kernel_size=patch_size, stride=patch_size)
         pixels = pixels.transpose(1, 2)

         hidden = latents.reshape(batch_size * num_patches, self.transformer_hidden_size)
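Aside (not part of the diff): the `unfold` call extracts non-overlapping patches and the `transpose` moves them onto the sequence axis. A minimal sketch with hypothetical sizes:

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration only.
batch_size, channels, height, width, patch_size = 2, 3, 32, 32, 8
pixels = torch.randn(batch_size, channels, height, width)

# unfold extracts non-overlapping patches:
# (B, C, H, W) -> (B, C * patch_size**2, num_patches)
patches = nn.functional.unfold(pixels, kernel_size=patch_size, stride=patch_size)
num_patches = (height // patch_size) * (width // patch_size)
assert patches.shape == (batch_size, channels * patch_size**2, num_patches)

# transpose(1, 2) moves patches onto the sequence axis:
# (B, num_patches, C * patch_size**2)
patches = patches.transpose(1, 2)
```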
@@ -91,7 +91,7 @@ def __call__(
         for block in self.blocks:
             latents_dct = block(latents_dct, hidden)

-        latents_dct = latents_dct.transpose(1, 2).reshape(batch_size, -1).transpose(1, 2)
+        latents_dct = latents_dct.transpose(1, 2).reshape(batch_size, num_patches, -1).transpose(1, 2)
         latents_dct = nn.functional.fold(
             latents_dct,
             output_size=(height, width),
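The one-line fix above matters because `fold` needs a 3-D `(B, C * p*p, num_patches)` input, and the old `reshape(batch_size, -1)` collapsed everything into a 2-D tensor, so the trailing `transpose(1, 2)` raised a dimension-out-of-range error. A self-contained sketch of the round trip, with hypothetical sizes and a stand-in for the block output:

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration only.
batch_size, channels, height, width, patch_size = 2, 3, 32, 32, 8
num_patches = (height // patch_size) * (width // patch_size)

# Stand-in for the per-patch block output, one row per (image, patch) pair:
# (B * num_patches, patch_size**2, C)
latents_dct = torch.randn(batch_size * num_patches, patch_size**2, channels)

# Keeping num_patches as its own axis yields the (B, C * p*p, num_patches)
# layout that fold expects; reshape(batch_size, -1) would have been 2-D.
latents_dct = latents_dct.transpose(1, 2).reshape(batch_size, num_patches, -1).transpose(1, 2)

# fold inverts unfold: (B, C * p*p, num_patches) -> (B, C, H, W)
restored = nn.functional.fold(
    latents_dct, output_size=(height, width), kernel_size=patch_size, stride=patch_size
)
assert restored.shape == (batch_size, channels, height, width)
```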
@@ -129,11 +129,10 @@ def __call__(self, inputs: torch.Tensor) -> torch.Tensor:
         batch, pixels, channels = inputs.shape
         patch_size = int(pixels ** 0.5)
         input_dtype = inputs.dtype
-        inputs = inputs.to(dtype=self.embedder[0].weight.dtype)
-        dct = self.fetch_pos(patch_size, inputs.device, self.dtype)
-        dct = dct.repeat(batch, 1, 1)
+        dct = self.fetch_pos(patch_size)
+        dct = dct.repeat(batch, 1, 1).to(dtype=input_dtype, device=inputs.device)
         inputs = torch.cat((inputs, dct), dim=-1)
-        return self.embedder(inputs).to(dtype=input_dtype)
+        return self.embedder(inputs)

 class NerfGLUBlock(nn.Module):
     def __init__(self, transformer_hidden_size: int, nerf_hidden_size: int, mlp_ratio, eps):
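This hunk moves the dtype/device handling onto the positional table: instead of casting the activations to the embedder weight dtype and back, the cached DCT features are cast to match the inputs. A rough sketch of the new flow, with hypothetical stand-ins for `fetch_pos` and `embedder`:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the module's cached DCT table and embedder MLP.
batch, pixels, channels, pos_dim = 2, 64, 3, 4
embedder = nn.Sequential(nn.Linear(channels + pos_dim, 16), nn.SiLU()).to(torch.bfloat16)

def fetch_pos(patch_size: int) -> torch.Tensor:
    # Stand-in cached positional table: (1, patch_size**2, pos_dim),
    # possibly kept in float32 on a different device than the inputs.
    return torch.randn(1, patch_size**2, pos_dim)

inputs = torch.randn(batch, pixels, channels, dtype=torch.bfloat16)
patch_size = int(pixels ** 0.5)  # 64 pixels per patch -> patch_size = 8

dct = fetch_pos(patch_size)
# New flow: cast the positional table to the activations' dtype/device,
# rather than round-tripping the activations through the weight dtype.
dct = dct.repeat(batch, 1, 1).to(dtype=inputs.dtype, device=inputs.device)
out = embedder(torch.cat((inputs, dct), dim=-1))
assert out.dtype == torch.bfloat16 and out.shape == (batch, pixels, 16)
```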
@@ -933,7 +932,8 @@ def forward(
         If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
         `tuple` where the first element is the sample tensor.
         """
-        print("states", hidden_states.shape, encoder_hidden_states.shape)
+        print(self.device)
+        pixels = hidden_states.to(self.device)
         if joint_attention_kwargs is not None:
             joint_attention_kwargs = joint_attention_kwargs.copy()
             lora_scale = joint_attention_kwargs.pop("scale", 1.0)
@@ -948,23 +948,16 @@ def forward(
             logger.warning(
                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
             )
-
-        pixels = nn.functional.unfold(hidden_states, kernel_size=self.config.patch_size, stride=self.config.patch_size)
-        pixels = pixels.transpose(1, 2)
-        print("pixels", pixels.shape)
         hidden_states = self.x_embedder_patch(hidden_states)
-        print("img_patch:", hidden_states.shape)
         num_patches = hidden_states.shape[2] * hidden_states.shape[3]
         hidden_states = hidden_states.flatten(2).transpose(1, 2)
-        print(hidden_states.shape)

         timestep = timestep.to(hidden_states.dtype) * 1000

         input_vec = self.time_text_embed(timestep)
         pooled_temb = self.distilled_guidance_layer(input_vec)

         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
-        print(encoder_hidden_states.shape)

         if txt_ids.ndim == 3:
             logger.warning(
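Here the manual `unfold` is dropped in favor of `self.x_embedder_patch`. Assuming it is a Conv2d patch embedder with `kernel_size == stride == patch_size` (which the `shape[2] * shape[3]` num_patches computation implies), the flatten/transpose produces the usual `(B, num_patches, inner_dim)` token layout. A sketch under that assumption:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for x_embedder_patch: a Conv2d whose
# kernel_size == stride == patch_size acts as a patch embedder.
batch, in_channels, height, width = 2, 16, 64, 64
patch_size, inner_dim = 2, 64
x_embedder_patch = nn.Conv2d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)

hidden_states = torch.randn(batch, in_channels, height, width)
hidden_states = x_embedder_patch(hidden_states)            # (B, inner_dim, H/p, W/p)

num_patches = hidden_states.shape[2] * hidden_states.shape[3]
hidden_states = hidden_states.flatten(2).transpose(1, 2)   # (B, num_patches, inner_dim)
assert hidden_states.shape == (batch, num_patches, inner_dim)
```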