Skip to content

feat(zenflow): run the overlapped CPU optimizer in a native process#8058

Open
Antlera wants to merge 13 commits into
masterfrom
zenflow-cpu-op
Open

feat(zenflow): run the overlapped CPU optimizer in a native process#8058
Antlera wants to merge 13 commits into
masterfrom
zenflow-cpu-op

Conversation

@Antlera

@Antlera Antlera commented Jun 10, 2026

Copy link
Copy Markdown
Collaborator

What changes this PR introduce

ZenFlow's overlapped CPU optimizer step previously ran in a Python multiprocessing
subprocess coordinated by a pickling Pipe. This PR moves that optimizer into a
native CPU optimizer process packaged inside the cpu_adam op, coordinated through a
shared-memory POSIX-semaphore control block instead of pickling. The Adam state is
allocated in that process, NUMA-local to the optimizer's pinned thread pool.

Highlights:

  • Fused multi-tensor CPU Adam (adam_update_multi): drives a whole flattened partition
    in C++ and writes the stale snapshot natively, removing the per-parameter Python↔C++ loop
    and the Python-side clone().
  • ZenFlowAdam native class: a pinned std::thread pool (pinned to ZenFlow's dedicated
    cores) running the serial Adam kernel per slice, driven from the main process via the
    shared-memory control block (run_worker / submit / wait).
  • Covers ZeRO stages 1, 2, and 3; removes the old pickling subprocess entirely.
  • Chunked copyback: streaming the updated fp32 master partition back to the GPU bit16
    partition in chunks drops a transient GPU spike from ~2944 MiB to ~256 MiB for a
    0.75B-param partition (the old fp32.to(device) materialized the whole fp32 partition on
    the GPU first).
  • ZenFlowCPUAdam is now a recognized ZeRO optimizer, so zero_allow_untested_optimizer
    is no longer required in ZenFlow configs.

Correctness & performance

  • Bit-identical to the previous subprocess path: cross-process and fused-op unit tests, plus
    seeded end-to-end loss across ZeRO stages 1/2/3.
  • Real Qwen2.5-1.5B end-to-end (ZeRO-2, CPU offload, 2 GPUs): per-step throughput unchanged
    (no regression). Small / IPC-bound configurations are faster (the per-step pickling/IPC
    overhead is removed).

Dependency / merge order

This branch is based on top of #7771 ("Fix ZenFlow NaN under PyTorch-style backward"), so its
backward_prologue commit rides along here. Please merge #7771 first, then this PR — after
#7771 lands, that commit is already in master and only the native-optimizer changes remain.

Testing

  • tests/unit/ops/adam/test_cpu_adam.py: test_zenflow_adam_cross_process (production path,
    bit-identical to the fused reference) and TestCPUAdamFusedMultiTensor.
  • End-to-end ZenFlow finetuning on ZeRO stages 1/2/3 (single- and multi-GPU).

Note: the native optimizer process uses POSIX semaphores and is Linux-only.

Antlera added 11 commits June 1, 2026 00:36
The PyTorch-style backward API drives backward through loss.backward() and
the engine's autograd hooks, which call optimizer.backward_prologue() at the
start of each backward pass instead of ZenFlow's own backward(). ZenFlow's
per-microbatch setup therefore never ran, leaving micro_step unadvanced and
the selective optimizer unsynced at a selection boundary, so the top-k update
operated on stale state and the loss went NaN.

- Override backward_prologue() with ZenFlow's per-microbatch setup: advance
  micro_step, refresh the auto-update bookkeeping, and on a selection boundary
  resync the fp32 master partition and clear the selective optimizer's moments.
- Remove the standalone backward() override, which the PyTorch-style engine no
  longer calls.

Validated on Qwen2.5-0.5B + Alpaca (ZeRO-2 offload, overlap step): loss now
matches the old-version ZenFlow step-for-step instead of diverging to NaN.

Signed-off-by: Tingfeng Lan <erc8gx@virginia.edu>
ZenFlow's overlapped CPU optimizer stepped each parameter through a separate
`adam_update` call from Python and kept a stale snapshot for the GPU sync via
`p.stale_param.data.copy_(p.data.clone())`. For a group with many parameters
this pays one Python<->C++ crossing (and one OpenMP region spawn) per parameter,
and the `clone()` adds a full allocation plus an extra memory pass every step.

Add a fused multi-tensor entry that drives the whole group in C++ and writes the
stale snapshot natively, so the overlapped step issues a single native call.

