67 changes: 54 additions & 13 deletions pkg/controller/jobs/statefulset/statefulset_reconciler.go
@@ -21,6 +21,7 @@ import (
"time"

"github.com/go-logr/logr"
"golang.org/x/sync/errgroup"
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@@ -29,6 +30,7 @@ import (
"k8s.io/client-go/tools/record"
"k8s.io/client-go/util/workqueue"
"k8s.io/klog/v2"
"k8s.io/utils/ptr"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
@@ -37,6 +39,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/predicate"
"sigs.k8s.io/controller-runtime/pkg/reconcile"

kueue "sigs.k8s.io/kueue/apis/kueue/v1beta2"
"sigs.k8s.io/kueue/pkg/controller/jobframework"
podcontroller "sigs.k8s.io/kueue/pkg/controller/jobs/pod/constants"
clientutil "sigs.k8s.io/kueue/pkg/util/client"
@@ -65,34 +68,35 @@ func (r *Reconciler) Reconcile(ctx context.Context, req reconcile.Request) (reco
log := ctrl.LoggerFrom(ctx)
log.V(2).Info("Reconcile StatefulSet")

err := r.fetchAndFinalizePods(ctx, req)
return ctrl.Result{}, err
}

func (r *Reconciler) fetchAndFinalizePods(ctx context.Context, req reconcile.Request) error {
podList := &corev1.PodList{}
if err := r.client.List(ctx, podList, client.InNamespace(req.Namespace), client.MatchingLabels{
podcontroller.GroupNameLabel: GetWorkloadName(req.Name),
}); err != nil {
return err
}

// If no Pods are found, there's nothing to do.
if len(podList.Items) == 0 {
return nil
return ctrl.Result{}, nil
}

sts := &appsv1.StatefulSet{}
err := r.client.Get(ctx, req.NamespacedName, sts)
if client.IgnoreNotFound(err) != nil {
return err
return ctrl.Result{}, err
}

if err != nil {
sts = nil
}

return r.finalizePods(ctx, sts, podList.Items)
eg, ctx := errgroup.WithContext(ctx)

eg.Go(func() error {
return r.finalizePods(ctx, sts, podList.Items)
})

eg.Go(func() error {
return r.reconcileWorkload(ctx, sts)
})

err = eg.Wait()
return ctrl.Result{}, err
}

func (r *Reconciler) finalizePods(ctx context.Context, sts *appsv1.StatefulSet, pods []corev1.Pod) error {
Expand Down Expand Up @@ -139,6 +143,43 @@ func shouldFinalize(sts *appsv1.StatefulSet, pod *corev1.Pod) bool {
return shouldUngate(sts, pod) || utilpod.IsTerminated(pod)
}

func (r *Reconciler) reconcileWorkload(ctx context.Context, sts *appsv1.StatefulSet) error {
if sts == nil {
return nil
}

wl := &kueue.Workload{}
err := r.client.Get(ctx, client.ObjectKey{Namespace: sts.Namespace, Name: GetWorkloadName(sts.Name)}, wl)
if err != nil {
return client.IgnoreNotFound(err)
}

hasOwnerReference, err := controllerutil.HasOwnerReference(wl.OwnerReferences, sts, r.client.Scheme())
if err != nil {
return err
}

var (
shouldUpdate = false
replicas = ptr.Deref(sts.Spec.Replicas, 1)
)

switch {
case hasOwnerReference && replicas == 0:
shouldUpdate = true
err = controllerutil.RemoveOwnerReference(sts, wl, r.client.Scheme())
Comment on lines +169 to +170

Contributor:

Can you describe why we remove the owner reference from the workload when scaling down the StatefulSet?

IIUC this will result in deleting the workload, right? It seems safer to just update the count in the workload object, similarly to what we do for Jobs. For example, when we suspend a Job we don't delete the workload, IIRC.

Contributor (Author):

> For example, when we suspend a Job we don't delete the workload, IIRC.

That is a very interesting behavior in Job. I tried to suspend the Job manually, but it automatically unsuspends itself. The only way to suspend the Job is to deactivate the workload.

Contributor (Author):

> Can you describe why we remove the owner reference from the workload when scaling down the StatefulSet?

The problem is that after activation the replicas can change (replicas=3 → replicas=0 → replicas=5). Our workload is admitted, and we can't change the PodSet count since that field is immutable. I think it's much easier to just delete it.

Contributor (Author):

> IIUC this will result in deleting the workload, right?

Yeah, you're right. It should delete the workload.

Contributor:

Yeah, for Job, when the count is updated we also update the count in the workload.

So my natural preference is to also update the count in the workload for STS; I'm not sure it is much more complex. I'm not sure myself - still considering the options.

> The problem is that after activation the replicas can change (replicas=3 → replicas=0 → replicas=5). Our workload is admitted, and we can't change the PodSet count since that field is immutable. I think it's much easier to just delete it.

I see, but to me it suggests we should also evict the workload for STS in that case, to be consistent with Job. Can you try to prototype it, maybe in a separate PR, so that we can compare? I don't think we need to rush the decision. I would prefer to do it well.

Contributor:

Thank you for the investigation and answers. The PR looks good assuming we take the approach of deleting the workload, but I would prefer to explore keeping and updating the workload, as we do for Jobs.

A rough, hypothetical sketch of the count-update alternative discussed here follows the reconcileWorkload function below.
case !hasOwnerReference && replicas > 0:
shouldUpdate = true
err = controllerutil.SetOwnerReference(sts, wl, r.client.Scheme())
}
if err != nil || !shouldUpdate {
return err
}

err = r.client.Update(ctx, wl)
return err
}
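
As discussed in the review thread above, the alternative to deleting the Workload on scale-down is to keep it and update its PodSet count, as the Job integration does. The Go sketch below is not part of this PR; it is a minimal, hypothetical illustration of that idea. The helper name updateWorkloadPodSetCount, the single-PodSet assumption, and the apimeta alias (for k8s.io/apimachinery/pkg/api/meta) are illustrative choices, not existing Kueue code.

// updateWorkloadPodSetCount is a hypothetical alternative to reconcileWorkload:
// instead of removing the owner reference (which, per the thread above, results
// in the Workload being deleted), it keeps the Workload and aligns its PodSet
// count with spec.replicas.
func (r *Reconciler) updateWorkloadPodSetCount(ctx context.Context, sts *appsv1.StatefulSet) error {
	if sts == nil {
		return nil
	}

	wl := &kueue.Workload{}
	err := r.client.Get(ctx, client.ObjectKey{Namespace: sts.Namespace, Name: GetWorkloadName(sts.Name)}, wl)
	if err != nil {
		return client.IgnoreNotFound(err)
	}

	replicas := ptr.Deref(sts.Spec.Replicas, 1)
	// Assumes the StatefulSet Workload carries exactly one PodSet.
	if len(wl.Spec.PodSets) != 1 || wl.Spec.PodSets[0].Count == replicas {
		return nil
	}

	// The PodSet count of an admitted Workload is immutable, so a real
	// implementation would first have to evict the Workload before changing
	// the count; that part of the flow is omitted in this sketch.
	if apimeta.IsStatusConditionTrue(wl.Status.Conditions, kueue.WorkloadAdmitted) {
		return nil
	}

	wl.Spec.PodSets[0].Count = replicas
	return r.client.Update(ctx, wl)
}

Whether this is preferable to deletion hinges on how eviction and re-admission should behave after a deactivate/scale-down/reactivate cycle, which is exactly the open question in the thread above.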

func (r *Reconciler) SetupWithManager(mgr ctrl.Manager) error {
ctrl.Log.V(3).Info("Setting up StatefulSet reconciler")
return ctrl.NewControllerManagedBy(mgr).
81 changes: 80 additions & 1 deletion pkg/controller/jobs/statefulset/statefulset_reconciler_test.go
@@ -27,8 +27,10 @@ import (
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/reconcile"

kueue "sigs.k8s.io/kueue/apis/kueue/v1beta2"
podconstants "sigs.k8s.io/kueue/pkg/controller/jobs/pod/constants"
utiltesting "sigs.k8s.io/kueue/pkg/util/testing"
utiltestingapi "sigs.k8s.io/kueue/pkg/util/testing/v1beta2"
testingjobspod "sigs.k8s.io/kueue/pkg/util/testingjobs/pod"
statefulsettesting "sigs.k8s.io/kueue/pkg/util/testingjobs/statefulset"
)
@@ -45,8 +47,10 @@ func TestReconciler(t *testing.T) {
stsKey client.ObjectKey
statefulSet *appsv1.StatefulSet
pods []corev1.Pod
workloads []kueue.Workload
wantStatefulSet *appsv1.StatefulSet
wantPods []corev1.Pod
wantWorkloads []kueue.Workload
wantErr error
}{
"statefulset not found": {
Expand Down Expand Up @@ -170,14 +174,76 @@ func TestReconciler(t *testing.T) {
Obj(),
},
},
"should add StatefulSet to Workload owner references if replicas > 0": {
stsKey: client.ObjectKey{Name: "sts", Namespace: "ns"},
statefulSet: statefulsettesting.MakeStatefulSet("sts", "ns").
UID("sts-uid").
Queue("lq").
Obj(),
workloads: []kueue.Workload{
*utiltestingapi.MakeWorkload(GetWorkloadName("sts"), "ns").
Obj(),
},
wantStatefulSet: statefulsettesting.MakeStatefulSet("sts", "ns").
UID("sts-uid").
Queue("lq").
DeepCopy(),
wantWorkloads: []kueue.Workload{
*utiltestingapi.MakeWorkload(GetWorkloadName("sts"), "ns").
OwnerReference(gvk, "sts", "sts-uid").
Obj(),
},
},
"shouldn't add StatefulSet to Workload owner references if replicas = 0": {
stsKey: client.ObjectKey{Name: "sts", Namespace: "ns"},
statefulSet: statefulsettesting.MakeStatefulSet("sts", "ns").
UID("sts-uid").
Queue("lq").
Replicas(0).
Obj(),
workloads: []kueue.Workload{
*utiltestingapi.MakeWorkload(GetWorkloadName("sts"), "ns").
Obj(),
},
wantStatefulSet: statefulsettesting.MakeStatefulSet("sts", "ns").
UID("sts-uid").
Queue("lq").
Replicas(0).
DeepCopy(),
wantWorkloads: []kueue.Workload{
*utiltestingapi.MakeWorkload(GetWorkloadName("sts"), "ns").
Obj(),
},
},
"should remove StatefulSet from Workload owner references if replicas = 0": {
stsKey: client.ObjectKey{Name: "sts", Namespace: "ns"},
statefulSet: statefulsettesting.MakeStatefulSet("sts", "ns").
UID("sts-uid").
Queue("lq").
Replicas(0).
Obj(),
workloads: []kueue.Workload{
*utiltestingapi.MakeWorkload(GetWorkloadName("sts"), "ns").
OwnerReference(gvk, "sts", "sts-uid").
Obj(),
},
wantStatefulSet: statefulsettesting.MakeStatefulSet("sts", "ns").
UID("sts-uid").
Queue("lq").
Replicas(0).
DeepCopy(),
wantWorkloads: []kueue.Workload{
*utiltestingapi.MakeWorkload(GetWorkloadName("sts"), "ns").
Obj(),
},
},
}
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
ctx, _ := utiltesting.ContextWithLog(t)
clientBuilder := utiltesting.NewClientBuilder()
indexer := utiltesting.AsIndexer(clientBuilder)

objs := make([]client.Object, 0, len(tc.pods)+1)
objs := make([]client.Object, 0, len(tc.pods)+len(tc.workloads)+1)
if tc.statefulSet != nil {
objs = append(objs, tc.statefulSet)
}
@@ -186,6 +252,10 @@ func TestReconciler(t *testing.T) {
objs = append(objs, p.DeepCopy())
}

for _, wl := range tc.workloads {
objs = append(objs, wl.DeepCopy())
}

kClient := clientBuilder.WithObjects(objs...).Build()

reconciler, err := NewReconciler(ctx, kClient, indexer, nil)
Expand Down Expand Up @@ -219,6 +289,15 @@ func TestReconciler(t *testing.T) {
if diff := cmp.Diff(tc.wantPods, gotPodList.Items, baseCmpOpts...); diff != "" {
t.Errorf("Pods after reconcile (-want,+got):\n%s", diff)
}

gotWorkloadList := &kueue.WorkloadList{}
if err := kClient.List(ctx, gotWorkloadList); err != nil {
t.Fatalf("Could not get WorkloadList after reconcile: %v", err)
}

if diff := cmp.Diff(tc.wantWorkloads, gotWorkloadList.Items, baseCmpOpts...); diff != "" {
t.Errorf("Workloads after reconcile (-want,+got):\n%s", diff)
}
})
}
}
108 changes: 108 additions & 0 deletions test/e2e/singlecluster/statefulset_test.go
@@ -535,6 +535,114 @@ var _ = ginkgo.Describe("StatefulSet integration", func() {
g.Expect(createdHighPriorityWl.Status.Conditions).To(utiltesting.HaveConditionStatusTrue(kueue.WorkloadAdmitted))
}, util.Timeout, util.Interval).Should(gomega.Succeed())
})

ginkgo.By("Check the low priority Workload is preempted", func() {
gomega.Eventually(func(g gomega.Gomega) {
g.Expect(k8sClient.Get(ctx, lowPriorityWlKey, createdLowPriorityWl)).To(gomega.Succeed())
g.Expect(createdLowPriorityWl.Status.Conditions).To(utiltesting.HaveConditionStatusFalse(kueue.WorkloadAdmitted))
g.Expect(createdLowPriorityWl.Status.Conditions).To(utiltesting.HaveConditionStatusTrue(kueue.WorkloadPreempted))
}, util.Timeout, util.Interval).Should(gomega.Succeed())
})
})
})

ginkgo.When("Workload deactivated", func() {
ginkgo.It("shouldn't delete deactivated Workload", func() {
statefulSet := statefulsettesting.MakeStatefulSet("sts", ns.Name).
Image(util.GetAgnHostImage(), util.BehaviorWaitForDeletion).
RequestAndLimit(corev1.ResourceCPU, "200m").
TerminationGracePeriod(1).
Replicas(3).
Queue(lq.Name).
Obj()

ginkgo.By("Create StatefulSet", func() {
gomega.Expect(k8sClient.Create(ctx, statefulSet)).To(gomega.Succeed())
})

wlLookupKey := types.NamespacedName{Name: statefulset.GetWorkloadName(statefulSet.Name), Namespace: ns.Name}
createdWorkload := &kueue.Workload{}
ginkgo.By("Check workload is created", func() {
gomega.Eventually(func(g gomega.Gomega) {
g.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).To(gomega.Succeed())
g.Expect(createdWorkload.Status.Conditions).To(utiltesting.HaveConditionStatusTrue(kueue.WorkloadAdmitted))
}, util.Timeout, util.Interval).Should(gomega.Succeed())
})

createdWorkloadUID := createdWorkload.UID

ginkgo.By("Waiting for all replicas to be ready", func() {
gomega.Eventually(func(g gomega.Gomega) {
createdStatefulSet := &appsv1.StatefulSet{}
g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(3)))
}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
})

ginkgo.By("Deactivate the workload", func() {
gomega.Eventually(func(g gomega.Gomega) {
g.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).To(gomega.Succeed())
createdWorkload.Spec.Active = ptr.To(false)
g.Expect(k8sClient.Update(ctx, createdWorkload)).To(gomega.Succeed())
}, util.Timeout, util.Interval).Should(gomega.Succeed())
})

ginkgo.By("Check workload is deactivated", func() {
gomega.Eventually(func(g gomega.Gomega) {
g.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).To(gomega.Succeed())
g.Expect(createdWorkload.UID).Should(gomega.Equal(createdWorkloadUID))
g.Expect(createdWorkload.Spec.Active).Should(gomega.Equal(ptr.To(false)))
g.Expect(createdWorkload.Status.Conditions).Should(utiltesting.HaveConditionStatusTrue(kueue.WorkloadEvicted))
}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
})

ginkgo.By("Waiting for all replicas to be not ready", func() {
gomega.Eventually(func(g gomega.Gomega) {
createdStatefulSet := &appsv1.StatefulSet{}
g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(0)))
}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
})

ginkgo.By("Re-activate the workload", func() {
gomega.Eventually(func(g gomega.Gomega) {
g.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).To(gomega.Succeed())
createdWorkload.Spec.Active = ptr.To(true)
g.Expect(k8sClient.Update(ctx, createdWorkload)).To(gomega.Succeed())
}, util.Timeout, util.Interval).Should(gomega.Succeed())
})

ginkgo.By("Check workload is re-admitted", func() {
gomega.Eventually(func(g gomega.Gomega) {
g.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).To(gomega.Succeed())
g.Expect(createdWorkload.UID).Should(gomega.Equal(createdWorkloadUID))
g.Expect(createdWorkload.Status.Conditions).To(utiltesting.HaveConditionStatusTrue(kueue.WorkloadAdmitted))
}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
})

ginkgo.By("Waiting for all replicas to be ready", func() {
gomega.Eventually(func(g gomega.Gomega) {
createdStatefulSet := &appsv1.StatefulSet{}
g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(3)))
}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
})

ginkgo.By("Delete StatefulSet", func() {
gomega.Expect(k8sClient.Delete(ctx, statefulSet)).To(gomega.Succeed())
})

ginkgo.By("Check all pods are deleted", func() {
pods := &corev1.PodList{}
gomega.Eventually(func(g gomega.Gomega) {
g.Expect(k8sClient.List(ctx, pods, client.InNamespace(ns.Name))).To(gomega.Succeed())
g.Expect(pods.Items).Should(gomega.BeEmpty())
}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
})

ginkgo.By("Check workload is deleted", func() {
util.ExpectObjectToBeDeletedWithTimeout(ctx, k8sClient, createdWorkload, false, util.LongTimeout)
})
})
})
})