Skip to content

Commit 6ac0aec

Browse files
authored
Merge branch 'master' into master-mc-bug-fixes
2 parents 2fb0ac1 + 3cc40a3 commit 6ac0aec

File tree

7 files changed

+131
-16
lines changed

7 files changed

+131
-16
lines changed
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
# PR checks for pull requests that target the "master-v2" branch.
#
# Flow: collab-check -> wait-for-approval -> (codestyle-doc-tests, unit-tests, integ-tests).
# The three test jobs each assume an AWS role and start a CodeBuild project
# against the pull request's head commit.
name: Sagemaker PR Checks (Master-v2)
on:
  pull_request_target:
    branches:
      - "master-v2"

# One active run per pull request; a newer run cancels the one in progress.
concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.head_ref }}
  cancel-in-progress: true

# id-token is requested for the configure-aws-credentials steps below.
permissions:
  id-token: write

jobs:
  # Decides which environment the approval job runs in:
  # "auto-approve" for repo collaborators, "manual-approval" for everyone else.
  collab-check:
    runs-on: ubuntu-latest
    outputs:
      approval-env: ${{ steps.collab-check.outputs.result }}
    steps:
      - name: Collaborator Check
        uses: actions/github-script@v7
        id: collab-check
        with:
          github-token: ${{ secrets.COLLAB_CHECK_TOKEN }}
          result-encoding: string
          script: |
            // Read the PR author from the event payload at runtime rather than
            // interpolating a workflow expression into this script's source, so
            // event data is never parsed as JavaScript.
            const login = context.payload.pull_request.user.login;
            try {
              const res = await github.rest.repos.checkCollaborator({
                owner: context.repo.owner,
                repo: context.repo.repo,
                username: login,
              });
              // Only report success after the status has actually been checked.
              if (res.status === 204) {
                console.log(`Verified ${login} is a repo collaborator. Auto Approving PR Checks.`)
                return "auto-approve"
              }
              console.log(`Unexpected status ${res.status} while checking ${login}. Requiring Manual Approval to run PR Checks.`)
              return "manual-approval"
            } catch (error) {
              if (error.status === 404) {
                console.log(`${login} is not a collaborator. Requiring Manual Approval to run PR Checks.`)
              } else {
                // Token, rate-limit or network problems: fail closed, but say so.
                console.log(`Collaborator check for ${login} failed (${error.status || error.message}). Requiring Manual Approval to run PR Checks.`)
              }
              return "manual-approval"
            }
  # Gate job: runs in the environment chosen by collab-check.
  # NOTE(review): presumably the "manual-approval" environment has required
  # reviewers configured in the repository settings - confirm.
  wait-for-approval:
    runs-on: ubuntu-latest
    needs: [collab-check]
    environment: ${{ needs.collab-check.outputs.approval-env }}
    steps:
      - run: echo "Workflow Approved! Starting PR Checks."
  codestyle-doc-tests:
    runs-on: ubuntu-latest
    needs: [wait-for-approval]
    steps:
      - name: Configure AWS Credentials
        uses: aws-actions/configure-aws-credentials@v4
        with:
          role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }}
          aws-region: us-west-2
          role-duration-seconds: 10800
      - name: Run Codestyle & Doc Tests
        uses: aws-actions/aws-codebuild-run-build@v1
        with:
          project-name: ${{ github.event.repository.name }}-ci-codestyle-doc-tests
          # Pin the build to the exact head commit of the pull request.
          source-version-override: 'refs/pull/${{ github.event.pull_request.number }}/head^{${{ github.event.pull_request.head.sha }}}'
  unit-tests:
    runs-on: ubuntu-latest
    needs: [wait-for-approval]
    strategy:
      # Let every Python version finish even if one of them fails.
      fail-fast: false
      matrix:
        python-version: ["py39","py310","py311","py312"]
    steps:
      - name: Configure AWS Credentials
        uses: aws-actions/configure-aws-credentials@v4
        with:
          role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }}
          aws-region: us-west-2
          role-duration-seconds: 10800
      - name: Run Unit Tests
        uses: aws-actions/aws-codebuild-run-build@v1
        with:
          project-name: ${{ github.event.repository.name }}-ci-unit-tests
          source-version-override: 'refs/pull/${{ github.event.pull_request.number }}/head^{${{ github.event.pull_request.head.sha }}}'
          # Forward the matrix entry to the CodeBuild job as PY_VERSION.
          env-vars-for-codebuild: |
            PY_VERSION
        env:
          PY_VERSION: ${{ matrix.python-version }}
  integ-tests:
    runs-on: ubuntu-latest
    needs: [wait-for-approval]
    steps:
      - name: Configure AWS Credentials
        uses: aws-actions/configure-aws-credentials@v4
        with:
          role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }}
          aws-region: us-west-2
          role-duration-seconds: 10800
      - name: Run Integ Tests
        uses: aws-actions/aws-codebuild-run-build@v1
        with:
          project-name: ${{ github.event.repository.name }}-ci-integ-tests
          source-version-override: 'refs/pull/${{ github.event.pull_request.number }}/head^{${{ github.event.pull_request.head.sha }}}'

