[DECOUPLED-MODE] Adding Decoupling Logic #2865
base: main
Conversation
(cherry picked from commit e8cc951)
(cherry picked from commit 0b58e96)
(cherry picked from commit 14f0508)
…ts library (cherry picked from commit 6f0b361)
(cherry picked from commit e43e370)
(cherry picked from commit 1c14d6c)
…ck, todo: remove this after updating jax. Configure ICI data parallelism for decoupled mode
tests/train_compile_test.py
Outdated
from MaxText.globals import MAXTEXT_PKG_DIR
from maxtext.tests.test_utils import get_test_config_path

pytestmark = [pytest.mark.tpu_only]
These tests are supposed to run on CPUs. Why are we adding the tpu_only marker?
They are supposed to run on CPUs; however, generating the TPU topology requires libtpu. When libtpu is not available, the tests error out.
# Leave dataset-related keys to be overridden by individual tests.
dataset_type: ""
#dataset_type: ""
Is this commented-out #dataset_type line intentional?
from MaxText.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_PKG_DIR
from tests.test_utils import get_test_config_path, get_test_dataset_path, get_test_base_output_directory

pytestmark = [pytest.mark.tpu_only, pytest.mark.external_serving]
A few tests in this script run on GPU only. Can you remove pytest.mark.tpu_only?
class DecodeTests(unittest.TestCase):
  """Tests decode with various configs."""

  decoupled = is_decoupled()
nit: this is unused.
run_checkpoint_compatibility("tpu", "autoselected")


@pytest.mark.external_serving
Should be external_training
In decoupled mode (DECOUPLE_GCLOUD=TRUE) cloud diagnostics may be stubbed; if so, skip wrapping.
"""
if is_decoupled() or getattr(diagnostic, "__class__", None).__name__ == "_StubDiag":  # runtime skip
Instead, you can use contextlib.nullcontext to conditionally apply the diagnostic wrapper while keeping a single, clean call to your training loop.
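A minimal sketch of that suggestion, assuming diagnostic.diagnose() returns a context manager and using a hypothetical run_training() as a stand-in for the body that currently sits under the wrapper:

import contextlib

# Choose a no-op context in decoupled/stubbed mode, the real wrapper otherwise,
# so the training loop is invoked from a single place.
if is_decoupled() or type(diagnostic).__name__ == "_StubDiag":
  diagnose_ctx = contextlib.nullcontext()
else:
  diagnose_ctx = diagnostic.diagnose(diagnostic_config)

with diagnose_ctx:
  run_training()  # hypothetical: the existing training-loop body goes here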
vertex_tensorboard_manager = VertexTensorboardManager()
if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"):
  vertex_tensorboard_manager.configure_vertex_tensorboard(config)
if _vertex_tb_is_stub:
Can this if check be moved into configure_vertex_tensorboard() to keep train.py clean?
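Roughly, as a sketch (assuming _vertex_tb_is_stub and max_logging are visible inside the module that defines VertexTensorboardManager):

def configure_vertex_tensorboard(self, config):
  """Configure Vertex TensorBoard, or no-op when the dependency is stubbed."""
  if _vertex_tb_is_stub:
    max_logging.log("[DECOUPLED NO-OP] skipping Vertex TensorBoard configuration.")
    return
  # ... existing configuration logic unchanged ...

The call site in train.py then stays a single unconditional configure_vertex_tensorboard(config) call.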
elif (
    config.report_heartbeat_metric_for_gcp_monitoring or config.report_performance_metric_for_gcp_monitoring
) and _monitor_is_stub:
  max_logging.log("[DECOUPLED NO-OP] skipping GCP workload monitoring threads.")
We can move this to be the first check inside get_performance_metric_queue().
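A hedged sketch of that, assuming get_performance_metric_queue(config) is the entry point that currently creates the monitoring queue and threads:

def get_performance_metric_queue(config):
  # Early-out first: in decoupled mode the GCP monitoring client is a stub,
  # so skip creating the metric queue and monitoring threads entirely.
  if _monitor_is_stub:
    max_logging.log("[DECOUPLED NO-OP] skipping GCP workload monitoring threads.")
    return None
  # ... existing queue/monitoring setup unchanged ...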
__all__.append("vertex_tensorboard_components")

# ---------------- TensorBoardX (moved stub) -----------------
# ---------------- ML Diagnostics (google_cloud_mldiagnostics) -----------------
The stub classes use both a _Stub and a _Dummy prefix. Can we use a single convention, such as _Stub, throughout the module for consistency?
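For example (illustrative names only; of these, only _StubDiag appears in the diff excerpts):

# One prefix everywhere, so a grep for "_Stub" finds every placeholder class.
class _StubDiag: ...           # diagnostics stub
class _StubSummaryWriter: ...  # hypothetical TensorBoardX stand-in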
with diagnostic.diagnose(diagnostic_config):
  with maybe_record_goodput(recorder, GoodputEvent.JOB), maybe_monitor_goodput(config):
# In decoupled mode or when diagnostics are stubbed, skip the diagnose wrapper
if is_decoupled() or getattr(diagnostic, "__class__", None).__name__ == "_StubDiag":
Same comment as train.py
from MaxText.layers import models
import pytest

pytestmark = [pytest.mark.external_serving]  # uses pre-generated checkpoint
Should be external_training.
from MaxText.experimental.rl import grpo_utils

# This test is for serving pathways via offline_engine and maxengine.
pytestmark = [pytest.mark.external_serving]
Should be external_training.
Description
This PR is the second part of the decoupling support. It adds the decoupling logic itself, along with the test modifications needed to run with decoupling enabled.
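For context, a minimal sketch of the decoupling toggle referenced throughout the review; the DECOUPLE_GCLOUD environment variable is quoted in a docstring above, but the exact implementation of is_decoupled() is not shown in these excerpts:

import os

def is_decoupled() -> bool:
  # Decoupled mode: run without Google Cloud integrations (ML diagnostics,
  # Vertex TensorBoard, GCP workload monitoring), which are replaced by stubs.
  return os.environ.get("DECOUPLE_GCLOUD", "").upper() == "TRUE"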
Details:
Tests
All unit tests pass in decoupled mode.
UT results:
== 306 passed, 170 skipped, 25 deselected, 6588 warnings in 975.16s (0:16:15) ==
Train test:
python -m MaxText.train MaxText/configs/base.yml run_name=test hardware=gpu steps=5 model_name=llama2-7b attention=cudnn_flash_te enable_checkpointing=False ici_expert_parallelism=1 ici_fsdp_parallelism=-1 ici_data_parallelism=1 remat_policy=minimal scan_layers=True dataset_type=synthetic logits_dot_in_fp32=False dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_target_length=2048 shardy=False
works.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.