Skip to content

Commit 367ac58

Browse files
committed
Add StatefulSet as owner of Workload.
1 parent a469237 commit 367ac58

File tree

3 files changed

+242
-14
lines changed

3 files changed

+242
-14
lines changed

pkg/controller/jobs/statefulset/statefulset_reconciler.go

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"time"
2222

2323
"github.com/go-logr/logr"
24+
"golang.org/x/sync/errgroup"
2425
appsv1 "k8s.io/api/apps/v1"
2526
corev1 "k8s.io/api/core/v1"
2627
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@@ -29,6 +30,7 @@ import (
2930
"k8s.io/client-go/tools/record"
3031
"k8s.io/client-go/util/workqueue"
3132
"k8s.io/klog/v2"
33+
"k8s.io/utils/ptr"
3234
ctrl "sigs.k8s.io/controller-runtime"
3335
"sigs.k8s.io/controller-runtime/pkg/client"
3436
"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
@@ -37,6 +39,7 @@ import (
3739
"sigs.k8s.io/controller-runtime/pkg/predicate"
3840
"sigs.k8s.io/controller-runtime/pkg/reconcile"
3941

42+
kueue "sigs.k8s.io/kueue/apis/kueue/v1beta2"
4043
"sigs.k8s.io/kueue/pkg/controller/jobframework"
4144
podcontroller "sigs.k8s.io/kueue/pkg/controller/jobs/pod/constants"
4245
clientutil "sigs.k8s.io/kueue/pkg/util/client"
@@ -65,34 +68,35 @@ func (r *Reconciler) Reconcile(ctx context.Context, req reconcile.Request) (reco
6568
log := ctrl.LoggerFrom(ctx)
6669
log.V(2).Info("Reconcile StatefulSet")
6770

68-
err := r.fetchAndFinalizePods(ctx, req)
69-
return ctrl.Result{}, err
70-
}
71-
72-
func (r *Reconciler) fetchAndFinalizePods(ctx context.Context, req reconcile.Request) error {
7371
podList := &corev1.PodList{}
7472
if err := r.client.List(ctx, podList, client.InNamespace(req.Namespace), client.MatchingLabels{
7573
podcontroller.GroupNameLabel: GetWorkloadName(req.Name),
7674
}); err != nil {
77-
return err
78-
}
79-
80-
// If no Pods are found, there's nothing to do.
81-
if len(podList.Items) == 0 {
82-
return nil
75+
return ctrl.Result{}, err
8376
}
8477

8578
sts := &appsv1.StatefulSet{}
8679
err := r.client.Get(ctx, req.NamespacedName, sts)
8780
if client.IgnoreNotFound(err) != nil {
88-
return err
81+
return ctrl.Result{}, err
8982
}
9083

9184
if err != nil {
9285
sts = nil
9386
}
9487

95-
return r.finalizePods(ctx, sts, podList.Items)
88+
eg, ctx := errgroup.WithContext(ctx)
89+
90+
eg.Go(func() error {
91+
return r.finalizePods(ctx, sts, podList.Items)
92+
})
93+
94+
eg.Go(func() error {
95+
return r.reconcileWorkload(ctx, sts)
96+
})
97+
98+
err = eg.Wait()
99+
return ctrl.Result{}, err
96100
}
97101

98102
func (r *Reconciler) finalizePods(ctx context.Context, sts *appsv1.StatefulSet, pods []corev1.Pod) error {
@@ -139,6 +143,43 @@ func shouldFinalize(sts *appsv1.StatefulSet, pod *corev1.Pod) bool {
139143
return shouldUngate(sts, pod) || utilpod.IsTerminated(pod)
140144
}
141145

146+
func (r *Reconciler) reconcileWorkload(ctx context.Context, sts *appsv1.StatefulSet) error {
147+
if sts == nil {
148+
return nil
149+
}
150+
151+
wl := &kueue.Workload{}
152+
err := r.client.Get(ctx, client.ObjectKey{Namespace: sts.Namespace, Name: GetWorkloadName(sts.Name)}, wl)
153+
if err != nil {
154+
return client.IgnoreNotFound(err)
155+
}
156+
157+
hasOwnerReference, err := controllerutil.HasOwnerReference(wl.OwnerReferences, sts, r.client.Scheme())
158+
if err != nil {
159+
return err
160+
}
161+
162+
var (
163+
shouldUpdate = false
164+
replicas = ptr.Deref(sts.Spec.Replicas, 1)
165+
)
166+
167+
switch {
168+
case hasOwnerReference && replicas == 0:
169+
shouldUpdate = true
170+
err = controllerutil.RemoveOwnerReference(sts, wl, r.client.Scheme())
171+
case !hasOwnerReference && replicas > 0:
172+
shouldUpdate = true
173+
err = controllerutil.SetOwnerReference(sts, wl, r.client.Scheme())
174+
}
175+
if err != nil || !shouldUpdate {
176+
return err
177+
}
178+
179+
err = r.client.Update(ctx, wl)
180+
return err
181+
}
182+
142183
func (r *Reconciler) SetupWithManager(mgr ctrl.Manager) error {
143184
ctrl.Log.V(3).Info("Setting up StatefulSet reconciler")
144185
return ctrl.NewControllerManagedBy(mgr).

pkg/controller/jobs/statefulset/statefulset_reconciler_test.go

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@ import (
2727
"sigs.k8s.io/controller-runtime/pkg/client"
2828
"sigs.k8s.io/controller-runtime/pkg/reconcile"
2929

30+
kueue "sigs.k8s.io/kueue/apis/kueue/v1beta2"
3031
podconstants "sigs.k8s.io/kueue/pkg/controller/jobs/pod/constants"
3132
utiltesting "sigs.k8s.io/kueue/pkg/util/testing"
33+
utiltestingapi "sigs.k8s.io/kueue/pkg/util/testing/v1beta2"
3234
testingjobspod "sigs.k8s.io/kueue/pkg/util/testingjobs/pod"
3335
statefulsettesting "sigs.k8s.io/kueue/pkg/util/testingjobs/statefulset"
3436
)
@@ -45,8 +47,10 @@ func TestReconciler(t *testing.T) {
4547
stsKey client.ObjectKey
4648
statefulSet *appsv1.StatefulSet
4749
pods []corev1.Pod
50+
workloads []kueue.Workload
4851
wantStatefulSet *appsv1.StatefulSet
4952
wantPods []corev1.Pod
53+
wantWorkloads []kueue.Workload
5054
wantErr error
5155
}{
5256
"statefulset not found": {
@@ -170,14 +174,76 @@ func TestReconciler(t *testing.T) {
170174
Obj(),
171175
},
172176
},
177+
"should add StatefulSet to Workload owner references if replicas > 0": {
178+
stsKey: client.ObjectKey{Name: "sts", Namespace: "ns"},
179+
statefulSet: statefulsettesting.MakeStatefulSet("sts", "ns").
180+
UID("sts-uid").
181+
Queue("lq").
182+
Obj(),
183+
workloads: []kueue.Workload{
184+
*utiltestingapi.MakeWorkload(GetWorkloadName("sts"), "ns").
185+
Obj(),
186+
},
187+
wantStatefulSet: statefulsettesting.MakeStatefulSet("sts", "ns").
188+
UID("sts-uid").
189+
Queue("lq").
190+
DeepCopy(),
191+
wantWorkloads: []kueue.Workload{
192+
*utiltestingapi.MakeWorkload(GetWorkloadName("sts"), "ns").
193+
OwnerReference(gvk, "sts", "sts-uid").
194+
Obj(),
195+
},
196+
},
197+
"shouldn't add StatefulSet to Workload owner references if replicas = 0": {
198+
stsKey: client.ObjectKey{Name: "sts", Namespace: "ns"},
199+
statefulSet: statefulsettesting.MakeStatefulSet("sts", "ns").
200+
UID("sts-uid").
201+
Queue("lq").
202+
Obj(),
203+
workloads: []kueue.Workload{
204+
*utiltestingapi.MakeWorkload(GetWorkloadName("sts"), "ns").
205+
Obj(),
206+
},
207+
wantStatefulSet: statefulsettesting.MakeStatefulSet("sts", "ns").
208+
UID("sts-uid").
209+
Queue("lq").
210+
DeepCopy(),
211+
wantWorkloads: []kueue.Workload{
212+
*utiltestingapi.MakeWorkload(GetWorkloadName("sts"), "ns").
213+
OwnerReference(gvk, "sts", "sts-uid").
214+
Obj(),
215+
},
216+
},
217+
"should remove StatefulSet from Workload owner references if replicas = 0": {
218+
stsKey: client.ObjectKey{Name: "sts", Namespace: "ns"},
219+
statefulSet: statefulsettesting.MakeStatefulSet("sts", "ns").
220+
UID("sts-uid").
221+
Queue("lq").
222+
Replicas(0).
223+
Obj(),
224+
workloads: []kueue.Workload{
225+
*utiltestingapi.MakeWorkload(GetWorkloadName("sts"), "ns").
226+
OwnerReference(gvk, "sts", "sts-uid").
227+
Obj(),
228+
},
229+
wantStatefulSet: statefulsettesting.MakeStatefulSet("sts", "ns").
230+
UID("sts-uid").
231+
Queue("lq").
232+
Replicas(0).
233+
DeepCopy(),
234+
wantWorkloads: []kueue.Workload{
235+
*utiltestingapi.MakeWorkload(GetWorkloadName("sts"), "ns").
236+
Obj(),
237+
},
238+
},
173239
}
174240
for name, tc := range cases {
175241
t.Run(name, func(t *testing.T) {
176242
ctx, _ := utiltesting.ContextWithLog(t)
177243
clientBuilder := utiltesting.NewClientBuilder()
178244
indexer := utiltesting.AsIndexer(clientBuilder)
179245

180-
objs := make([]client.Object, 0, len(tc.pods)+1)
246+
objs := make([]client.Object, 0, len(tc.pods)+len(tc.workloads)+1)
181247
if tc.statefulSet != nil {
182248
objs = append(objs, tc.statefulSet)
183249
}
@@ -186,6 +252,10 @@ func TestReconciler(t *testing.T) {
186252
objs = append(objs, p.DeepCopy())
187253
}
188254

255+
for _, wl := range tc.workloads {
256+
objs = append(objs, wl.DeepCopy())
257+
}
258+
189259
kClient := clientBuilder.WithObjects(objs...).Build()
190260

191261
reconciler, err := NewReconciler(ctx, kClient, indexer, nil)
@@ -219,6 +289,15 @@ func TestReconciler(t *testing.T) {
219289
if diff := cmp.Diff(tc.wantPods, gotPodList.Items, baseCmpOpts...); diff != "" {
220290
t.Errorf("Pods after reconcile (-want,+got):\n%s", diff)
221291
}
292+
293+
gotWorkloadList := &kueue.WorkloadList{}
294+
if err := kClient.List(ctx, gotWorkloadList); err != nil {
295+
t.Fatalf("Could not get WorkloadList after reconcile: %v", err)
296+
}
297+
298+
if diff := cmp.Diff(tc.wantWorkloads, gotWorkloadList.Items, baseCmpOpts...); diff != "" {
299+
t.Errorf("Pods after reconcile (-want,+got):\n%s", diff)
300+
}
222301
})
223302
}
224303
}

test/e2e/singlecluster/statefulset_test.go

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,114 @@ var _ = ginkgo.Describe("StatefulSet integration", func() {
535535
g.Expect(createdHighPriorityWl.Status.Conditions).To(utiltesting.HaveConditionStatusTrue(kueue.WorkloadAdmitted))
536536
}, util.Timeout, util.Interval).Should(gomega.Succeed())
537537
})
538+
539+
ginkgo.By("Check the low priority Workload is preempted", func() {
540+
gomega.Eventually(func(g gomega.Gomega) {
541+
g.Expect(k8sClient.Get(ctx, lowPriorityWlKey, createdLowPriorityWl)).To(gomega.Succeed())
542+
g.Expect(createdLowPriorityWl.Status.Conditions).To(utiltesting.HaveConditionStatusFalse(kueue.WorkloadAdmitted))
543+
g.Expect(createdLowPriorityWl.Status.Conditions).To(utiltesting.HaveConditionStatusTrue(kueue.WorkloadPreempted))
544+
}, util.Timeout, util.Interval).Should(gomega.Succeed())
545+
})
546+
})
547+
})
548+
549+
ginkgo.When("Workload deactivated", func() {
550+
ginkgo.It("shouldn't delete deactivated Workload", func() {
551+
statefulSet := statefulsettesting.MakeStatefulSet("sts", ns.Name).
552+
Image(util.GetAgnHostImage(), util.BehaviorWaitForDeletion).
553+
RequestAndLimit(corev1.ResourceCPU, "200m").
554+
TerminationGracePeriod(1).
555+
Replicas(3).
556+
Queue(lq.Name).
557+
Obj()
558+
559+
ginkgo.By("Create StatefulSet", func() {
560+
gomega.Expect(k8sClient.Create(ctx, statefulSet)).To(gomega.Succeed())
561+
})
562+
563+
wlLookupKey := types.NamespacedName{Name: statefulset.GetWorkloadName(statefulSet.Name), Namespace: ns.Name}
564+
createdWorkload := &kueue.Workload{}
565+
ginkgo.By("Check workload is created", func() {
566+
gomega.Eventually(func(g gomega.Gomega) {
567+
g.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).To(gomega.Succeed())
568+
g.Expect(createdWorkload.Status.Conditions).To(utiltesting.HaveConditionStatusTrue(kueue.WorkloadAdmitted))
569+
}, util.Timeout, util.Interval).Should(gomega.Succeed())
570+
})
571+
572+
createdWorkloadUID := createdWorkload.UID
573+
574+
ginkgo.By("Waiting for all replicas to be ready", func() {
575+
gomega.Eventually(func(g gomega.Gomega) {
576+
createdStatefulSet := &appsv1.StatefulSet{}
577+
g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
578+
g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(3)))
579+
}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
580+
})
581+
582+
ginkgo.By("Deactivate the workload", func() {
583+
gomega.Eventually(func(g gomega.Gomega) {
584+
g.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).To(gomega.Succeed())
585+
createdWorkload.Spec.Active = ptr.To(false)
586+
g.Expect(k8sClient.Update(ctx, createdWorkload)).To(gomega.Succeed())
587+
}, util.Timeout, util.Interval).Should(gomega.Succeed())
588+
})
589+
590+
ginkgo.By("Check workload is deactivated", func() {
591+
gomega.Eventually(func(g gomega.Gomega) {
592+
g.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).To(gomega.Succeed())
593+
g.Expect(createdWorkload.UID).Should(gomega.Equal(createdWorkloadUID))
594+
g.Expect(createdWorkload.Spec.Active).Should(gomega.Equal(ptr.To(false)))
595+
g.Expect(createdWorkload.Status.Conditions).Should(utiltesting.HaveConditionStatusTrue(kueue.WorkloadEvicted))
596+
}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
597+
})
598+
599+
ginkgo.By("Waiting for all replicas to be not ready", func() {
600+
gomega.Eventually(func(g gomega.Gomega) {
601+
createdStatefulSet := &appsv1.StatefulSet{}
602+
g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
603+
g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(0)))
604+
}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
605+
})
606+
607+
ginkgo.By("Re-activate the workload", func() {
608+
gomega.Eventually(func(g gomega.Gomega) {
609+
g.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).To(gomega.Succeed())
610+
createdWorkload.Spec.Active = ptr.To(true)
611+
g.Expect(k8sClient.Update(ctx, createdWorkload)).To(gomega.Succeed())
612+
}, util.Timeout, util.Interval).Should(gomega.Succeed())
613+
})
614+
615+
ginkgo.By("Check workload is re-admitted", func() {
616+
gomega.Eventually(func(g gomega.Gomega) {
617+
g.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).To(gomega.Succeed())
618+
g.Expect(createdWorkload.UID).Should(gomega.Equal(createdWorkloadUID))
619+
g.Expect(createdWorkload.Status.Conditions).To(utiltesting.HaveConditionStatusTrue(kueue.WorkloadAdmitted))
620+
}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
621+
})
622+
623+
ginkgo.By("Waiting for all replicas to be ready", func() {
624+
gomega.Eventually(func(g gomega.Gomega) {
625+
createdStatefulSet := &appsv1.StatefulSet{}
626+
g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
627+
g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(3)))
628+
}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
629+
})
630+
631+
ginkgo.By("Delete StatefulSet", func() {
632+
gomega.Expect(k8sClient.Delete(ctx, statefulSet)).To(gomega.Succeed())
633+
})
634+
635+
ginkgo.By("Check all pods are deleted", func() {
636+
pods := &corev1.PodList{}
637+
gomega.Eventually(func(g gomega.Gomega) {
638+
g.Expect(k8sClient.List(ctx, pods, client.InNamespace(ns.Name))).To(gomega.Succeed())
639+
g.Expect(pods.Items).Should(gomega.BeEmpty())
640+
}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
641+
})
642+
643+
ginkgo.By("Check workload is deleted", func() {
644+
util.ExpectObjectToBeDeletedWithTimeout(ctx, k8sClient, createdWorkload, false, util.LongTimeout)
645+
})
538646
})
539647
})
540648
})

0 commit comments

Comments
 (0)