.github/workflows/submodule-codebuild-ci.yml renamed to .github/workflows/pr-checks-master.yml

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
name: Sagemaker PR Checks
1+
name: Sagemaker PR Checks (Master)
22
on:
33
pull_request_target:
44
branches:
5-
- "master*"
5+
- "master"
66
paths:
77
- 'sagemaker-train/**'
88
- 'sagemaker-serve/**'
@@ -56,19 +56,18 @@ jobs:
5656
- uses: actions/checkout@v3
5757
with:
5858
fetch-depth: 0
59-
token: ${{ secrets.GH_PAT }} # or use appropriate token
60-
ref: ${{ github.event.pull_request.base.ref }} # Target branch (master-v3)
59+
token: ${{ secrets.GH_PAT }}
60+
ref: ${{ github.event.pull_request.base.ref }}
6161
- name: Detect Changes
6262
id: check-changes
6363
run: |
64-
set -e # Exit on error
64+
set -e
6565
66-
# Debug information
6766
echo "Target Branch: ${{ github.event.pull_request.base.ref }}"
6867
echo "Current Target SHA: $(git rev-parse HEAD)"
6968
echo "PR Number: ${{ github.event.pull_request.number }}"
7069
echo "PR Latest SHA: ${{ github.event.pull_request.head.sha }}"
71-
# Fetch PR without creating a branch
70+
7271
git fetch origin pull/${{ github.event.pull_request.number }}/head
7372
CHANGES=$(git diff --name-only HEAD FETCH_HEAD)
7473

sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def _calculate_transition_duration(trans) -> Tuple[str, str]:
157157
def wait(
158158
training_job: TrainingJob,
159159
poll: int = 5,
160-
timeout: Optional[int] = None
160+
timeout: Optional[int] = 3000
161161
) -> None:
162162
"""Wait for training job to complete with progress tracking.
163163
@@ -192,8 +192,10 @@ def wait(
192192
iteration = 0
193193
while True:
194194
iteration += 1
195-
time.sleep(poll)
196-
training_job.refresh()
195+
time.sleep(1)
196+
if iteration == poll:
197+
training_job.refresh()
198+
iteration = 0
197199
clear_output(wait=True)
198200

199201
status = training_job.training_job_status
@@ -302,7 +304,7 @@ def wait(
302304
raise FailedStatusError(resource_type="TrainingJob", status=status, reason=failure_reason)
303305

304306
if timeout and elapsed >= timeout:
305-
raise TimeoutExceededError(resouce_type="TrainingJob", status=status)
307+
raise TimeoutExceededError(resource_type="TrainingJob", status=status)
306308

307309
else:
308310
print(f"\nTrainingJob Name: {training_job.training_job_name}")
@@ -363,7 +365,7 @@ def wait(
363365
raise FailedStatusError(resource_type="TrainingJob", status=status, reason=failure_reason)
364366

365367
if timeout and elapsed >= timeout:
366-
raise TimeoutExceededError(resouce_type="TrainingJob", status=status)
368+
raise TimeoutExceededError(resource_type="TrainingJob", status=status)
367369

368370

369371
except (FailedStatusError, TimeoutExceededError):

sagemaker-train/src/sagemaker/train/dpo_trainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,11 @@ def train(self,
261261

262262
if wait:
263263
from sagemaker.train.common_utils.trainer_wait import wait as _wait
264-
_wait(training_job)
264+
from sagemaker.core.utils.exceptions import TimeoutExceededError
265+
try :
266+
_wait(training_job)
267+
except TimeoutExceededError as e:
268+
logger.error("Error: %s", e)
265269

266270
self.latest_training_job = training_job
267271
return training_job

sagemaker-train/src/sagemaker/train/rlaif_trainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,11 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
268268

269269
if wait:
270270
from sagemaker.train.common_utils.trainer_wait import wait as _wait
271-
_wait(training_job)
271+
from sagemaker.core.utils.exceptions import TimeoutExceededError
272+
try :
273+
_wait(training_job)
274+
except TimeoutExceededError as e:
275+
logger.error("Error: %s", e)
272276

273277
self.latest_training_job = training_job
274278
return training_job

sagemaker-train/src/sagemaker/train/rlvr_trainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,11 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
268268

269269
if wait:
270270
from sagemaker.train.common_utils.trainer_wait import wait as _wait
271-
_wait(training_job)
271+
from sagemaker.core.utils.exceptions import TimeoutExceededError
272+
try:
273+
_wait(training_job)
274+
except TimeoutExceededError as e:
275+
logger.error("Error: %s", e)
272276

273277
self.latest_training_job = training_job
274278
return training_job

sagemaker-train/src/sagemaker/train/sft_trainer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from logging import exception
12
from typing import Optional, Union
23
import logging
34
from sagemaker.train.base_trainer import BaseTrainer
@@ -261,7 +262,11 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
261262

262263
if wait:
263264
from sagemaker.train.common_utils.trainer_wait import wait as _wait
264-
_wait(training_job)
265+
from sagemaker.core.utils.exceptions import TimeoutExceededError
266+
try :
267+
_wait(training_job)
268+
except TimeoutExceededError as e:
269+
logger.error("Error: %s", e)
265270

266271
self.latest_training_job = training_job
267272
return training_job

0 commit comments

Comments
 (0)