- Add `ds_adam_step_multi` (bound as `adam_update_multi`): one call updates a
  list of params/grads/exp_avg/exp_avg_sq, advancing the bias-correction state
  once for the shared step; when a stale list is provided, each post-update
  parameter is snapshotted into it via a native copy.
- Rewrite `ZenFlowCPUAdam._parallel_step` to collect the group's tensors and
  issue a single `adam_update_multi`, dropping the per-parameter calls and the
  Python-side `clone()`.
- Leave the existing per-parameter `ds_adam_step` path unchanged.
- Add a numerical-equivalence test: fused vs per-parameter is bit-for-bit equal
  across fp16/bf16/fp32 (params, moments, and the stale snapshot), plus the
  empty-stale path.

Behavior is identical to the per-parameter path, verified bit-for-bit at the op
level and as an unchanged end-to-end loss trajectory across ZeRO stages 1/2/3.

Signed-off-by: Tingfeng Lan <erc8gx@virginia.edu>
Prepare the kernel for ZenFlow's in-process optimizer thread (L2). When the
optimizer runs on a background thread pinned to a dedicated set of cores, it must
not spawn OpenMP teams from the global libgomp pool — that pool is shared with
the training thread's torch ops and would defeat the core partitioning.

Thread a `parallel` flag through the step path (`Step_1/4/8`, `Step_AVX`,
`step_invoker`, the dtype dispatch map, and `invoke`) and turn the two
`#pragma omp parallel for` into `if (parallel)`. With `parallel=true` (the
default everywhere) the region is identical to before; with `parallel=false` the
loop runs serially in the calling thread, so a pinned pool can drive each element
slice itself.

- Expose the flag as an optional `parallel` argument on `adam_update_multi`
  (defaults to true, so existing callers are unchanged).
- Add a test that the serial path matches the OpenMP path bit-for-bit across
  fp16/bf16/fp32.

No behavior change for existing paths; Adam math is untouched.

Signed-off-by: Tingfeng Lan <erc8gx@virginia.edu>
Add the native side of ZenFlow's overlapped optimizer so the CPU Adam step can
run concurrently with the Python training thread without a separate process. The
existing design dodges the GIL by running the step in a multiprocessing
subprocess, which costs process spawn, shared-memory tensors, a pipe, and
per-step rebinding. With the step in native code that releases the GIL, a
background thread in the same process achieves the same overlap and touches the
same tensors directly.

ZenFlowAdam owns a dispatcher thread and a pool of worker threads pinned to
ZenFlow's dedicated cores. submit_step() hands a step to the dispatcher and
returns immediately; wait_step() blocks (with the GIL released) until it
finishes. The dispatcher advances the shared optimizer's bias-correction state
per group, then fans each group's elements out to the pinned pool, where every
thread runs its slice through the serial (parallel=false) kernel -- so the pool,
not OpenMP, provides the parallelism and stays on the ZenFlow cores.

- Pin pool threads with pthread_setaffinity_np (Linux); slice boundaries are
  rounded to the SIMD block so each slice's AVX/scalar split matches the
  whole-tensor kernel and the result is bit-identical.
- Expose a small C handle API (zenflow_adam_create/register_group/submit/wait/
  destroy); submit/wait/destroy release the GIL.
