Add GRPO micro-batch size validation to prevent incorrect grouping and advantage computation #820
Summary
This PR adds defensive validation to ensure GRPO's required grouping
invariants are not silently violated during rollout or log-probability
computation.
GRPO requires that for each prompt, the policy generates G completions
(`num_generations`). These G completions must stay grouped together
throughout rollout, reward computation, and log-prob computation.
If the configured micro-batch sizes are smaller than `num_generations`,
the framework silently reshapes rewards/log-probs incorrectly, mixing
generations from different prompts. This leads to invalid advantage
computations and results in the model failing to learn (e.g., 0%
accuracy for all metrics).
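To see why the grouping invariant matters, here is a minimal sketch (the array names and shapes are illustrative assumptions, not this framework's API): GRPO normalizes each completion's reward against the other completions of the *same* prompt, so the rewards must be reshapable into whole groups of `num_generations`.

```python
import numpy as np

# Illustrative setup: G = 4 completions per prompt, 2 prompts.
num_generations = 4
rewards = np.array([1.0, 0.0, 1.0, 0.0,   # prompt A's 4 completions
                    0.0, 0.0, 1.0, 1.0])  # prompt B's 4 completions

# Correct grouping: reshape into (num_prompts, G) and compute each
# completion's advantage relative to its own group's mean reward.
groups = rewards.reshape(-1, num_generations)
advantages = groups - groups.mean(axis=1, keepdims=True)

# A micro-batch size of 2 (< G) would split each group of 4 across
# micro-batches; re-grouping those fragments would mix prompt A's and
# prompt B's completions, so the per-group baseline (and hence every
# advantage) would be computed against the wrong reference set.
```

Each row of `advantages` sums to zero, which is exactly the property that breaks when completions from different prompts are mixed into one group.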
This PR introduces explicit validation checks to stop training early if
incorrect batching configurations are used.
What This Fix Does
Adds two safety checks, one on the rollout path and one on the
log-probability computation path, validating that the configured
micro-batch size is compatible with `num_generations`.
These checks prevent GRPO from running with invalid batching
configuration that silently corrupts reward grouping.
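A sketch of the kind of validation this PR describes (the function name, parameter names, and error messages here are assumptions for illustration, not the PR's exact code):

```python
def validate_grpo_batching(micro_batch_size: int, num_generations: int) -> None:
    """Fail fast if micro-batching would split GRPO completion groups.

    Hypothetical helper illustrating the checks described in this PR.
    """
    if micro_batch_size < num_generations:
        raise ValueError(
            f"micro_batch_size ({micro_batch_size}) must be >= "
            f"num_generations ({num_generations}); smaller values split a "
            "prompt's completion group across micro-batches."
        )
    if micro_batch_size % num_generations != 0:
        raise ValueError(
            f"micro_batch_size ({micro_batch_size}) must be a multiple of "
            f"num_generations ({num_generations}) so each micro-batch "
            "contains only whole completion groups."
        )

# A valid configuration passes silently; an invalid one raises before
# any training step runs:
validate_grpo_batching(micro_batch_size=8, num_generations=4)   # ok
# validate_grpo_batching(micro_batch_size=2, num_generations=4) # raises
```

Raising at configuration time turns a silent 0%-accuracy failure into an immediate, descriptive error.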
Why This Matters
Without these checks, training proceeds silently on corrupted reward
groups: advantages are computed across completions from unrelated
prompts, and the model fails to learn with no error ever raised.
With this fix, the error is explicit, descriptive, and actionable.