Skip to content

Commit ec2f6bb

Browse files
committed
feat: add KV caching support for Wan and VACE models
1 parent 79cd005 commit ec2f6bb

16 files changed

Lines changed: 1027 additions & 158 deletions

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -355,6 +355,7 @@ use_cfg_cache: False
355355
# Batch positive and negative prompts in text encoder to save compute.
356356
use_batched_text_encoder: False
357357

358+
use_kv_cache: False
358359
use_magcache: False
359360
magcache_thresh: 0.12
360361
magcache_K: 2

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -301,6 +301,7 @@ flow_shift: 3.0
301301

302302
# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
303303
use_cfg_cache: False
304+
use_kv_cache: False
304305

305306
# Batch positive and negative prompts in text encoder to save compute.
306307
use_batched_text_encoder: False

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -331,7 +331,7 @@ use_cfg_cache: False
331331
# Batch positive and negative prompts in text encoder to save compute.
332332
use_batched_text_encoder: False
333333

334-
334+
use_kv_cache: False
335335
# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) — skip forward pass
336336
# when predicted output change (based on accumulated latent/timestep drift) is small
337337
use_sen_cache: False

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -318,7 +318,7 @@ use_cfg_cache: False
318318
# Batch positive and negative prompts in text encoder to save compute.
319319
use_batched_text_encoder: False
320320

321-
321+
use_kv_cache: False
322322
# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208)
323323
use_sen_cache: False
324324
use_magcache: False

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -330,7 +330,7 @@ use_cfg_cache: False
330330
# Batch positive and negative prompts in text encoder to save compute.
331331
use_batched_text_encoder: False
332332

333-
333+
use_kv_cache: False
334334
# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208)
335335
use_sen_cache: False
336336

src/maxdiffusion/generate_wan.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -104,6 +104,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
104104
magcache_thresh=config.magcache_thresh,
105105
magcache_K=config.magcache_K,
106106
retention_ratio=config.retention_ratio,
107+
use_kv_cache=config.use_kv_cache,
107108
)
108109
elif model_key == WAN2_2:
109110
return pipeline(
@@ -118,6 +119,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
118119
guidance_scale_high=config.guidance_scale_high,
119120
use_cfg_cache=config.use_cfg_cache,
120121
use_sen_cache=config.use_sen_cache,
122+
use_kv_cache=config.use_kv_cache,
121123
)
122124
else:
123125
raise ValueError(f"Unsupported model_name for I2V in config: {model_key}")
@@ -136,6 +138,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
136138
magcache_thresh=config.magcache_thresh,
137139
magcache_K=config.magcache_K,
138140
retention_ratio=config.retention_ratio,
141+
use_kv_cache=config.use_kv_cache,
139142
)
140143
elif model_key == WAN2_2:
141144
return pipeline(
@@ -149,6 +152,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
149152
guidance_scale_high=config.guidance_scale_high,
150153
use_cfg_cache=config.use_cfg_cache,
151154
use_sen_cache=config.use_sen_cache,
155+
use_kv_cache=getattr(config, "use_kv_cache", False),
152156
)
153157
else:
154158
raise ValueError(f"Unsupported model_name for T2Vin config: {model_key}")

0 commit comments

Comments
 (0)