diff --git a/README.md b/README.md
index b5720a6a..60eb2875 100644
--- a/README.md
+++ b/README.md
@@ -254,6 +254,9 @@ After installation completes, run the training script.
   - In Wan2.1, the ici_fsdp_parallelism axis is used for sequence parallelism, the ici_tensor_parallelism axis is used for head parallelism.
   - You can enable both, keeping in mind that Wan2.1 has 40 heads and 40 must be evenly divisible by ici_tensor_parallelism.
   - For Sequence parallelism, the code pads the sequence length to evenly divide the sequence. Try out different ici_fsdp_parallelism numbers, but we find 2 and 4 to be the best right now.
+  - For use on GPU, it is recommended to enable the cudnn_flash_te attention kernel for optimal performance.
+  - Best performance is achieved with batch parallelism, which can be enabled via the ici_fsdp_batch_parallelism axis. Note that this parallelism strategy does not support fractional batch sizes.
+  - ici_fsdp_batch_parallelism and ici_fsdp_parallelism can be combined to allow fractional batch sizes. However, padding is not currently supported for the cudnn_flash_te attention kernel, so the sequence length must be evenly divisible by the number of devices in the ici_fsdp_parallelism axis.

 You should eventually see a training run as:
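The constraints above can be checked mechanically. A minimal, illustrative sketch in Python (the device count, sequence length, batch size, and axis sizes below are hypothetical, not values from this change):

```python
# Illustrative only: checks the divisibility rules called out in the README bullets
# for a hypothetical 8-device, single-slice Wan2.1 run.
num_devices = 8
num_heads = 40          # Wan2.1 attention heads
seq_len = 30_000        # example latent sequence length
global_batch = 2        # example global batch size

ici_fsdp_batch_parallelism = 2   # batch parallelism
ici_fsdp_parallelism = 4         # sequence parallelism
ici_tensor_parallelism = 1       # head parallelism

assert ici_fsdp_batch_parallelism * ici_fsdp_parallelism * ici_tensor_parallelism == num_devices
assert num_heads % ici_tensor_parallelism == 0          # head-parallelism constraint
assert global_batch % ici_fsdp_batch_parallelism == 0   # no fractional batches on the batch axis
assert seq_len % ici_fsdp_parallelism == 0              # required for cudnn_flash_te (no padding)
```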
diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml
index 1b647424..7be1ec38 100644
--- a/src/maxdiffusion/configs/base_wan_14b.yml
+++ b/src/maxdiffusion/configs/base_wan_14b.yml
@@ -148,7 +148,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
 skip_jax_distributed_system: False

 # Parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'fsdp_batch', 'fsdp', 'tensor']

 # batch : batch dimension of data and activations
 # hidden :
@@ -163,30 +163,32 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 # conv_in : conv.shape[2] weight
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
-    ['batch', 'data'],
-    ['activation_batch', 'data'],
+    ['batch', ['data', 'fsdp_batch']],
+    ['activation_batch', ['data', 'fsdp_batch']],
+    ['activation_length', 'fsdp'],
     ['activation_self_attn_heads', ['fsdp', 'tensor']],
     ['activation_cross_attn_q_length', ['fsdp', 'tensor']],
-    ['activation_length', 'fsdp'],
     ['activation_heads', 'tensor'],
     ['mlp','tensor'],
-    ['embed','fsdp'],
+    ['embed', ['fsdp', 'fsdp_batch']],
     ['heads', 'tensor'],
     ['norm', 'tensor'],
-    ['conv_batch', ['data','fsdp']],
+    ['conv_batch', ['data', 'fsdp', 'fsdp_batch']],
     ['out_channels', 'tensor'],
     ['conv_out', 'fsdp'],
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'fsdp_batch', 'fsdp', 'tensor']]

 # One axis for each parallelism type may hold a placeholder (-1)
 # value to auto-shard based on available slices and devices.
 # By default, product of the DCN axes should equal number of slices
 # and product of the ICI axes should equal number of devices per slice.
 dcn_data_parallelism: 1  # recommended DCN axis to be auto-sharded
+dcn_fsdp_batch_parallelism: 1
 dcn_fsdp_parallelism: -1
 dcn_tensor_parallelism: 1
 ici_data_parallelism: 1
+ici_fsdp_batch_parallelism: 1
 ici_fsdp_parallelism: -1  # recommended ICI axis to be auto-sharded
 ici_tensor_parallelism: 1
diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py
index e3365e96..65a058ec 100644
--- a/src/maxdiffusion/generate_wan.py
+++ b/src/maxdiffusion/generate_wan.py
@@ -209,7 +209,10 @@ def run(config, pipeline=None, filename_prefix=""):

 def main(argv: Sequence[str]) -> None:
   pyconfig.initialize(argv)
-  flax.config.update("flax_always_shard_variable", False)
+  try:
+    flax.config.update("flax_always_shard_variable", False)
+  except Exception:
+    pass
   run(pyconfig.config)

diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py
index 48c6ca44..8e706c87 100644
--- a/src/maxdiffusion/max_utils.py
+++ b/src/maxdiffusion/max_utils.py
@@ -268,17 +268,30 @@ def create_device_mesh(config, devices=None, logging=True):
     max_logging.log(f"Devices: {devices} (num_devices: {num_devices})")

   multi_slice_env = num_slices > 1
-
-  dcn_parallelism = [
-      config.dcn_data_parallelism,
-      config.dcn_fsdp_parallelism,
-      config.dcn_tensor_parallelism,
-  ]
-  ici_parallelism = [
-      config.ici_data_parallelism,
-      config.ici_fsdp_parallelism,
-      config.ici_tensor_parallelism,
-  ]
+  if "dcn_fsdp_batch_parallelism" in config.get_keys():
+    dcn_parallelism = [
+        config.dcn_data_parallelism,
+        config.dcn_fsdp_batch_parallelism,
+        config.dcn_fsdp_parallelism,
+        config.dcn_tensor_parallelism,
+    ]
+    ici_parallelism = [
+        config.ici_data_parallelism,
+        config.ici_fsdp_batch_parallelism,
+        config.ici_fsdp_parallelism,
+        config.ici_tensor_parallelism,
+    ]
+  else:
+    dcn_parallelism = [
+        config.dcn_data_parallelism,
+        config.dcn_fsdp_parallelism,
+        config.dcn_tensor_parallelism,
+    ]
+    ici_parallelism = [
+        config.ici_data_parallelism,
+        config.ici_fsdp_parallelism,
+        config.ici_tensor_parallelism,
+    ]

   # Find possible unspecified parallelisms
   ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI")
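For orientation, a minimal sketch of the four-axis mesh this config describes, built with stock JAX utilities and assuming 8 devices on a single slice (the axis sizes are illustrative; max_utils.create_device_mesh fills any -1 axes before producing a comparable layout):

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# Axis sizes must multiply to the number of devices per slice,
# matching the ICI product rule in the config comments.
ici_parallelism = (1, 2, 4, 1)  # ('data', 'fsdp_batch', 'fsdp', 'tensor') on 8 devices
device_grid = mesh_utils.create_device_mesh(ici_parallelism)
mesh = Mesh(device_grid, axis_names=("data", "fsdp_batch", "fsdp", "tensor"))
print(mesh.shape)  # e.g. OrderedDict([('data', 1), ('fsdp_batch', 2), ('fsdp', 4), ('tensor', 1)])
```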
diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index 218b3b79..7ecddb98 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -78,8 +78,11 @@ def _reshape_data_from_cudnn_flash(tensor):

 def _reshape_data_for_cudnn_flash(tensor, heads):
   # reshapes from [b, s, h * d] to [b, s, h, d] (input format to flash format)
-  batch, seq, heads_and_dim_head = tensor.shape
-  tensor = tensor.reshape(batch, seq, heads, heads_and_dim_head // heads)
+  if len(tensor.shape) == 3:
+    batch, seq, dim_head = tensor.shape
+    tensor = tensor.reshape(batch, seq, heads, dim_head // heads)
+  else:
+    tensor = jnp.transpose(tensor, (0, 2, 1, 3))
   return tensor

@@ -89,7 +92,8 @@ def _reshape_batch_dim_to_heads(tensor, heads):
   tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
   tensor = jnp.transpose(tensor, (0, 2, 1, 3))
   reshaped_tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
-  return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor"))
+  axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD))
+  return jax.lax.with_sharding_constraint(reshaped_tensor, axis_names)

 def _reshape_heads_to_batch_dim(tensor, heads):
@@ -102,8 +106,8 @@ def _reshape_heads_to_batch_dim(tensor, heads):
   else:
     batch_size, head_size, seq_len, head_dim = tensor.shape
   reshaped_tensor = tensor.reshape(batch_size * head_size, seq_len, head_dim)
-
-  return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor"))
+  axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD))
+  return jax.lax.with_sharding_constraint(reshaped_tensor, axis_names)

 def _reshape_heads_to_head_dim(tensor):
@@ -112,7 +116,8 @@ def _reshape_heads_to_head_dim(tensor):
   b, h, s, d = tensor.shape
   tensor = jnp.transpose(tensor, axes=[0, 2, 1, 3])
   reshaped_tensor = jnp.reshape(tensor, (b, -1, h * d))
-  return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor"))
+  axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD))
+  return jax.lax.with_sharding_constraint(reshaped_tensor, axis_names)

 def _unflatten_heads(tensor, heads):
@@ -492,24 +497,12 @@ def _cudnn_flash_attention(query: Array, key: Array, value: Array, heads: int, m
   key = _reshape_data_for_cudnn_flash(key, heads)
   value = _reshape_data_for_cudnn_flash(value, heads)

-  cudnn_flash_axis_names = (BATCH, LENGTH, HEAD, D_KV)
-  axis_names = nn.logical_to_mesh_axes(cudnn_flash_axis_names)
-
-  query = nn.with_logical_constraint(query, axis_names)
-  key = nn.with_logical_constraint(key, axis_names)
-  value = nn.with_logical_constraint(value, axis_names)
-
-  @functools.partial(
-      shard_map.shard_map,
-      mesh=mesh,
-      in_specs=(axis_names, axis_names, axis_names),
-      out_specs=axis_names,
-      check_rep=False,
-  )
-  def wrap_flash_attention(query, key, value):
-    return jax.vmap(dpa_layer)(query, key, value, mask=None)
-
-  out = wrap_flash_attention(query, key, value)
+  axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD, D_KV))
+  query = jax.lax.with_sharding_constraint(query, axis_names)
+  key = jax.lax.with_sharding_constraint(key, axis_names)
+  value = jax.lax.with_sharding_constraint(value, axis_names)
+
+  out = dpa_layer(query, key, value, mask=None)
   return _reshape_data_from_cudnn_flash(out)

@@ -706,7 +699,24 @@ def __init__(
   ):
     self.dpa_layer = None
     if attention_kernel == "cudnn_flash_te":
-      raise NotImplementedError(f"{self} has not been tested with {attention_kernel}")
+      from transformer_engine.jax.flax.transformer import DotProductAttention  # pytype: disable=import-error
+      jax.config.update("jax_use_shardy_partitioner", False)
+
+      dpa_layer = DotProductAttention(
+          head_dim=dim_head,
+          num_attention_heads=heads,
+          num_gqa_groups=heads,
+          attn_mask_type="no_mask",  # 'no_mask', 'padding', 'causal', or 'padding_causal'
+          attn_bias_type="NO_BIAS",  # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
+          # attention_dropout=self.dropout_rate,
+          dropout_rng_name="aqt",
+          dtype=dtype,
+          qkv_layout="BSHD_BSHD_BSHD",  # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
+          scale_factor=scale,
+          transpose_batch_sequence=False,
+      )
+      variables = {}
+      self.dpa_layer = functools.partial(dpa_layer.apply, variables)
     self.mesh = mesh
     self.scale = scale
@@ -769,8 +779,9 @@ def setup(self):
     self.dpa_layer = None
     if self.attention_kernel == "cudnn_flash_te":
       from transformer_engine.jax.flax.transformer import DotProductAttention  # pytype: disable=import-error
+      jax.config.update("jax_use_shardy_partitioner", False)

-      self.dpa_layer = DotProductAttention(
+      dpa_layer = DotProductAttention(
           head_dim=self.dim_head,
           num_attention_heads=self.heads,
           num_gqa_groups=self.heads,
@@ -784,6 +795,9 @@ def setup(self):
           scale_factor=self.scale,
           transpose_batch_sequence=False,
       )
+      variables = {}
+      self.dpa_layer = functools.partial(dpa_layer.apply, variables)
+

   def apply_attention(self, query: Array, key: Array, value: Array):
     return _apply_attention(
@@ -839,9 +853,6 @@ def __init__(
       residual_checkpoint_name: str | None = None,
       enable_jax_named_scopes: bool = False,
   ):
-    if attention_kernel == "cudnn_flash_te":
-      raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
-
     if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None:
       raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
     self.dim_head = dim_head
@@ -998,8 +1009,9 @@ def __call__(
       deterministic: bool = True,
       rngs: nnx.Rngs = None,
   ) -> jax.Array:
-    hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
-    encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", "tensor"))
+    axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD))
+    hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names)
+    encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names)
     dtype = hidden_states.dtype
     if encoder_hidden_states is None:
       encoder_hidden_states = hidden_states
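These helpers now derive shardings from logical axis names via nn.logical_to_mesh_axes instead of a hard-coded PartitionSpec. A self-contained sketch of that mechanism, assuming a simple three-axis mesh and a trimmed-down rule set (shapes, rules, and axis sizes are illustrative):

```python
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax.experimental import mesh_utils
from jax.sharding import Mesh

rules = (("activation_batch", "data"), ("activation_length", "fsdp"), ("activation_heads", "tensor"))
mesh = Mesh(mesh_utils.create_device_mesh((1, jax.device_count(), 1)), ("data", "fsdp", "tensor"))

@jax.jit
def constrain(x):
  # Logical names resolve to PartitionSpec('data', 'fsdp', 'tensor') under the rules above.
  spec = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_heads"))
  return jax.lax.with_sharding_constraint(x, spec)

x = jnp.zeros((2, 128, 40 * 128))  # [batch, seq, heads * head_dim]
with mesh, nn_partitioning.axis_rules(rules):
  print(constrain(x).sharding)
```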
diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py
index 77f35073..179d7402 100644
--- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py
+++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py
@@ -28,7 +28,10 @@
 BlockSizes = common_types.BlockSizes

 CACHE_T = 2
-flax.config.update('flax_always_shard_variable', False)
+try:
+  flax.config.update('flax_always_shard_variable', False)
+except Exception:
+  pass

 # Helper to ensure kernel_size, stride, padding are tuples of 3 integers
 def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]:
@@ -73,7 +76,7 @@ def __init__(
     self._depth_padding_before = self._causal_padding[1][0]  # 2 * padding_tuple[0]

     # Set sharding dynamically based on out_channels.
-    num_fsdp_axis_devices = mesh.device_ids.shape[1]
+    num_fsdp_axis_devices = mesh.shape["fsdp"]
     kernel_sharding = (None, None, None, None, None)
     if out_channels % num_fsdp_axis_devices == 0:
       kernel_sharding = (None, None, None, None, "conv_out")
diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py
index cb952afa..a432c4d9 100644
--- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py
+++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py
@@ -362,9 +362,11 @@ def __call__(
     shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
         (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
     )
-    hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
+    axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_heads"))
+    hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names)
     hidden_states = checkpoint_name(hidden_states, "hidden_states")
-    encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
+    axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_kv"))
+    encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names)

     # 1. Self-attention
     with self.conditional_named_scope("self_attn"):
@@ -515,7 +517,7 @@ def init_block(rngs):
     if scan_layers:
       self.blocks = init_block(rngs)
     else:
-      blocks = nnx.List([])
+      blocks = []
       for _ in range(num_layers):
         block = WanTransformerBlock(
             rngs=rngs,
@@ -535,7 +537,7 @@ def init_block(rngs):
             enable_jax_named_scopes=enable_jax_named_scopes,
         )
         blocks.append(block)
-      self.blocks = blocks
+      self.blocks = nnx.data(blocks)

     self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False)
     self.proj_out = nnx.Linear(
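The VAE change looks up the fsdp axis size by name (mesh.shape["fsdp"]) rather than by position in mesh.device_ids, which stays correct once the mesh gains the new fsdp_batch axis. A standalone sketch of that divisibility check (channel count and mesh layout are illustrative):

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

mesh = Mesh(mesh_utils.create_device_mesh((1, 1, jax.device_count(), 1)),
            ("data", "fsdp_batch", "fsdp", "tensor"))

out_channels = 96                              # example conv width
num_fsdp_axis_devices = mesh.shape["fsdp"]     # axis size by name, not by position
kernel_sharding = (None, None, None, None, None)
if out_channels % num_fsdp_axis_devices == 0:
  # Shard the conv output-channel dimension only when it divides evenly.
  kernel_sharding = (None, None, None, None, "conv_out")
print(num_fsdp_axis_devices, kernel_sharding)
```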
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py
index 5617e3b7..e5f878af 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py
@@ -17,6 +17,7 @@
 from typing import List, Union, Optional
 from ...pyconfig import HyperParameters
 from functools import partial
+from contextlib import nullcontext
 from flax import nnx
 from flax.linen import partitioning as nn_partitioning
 import jax
@@ -113,8 +114,14 @@ def __call__(
         scheduler=self.scheduler,
         scheduler_state=scheduler_state,
     )

+    # Set the TE shard_guard context manager when using the TE cudnn_flash_te attention kernel.
+    if self.config.attention == "cudnn_flash_te":
+      from transformer_engine.jax.sharding import global_shard_guard, MeshResource  # pytype: disable=import-error
+      shard_guard = global_shard_guard(MeshResource(cp_resource="fsdp"))
+    else:
+      shard_guard = nullcontext()
-    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules), shard_guard:
       latents = p_run_inference(
           graphdef=graphdef,
           sharded_state=state,
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py
index 9efccf90..c82c7cc4 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py
@@ -17,6 +17,7 @@
 from typing import List, Union, Optional
 from ...pyconfig import HyperParameters
 from functools import partial
+from contextlib import nullcontext
 from flax import nnx
 from flax.linen import partitioning as nn_partitioning
 import jax
@@ -127,8 +128,14 @@ def __call__(
         scheduler=self.scheduler,
         scheduler_state=scheduler_state,
     )
-
-    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+    # Set the TE shard_guard context manager when using the TE cudnn_flash_te attention kernel.
+    if self.config.attention == "cudnn_flash_te":
+      from transformer_engine.jax.sharding import global_shard_guard, MeshResource  # pytype: disable=import-error
+      shard_guard = global_shard_guard(MeshResource(cp_resource="fsdp"))
+    else:
+      shard_guard = nullcontext()
+
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules), shard_guard:
       latents = p_run_inference(
           low_noise_graphdef=low_noise_graphdef,
           low_noise_state=low_noise_state,
diff --git a/src/maxdiffusion/train_wan.py b/src/maxdiffusion/train_wan.py
index fea15720..cc246797 100644
--- a/src/maxdiffusion/train_wan.py
+++ b/src/maxdiffusion/train_wan.py
@@ -35,7 +35,10 @@ def main(argv: Sequence[str]) -> None:
   config = pyconfig.config
   validate_train_config(config)
   max_logging.log(f"Found {jax.device_count()} devices.")
-  flax.config.update("flax_always_shard_variable", False)
+  try:
+    flax.config.update("flax_always_shard_variable", False)
+  except Exception:
+    pass
   train(config)
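Both pipelines (and the trainer below) gate Transformer Engine's global_shard_guard on the cudnn_flash_te kernel so that TE knows which mesh axis carries context (sequence) parallelism. A condensed sketch of that wiring; the helper name is illustrative and not part of this change:

```python
from contextlib import nullcontext

def te_shard_guard(config):
  """Return TE's shard guard for cudnn_flash_te, otherwise a no-op context."""
  if config.attention == "cudnn_flash_te":
    from transformer_engine.jax.sharding import global_shard_guard, MeshResource  # pytype: disable=import-error
    # "fsdp" is the mesh axis used for sequence/context parallelism in this config.
    return global_shard_guard(MeshResource(cp_resource="fsdp"))
  return nullcontext()

# Usage mirrors the pipeline and trainer code:
#   with mesh, nn_partitioning.axis_rules(config.logical_axis_rules), te_shard_guard(config):
#     ...run the jitted step...
```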
diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index f23836a5..7a3d1bee 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -20,6 +20,7 @@
 import pprint
 import numpy as np
 import threading
+from contextlib import nullcontext
 from concurrent.futures import ThreadPoolExecutor
 import tensorflow as tf
 import jax.numpy as jnp
@@ -210,8 +211,8 @@ def prepare_sample_eval(features):
     return data_iterator

   def start_training(self):
-
-    pipeline, opt_state, step = self.checkpointer.load_checkpoint()
+    with nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      pipeline, opt_state, step = self.checkpointer.load_checkpoint()

     restore_args = {}
     if opt_state and step:
       restore_args = {"opt_state": opt_state, "step": step}
@@ -309,7 +310,8 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
       pretty_string = pprint.pformat(state_spec.opt_state, indent=4, width=60)
       max_logging.log(pretty_string)
       max_logging.log("------------------------------------------------")
-    max_utils.delete_pytree(params)
+    if self.config.hardware != 'gpu':
+      max_utils.delete_pytree(params)

     data_shardings = self.get_data_shardings(mesh)
     eval_data_shardings = self.get_eval_data_shardings(mesh)
@@ -364,10 +366,17 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
       if self.config.enable_profiler and step == first_profiling_step:
         max_utils.activate_profiler(self.config)
       start_step_time = datetime.datetime.now()
+
+      # Designate the context parallel axis for sharding.
+      if self.config.attention == "cudnn_flash_te":
+        from transformer_engine.jax.sharding import global_shard_guard, MeshResource  # pytype: disable=import-error
+        shard_guard = global_shard_guard(MeshResource(cp_resource="fsdp"))
+      else:
+        shard_guard = nullcontext()
+
       next_batch_future = executor.submit(load_next_batch, train_data_iterator, example_batch, self.config)
-      with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules(
-          self.config.logical_axis_rules
-      ):
+      with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, \
+          shard_guard, nn_partitioning.axis_rules(self.config.logical_axis_rules):
         state, scheduler_state, train_metric, rng = p_train_step(state, example_batch, rng, scheduler_state)
         train_metric["scalar"]["learning/loss"].block_until_ready()
       last_step_completion = datetime.datetime.now()
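The training loop above also overlaps host-side batch loading with the device step through a single-worker executor. A toy, self-contained version of that prefetch pattern (all names and the dummy step are stand-ins, not the trainer's actual implementation):

```python
import time
from concurrent.futures import ThreadPoolExecutor

def load_next_batch(it, fallback):
  # Stand-in for host-side data loading that can overlap the device step.
  return next(it, fallback)

def train_step(state, batch):
  time.sleep(0.01)  # stand-in for the jitted device computation
  return state + batch

train_iter = iter(range(100))
executor = ThreadPoolExecutor(max_workers=1)
state, example_batch = 0, next(train_iter)
for _ in range(10):
  future = executor.submit(load_next_batch, train_iter, example_batch)  # fetch the next batch
  state = train_step(state, example_batch)                              # while this step runs
  example_batch = future.result()
print(state)
```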