Skip to content

Commit 7215a32

Browse files
authored
Validate shard id in shard controller (#3776)
1 parent 32b5d91 commit 7215a32

File tree

2 files changed

+45
-14
lines changed

2 files changed

+45
-14
lines changed

service/history/shard/controller_impl.go

+28-8
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import (
3333
"time"
3434

3535
"go.opentelemetry.io/otel/trace"
36+
"go.temporal.io/api/serviceerror"
3637

3738
"go.temporal.io/server/api/historyservice/v1"
3839
"go.temporal.io/server/client"
@@ -58,6 +59,11 @@ const (
5859
shardControllerMembershipUpdateListenerName = "ShardController"
5960
)
6061

62+
var (
63+
invalidShardIdLowerBound = serviceerror.NewInvalidArgument("shard Id cannot be equal or lower than zero")
64+
invalidShardIdUpperBound = serviceerror.NewInvalidArgument("shard Id cannot be larger than max shard count")
65+
)
66+
6167
type (
6268
ControllerImpl struct {
6369
membershipUpdateCh chan *membership.ChangedEvent
@@ -211,6 +217,17 @@ func (c *ControllerImpl) CloseShardByID(shardID int32) {
211217
}
212218
}
213219

220+
func (c *ControllerImpl) ShardIDs() []int32 {
221+
c.RLock()
222+
defer c.RUnlock()
223+
224+
ids := make([]int32, 0, len(c.historyShards))
225+
for id := range c.historyShards {
226+
ids = append(ids, id)
227+
}
228+
return ids
229+
}
230+
214231
func (c *ControllerImpl) shardClosedCallback(shard *ContextImpl) {
215232
startTime := time.Now().UTC()
216233
defer func() {
@@ -231,6 +248,10 @@ func (c *ControllerImpl) shardClosedCallback(shard *ContextImpl) {
231248
// if necessary. If a shard context is created, it will initialize in the background.
232249
// This function won't block on rangeid lease acquisition.
233250
func (c *ControllerImpl) getOrCreateShardContext(shardID int32) (*ContextImpl, error) {
251+
err := c.validateShardId(shardID)
252+
if err != nil {
253+
return nil, err
254+
}
234255
c.RLock()
235256
if shard, ok := c.historyShards[shardID]; ok {
236257
if shard.isValid() {
@@ -443,15 +464,14 @@ func (c *ControllerImpl) doShutdown() {
443464
c.historyShards = nil
444465
}
445466

446-
func (c *ControllerImpl) ShardIDs() []int32 {
447-
c.RLock()
448-
defer c.RUnlock()
449-
450-
ids := make([]int32, 0, len(c.historyShards))
451-
for id := range c.historyShards {
452-
ids = append(ids, id)
467+
func (c *ControllerImpl) validateShardId(shardID int32) error {
468+
if shardID <= 0 {
469+
return invalidShardIdLowerBound
453470
}
454-
return ids
471+
if shardID > c.config.NumberOfShards {
472+
return invalidShardIdUpperBound
473+
}
474+
return nil
455475
}
456476

457477
func IsShardOwnershipLostError(err error) bool {

service/history/shard/controller_test.go

+17-6
Original file line numberDiff line numberDiff line change
@@ -594,9 +594,9 @@ func (s *controllerSuite) TestShardExplicitUnload() {
594594
s.mockClusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestSingleDCClusterInfo).AnyTimes()
595595
mockEngine := NewMockEngine(s.controller)
596596
mockEngine.EXPECT().Stop().AnyTimes()
597-
s.setupMocksForAcquireShard(0, mockEngine, 5, 6, false)
597+
s.setupMocksForAcquireShard(1, mockEngine, 5, 6, false)
598598

599-
shard, err := s.shardController.getOrCreateShardContext(0)
599+
shard, err := s.shardController.getOrCreateShardContext(1)
600600
s.NoError(err)
601601
s.Equal(1, len(s.shardController.ShardIDs()))
602602

@@ -618,7 +618,7 @@ func (s *controllerSuite) TestShardExplicitUnloadCancelGetOrCreate() {
618618
mockEngine := NewMockEngine(s.controller)
619619
mockEngine.EXPECT().Stop().AnyTimes()
620620

621-
shardID := int32(0)
621+
shardID := int32(1)
622622
s.mockServiceResolver.EXPECT().Lookup(convert.Int32ToString(shardID)).Return(s.hostInfo, nil)
623623

624624
ready := make(chan struct{})
@@ -638,7 +638,7 @@ func (s *controllerSuite) TestShardExplicitUnloadCancelGetOrCreate() {
638638
})
639639

640640
// get shard, will start initializing in background
641-
shard, err := s.shardController.getOrCreateShardContext(0)
641+
shard, err := s.shardController.getOrCreateShardContext(1)
642642
s.NoError(err)
643643

644644
<-ready
@@ -659,7 +659,7 @@ func (s *controllerSuite) TestShardExplicitUnloadCancelAcquire() {
659659
mockEngine := NewMockEngine(s.controller)
660660
mockEngine.EXPECT().Stop().AnyTimes()
661661

662-
shardID := int32(0)
662+
shardID := int32(1)
663663
s.mockServiceResolver.EXPECT().Lookup(convert.Int32ToString(shardID)).Return(s.hostInfo, nil)
664664
// return success from GetOrCreateShard
665665
s.mockShardManager.EXPECT().GetOrCreateShard(gomock.Any(), getOrCreateShardRequestMatcher(shardID)).Return(
@@ -691,7 +691,7 @@ func (s *controllerSuite) TestShardExplicitUnloadCancelAcquire() {
691691
})
692692

693693
// get shard, will start initializing in background
694-
shard, err := s.shardController.getOrCreateShardContext(0)
694+
shard, err := s.shardController.getOrCreateShardContext(1)
695695
s.NoError(err)
696696

697697
<-ready
@@ -834,6 +834,17 @@ func (s *controllerSuite) TestShardControllerFuzz() {
834834
}, 1*time.Second, 50*time.Millisecond, "engine start/stop")
835835
}
836836

837+
func (s *controllerSuite) Test_GetOrCreateShard_InvalidShardID() {
838+
numShards := int32(2)
839+
s.config.NumberOfShards = numShards
840+
841+
_, err := s.shardController.getOrCreateShardContext(0)
842+
s.ErrorIs(err, invalidShardIdLowerBound)
843+
844+
_, err = s.shardController.getOrCreateShardContext(3)
845+
s.ErrorIs(err, invalidShardIdUpperBound)
846+
}
847+
837848
func (s *controllerSuite) setupMocksForAcquireShard(
838849
shardID int32,
839850
mockEngine *MockEngine,

0 commit comments

Comments
 (0)