@aaghaazkhan commented Dec 1, 2025

Summary

This PR adds defensive validation to ensure GRPO's required grouping
invariants are not silently violated during rollout or log-probability
computation.

GRPO requires that for each prompt, the policy generates G completions
(num_generations). These G completions must stay grouped together
throughout rollout, reward computation, and log-prob computation.

If the configured micro-batch sizes are smaller than num_generations,
the framework silently reshapes rewards/logprobs incorrectly, mixing
generations from different prompts. This leads to invalid advantage
computations and results in the model failing to learn (e.g., 0%
accuracy for all metrics).
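
The failure mode can be illustrated with a minimal sketch. It assumes the usual group-relative normalization (reward minus group mean, divided by group standard deviation). The helper `group_advantages` and the interleaved ordering below are hypothetical stand-ins chosen for illustration, not the framework's actual code path.

```python
from statistics import mean, pstdev


def group_advantages(rewards, num_generations, eps=1e-6):
    """Normalize rewards within consecutive groups of `num_generations`.

    Hypothetical helper: it assumes the G completions of each prompt are
    stored contiguously in `rewards`.
    """
    if len(rewards) % num_generations != 0:
        raise ValueError("reward count must be a multiple of num_generations")
    advantages = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start:start + num_generations]
        mu, sigma = mean(group), pstdev(group)
        advantages.extend((r - mu) / (sigma + eps) for r in group)
    return advantages


G = 4
# Prompt A is easy (every completion scores 1.0); prompt B is hard (all 0.0).
rewards_a = [1.0] * G
rewards_b = [0.0] * G

# Correct layout: each prompt's completions are contiguous.
correct = group_advantages(rewards_a + rewards_b, G)

# Assumed misordering: completions from the two prompts end up interleaved,
# so each "group" mixes generations from different prompts.
interleaved = [r for pair in zip(rewards_a, rewards_b) for r in pair]
mixed = group_advantages(interleaved, G)

print(correct)  # all zeros: no completion is better than its own group
print(mixed)    # roughly +1 / -1: the signal reflects prompt difficulty only
```

With correct grouping every advantage is zero, because no completion beats its siblings. With mixed groups the advantages are large but carry no information about completion quality, which is consistent with a model that fails to learn.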

This PR introduces explicit validation checks to stop training early if
incorrect batching configurations are used.


What This Fix Does

Adds two safety checks:

  1. Validate that rollout micro-batch size is large enough:
         if self._rollout_micro_batch_size < self.algo_config.num_generations:
             raise ValueError(...)
  2. Validate that log-prob computation micro-batch size is also large enough:
         if self._compute_logps_micro_batch_size < self.algo_config.num_generations:
             raise ValueError(...)

These checks prevent GRPO from running with an invalid batching
configuration that silently corrupts reward grouping.
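
As a rough standalone sketch of the two checks, the validation could look like the following. The function name, its parameters, and the wording of the error message are assumptions made for illustration; the PR's actual messages are elided in the snippets above, and the real checks live on the trainer object rather than in a free function.

```python
def validate_grpo_micro_batch_sizes(
    rollout_micro_batch_size: int,
    compute_logps_micro_batch_size: int,
    num_generations: int,
) -> None:
    """Raise ValueError if a micro-batch cannot hold one full group.

    Hypothetical helper mirroring the two checks in this PR.
    """
    sizes = {
        "rollout_micro_batch_size": rollout_micro_batch_size,
        "compute_logps_micro_batch_size": compute_logps_micro_batch_size,
    }
    for name, size in sizes.items():
        if size < num_generations:
            # Illustrative message: name the offending setting and the
            # minimum value that keeps a prompt's completions together.
            raise ValueError(
                f"{name} ({size}) must be >= num_generations "
                f"({num_generations}) so that all completions for a prompt "
                "stay in the same micro-batch."
            )


# A valid configuration passes silently.
validate_grpo_micro_batch_sizes(8, 4, num_generations=4)

# An invalid one fails before any training step runs.
try:
    validate_grpo_micro_batch_sizes(2, 8, num_generations=4)
except ValueError as err:
    print(err)
```

Naming the offending setting and the required minimum in the message is what makes the error actionable: the fix is to raise that micro-batch size or lower num_generations.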

Why This Matters

Without these checks:

  • Batches are split incorrectly
  • Rewards are grouped incorrectly
  • Advantages are computed with wrong (prompt, generation) alignment
  • Training converges to 0% accuracy even on simple tasks
  • Debugging is very difficult because the failure is silent and unexpected

With this fix, the error is explicit, descriptive, and actionable.


Checklist:

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and all unit tests pass.
  • I have added all appropriate doc-strings/documentation.
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have signed the Contributor License Agreement.
  • I have followed Contribution Guidelines.

@aaghaazkhan changed the title from "Fix: enforce GRPO micro-batch grouping validation to prevent incorrect advantage computation" to "Add GRPO micro-batch size validation to prevent incorrect grouping and advantage computation" Dec 1, 2025