2 changes: 1 addition & 1 deletion src/maxdiffusion/common_types.py
@@ -35,7 +35,7 @@
 AxisNames = tuple[str, ...]
 # Physical axis names for device meshes.
 DATA = "data"
-FSDP = "fsdp"
+FSDP = "fsdp_tpu"
 TENSOR = "tensor"
 # Logical axis names for model parameters and activations.
 BATCH = "activation_batch"
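For context (illustrative, not part of this PR): the constants above are the physical mesh axis names, so any mesh built from them picks up the rename automatically. A minimal sketch of how these names typically feed a `jax.sharding.Mesh`; the 8-device shape and the construction below are assumptions for illustration, not maxdiffusion's actual mesh code.

```python
# Illustrative sketch only: assumes an 8-device slice and a hand-picked shape.
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

DATA, FSDP, TENSOR = "data", "fsdp_tpu", "tensor"  # FSDP renamed in this change

# Put all 8 devices on the fsdp_tpu axis; data and tensor stay unsharded.
devices = mesh_utils.create_device_mesh((1, 8, 1))
mesh = Mesh(devices, axis_names=(DATA, FSDP, TENSOR))
print(mesh.shape)  # {'data': 1, 'fsdp_tpu': 8, 'tensor': 1}
```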
12 changes: 6 additions & 6 deletions src/maxdiffusion/configs/base14.yml
@@ -106,7 +106,7 @@ skip_jax_distributed_system: False
 base_output_directory: ""
 
 # Parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'fsdp_tpu', 'tensor']
 
 # batch : batch dimension of data and activations
 # hidden :
@@ -122,16 +122,16 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
 ['batch', 'data'],
-['activation_batch', ['data','fsdp']],
+['activation_batch', ['data','fsdp_tpu']],
 ['activation_heads', 'tensor'],
 ['activation_kv', 'tensor'],
-['embed','fsdp'],
+['embed','fsdp_tpu'],
 ['heads', 'tensor'],
-['conv_batch', ['data','fsdp']],
+['conv_batch', ['data','fsdp_tpu']],
 ['out_channels', 'tensor'],
-['conv_out', 'fsdp'],
+['conv_out', 'fsdp_tpu'],
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'fsdp_tpu', 'tensor']]
 
 # One axis for each parallelism type may hold a placeholder (-1)
 # value to auto-shard based on available slices and devices.
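For context (illustrative, not part of this PR): `logical_axis_rules` is an ordered list of (logical axis, mesh axis) pairs, which is why the rename must touch the right-hand side of every rule. A minimal sketch of how Flax resolves such rules into a `PartitionSpec`, using an abridged subset of the rules above:

```python
# Sketch: resolve logical axis names through the renamed rules (abridged list).
import flax.linen as nn

rules = (
    ("batch", "data"),
    ("activation_batch", ("data", "fsdp_tpu")),
    ("embed", "fsdp_tpu"),
    ("heads", "tensor"),
)
with nn.logical_axis_rules(rules):
    spec = nn.logical_to_mesh_axes(("embed", "heads"))
print(spec)  # PartitionSpec('fsdp_tpu', 'tensor')
```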
12 changes: 6 additions & 6 deletions src/maxdiffusion/configs/base21.yml
@@ -108,7 +108,7 @@ skip_jax_distributed_system: False
 base_output_directory: ""
 
 # Parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'fsdp_tpu', 'tensor']
 
 # batch : batch dimension of data and activations
 # hidden :
@@ -124,16 +124,16 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
 ['batch', 'data'],
-['activation_batch', ['data','fsdp']],
+['activation_batch', ['data','fsdp_tpu']],
 ['activation_heads', 'tensor'],
 ['activation_kv', 'tensor'],
-['embed','fsdp'],
+['embed','fsdp_tpu'],
 ['heads', 'tensor'],
-['conv_batch', ['data','fsdp']],
+['conv_batch', ['data','fsdp_tpu']],
 ['out_channels', 'tensor'],
-['conv_out', 'fsdp'],
+['conv_out', 'fsdp_tpu'],
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'fsdp_tpu', 'tensor']]
 
 # One axis for each parallelism type may hold a placeholder (-1)
 # value to auto-shard based on available slices and devices.
12 changes: 6 additions & 6 deletions src/maxdiffusion/configs/base_2_base.yml
@@ -121,7 +121,7 @@ skip_jax_distributed_system: False
 base_output_directory: ""
 
 # Parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'fsdp_tpu', 'tensor']
 
 # batch : batch dimension of data and activations
 # hidden :
@@ -137,16 +137,16 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
 ['batch', 'data'],
-['activation_batch', ['data','fsdp']],
+['activation_batch', ['data','fsdp_tpu']],
 ['activation_heads', 'tensor'],
 ['activation_kv', 'tensor'],
-['embed','fsdp'],
+['embed','fsdp_tpu'],
 ['heads', 'tensor'],
-['conv_batch', ['data','fsdp']],
+['conv_batch', ['data','fsdp_tpu']],
 ['out_channels', 'tensor'],
-['conv_out', 'fsdp'],
+['conv_out', 'fsdp_tpu'],
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'fsdp_tpu', 'tensor']]
 
 # One axis for each parallelism type may hold a placeholder (-1)
 # value to auto-shard based on available slices and devices.
12 changes: 6 additions & 6 deletions src/maxdiffusion/configs/base_flux_dev.yml
@@ -132,7 +132,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
 skip_jax_distributed_system: False
 
 # Parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'fsdp_tpu', 'tensor']
 
 # batch : batch dimension of data and activations
 # hidden :
@@ -148,17 +148,17 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
 ['batch', 'data'],
-['activation_batch', ['data','fsdp']],
+['activation_batch', ['data','fsdp_tpu']],
 ['activation_heads', 'tensor'],
 ['activation_kv', 'tensor'],
 ['mlp','tensor'],
-['embed','fsdp'],
+['embed','fsdp_tpu'],
 ['heads', 'tensor'],
-['conv_batch', ['data','fsdp']],
+['conv_batch', ['data','fsdp_tpu']],
 ['out_channels', 'tensor'],
-['conv_out', 'fsdp'],
+['conv_out', 'fsdp_tpu'],
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'fsdp_tpu', 'tensor']]
 
 # One axis for each parallelism type may hold a placeholder (-1)
 # value to auto-shard based on available slices and devices.
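For context (illustrative, not part of this PR): `data_sharding` stacks all three mesh axes onto the batch dimension of the input, so the global batch is split across every device no matter how the mesh is factored between data, FSDP, and tensor parallelism. A sketch under an assumed 8-device mesh; the axis names match the config, the mesh itself is illustrative:

```python
# Sketch: shard a host batch across ('data', 'fsdp_tpu', 'tensor') combined.
import jax
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(mesh_utils.create_device_mesh((1, 4, 2)), ("data", "fsdp_tpu", "tensor"))
batch_spec = PartitionSpec(("data", "fsdp_tpu", "tensor"))  # all axes on dim 0
x = jax.device_put(np.zeros((64, 128), np.float32), NamedSharding(mesh, batch_spec))
# Each of the 8 devices now holds an (8, 128) shard of the batch.
```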
14 changes: 7 additions & 7 deletions src/maxdiffusion/configs/base_flux_dev_multi_res.yml
@@ -132,7 +132,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
 skip_jax_distributed_system: False
 
 # Parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'fsdp_tpu', 'tensor']
 
 # batch : batch dimension of data and activations
 # hidden :
@@ -148,17 +148,17 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
 ['batch', 'data'],
-['activation_batch', ['data','fsdp']],
+['activation_batch', ['data','fsdp_tpu']],
 ['activation_heads', 'tensor'],
 ['activation_kv', 'tensor'],
-# ['embed','fsdp'],
-['mlp',['fsdp','tensor']],
+# ['embed','fsdp_tpu'],
+['mlp',['fsdp_tpu','tensor']],
 ['heads', 'tensor'],
-['conv_batch', ['data','fsdp']],
+['conv_batch', ['data','fsdp_tpu']],
 ['out_channels', 'tensor'],
-['conv_out', 'fsdp'],
+['conv_out', 'fsdp_tpu'],
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'fsdp_tpu', 'tensor']]
 
 # One axis for each parallelism type may hold a placeholder (-1)
 # value to auto-shard based on available slices and devices.
12 changes: 6 additions & 6 deletions src/maxdiffusion/configs/base_flux_schnell.yml
@@ -140,7 +140,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
 skip_jax_distributed_system: False
 
 # Parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'fsdp_tpu', 'tensor']
 
 # batch : batch dimension of data and activations
 # hidden :
@@ -156,17 +156,17 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
 ['batch', 'data'],
-['activation_batch', ['data','fsdp']],
+['activation_batch', ['data','fsdp_tpu']],
 ['activation_heads', 'tensor'],
 ['activation_kv', 'tensor'],
 ['mlp','tensor'],
-['embed','fsdp'],
+['embed','fsdp_tpu'],
 ['heads', 'tensor'],
-['conv_batch', ['data','fsdp']],
+['conv_batch', ['data','fsdp_tpu']],
 ['out_channels', 'tensor'],
-['conv_out', 'fsdp'],
+['conv_out', 'fsdp_tpu'],
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'fsdp_tpu', 'tensor']]
 
 # One axis for each parallelism type may hold a placeholder (-1)
 # value to auto-shard based on available slices and devices.
26 changes: 14 additions & 12 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -148,7 +148,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
 skip_jax_distributed_system: False
 
 # Parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'tensor', 'fsdp_tpu', 'fsdp_gpu']
 
 # batch : batch dimension of data and activations
 # hidden :
@@ -163,32 +163,34 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 # conv_in : conv.shape[2] weight
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
-['batch', 'data'],
-['activation_batch', 'data'],
-['activation_self_attn_heads', ['fsdp', 'tensor']],
-['activation_cross_attn_q_length', ['fsdp', 'tensor']],
-['activation_length', 'fsdp'],
+['batch', ['data', 'fsdp_gpu']],
+['activation_batch', ['data', 'fsdp_gpu']],
+['activation_length', 'fsdp_tpu'],
+['activation_self_attn_heads', ['fsdp_tpu', 'tensor']],
+['activation_cross_attn_q_length', ['fsdp_tpu', 'tensor']],
 ['activation_heads', 'tensor'],
 ['mlp','tensor'],
-['embed','fsdp'],
+['embed', ['fsdp_tpu', 'fsdp_gpu']],
 ['heads', 'tensor'],
 ['norm', 'tensor'],
-['conv_batch', ['data','fsdp']],
+['conv_batch', ['data', 'fsdp_tpu', 'fsdp_gpu']],
 ['out_channels', 'tensor'],
-['conv_out', 'fsdp'],
+['conv_out', 'fsdp_tpu'],
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'tensor', 'fsdp_tpu', 'fsdp_gpu']]
 
 # One axis for each parallelism type may hold a placeholder (-1)
 # value to auto-shard based on available slices and devices.
 # By default, product of the DCN axes should equal number of slices
 # and product of the ICI axes should equal number of devices per slice.
 dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
-dcn_fsdp_parallelism: -1
 dcn_tensor_parallelism: 1
+dcn_fsdp_tpu_parallelism: -1
+dcn_fsdp_gpu_parallelism: 1 # recommended DCN axis to be auto-sharded
 ici_data_parallelism: 1
-ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
 ici_tensor_parallelism: 1
+ici_fsdp_tpu_parallelism: -1
+ici_fsdp_gpu_parallelism: 1 # recommended ICI axis to be auto-sharded
 
 allow_split_physical_axes: False
 
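For context (illustrative, not part of this PR): with the extra `fsdp_gpu` axis there is now one `dcn_*`/`ici_*` knob per mesh axis, and the single -1 placeholder per group is auto-filled so that the ICI axes multiply out to the devices per slice (and the DCN axes to the number of slices), as the comment above describes. A small sketch of that arithmetic; `fill_auto_axis` is an illustrative helper, not maxdiffusion's actual resolver:

```python
# Sketch: fill the single -1 placeholder so the axes multiply out to the target.
import math

def fill_auto_axis(parallelism: dict[str, int], target: int) -> dict[str, int]:
    """parallelism maps mesh axis -> requested size; at most one value is -1."""
    fixed = math.prod(v for v in parallelism.values() if v != -1)
    assert target % fixed == 0, "fixed axes must divide the device/slice count"
    return {k: (target // fixed if v == -1 else v) for k, v in parallelism.items()}

ici = {"data": 1, "tensor": 1, "fsdp_tpu": -1, "fsdp_gpu": 1}
print(fill_auto_axis(ici, target=256))
# {'data': 1, 'tensor': 1, 'fsdp_tpu': 256, 'fsdp_gpu': 1}
```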
12 changes: 6 additions & 6 deletions src/maxdiffusion/configs/base_wan_27b.yml
@@ -137,7 +137,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
 skip_jax_distributed_system: False
 
 # Parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'fsdp_tpu', 'tensor']
 
 # batch : batch dimension of data and activations
 # hidden :
@@ -154,18 +154,18 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 logical_axis_rules: [
 ['batch', 'data'],
 ['activation_batch', 'data'],
-['activation_length', 'fsdp'],
+['activation_length', 'fsdp_tpu'],
 
 ['activation_heads', 'tensor'],
 ['mlp','tensor'],
-['embed','fsdp'],
+['embed','fsdp_tpu'],
 ['heads', 'tensor'],
 ['norm', 'tensor'],
-['conv_batch', ['data','fsdp']],
+['conv_batch', ['data','fsdp_tpu']],
 ['out_channels', 'tensor'],
-['conv_out', 'fsdp'],
+['conv_out', 'fsdp_tpu'],
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'fsdp_tpu', 'tensor']]
 
 # One axis for each parallelism type may hold a placeholder (-1)
 # value to auto-shard based on available slices and devices.
12 changes: 6 additions & 6 deletions src/maxdiffusion/configs/base_xl.yml
@@ -106,7 +106,7 @@ base_output_directory: ""
 hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
 skip_jax_distributed_system: False
 # Parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'fsdp_tpu', 'tensor']
 
 # batch : batch dimension of data and activations
 # hidden :
@@ -122,16 +122,16 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
 ['batch', 'data'],
-['activation_batch', ['data','fsdp']],
+['activation_batch', ['data','fsdp_tpu']],
 ['activation_heads', 'tensor'],
 ['activation_kv', 'tensor'],
-['embed','fsdp'],
+['embed','fsdp_tpu'],
 ['heads', 'tensor'],
-['conv_batch', ['data','fsdp']],
+['conv_batch', ['data','fsdp_tpu']],
 ['out_channels', 'tensor'],
-['conv_out', 'fsdp'],
+['conv_out', 'fsdp_tpu'],
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'fsdp_tpu', 'tensor']]
 
 # One axis for each parallelism type may hold a placeholder (-1)
 # value to auto-shard based on available slices and devices.
12 changes: 6 additions & 6 deletions src/maxdiffusion/configs/base_xl_lightning.yml
@@ -86,7 +86,7 @@ skip_jax_distributed_system: False
 base_output_directory: ""
 
 # Parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'fsdp_tpu', 'tensor']
 
 # batch : batch dimension of data and activations
 # hidden :
@@ -102,16 +102,16 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
 ['batch', 'data'],
-['activation_batch', ['data','fsdp']],
+['activation_batch', ['data','fsdp_tpu']],
 ['activation_heads', 'tensor'],
 ['activation_kv', 'tensor'],
-['embed','fsdp'],
+['embed','fsdp_tpu'],
 ['heads', 'tensor'],
-['conv_batch', ['data','fsdp']],
+['conv_batch', ['data','fsdp_tpu']],
 ['out_channels', 'tensor'],
-['conv_out', 'fsdp'],
+['conv_out', 'fsdp_tpu'],
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'fsdp_tpu', 'tensor']]
 
 # One axis for each parallelism type may hold a placeholder (-1)
 # value to auto-shard based on available slices and devices.
16 changes: 8 additions & 8 deletions src/maxdiffusion/configs/ltx_video.yml
@@ -62,22 +62,22 @@ second_pass:
 cfg_star_rescale: True
 
 #parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'fsdp_tpu', 'tensor']
 logical_axis_rules: [
 ['batch', 'data'],
-['activation_heads', 'fsdp'],
+['activation_heads', 'fsdp_tpu'],
 ['activation_batch', 'data'],
 ['activation_kv', 'tensor'],
 ['mlp','tensor'],
-['embed','fsdp'],
+['embed','fsdp_tpu'],
 ['heads', 'tensor'],
-['norm', 'fsdp'],
-['conv_batch', ['data','fsdp']],
+['norm', 'fsdp_tpu'],
+['conv_batch', ['data','fsdp_tpu']],
 ['out_channels', 'tensor'],
-['conv_out', 'fsdp'],
-['conv_in', 'fsdp']
+['conv_out', 'fsdp_tpu'],
+['conv_in', 'fsdp_tpu']
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'fsdp_tpu', 'tensor']]
 dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
 dcn_fsdp_parallelism: -1
 dcn_tensor_parallelism: 1