Implement Warmup-Stable-Decay (WSD) Learning Rate Schedule #2883
base: main
```diff
@@ -1103,16 +1103,25 @@ def create_device_mesh(config, devices=None):
 def create_learning_rate_schedule(config):
-  """Creates a warmup and cosine decay learning rate schedule:
-  We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
-  Learning rate schedule has either two or three parts:
+  """Creates a learning rate schedule with warmup and decay.
+
+  Supports two schedule types:
+  - Cosine: Inspired by Llama2's learning rate schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
+  - WSD (Warmup-Stable-Decay): Maintains constant learning rate for most of training before final decay
+
+  Schedule structure:
   1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
-  2) Cosine from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] until learning_rate_schedule_steps
+  2) Decay from [learning_rate] to a final value until learning_rate_schedule_steps
+     - Cosine: decays to [learning_rate * cosine_learning_rate_final_fraction]
+     - WSD: maintains [learning_rate] for a stable phase, then decays to [learning_rate * wsd_learning_rate_final_fraction]
+       using either linear or cosine decay based on wsd_decay_style
   3) Constant learning rate of 0 from learning_rate_schedule_steps to steps.
   The zero learning rate section can be used to more accurately measure the fully trained model's performance.
   """

   def make_cos_schedule(init_lr, final_lr, len_steps):
     """Creates a cosine decay schedule from init_lr to final_lr over len_steps."""

     def schedule(step):
       pct = (step) / len_steps
       a = 0.5 * (jnp.cos(jnp.pi * pct) + 1)
```
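To make the phase split in the new docstring concrete, here is a small worked example. The step counts and fractions below are made-up illustrative values, not defaults proposed by this PR, but the arithmetic mirrors the lines added further down in the diff.

```python
# Illustrative phase arithmetic for the WSD schedule (example values, not PR defaults).
learning_rate_schedule_steps = 10_000
steps = 12_000  # total training steps
warmup_steps_fraction = 0.05
wsd_decay_steps_fraction = 0.20

warmup_steps = int(learning_rate_schedule_steps * warmup_steps_fraction)    # 500
decay_steps = int(learning_rate_schedule_steps * wsd_decay_steps_fraction)  # 2_000
stable_steps = learning_rate_schedule_steps - warmup_steps - decay_steps    # 7_500
constant_zero_steps = steps - learning_rate_schedule_steps                  # 2_000 steps at lr = 0
```

With these values the schedule warms up for 500 steps, holds the peak learning rate until step 8,000, decays from step 8,000 to step 10,000, and then stays at zero for the final 2,000 steps.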
```diff
@@ -1122,25 +1131,50 @@ def schedule(step):
     return schedule

   lr = config.learning_rate
-  cos_final_lr = lr * config.cosine_learning_rate_final_fraction

   warmup_steps = int(config.learning_rate_schedule_steps * config.warmup_steps_fraction)
-  cos_steps = config.learning_rate_schedule_steps - warmup_steps
   constant_zero_steps = config.steps - config.learning_rate_schedule_steps

   warmup_schedule = optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup_steps)
-  cos_schedule = make_cos_schedule(lr, cos_final_lr, cos_steps)
-  constant_schedule = optax.constant_schedule(0.0)

-  pieces = [warmup_schedule, cos_schedule]
-  boundaries = [
-      warmup_steps,
-      warmup_steps + cos_steps,
-  ]
+  if config.lr_schedule_type == "cosine":
+    cos_final_lr = lr * config.cosine_learning_rate_final_fraction
+    cos_steps = config.learning_rate_schedule_steps - warmup_steps
+    cos_schedule = make_cos_schedule(lr, cos_final_lr, cos_steps)
+
+    pieces = [warmup_schedule, cos_schedule]
+    boundaries = [warmup_steps, warmup_steps + cos_steps]
+
+  elif config.lr_schedule_type == "wsd":
```
**Collaborator:** could you use `types.LearningRateScheduleType.WSD` here?
```diff
+    wsd_final_lr = lr * config.wsd_learning_rate_final_fraction
+    decay_steps = int(config.learning_rate_schedule_steps * config.wsd_decay_steps_fraction)
+    stable_steps = config.learning_rate_schedule_steps - warmup_steps - decay_steps
+
+    if stable_steps < 0:
+      raise ValueError(
+          f"Invalid WSD schedule: warmup_steps_fraction ({config.warmup_steps_fraction}) + "
+          f"wsd_decay_steps_fraction ({config.wsd_decay_steps_fraction}) must not exceed 1.0. "
+          f"Current sum: {config.warmup_steps_fraction + config.wsd_decay_steps_fraction}"
+      )
+
+    stable_schedule = optax.constant_schedule(lr)
+
+    # Create decay schedule based on wsd_decay_style
+    if config.wsd_decay_style == "linear":
+      decay_schedule = optax.linear_schedule(init_value=lr, end_value=wsd_final_lr, transition_steps=decay_steps)
+    elif config.wsd_decay_style == "cosine":
```
**Collaborator** (on lines +1161 to +1163): please use `types.WsdDecayStyle.LINEAR` and `types.WsdDecayStyle.COSINE` instead of `"linear"` and `"cosine"`
```diff
+      decay_schedule = make_cos_schedule(lr, wsd_final_lr, decay_steps)
+    else:
+      raise ValueError(f"Invalid wsd_decay_style: {config.wsd_decay_style}. " "Must be either 'linear' or 'cosine'.")
+
+    pieces = [warmup_schedule, stable_schedule, decay_schedule]
+    boundaries = [warmup_steps, warmup_steps + stable_steps]
+
+  else:
+    raise ValueError(f"Invalid lr_schedule_type: {config.lr_schedule_type}. " "Must be either 'cosine' or 'wsd'.")

   if constant_zero_steps > 0:
+    constant_schedule = optax.constant_schedule(0.0)
     pieces.append(constant_schedule)
-    boundaries.append(warmup_steps + cos_steps + constant_zero_steps)
+    boundaries.append(boundaries[-1] + constant_zero_steps)

   return optax.join_schedules(pieces, boundaries)
```
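Not part of the PR, but a minimal sketch of how the returned schedule could be probed end to end. It assumes `create_learning_rate_schedule` is importable from the module this diff modifies (the file path is not shown in this extract) and uses a plain namespace carrying only the config fields the function reads.

```python
from types import SimpleNamespace

# from <module containing this diff> import create_learning_rate_schedule  # assumed import; adjust to the actual path

# Hypothetical stand-in for the real config object; field names match those read in the diff above.
config = SimpleNamespace(
    learning_rate=3e-4,
    learning_rate_schedule_steps=10_000,
    steps=12_000,
    warmup_steps_fraction=0.05,
    lr_schedule_type="wsd",
    wsd_learning_rate_final_fraction=0.1,
    wsd_decay_steps_fraction=0.20,
    wsd_decay_style="linear",
    cosine_learning_rate_final_fraction=0.1,  # only read by the "cosine" branch
)

schedule = create_learning_rate_schedule(config)

# Probe one step per phase: warmup, stable, decay, and the trailing zero-LR tail.
for step in (0, 250, 5_000, 9_000, 11_000):
    print(step, float(schedule(step)))
```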
**Collaborator:** could you use `types.LearningRateScheduleType.COSINE` instead of string literal `"cosine"`?
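Taken together, the three review comments ask for enum constants from the project's `types` module instead of string literals. Their actual definitions are not visible in this extract, so the following is only an illustrative sketch of the shape such enums could take; the classes below are hypothetical stand-ins, not the repo's real `types.LearningRateScheduleType` and `types.WsdDecayStyle`.

```python
import enum


class LearningRateScheduleType(str, enum.Enum):
  """Hypothetical stand-in for the types.LearningRateScheduleType the reviewers reference."""
  COSINE = "cosine"
  WSD = "wsd"


class WsdDecayStyle(str, enum.Enum):
  """Hypothetical stand-in for the types.WsdDecayStyle the reviewers reference."""
  LINEAR = "linear"
  COSINE = "cosine"


# With str-mixin enums, the comparisons in the diff keep working even if the config
# still carries plain strings (e.g. values parsed from a YAML config):
#   if config.lr_schedule_type == LearningRateScheduleType.COSINE:
#       ...
#   elif config.lr_schedule_type == LearningRateScheduleType.WSD:
#       if config.wsd_decay_style == WsdDecayStyle.LINEAR:
#           ...
#       elif config.wsd_decay_style == WsdDecayStyle.COSINE:
#           ...
```

The str mixin is one way to keep the enum comparisons backward compatible with string-valued config entries; if the project's `types` module already defines these enums differently, its convention should take precedence.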