@@ -152,8 +152,9 @@ def prepare_latents(
152152 jax .debug .print ("condition stats: mask_mean={mm}, latent_mean={lm}" ,
153153 mm = jnp .mean (condition [..., 0 ]),
154154 lm = jnp .mean (condition [..., 1 :]))
155+ jax .debug .print ("condition latent std={std}" , std = jnp .std (condition [..., 1 :]))
155156
156- return latents , condition , None
157+ return latents , condition , first_frame_mask
157158
158159
159160 def __call__ (
@@ -212,6 +213,12 @@ def __call__(
212213 last_image = last_image_tensor ,
213214 num_videos_per_prompt = num_videos_per_prompt ,
214215 )
216+ if first_frame_mask is not None :
217+ jax .debug .print ("FIRST FRAME MASK stats: min={mn}, max={mx}, mean={mean}, shape={shape}" ,
218+ mn = jnp .min (first_frame_mask ),
219+ mx = jnp .max (first_frame_mask ),
220+ mean = jnp .mean (first_frame_mask ),
221+ shape = first_frame_mask .shape )
215222
216223 scheduler_state = self .scheduler .set_timesteps (
217224 self .scheduler_state , num_inference_steps = num_inference_steps , shape = latents .shape
@@ -291,25 +298,43 @@ def run_inference_2_1_i2v(
291298 if do_classifier_free_guidance :
292299 prompt_embeds = jnp .concatenate ([prompt_embeds , negative_prompt_embeds ], axis = 0 )
293300 image_embeds = jnp .concatenate ([image_embeds , image_embeds ], axis = 0 )
294- condition = jnp . concatenate ([ condition ] * 2 )
295- if first_frame_mask is not None :
301+ if expand_timesteps :
302+ condition = jnp . concatenate ([ condition ] * 2 )
296303 first_frame_mask = jnp .concatenate ([first_frame_mask ] * 2 )
304+ else :
305+ condition = jnp .concatenate ([condition ] * 2 )
297306
298307
299308 def loop_body (step , vals ):
300309 latents , scheduler_state , rng = vals
301310 original_dtype = latents .dtype
302311 rng , timestep_rng = jax .random .split (rng )
303312 t = jnp .array (scheduler_state .timesteps , dtype = jnp .int32 )[step ]
313+ jax .debug .print ("Step {s}: timestep={t}" , s = step , t = t )
304314
305315 latents_input = latents
306316 if do_classifier_free_guidance :
307317 latents_input = jnp .concatenate ([latents , latents ], axis = 0 )
318+ jax .debug .print ("Step{s}: latents_input stats min={mn}, max={mx}, mean={mean}, std={std}" ,
319+ s = step ,
320+ mn = jnp .min (latents_input ),
321+ mx = jnp .max (latents_input ),
322+ mean = jnp .mean (latents_input ),
323+ std = jnp .std (latents_input ))
308324
309325 latent_model_input = jnp .concatenate ([latents_input , condition ], axis = - 1 )
310326 timestep = jnp .broadcast_to (t , latents_input .shape [0 ])
311327 latent_model_input = jnp .transpose (latent_model_input , (0 , 4 , 1 , 2 , 3 ))
312328
329+ jax .debug .print ("Step {s}: latent_model_input shape: {shape}" ,
330+ s = step ,
331+ shape = latent_model_input .shape )
332+
333+ channel_energy = jnp .sum (latent_model_input * latent_model_input ,axis = (0 ,2 ,3 ,4 ))
334+ jax .debug .print ("Step {s}: channel energy first 10={ce}" ,
335+ s = step ,
336+ ce = channel_energy [:10 ])
337+
313338 prompt_embeds_input = prompt_embeds
314339 image_embeds_input = image_embeds
315340
@@ -322,19 +347,28 @@ def loop_body(step, vals):
322347 encoder_hidden_states_image = image_embeds_input ,
323348 )
324349 noise_pred = jnp .transpose (noise_pred , (0 , 2 , 3 , 4 , 1 ))
350+ jax .debug .print ("Step {s}: noise_pred stats min={mn}, max={mx}, mean={mean}, std={std}" ,
351+ s = step ,
352+ mn = jnp .min (noise_pred ),
353+ mx = jnp .max (noise_pred ),
354+ mean = jnp .mean (noise_pred ),
355+ std = jnp .std (noise_pred ))
325356 jax .debug .print ("Step {s}: latents_prev std={std}, mean={mean}" ,
326357 s = step ,
327358 std = jnp .std (latents ),
328359 mean = jnp .mean (latents ))
360+ jax .debug .print ("first_frame_mask shape:" , first_frame_mask .shape if first_frame_mask is not None else (- 1 ,))
361+ jax .debug .print ("first_frame_mask unique values:" , jnp .unique (first_frame_mask ))
362+ jax .debug .print ("condition shape:" , condition .shape )
363+ jax .debug .print ("condition stats:" , jnp .min (condition ), jnp .max (condition ), jnp .mean (condition ))
364+ if first_frame_mask is not None :
365+ clean_latents = condition [..., 4 :]
366+ latents = first_frame_mask * clean_latents + (1 - first_frame_mask ) * latents
329367 latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
330368 jax .debug .print ("Step {s}: latents_next std={std}, mean={mean}" ,
331369 s = step ,
332370 std = jnp .std (latents ),
333371 mean = jnp .mean (latents ))
334- # Apply first frame preservation
335- if first_frame_mask is not None :
336- clean_latents = condition [..., 4 :]
337- latents = first_frame_mask * clean_latents + (1 - first_frame_mask ) * latents
338372 latents = latents .astype (original_dtype )
339373 return latents , scheduler_state , rng
340374
@@ -346,4 +380,4 @@ def loop_body(step, vals):
346380 lmean = jnp .mean (latents ),
347381 lstd = jnp .std (latents ))
348382 max_logging .log ("Finished fori_loop." )
349- return latents
383+ return latents
0 commit comments