Skip to content

Commit 5de4c30

Browse files
committed
reverted
1 parent 0f55e85 commit 5de4c30

1 file changed

Lines changed: 42 additions & 8 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,9 @@ def prepare_latents(
152152
jax.debug.print("condition stats: mask_mean={mm}, latent_mean={lm}",
153153
mm=jnp.mean(condition[..., 0]),
154154
lm=jnp.mean(condition[..., 1:]))
155+
jax.debug.print("condition latent std={std}", std=jnp.std(condition[..., 1:]))
155156

156-
return latents, condition, None
157+
return latents, condition, first_frame_mask
157158

158159

159160
def __call__(
@@ -212,6 +213,12 @@ def __call__(
212213
last_image=last_image_tensor,
213214
num_videos_per_prompt=num_videos_per_prompt,
214215
)
216+
if first_frame_mask is not None:
217+
jax.debug.print("FIRST FRAME MASK stats: min={mn}, max={mx}, mean={mean}, shape={shape}",
218+
mn=jnp.min(first_frame_mask),
219+
mx=jnp.max(first_frame_mask),
220+
mean=jnp.mean(first_frame_mask),
221+
shape=first_frame_mask.shape)
215222

216223
scheduler_state = self.scheduler.set_timesteps(
217224
self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
@@ -291,25 +298,43 @@ def run_inference_2_1_i2v(
291298
if do_classifier_free_guidance:
292299
prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
293300
image_embeds = jnp.concatenate([image_embeds, image_embeds], axis=0)
294-
condition = jnp.concatenate([condition] * 2)
295-
if first_frame_mask is not None:
301+
if expand_timesteps:
302+
condition = jnp.concatenate([condition] * 2)
296303
first_frame_mask = jnp.concatenate([first_frame_mask] * 2)
304+
else:
305+
condition = jnp.concatenate([condition] * 2)
297306

298307

299308
def loop_body(step, vals):
300309
latents, scheduler_state, rng = vals
301310
original_dtype = latents.dtype
302311
rng, timestep_rng = jax.random.split(rng)
303312
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
313+
jax.debug.print("Step {s}: timestep={t}", s=step, t=t)
304314

305315
latents_input = latents
306316
if do_classifier_free_guidance:
307317
latents_input = jnp.concatenate([latents, latents], axis=0)
318+
jax.debug.print("Step{s}: latents_input stats min={mn}, max={mx}, mean={mean}, std={std}",
319+
s=step,
320+
mn=jnp.min(latents_input),
321+
mx=jnp.max(latents_input),
322+
mean=jnp.mean(latents_input),
323+
std=jnp.std(latents_input))
308324

309325
latent_model_input = jnp.concatenate([latents_input, condition], axis=-1)
310326
timestep = jnp.broadcast_to(t, latents_input.shape[0])
311327
latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
312328

329+
jax.debug.print("Step {s}: latent_model_input shape: {shape}",
330+
s=step,
331+
shape=latent_model_input.shape)
332+
333+
channel_energy = jnp.sum(latent_model_input*latent_model_input,axis=(0,2,3,4))
334+
jax.debug.print("Step {s}: channel energy first 10={ce}",
335+
s=step,
336+
ce=channel_energy[:10])
337+
313338
prompt_embeds_input = prompt_embeds
314339
image_embeds_input = image_embeds
315340

@@ -322,19 +347,28 @@ def loop_body(step, vals):
322347
encoder_hidden_states_image=image_embeds_input,
323348
)
324349
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
350+
jax.debug.print("Step {s}: noise_pred stats min={mn}, max={mx}, mean={mean}, std={std}",
351+
s=step,
352+
mn=jnp.min(noise_pred),
353+
mx=jnp.max(noise_pred),
354+
mean=jnp.mean(noise_pred),
355+
std=jnp.std(noise_pred))
325356
jax.debug.print("Step {s}: latents_prev std={std}, mean={mean}",
326357
s=step,
327358
std=jnp.std(latents),
328359
mean=jnp.mean(latents))
360+
jax.debug.print("first_frame_mask shape:", first_frame_mask.shape if first_frame_mask is not None else (-1,))
361+
jax.debug.print("first_frame_mask unique values:", jnp.unique(first_frame_mask))
362+
jax.debug.print("condition shape:", condition.shape)
363+
jax.debug.print("condition stats:", jnp.min(condition), jnp.max(condition), jnp.mean(condition))
364+
if first_frame_mask is not None:
365+
clean_latents = condition[..., 4:]
366+
latents = first_frame_mask * clean_latents + (1 - first_frame_mask) * latents
329367
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
330368
jax.debug.print("Step {s}: latents_next std={std}, mean={mean}",
331369
s=step,
332370
std=jnp.std(latents),
333371
mean=jnp.mean(latents))
334-
# Apply first frame preservation
335-
if first_frame_mask is not None:
336-
clean_latents = condition[..., 4:]
337-
latents = first_frame_mask * clean_latents + (1 - first_frame_mask) * latents
338372
latents = latents.astype(original_dtype)
339373
return latents, scheduler_state, rng
340374

@@ -346,4 +380,4 @@ def loop_body(step, vals):
346380
lmean=jnp.mean(latents),
347381
lstd=jnp.std(latents))
348382
max_logging.log("Finished fori_loop.")
349-
return latents
383+
return latents

0 commit comments

Comments
 (0)