2 changes: 1 addition & 1 deletion diffsynth/core/device/__init__.py
@@ -1,2 +1,2 @@
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name
from .npu_compatible_device import IS_NPU_AVAILABLE
from .npu_compatible_device import IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE
2 changes: 2 additions & 0 deletions diffsynth/models/wan_video_dit.py
@@ -5,6 +5,7 @@
from typing import Tuple, Optional
from einops import rearrange
from .wan_video_camera_controller import SimpleAdapter

try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
@@ -92,6 +93,7 @@ def rope_apply(x, freqs, num_heads):
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
x.shape[0], x.shape[1], x.shape[2], -1, 2))
freqs = freqs.to(torch.complex64) if freqs.device == "npu" else freqs
x_out = torch.view_as_real(x_out * freqs).flatten(2)
Comment on lines 94 to 97
Contributor
Severity: high

The current implementation converts freqs to complex64 on NPU, but x_out remains complex128 (as it's created from x.to(torch.float64)). During the multiplication x_out * freqs, PyTorch will promote freqs back to complex128, which means the operation will still use complex128 and not achieve the intended performance improvement on NPU. To fix this, x_out should also be complex64 when running on NPU. The suggested change ensures both tensors are of the correct data type for the operation.

Suggested change
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
x.shape[0], x.shape[1], x.shape[2], -1, 2))
freqs = freqs.to(torch.complex64) if IS_NPU_AVAILABLE else freqs
x_out = torch.view_as_real(x_out * freqs).flatten(2)
x_out = torch.view_as_complex(x.to(torch.float32 if IS_NPU_AVAILABLE else torch.float64).reshape(
x.shape[0], x.shape[1], x.shape[2], -1, 2))
freqs = freqs.to(x_out.dtype)
x_out = torch.view_as_real(x_out * freqs).flatten(2)
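
For illustration, a minimal standalone sketch of the promotion behaviour described above (tensor shapes are arbitrary and chosen only for the demo; the same promotion rules apply on CPU, CUDA, or NPU):

```python
import torch

# x_out is complex128 because it is built from a float64 view,
# mirroring torch.view_as_complex(x.to(torch.float64)...) in the patch.
x_out = torch.view_as_complex(torch.randn(1, 4, 8, 2, dtype=torch.float64))

# freqs has been downcast to complex64, as in the patched line.
freqs = torch.randn(1, 4, 8, dtype=torch.complex64)

result = x_out * freqs
print(result.dtype)  # torch.complex128 -- freqs is promoted back, so the multiply still runs in complex128
```

Casting `x_out` itself to complex64 (or, as in the suggestion, building it from a float32 view on NPU) keeps the whole multiplication in complex64.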

return x_out.to(x.dtype)

2 changes: 1 addition & 1 deletion diffsynth/utils/xfuser/xdit_context_parallel.py
@@ -50,7 +50,7 @@ def rope_apply(x, freqs, num_heads):
sp_rank = get_sequence_parallel_rank()
freqs = pad_freqs(freqs, s_per_rank * sp_size)
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]

freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device == "npu" else freqs_rank
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
return x_out.to(x.dtype)

2 changes: 1 addition & 1 deletion docs/en/Pipeline_Usage/GPU_support.md
@@ -59,7 +59,7 @@ save_video(video, "video.mp4", fps=15, quality=5)
```

### Training
NPU startup script samples have been added for each type of model; the scripts are stored in `examples/xxx/special/npu_scripts`, for example `examples/wanvideo/model_training/special/npu_scripts/Wan2.2-T2V-A14B-NPU.sh`.
NPU startup script samples have been added for each type of model; the scripts are stored in `examples/xxx/special/npu_training`, for example `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`.

In the NPU training scripts, NPU-specific environment variables that can optimize performance have been added, and relevant parameters have been enabled for specific models.

2 changes: 1 addition & 1 deletion docs/zh/Pipeline_Usage/GPU_support.md
@@ -59,7 +59,7 @@ save_video(video, "video.mp4", fps=15, quality=5)
```

### 训练
当前已为每类模型添加NPU的启动脚本样例,脚本存放在`examples/xxx/special/npu_scripts`目录下,例如 `examples/wanvideo/model_training/special/npu_scripts/Wan2.2-T2V-A14B-NPU.sh`。
当前已为每类模型添加NPU的启动脚本样例,脚本存放在`examples/xxx/special/npu_training`目录下,例如 `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`。

在NPU训练脚本中,添加了可以优化性能的NPU特有环境变量,并针对特定模型开启了相关参数。
