diff --git a/requirements.txt b/requirements.txt index 478359fe..83f69d2d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -36,4 +36,4 @@ aqtp imageio==2.37.0 imageio-ffmpeg==0.6.0 hf_transfer>=0.1.9 -qwix@git+https://github.com/google/qwix.git \ No newline at end of file +qwix==0.1.5 \ No newline at end of file diff --git a/requirements_with_jax_ai_image.txt b/requirements_with_jax_ai_image.txt index c279edb8..c5a65135 100644 --- a/requirements_with_jax_ai_image.txt +++ b/requirements_with_jax_ai_image.txt @@ -37,4 +37,4 @@ aqtp imageio==2.37.0 imageio-ffmpeg==0.6.0 hf_transfer>=0.1.9 -qwix@git+https://github.com/google/qwix.git \ No newline at end of file +qwix==0.1.5 \ No newline at end of file diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 153c225d..557a9dfe 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -302,8 +302,8 @@ def get_fp8_config(cls, config: HyperParameters): act_qtype=jnp.float8_e4m3fn, bwd_qtype=jnp.float8_e5m2, disable_channelwise_axes=True, # per_tensor calibration - weight_calibration_method=config.quantization_calibration_method, - act_calibration_method=config.quantization_calibration_method, + weight_calibration_method="fixed,-224,224", + act_calibration_method="fixed,-224,224", bwd_calibration_method=config.quantization_calibration_method, op_names=("dot_general", "einsum"), ), @@ -313,8 +313,8 @@ def get_fp8_config(cls, config: HyperParameters): act_qtype=jnp.float8_e4m3fn, bwd_qtype=jnp.float8_e4m3fn, disable_channelwise_axes=True, # per_tensor calibration - weight_calibration_method=config.quantization_calibration_method, - act_calibration_method=config.quantization_calibration_method, + weight_calibration_method="fixed,-224,224", + act_calibration_method="fixed,-224,224", bwd_calibration_method=config.quantization_calibration_method, op_names=("conv_general_dilated"), ), diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index 34f0ef64..47c23edd 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -343,8 +343,8 @@ def create_real_rule_instance(*args, **kwargs): act_qtype=jnp.float8_e4m3fn, bwd_qtype=jnp.float8_e5m2, disable_channelwise_axes=True, # per_tensor calibration - weight_calibration_method=config_fp8_full.quantization_calibration_method, - act_calibration_method=config_fp8_full.quantization_calibration_method, + weight_calibration_method="fixed,-224,224", + act_calibration_method="fixed,-224,224", bwd_calibration_method=config_fp8_full.quantization_calibration_method, op_names=("dot_general", "einsum"), ), @@ -354,8 +354,8 @@ def create_real_rule_instance(*args, **kwargs): act_qtype=jnp.float8_e4m3fn, bwd_qtype=jnp.float8_e4m3fn, disable_channelwise_axes=True, # per_tensor calibration - weight_calibration_method=config_fp8_full.quantization_calibration_method, - act_calibration_method=config_fp8_full.quantization_calibration_method, + weight_calibration_method="fixed,-224,224", + act_calibration_method="fixed,-224,224", bwd_calibration_method=config_fp8_full.quantization_calibration_method, op_names=("conv_general_dilated"), ),