Skip to content

Conversation

@gulsumgudukbay
Copy link
Collaborator

Description

This PR is the second part of the decoupling support. It adds logic for decoupling support, along with some test modifications for decoupling to be enabled.

Details:

  1. Update decoupled_base_test.yml
  2. Add decoupling locig to src/MaxText/decode.py, src/MaxText/elastic_train.py, src/MaxText/experimental/rl/grpo_trainer.py, src/MaxText/gcp_workload_monitor.py, src/MaxText/max_utils.py, src/MaxText/maxengine.py, src/MaxText/maxengine_config.py, src/MaxText/maxengine_server.py, src/MaxText/metric_logger.py, src/MaxText/prefill_packing.py, src/MaxText/profiler.py, src/MaxText/sft/hooks.py, src/MaxText/sft/sft_trainer.py, src/MaxText/train.py, src/MaxText/utils/gcs_utils.py, src/MaxText/utils/goodput_utils.py, src/MaxText/vertex_tensorboard.py
  3. Update src/MaxText/gcloud_stub.py to add IS_STUB variables, and add google_cloud_mldiagnostics stub
  4. Update tests to support decoupled mode (add markers, update file paths, make them use decoupled_base_test.yml config file).

Tests

All unit tests pass in decoupled mode.
UT results:
== 306 passed, 170 skipped, 25 deselected, 6588 warnings in 975.16s (0:16:15) ==

Train test:
python -m MaxText.train MaxText/configs/base.yml run_name=test hardware=gpu steps=5 model_name=llama2-7b attention=cudnn_flash_te enable_checkpointing=False ici_expert_parallelism=1 ici_fsdp_parallelism=-1 ici_data_parallelism=1 remat_policy=minimal scan_layers=True dataset_type=synthetic logits_dot_in_fp32=False dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_target_length=2048 shardy=False

works.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

gulsumgudukbay and others added 25 commits December 21, 2025 06:16
(cherry picked from commit e8cc951)
(cherry picked from commit 0b58e96)
(cherry picked from commit 14f0508)
(cherry picked from commit e43e370)
(cherry picked from commit 1c14d6c)
…ck, todo: remove this after updating jax. Configure ICI data parallelism for decoupled mode
from MaxText.globals import MAXTEXT_PKG_DIR
from maxtext.tests.test_utils import get_test_config_path

pytestmark = [pytest.mark.tpu_only]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests are suppose to run on CPUs. Why are we adding tpu_only marker?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests are suppose to run on CPUs. Why are we adding tpu_only marker?

They are supposed to run on CPUs however it requires libtpu to generate the TPU topology. In the case where we do not have libtpu, it errors out.


# Leave dataset-related keys to be overridden by individual tests.
dataset_type: ""
#dataset_type: ""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this #dataset_type intentional?

from MaxText.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_PKG_DIR
from tests.test_utils import get_test_config_path, get_test_dataset_path, get_test_base_output_directory

pytestmark = [pytest.mark.tpu_only, pytest.mark.external_serving]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are few tests in this script that runs on gpu only. Can you remove pytest.mark.tpu_only?

class DecodeTests(unittest.TestCase):
"""Tests decode with various configs."""

decoupled = is_decoupled()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this is unused.

run_checkpoint_compatibility("tpu", "autoselected")


@pytest.mark.external_serving
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be external_training

In decoupled mode (DECOUPLE_GCLOUD=TRUE) cloud diagnostics may be stubbed; if so, skip wrapping.
"""
if is_decoupled() or getattr(diagnostic, "__class__", None).__name__ == "_StubDiag": # runtime skip
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead, you can use contextlib.nullcontext to conditionally apply the diagnostic wrapper while keeping a single, clean call to your training loop.

vertex_tensorboard_manager = VertexTensorboardManager()
if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"):
vertex_tensorboard_manager.configure_vertex_tensorboard(config)
if _vertex_tb_is_stub:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this if check be moved to configure_vertext_tensorboard() to keep train.py clean?

elif (
config.report_heartbeat_metric_for_gcp_monitoring or config.report_performance_metric_for_gcp_monitoring
) and _monitor_is_stub:
max_logging.log("[DECOUPLED NO-OP] skipping GCP workload monitoring threads.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can move this as the first check inside get_performance_metric_queue()

__all__.append("vertex_tensorboard_components")

# ---------------- TensorBoardX (moved stub) -----------------
# ---------------- ML Diagnostics (google_cloud_mldiagnostics) -----------------
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The stub classes use both _Stub and _Dummyprefix. Can we use a single convention, such as _Stub, throughout the module for consistency?

with diagnostic.diagnose(diagnostic_config):
with maybe_record_goodput(recorder, GoodputEvent.JOB), maybe_monitor_goodput(config):
# In decoupled mode or when diagnostics are stubbed, skip the diagnose wrapper
if is_decoupled() or getattr(diagnostic, "__class__", None).__name__ == "_StubDiag":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as train.py

from MaxText.layers import models
import pytest

pytestmark = [pytest.mark.external_serving] # uses pre-generated checkpoint
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

external_training

from MaxText.experimental.rl import grpo_utils

# This test is for serving pathways via offline_engine and maxengine.
pytestmark = [pytest.mark.external_serving]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

external_training

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants