Skip to content

Commit 5a447e4

Browse files
Handle starting the srv topo watch and add unit tests
Signed-off-by: siddharth16396 <[email protected]>
1 parent 5915807 commit 5a447e4

File tree

2 files changed

+298
-4
lines changed

2 files changed

+298
-4
lines changed

go/vt/vttablet/tabletserver/querythrottler/query_throttler.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,14 @@ func (qt *QueryThrottler) Shutdown() {
9090
qt.mu.Lock()
9191
defer qt.mu.Unlock()
9292

93+
// Cancel the watch context to stop the watch goroutine
94+
if qt.cancelWatchContext != nil {
95+
qt.cancelWatchContext()
96+
}
97+
98+
// Reset the watch started flag to allow restarting the watch if needed
99+
qt.watchStarted.Store(false)
100+
93101
// Stop the current strategy to clean up any background processes
94102
if qt.strategyHandlerInstance != nil {
95103
qt.strategyHandlerInstance.Stop()
@@ -116,7 +124,7 @@ func (qt *QueryThrottler) InitDBConfig(keyspace string) {
116124
log.Infof("QueryThrottler: initialized with keyspace=%s", keyspace)
117125

118126
// Start the topo server watch after the keyspace is set.
119-
//qt.startSrvKeyspaceWatch()
127+
qt.startSrvKeyspaceWatch()
120128
}
121129

122130
// Throttle checks if the tablet is under heavy load

go/vt/vttablet/tabletserver/querythrottler/query_throttler_test.go

Lines changed: 289 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,16 @@ import (
2020
"context"
2121
"errors"
2222
"fmt"
23+
"sync"
2324
"testing"
2425
"time"
2526

2627
"vitess.io/vitess/go/vt/log"
2728
querypb "vitess.io/vitess/go/vt/proto/query"
2829
topodatapb "vitess.io/vitess/go/vt/proto/topodata"
2930
"vitess.io/vitess/go/vt/sqlparser"
30-
"vitess.io/vitess/go/vt/srvtopo/fakesrvtopo"
31+
"vitess.io/vitess/go/vt/srvtopo"
32+
"vitess.io/vitess/go/vt/srvtopo/srvtopotest"
3133
"vitess.io/vitess/go/vt/topo"
3234

3335
"vitess.io/vitess/go/vt/vttablet/tabletserver/querythrottler/registry"
@@ -80,7 +82,9 @@ func TestQueryThrottler_StrategyLifecycleManagement(t *testing.T) {
8082
}
8183
env := tabletenv.NewEnv(vtenv.NewTestEnv(), config, "TestThrottler")
8284

83-
iqt := NewQueryThrottler(ctx, throttler, env, &topodatapb.TabletAlias{Cell: "test-cell", Uid: uint32(123)}, &fakesrvtopo.FakeSrvTopo{})
85+
srvTopoServer := srvtopotest.NewPassthroughSrvTopoServer()
86+
87+
iqt := NewQueryThrottler(ctx, throttler, env, &topodatapb.TabletAlias{Cell: "test-cell", Uid: uint32(123)}, srvTopoServer)
8488

8589
// Verify initial strategy was started (NoOpStrategy in this case)
8690
require.NotNil(t, iqt.strategyHandlerInstance)
@@ -104,7 +108,9 @@ func TestQueryThrottler_Shutdown(t *testing.T) {
104108
env := tabletenv.NewEnv(vtenv.NewTestEnv(), config, "TestThrottler")
105109

106110
throttler := &throttle.Throttler{}
107-
iqt := NewQueryThrottler(ctx, throttler, env, &topodatapb.TabletAlias{Cell: "test-cell", Uid: uint32(123)}, &fakesrvtopo.FakeSrvTopo{})
111+
srvTopoServer := srvtopotest.NewPassthroughSrvTopoServer()
112+
113+
iqt := NewQueryThrottler(ctx, throttler, env, &topodatapb.TabletAlias{Cell: "test-cell", Uid: uint32(123)}, srvTopoServer)
108114

109115
// Should not panic when called multiple times
110116
iqt.Shutdown()
@@ -724,3 +730,283 @@ func TestIsConfigUpdateRequired(t *testing.T) {
724730
})
725731
}
726732
}
733+
734+
// TestQueryThrottler_startSrvKeyspaceWatch_InitialLoad tests that initial configuration is loaded successfully when GetSrvKeyspace succeeds.
735+
func TestQueryThrottler_startSrvKeyspaceWatch_InitialLoad(t *testing.T) {
736+
ctx, cancel := context.WithCancel(context.Background())
737+
defer cancel()
738+
739+
env := tabletenv.NewEnv(vtenv.NewTestEnv(), &tabletenv.TabletConfig{}, "TestThrottler")
740+
741+
srvTopoServer := srvtopotest.NewPassthroughSrvTopoServer()
742+
srvTopoServer.SrvKeyspace = createTestSrvKeyspace(true, registry.ThrottlingStrategyTabletThrottler, false)
743+
srvTopoServer.SrvKeyspaceError = nil
744+
745+
throttler := &throttle.Throttler{}
746+
qt := NewQueryThrottler(ctx, throttler, env, &topodatapb.TabletAlias{Cell: "test-cell", Uid: uint32(123)}, srvTopoServer)
747+
748+
qt.InitDBConfig("test_keyspace")
749+
750+
// Verify watch was started
751+
require.Eventually(t, func() bool {
752+
return qt.watchStarted.Load()
753+
}, 2*time.Second, 10*time.Millisecond, "Watch should have been started")
754+
755+
// Verify that the configuration was loaded correctly
756+
require.Eventually(t, func() bool {
757+
qt.mu.RLock()
758+
defer qt.mu.RUnlock()
759+
return qt.cfg.Enabled &&
760+
qt.cfg.StrategyName == registry.ThrottlingStrategyTabletThrottler &&
761+
!qt.cfg.DryRun
762+
}, 2*time.Second, 10*time.Millisecond, "Config should be loaded correctly: enabled=true, strategy=TabletThrottler, dryRun=false")
763+
764+
require.Equal(t, "test_keyspace", qt.keyspace, "Keyspace should be set correctly")
765+
}
766+
767+
// TestQueryThrottler_startSrvKeyspaceWatch_InitialLoadFailure tests that watch starts even when initial GetSrvKeyspace fails.
768+
func TestQueryThrottler_startSrvKeyspaceWatch_InitialLoadFailure(t *testing.T) {
769+
ctx, cancel := context.WithCancel(context.Background())
770+
defer cancel()
771+
772+
env := tabletenv.NewEnv(vtenv.NewTestEnv(), &tabletenv.TabletConfig{}, "TestThrottler")
773+
774+
// Configure PassthroughSrvTopoServer to return an error on GetSrvKeyspace
775+
srvTopoServer := srvtopotest.NewPassthroughSrvTopoServer()
776+
srvTopoServer.SrvKeyspace = nil
777+
srvTopoServer.SrvKeyspaceError = fmt.Errorf("failed to fetch keyspace")
778+
779+
throttler := &throttle.Throttler{}
780+
qt := NewQueryThrottler(ctx, throttler, env, &topodatapb.TabletAlias{Cell: "test-cell", Uid: uint32(123)}, srvTopoServer)
781+
782+
// Initialize with keyspace to trigger startSrvKeyspaceWatch
783+
qt.InitDBConfig("test_keyspace")
784+
785+
// Verify watch was started despite initial load failure
786+
require.Eventually(t, func() bool {
787+
return qt.watchStarted.Load()
788+
}, 2*time.Second, 10*time.Millisecond, "Watch should be started even if initial load fails")
789+
790+
require.Equal(t, "test_keyspace", qt.keyspace, "Keyspace should be set correctly")
791+
792+
// Configuration should remain at default (NoOpStrategy) due to failure
793+
require.Eventually(t, func() bool {
794+
qt.mu.RLock()
795+
defer qt.mu.RUnlock()
796+
return !qt.cfg.Enabled
797+
}, 2*time.Second, 10*time.Millisecond, "Config should remain disabled after initial load failure")
798+
}
799+
800+
// TestQueryThrottler_startSrvKeyspaceWatch_OnlyStartsOnce tests that watch only starts once even with concurrent calls (atomic flag protection).
801+
func TestQueryThrottler_startSrvKeyspaceWatch_OnlyStartsOnce(t *testing.T) {
802+
ctx, cancel := context.WithCancel(context.Background())
803+
defer cancel()
804+
805+
env := tabletenv.NewEnv(vtenv.NewTestEnv(), &tabletenv.TabletConfig{}, "TestThrottler")
806+
807+
srvTopoServer := srvtopotest.NewPassthroughSrvTopoServer()
808+
srvTopoServer.SrvKeyspace = createTestSrvKeyspace(true, registry.ThrottlingStrategyTabletThrottler, false)
809+
srvTopoServer.SrvKeyspaceError = nil
810+
811+
throttler := &throttle.Throttler{}
812+
qt := NewQueryThrottler(ctx, throttler, env, &topodatapb.TabletAlias{Cell: "test-cell", Uid: uint32(123)}, srvTopoServer)
813+
814+
qt.InitDBConfig("test_keyspace")
815+
816+
// Attempt to start the watch multiple times concurrently
817+
const numGoroutines = 10
818+
startedCount := 0
819+
var wg sync.WaitGroup
820+
var mu sync.Mutex
821+
822+
for i := 0; i < numGoroutines; i++ {
823+
wg.Add(1)
824+
go func() {
825+
defer wg.Done()
826+
// Each goroutine tries to start the watch
827+
qt.startSrvKeyspaceWatch()
828+
mu.Lock()
829+
startedCount++
830+
mu.Unlock()
831+
}()
832+
}
833+
834+
// Wait for all goroutines to complete
835+
wg.Wait()
836+
837+
// Verify that the watch was started exactly once (atomic flag prevents multiple starts)
838+
require.Eventually(t, func() bool {
839+
return qt.watchStarted.Load()
840+
}, 2*time.Second, 10*time.Millisecond, "Watch should have been started")
841+
842+
require.Equal(t, numGoroutines, startedCount, "All goroutines should have called startSrvKeyspaceWatch")
843+
}
844+
845+
// TestQueryThrottler_startSrvKeyspaceWatch_RequiredFieldsValidation tests that watch doesn't start when required fields are missing.
846+
func TestQueryThrottler_startSrvKeyspaceWatch_RequiredFieldsValidation(t *testing.T) {
847+
tests := []struct {
848+
name string
849+
srvTopoServer srvtopo.Server
850+
keyspace string
851+
expectedWatchFlag bool
852+
description string
853+
}{
854+
{
855+
name: "Nil srvTopoServer prevents watch start",
856+
srvTopoServer: nil,
857+
keyspace: "test_keyspace",
858+
expectedWatchFlag: false,
859+
description: "Watch should not start when srvTopoServer is nil",
860+
},
861+
{
862+
name: "Empty keyspace prevents watch start",
863+
srvTopoServer: srvtopotest.NewPassthroughSrvTopoServer(),
864+
keyspace: "",
865+
expectedWatchFlag: false,
866+
description: "Watch should not start when keyspace is empty",
867+
},
868+
{
869+
name: "Valid fields allow watch to start",
870+
srvTopoServer: srvtopotest.NewPassthroughSrvTopoServer(),
871+
keyspace: "test_keyspace",
872+
expectedWatchFlag: true,
873+
description: "Watch should start when all required fields are valid",
874+
},
875+
}
876+
877+
for _, tt := range tests {
878+
t.Run(tt.name, func(t *testing.T) {
879+
ctx, cancel := context.WithCancel(context.Background())
880+
defer cancel()
881+
882+
env := tabletenv.NewEnv(vtenv.NewTestEnv(), &tabletenv.TabletConfig{}, "TestThrottler")
883+
884+
throttler := &throttle.Throttler{}
885+
qt := NewQueryThrottler(ctx, throttler, env, &topodatapb.TabletAlias{Cell: "test-cell", Uid: uint32(123)}, tt.srvTopoServer)
886+
887+
qt.InitDBConfig(tt.keyspace)
888+
889+
qt.startSrvKeyspaceWatch()
890+
891+
if tt.expectedWatchFlag {
892+
require.Eventually(t, func() bool {
893+
return qt.watchStarted.Load()
894+
}, 2*time.Second, 10*time.Millisecond, tt.description)
895+
} else {
896+
// For negative cases, ensure the watch doesn't start within a reasonable time
897+
require.Never(t, func() bool {
898+
return qt.watchStarted.Load()
899+
}, 500*time.Millisecond, 10*time.Millisecond, tt.description)
900+
}
901+
})
902+
}
903+
}
904+
905+
// TestQueryThrottler_startSrvKeyspaceWatch_WatchCallback tests that WatchSrvKeyspace callback receives config updates and HandleConfigUpdate is invoked correctly.
906+
func TestQueryThrottler_startSrvKeyspaceWatch_WatchCallback(t *testing.T) {
907+
tests := []struct {
908+
name string
909+
enabled bool
910+
strategy registry.ThrottlingStrategy
911+
dryRun bool
912+
expectedEnabled bool
913+
expectedStrategy registry.ThrottlingStrategy
914+
expectedDryRun bool
915+
}{
916+
{
917+
name: "TabletThrottler strategy with enabled and no dry-run",
918+
enabled: true,
919+
strategy: registry.ThrottlingStrategyTabletThrottler,
920+
dryRun: false,
921+
expectedEnabled: true,
922+
expectedStrategy: registry.ThrottlingStrategyTabletThrottler,
923+
expectedDryRun: false,
924+
},
925+
{
926+
name: "TabletThrottler disabled with dry-run",
927+
enabled: false,
928+
strategy: registry.ThrottlingStrategyTabletThrottler,
929+
dryRun: true,
930+
expectedEnabled: false,
931+
expectedStrategy: registry.ThrottlingStrategyTabletThrottler,
932+
expectedDryRun: true,
933+
},
934+
}
935+
936+
for _, tt := range tests {
937+
t.Run(tt.name, func(t *testing.T) {
938+
ctx, cancel := context.WithCancel(context.Background())
939+
defer cancel()
940+
941+
env := tabletenv.NewEnv(vtenv.NewTestEnv(), &tabletenv.TabletConfig{}, "TestThrottler")
942+
943+
srvTopoServer := srvtopotest.NewPassthroughSrvTopoServer()
944+
srvTopoServer.SrvKeyspace = createTestSrvKeyspace(tt.enabled, tt.strategy, tt.dryRun)
945+
srvTopoServer.SrvKeyspaceError = nil
946+
947+
throttler := &throttle.Throttler{}
948+
qt := NewQueryThrottler(ctx, throttler, env, &topodatapb.TabletAlias{Cell: "test-cell", Uid: uint32(123)}, srvTopoServer)
949+
950+
qt.InitDBConfig("test_keyspace")
951+
952+
// Verify watch was started
953+
require.Eventually(t, func() bool {
954+
return qt.watchStarted.Load()
955+
}, 2*time.Second, 10*time.Millisecond, "Watch should have been started")
956+
957+
// Verify that HandleConfigUpdate was called by checking if the config was updated
958+
require.Eventually(t, func() bool {
959+
qt.mu.RLock()
960+
defer qt.mu.RUnlock()
961+
return qt.cfg.Enabled == tt.expectedEnabled &&
962+
qt.cfg.StrategyName == tt.expectedStrategy &&
963+
qt.cfg.DryRun == tt.expectedDryRun
964+
}, 2*time.Second, 10*time.Millisecond, "Config should be updated correctly after callback is invoked")
965+
966+
})
967+
}
968+
}
969+
970+
// TestQueryThrottler_startSrvKeyspaceWatch_ShutdownStopsWatch tests that Shutdown properly cancels the watch context and stops the watch goroutine.
971+
func TestQueryThrottler_startSrvKeyspaceWatch_ShutdownStopsWatch(t *testing.T) {
972+
ctx, cancel := context.WithCancel(context.Background())
973+
defer cancel()
974+
975+
env := tabletenv.NewEnv(vtenv.NewTestEnv(), &tabletenv.TabletConfig{}, "TestThrottler")
976+
977+
srvTopoServer := srvtopotest.NewPassthroughSrvTopoServer()
978+
srvTopoServer.SrvKeyspace = createTestSrvKeyspace(true, registry.ThrottlingStrategyTabletThrottler, false)
979+
srvTopoServer.SrvKeyspaceError = nil
980+
981+
throttler := &throttle.Throttler{}
982+
qt := NewQueryThrottler(ctx, throttler, env, &topodatapb.TabletAlias{Cell: "test-cell", Uid: uint32(123)}, srvTopoServer)
983+
984+
qt.InitDBConfig("test_keyspace")
985+
986+
// Verify watch was started
987+
require.Eventually(t, func() bool {
988+
return qt.watchStarted.Load()
989+
}, 2*time.Second, 10*time.Millisecond, "Watch should have been started before shutdown")
990+
991+
require.NotNil(t, qt.cancelWatchContext, "Cancel function should be set before shutdown")
992+
993+
// Call Shutdown to stop the watch
994+
qt.Shutdown()
995+
996+
// Verify that the watch started flag is reset
997+
require.Eventually(t, func() bool {
998+
return !qt.watchStarted.Load()
999+
}, 2*time.Second, 10*time.Millisecond, "Watch should be marked as not started after shutdown")
1000+
1001+
// Verify that the strategy was stopped
1002+
qt.mu.RLock()
1003+
strategyInstance := qt.strategyHandlerInstance
1004+
qt.mu.RUnlock()
1005+
require.NotNil(t, strategyInstance, "Strategy instance should still exist after shutdown")
1006+
1007+
// Call Shutdown again to ensure it doesn't panic
1008+
qt.Shutdown()
1009+
1010+
// Verify the watch flag remains false
1011+
require.False(t, qt.watchStarted.Load(), "Watch should remain not started after multiple shutdowns")
1012+
}

0 commit comments

Comments
 (0)