- Tests: ZenFlowAdam matches the fused reference bit-for-bit with alternating
  double buffers and multiple groups, and the pipelined submit/wait (including
  the engine's skipped post-warmup wait) does not desync.

Packaged inside the cpu_adam op to reuse Adam_Optimizer and the dtype dispatch;
not yet wired into the ZenFlow engine.

Signed-off-by: Tingfeng Lan <erc8gx@virginia.edu>
Replace the multiprocessing optimizer subprocess with the in-process ZenFlowAdam
handle for ZeRO stage 1/2. The subprocess existed only to dodge the GIL; now that
the step runs in native code that releases the GIL, a background dispatcher plus a
pinned thread pool in the same process give the same overlap and operate on the
same tensors directly -- removing the pipe, shared-memory sharing, the manager
dict, and the per-step rebinding.

- `start_optimizer_process` branches: stage 1/2 builds an in-process ZenFlowCPUAdam,
  eagerly allocates the double-buffered moments, registers each group with the
  native handle, and confines the training thread to the PyTorch core set
  (affinity + OMP_NUM_THREADS + torch.set_num_threads) so it does not contend with
  the optimizer's pinned pool. Stage 3 keeps the subprocess for now.
- `ZenFlowCPUAdam` gains init_native_overlap/submit_overlap_step/wait_overlap_step
  and destroys the handle on teardown.
- stage 1/2 `zenflow_cpu_optimizer_step`/`wait_last_update_and_copy` call the handle's
  submit/wait instead of pipe send/recv.
- Factor the zf/pt core split into `_compute_zf_pt_affinity`, shared by both paths.
- Add an overlap_step=True unit test for stage 1/2 (the in-process path runs under
  the test harness; the stage 3 subprocess cannot spawn from the daemonic test
  process, which is itself a reason to migrate it).

Verified: native and subprocess paths produce bit-identical loss trajectories for
stage 1/2 over a seeded run.

Signed-off-by: Tingfeng Lan <erc8gx@virginia.edu>
Profiling the in-process design showed it regressed ~18% on large, memory-
bandwidth-bound updates: the Adam moments (two thirds of the step's memory
traffic) were allocated by the training thread and ended up NUMA-remote from the
optimizer's pinned pool, and the pool contended with the training thread inside
one process. A separate process avoids both -- it allocates its state locally on
its own NUMA node and is isolated -- which is why the old subprocess was faster
there. The old subprocess was only slow on small models because of its per-step
Python/pickle/Manager overhead.

So keep the separate process but make the coordination native: the optimizer
runs the ZenFlowAdam pinned pool in its own process and talks to the training
process through two process-shared semaphores in a shared-memory control block,
instead of a pickling pipe. No Python in the optimizer loop, no per-step
rebinding. Measured (ms/step, best of 3): 0.5M 7.6 vs 9.9, 134M 114 vs 119 --
faster than the old subprocess at both ends.

- C++: ZenControl shared-memory block (sem_t cmd_ready/done, command, per-group
  hyperparameters); ZenFlowAdam::run_worker drives the pool from it;
  zenflow_adam_ctrl_{size,init,submit,wait,exit} for the training side. Reuses the
  pinned pool and run_step; in-process submit/wait kept only as a fast unit-test
  driver for the pool. Linux-only (POSIX semaphores).
- Python: the optimizer process builds the pool, allocates state locally, and runs
  the worker loop; stage 1/2 submit/wait call the control functions. Drops the
  in-process ZenFlowCPUAdam overlap helpers.
- Test: a cross-process op test (plain, not DistributedTest, so the non-daemonic
  pytest process can spawn the optimizer) checks bit-for-bit equality with the
  fused reference across alternating double buffers. The engine-level overlap test
  is removed again: like the subprocess, the optimizer process cannot be spawned
  from the daemonic test worker.

Stage 3 still uses the pickling subprocess; migrating it is a follow-up.
Verified: stage 1/2 training loss is bit-identical to the subprocess over a
seeded run.

Signed-off-by: Tingfeng Lan <erc8gx@virginia.edu>
Migrate ZeRO stage 3 overlap to the same separate native-process optimizer used
for stage 1/2: the optimizer process runs the ZenFlowAdam pinned pool driven by
the shared-memory semaphore control block, instead of the pickling subprocess.

- Generalize the optimizer-process startup to gather groups from
  fp32_partitioned_groups_flat for stage 3 (one flat partition per sub-group)
  and from the param groups for stage 1/2; both carry overlap_grad double
  buffers and a stale snapshot. start_optimizer_process now always takes the
  native path.
- engine_stage3 submit/wait call zenflow_adam_ctrl_submit/ctrl_wait instead of
  the pipe; the warm-up transition guard is unchanged.
- Remove the now-unreachable pickling optimizer loop (zenflow_optimizer_process)
  and its subprocess setup.

Verified: stage 3 training loss is bit-identical to the old subprocess over a
seeded run.

Note: ZenFlowCPUAdam._parallel_step (and the adam_update_multi Python caller) are
now only reachable from tests; pruning those superseded layers is left to a
dedicated cleanup.

Signed-off-by: Tingfeng Lan <erc8gx@virginia.edu>
The training process waited unbounded on the optimizer process's ready signal. If
that process crashed during initialization (for example a SIGBUS when /dev/shm is
exhausted, or a bad spawn), the training process blocked forever on the first
step's wait with no indication of what went wrong.

Bound the wait and raise a clear error if the optimizer process never signals
ready, so the failure surfaces instead of hanging.

Verified at scale: ZeRO stage 1/2/3 overlap trains 0.5B and 1.5B parameter models
on 1 and 2 GPUs (the optimizer process registers the flattened partitions, signals
ready, and steps to finite loss).

Signed-off-by: Tingfeng Lan <erc8gx@virginia.edu>
When the overlapped CPU optimizer finishes, the updated fp32 master partition
is copied back to its GPU bit16 partition via bit16.copy_(fp32.to(device)).
The .to(device) first materializes the entire fp32 partition on the GPU -- a
transient spike of ~2x the bit16 partition (measured ~2944 MiB for a 0.75B-param
partition) stacked on top of the model, which is exactly the memory CPU offload
is meant to save.

Stream the copy in fixed-size chunks so only one chunk's fp32 staging tensor is
resident at a time; the transient peak drops to the chunk size (measured ~256
MiB) and the bit16 result is unchanged. End-to-end throughput is unaffected.

Signed-off-by: Tingfeng Lan <erc8gx@virginia.edu>
ZenFlow's overlapped optimizer now always runs in a dedicated process driven by a
shared-memory semaphore control block (ZenFlowAdam::run_worker). The earlier
in-process variant -- a background dispatcher thread with submit_step/wait_step,
exposed as zenflow_adam_submit/wait(handle) and ZenFlowCPUAdam._parallel_step --
was kept only as a unit-test driver and is no longer reachable in production.

Remove it: drop the dispatcher thread and its sync state from ZenFlowAdam, delete
the handle-based submit/wait bindings and _parallel_step, and delete the
TestZenFlowAdamNative test. With the in-process submit/wait gone, the cross-process
control-block ops reclaim the plain names zenflow_adam_submit/wait. The fused
adam_update_multi op (still used by the worker kernel and its own tests) is kept.

No functional change to the production cross-process path; cross-process and fused
unit tests and a stage 1/2 end-to-end run remain bit-identical.

Signed-off-by: Tingfeng Lan <erc8gx@virginia.edu>
is_zero_supported_optimizer matches the optimizer type exactly, so ZenFlowCPUAdam
(a DeepSpeedCPUAdam subclass used by ZenFlow's CPU offload) was treated as untested
and required zero_allow_untested_optimizer: true in every ZenFlow config. Add it to
ZERO_SUPPORTED_OPTIMIZERS so ZenFlow runs without that flag.

Signed-off-by: Tingfeng Lan <erc8gx@virginia.edu>
@Antlera Antlera requested review from delock and tohtana and removed request for loadams, tjruwase and tohtana June 10, 2026 21:34
@Antlera

Antlera commented Jun 10, 2026

Copy link
Copy Markdown
Collaborator Author

Hi @tohtana @delock, could you help review this PR. Thanks! I am trying to build better cpu_op for zenflow. So that we have better control over the cpu optimizer and memory usage part.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 13ce892f2c

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread csrc/adam/cpu_adam_impl.cpp Outdated
void zenflow_adam_wait(uintptr_t control_ptr)
{
auto* ctrl = reinterpret_cast<ZenControl*>(control_ptr);
while (sem_wait(&ctrl->done) != 0) {} // retry on EINTR

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Detect optimizer process death while waiting

If the native optimizer process exits after ready.set() but before posting done (for example a TORCH_CHECK/OOM during run_step), this wait has no timeout or process-liveness check and the training rank blocks forever on the semaphore. The old Pipe-based path would surface a closed pipe/error; please make the Python wait path poll zf_optimizer.process or use a timed wait so step-time crashes fail loudly instead of hanging distributed training.

Useful? React with 👍 / 👎.

Antlera added 2 commits June 10, 2026 21:42
CI's clang-format (18.1.3) expands the single-line constructor and
zenflow_adam_run_worker bodies to multi-line; match it.

Signed-off-by: Tingfeng Lan <erc8gx@virginia.edu>
If the optimizer process exited after signalling ready but before posting a
completion (e.g. an OOM or TORCH_CHECK in run_step), the training side blocked
forever on the done semaphore, hanging the whole distributed job -- unlike the
old Pipe path, which surfaced a closed-pipe error.

Make zenflow_adam_wait a bounded wait (sem_timedwait) returning whether a
completion was consumed. The training side (ZeRO stage 1/2 and 3) now loops on
it and, on each timeout, checks the optimizer process is still alive, raising a
clear error instead of hanging if it died. Normal steps are unaffected (the wait
returns as soon as the worker posts done).

Signed-off-by: Tingfeng Lan <erc8gx@virginia.edu>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant