Conversation

martinarroyo (Collaborator) commented Jan 7, 2026:

This brings in the VACE model ported from diffusers, staying as close as possible to the upstream conventions.

github-actions bot commented Jan 7, 2026

img_height, img_width = image.shape[-2:]
scale = min(image_size[0] / img_height, image_size[1] / img_width)
new_height, new_width = int(img_height * scale), int(img_width * scale)
# TODO: should we use jax/TF-based resizing here?
martinarroyo (Collaborator, Author) commented:

Let me know what you think about this.

Collaborator commented:

I don't think it is necessary right now. Wouldn't it require casting to numpy for running the torch function below?

martinarroyo (Collaborator, Author) commented Jan 9, 2026:

It has worked fine so far; video_processor.preprocess already returns a torch tensor, but I will keep an eye on it just in case.
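
For reference, here is a minimal sketch of what the JAX-based resizing mentioned in the TODO could look like, assuming the input is a jnp.ndarray with height and width as the last two axes; the function name and layout are illustrative only, and the PR itself keeps the torch path, since video_processor.preprocess returns a torch tensor.

import jax
import jax.numpy as jnp

def resize_keep_aspect(image: jnp.ndarray, image_size: tuple[int, int]) -> jnp.ndarray:
    # Same aspect-ratio-preserving scale as in the snippet above.
    img_height, img_width = image.shape[-2:]
    scale = min(image_size[0] / img_height, image_size[1] / img_width)
    new_height, new_width = int(img_height * scale), int(img_width * scale)
    # jax.image.resize expects the full output shape, so the leading axes stay unchanged.
    out_shape = (*image.shape[:-2], new_height, new_width)
    return jax.image.resize(image, out_shape, method="bilinear")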

martinarroyo force-pushed the martinarroyo-wan2.1-vace branch from e9086b5 to db2a559 on January 8, 2026 08:19
martinarroyo changed the title from "[WIP, WAN] Adds VACE conditioning to WAN 2.1" to "[WAN] Adds VACE conditioning to WAN 2.1" on Jan 8, 2026
martinarroyo marked this pull request as ready for review on January 8, 2026 14:04
blocks.append(block)
self.blocks = blocks

if scan_layers:
Collaborator commented:

I haven't looked too deeply at the vace architecture, but why is it that scan cannot be used?

martinarroyo (Collaborator, Author) commented Jan 9, 2026:

I am not sure how to do it, because the nnx.vmap decorator does not differentiate between the separate layers: it simply creates a tensor with an extra axis, so passing per-layer arguments like apply_input_projection=vace_block_id == 0 is, to my knowledge, not feasible. I think the nnx.scan call later could probably be used in this context if we keep a new variable that acts as a counter to identify the current iteration, but I was not able to work around the limitation in the initialization (and this parameter cannot be passed later, because it conditions how the layer is initialized). I would like to support this, though; if you have any ideas I would appreciate them!

I can also try to have the WAN layers vmap-initialized and skip it for the VACE ones.
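
To illustrate the limitation described above, here is a minimal sketch with a hypothetical Block class (not the actual VACE block), assuming flax.nnx: nnx.vmap traces a single constructor and stacks the resulting parameters along a new axis, so a Python-level flag that decides which parameters exist cannot take a different value for each layer.

from flax import nnx

class Block(nnx.Module):
    # The flag decides which parameters exist, so it must be known at init time.
    def __init__(self, dim: int, apply_input_projection: bool, *, rngs: nnx.Rngs):
        self.proj = nnx.Linear(dim, dim, rngs=rngs) if apply_input_projection else None
        self.linear = nnx.Linear(dim, dim, rngs=rngs)

# A single constructor is traced for the whole stack, so apply_input_projection
# is necessarily the same for every layer; a per-layer value such as
# apply_input_projection=(vace_block_id == 0) cannot be expressed here.
@nnx.split_rngs(splits=4)
@nnx.vmap(in_axes=(0,), out_axes=0)
def create_blocks(rngs: nnx.Rngs) -> Block:
    return Block(64, apply_input_projection=True, rngs=rngs)

blocks = create_blocks(nnx.Rngs(0))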

Collaborator commented:

Ok, sounds good, we can add it later.
