Commit d9bac92

Return error if cluster metadata is invalid (#3879)
* Return error if cluster metadata is invalid
1 parent af00719 commit d9bac92

File tree

3 files changed, +74 -62 lines changed


service/history/replication/poller_manager.go

+15 -10

@@ -25,14 +25,15 @@
 package replication
 
 import (
+	"errors"
 	"fmt"
 
 	"go.temporal.io/server/common/cluster"
 )
 
 type (
 	pollerManager interface {
-		getSourceClusterShardIDs(sourceClusterName string) []int32
+		getSourceClusterShardIDs(sourceClusterName string) ([]int32, error)
 	}
 
 	pollerManagerImpl struct {
@@ -53,18 +54,27 @@ func newPollerManager(
 	}
 }
 
-func (p pollerManagerImpl) getSourceClusterShardIDs(sourceClusterName string) []int32 {
+func (p pollerManagerImpl) getSourceClusterShardIDs(sourceClusterName string) ([]int32, error) {
 	currentCluster := p.clusterMetadata.GetCurrentClusterName()
 	allClusters := p.clusterMetadata.GetAllClusterInfo()
 	currentClusterInfo, ok := allClusters[currentCluster]
 	if !ok {
-		panic("Cannot get current cluster info from cluster metadata cache")
+		return nil, errors.New("cannot get current cluster info from cluster metadata cache")
 	}
 	remoteClusterInfo, ok := allClusters[sourceClusterName]
 	if !ok {
-		panic(fmt.Sprintf("Cannot get source cluster %s info from cluster metadata cache", sourceClusterName))
+		return nil, errors.New(fmt.Sprintf("cannot get source cluster %s info from cluster metadata cache", sourceClusterName))
 	}
-	return generateShardIDs(p.currentShardId, currentClusterInfo.ShardCount, remoteClusterInfo.ShardCount)
+
+	// The remote shard count and local shard count must be multiples.
+	large, small := remoteClusterInfo.ShardCount, currentClusterInfo.ShardCount
+	if small > large {
+		large, small = small, large
+	}
+	if large%small != 0 {
+		return nil, errors.New(fmt.Sprintf("remote shard count %d and local shard count %d are not multiples.", remoteClusterInfo.ShardCount, currentClusterInfo.ShardCount))
+	}
+	return generateShardIDs(p.currentShardId, currentClusterInfo.ShardCount, remoteClusterInfo.ShardCount), nil
 }
 
 func generateShardIDs(localShardId int32, localShardCount int32, remoteShardCount int32) []int32 {
@@ -75,12 +85,7 @@ func generateShardIDs(localShardId int32, localShardCount int32, remoteShardCount int32) []int32 {
 		}
 		return pollingShards
 	}
-
 	// remoteShardCount > localShardCount, replication poller will poll from multiple remote shard.
-	// The remote shard count and local shard count must be multiples.
-	if remoteShardCount%localShardCount != 0 {
-		panic(fmt.Sprintf("Remote shard count %d and local shard count %d are not multiples.", remoteShardCount, localShardCount))
-	}
 	for i := localShardId; i <= remoteShardCount; i += localShardCount {
 		pollingShards = append(pollingShards, i)
 	}
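For readers following the shard-mapping change, here is a minimal standalone sketch (not the committed code; the function name sourceShardIDs and its combined structure are invented for illustration) that folds the new multiple-count validation from getSourceClusterShardIDs together with the mapping done by generateShardIDs. In the committed code the validation lives only in getSourceClusterShardIDs, which is why the updated test below can call generateShardIDs directly with a 17/4 shard-count pair and get shard IDs back instead of a panic.

package main

import (
	"errors"
	"fmt"
)

// sourceShardIDs mirrors the post-commit behavior: validate that the shard
// counts are multiples (returning an error instead of panicking), then map
// the local shard to the remote shard IDs it should poll.
func sourceShardIDs(localShardID, localShardCount, remoteShardCount int32) ([]int32, error) {
	// In the real code this validation happens in getSourceClusterShardIDs.
	large, small := remoteShardCount, localShardCount
	if small > large {
		large, small = small, large
	}
	if large%small != 0 {
		return nil, errors.New("remote and local shard counts are not multiples")
	}

	// Mapping as in generateShardIDs: local shard N polls every remote shard
	// congruent to N modulo the local shard count.
	var shards []int32
	if remoteShardCount <= localShardCount {
		if localShardID <= remoteShardCount {
			shards = append(shards, localShardID)
		}
		return shards, nil
	}
	for i := localShardID; i <= remoteShardCount; i += localShardCount {
		shards = append(shards, i)
	}
	return shards, nil
}

func main() {
	fmt.Println(sourceShardIDs(4, 4, 16)) // [4 8 12 16] <nil>
	fmt.Println(sourceShardIDs(3, 4, 2))  // [] <nil> (local shard 3 has no remote counterpart)
	fmt.Println(sourceShardIDs(1, 4, 17)) // [] remote and local shard counts are not multiples
}
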
}

service/history/replication/poller_manager_test.go

+8 -15

@@ -36,60 +36,53 @@ func TestGetPollingShardIds(t *testing.T) {
 		shardID int32
 		remoteShardCount int32
 		localShardCount int32
-		expectedPanic bool
 		expectedShardIDs []int32
 	}{
 		{
 			1,
 			4,
 			4,
-			false,
 			[]int32{1},
 		},
 		{
 			1,
 			2,
 			4,
-			false,
 			[]int32{1},
 		},
 		{
 			3,
 			2,
 			4,
-			false,
-			[]int32{},
+			nil,
 		},
 		{
 			1,
 			16,
 			4,
-			false,
 			[]int32{1, 5, 9, 13},
 		},
 		{
 			4,
 			16,
 			4,
-			false,
 			[]int32{4, 8, 12, 16},
 		},
 		{
 			4,
 			17,
 			4,
-			true,
-			[]int32{},
+			[]int32{4, 8, 12, 16},
+		},
+		{
+			1,
+			17,
+			4,
+			[]int32{1, 5, 9, 13, 17},
 		},
 	}
 	for idx, tt := range testCases {
 		t.Run(fmt.Sprintf("Testcase %d", idx), func(t *testing.T) {
-			t.Parallel()
-			defer func() {
-				if r := recover(); tt.expectedPanic && r == nil {
-					t.Errorf("The code did not panic")
-				}
-			}()
 			shardIDs := generateShardIDs(tt.shardID, tt.localShardCount, tt.remoteShardCount)
 			assert.Equal(t, tt.expectedShardIDs, shardIDs)
 		})

service/history/replication/task_processor_manager.go

+51 -37

@@ -26,7 +26,6 @@ package replication
 
 import (
 	"context"
-	"fmt"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -72,7 +71,7 @@
 		logger log.Logger
 
 		taskProcessorLock sync.RWMutex
-		taskProcessors map[string]TaskProcessor
+		taskProcessors map[string][]TaskProcessor // cluster name - processor
 		minTxAckedTaskID int64
 		shutdownChan chan struct{}
 	}
@@ -114,7 +113,7 @@ func NewTaskProcessorManager(
 		),
 		logger: shard.GetLogger(),
 		metricsHandler: shard.GetMetricsHandler(),
-		taskProcessors: make(map[string]TaskProcessor),
+		taskProcessors: make(map[string][]TaskProcessor),
 		taskExecutorProvider: taskExecutorProvider,
 		taskPollerManager: newPollerManager(shard.GetShardID(), shard.GetClusterMetadata()),
 		minTxAckedTaskID: persistence.EmptyQueueMessageID,
@@ -149,8 +148,10 @@ func (r *taskProcessorManagerImpl) Stop() {
 
 	r.shard.GetClusterMetadata().UnRegisterMetadataChangeCallback(r)
 	r.taskProcessorLock.Lock()
-	for _, replicationTaskProcessor := range r.taskProcessors {
-		replicationTaskProcessor.Stop()
+	for _, taskProcessors := range r.taskProcessors {
+		for _, processor := range taskProcessors {
+			processor.Stop()
+		}
 	}
 	r.taskProcessorLock.Unlock()
 }
@@ -170,44 +171,57 @@ func (r *taskProcessorManagerImpl) handleClusterMetadataUpdate(
 	r.taskProcessorLock.Lock()
 	defer r.taskProcessorLock.Unlock()
 	currentClusterName := r.shard.GetClusterMetadata().GetCurrentClusterName()
+	// The metadata triggers an update when the following fields update: 1. Enabled 2. Initial Failover Version 3. Cluster address
+	// The callback covers three cases:
+	// Case 1: Remove a cluster Case 2: Add a new cluster Case 3: Refresh cluster metadata(1 + 2).
+
+	// Case 1 and Case 3
 	for clusterName := range oldClusterMetadata {
 		if clusterName == currentClusterName {
 			continue
 		}
-		sourceShardIds := r.taskPollerManager.getSourceClusterShardIDs(clusterName)
+		for _, processor := range r.taskProcessors[clusterName] {
+			processor.Stop()
+			delete(r.taskProcessors, clusterName)
+		}
+	}
+
+	// Case 2 and Case 3
+	for clusterName := range newClusterMetadata {
+		if clusterName == currentClusterName {
+			continue
+		}
+		if clusterInfo := newClusterMetadata[clusterName]; clusterInfo == nil || !clusterInfo.Enabled {
+			continue
+		}
+		sourceShardIds, err := r.taskPollerManager.getSourceClusterShardIDs(clusterName)
+		if err != nil {
+			r.logger.Error("Failed to get source shard id list", tag.Error(err), tag.ClusterName(clusterName))
+			continue
+		}
+		var processors []TaskProcessor
 		for _, sourceShardId := range sourceShardIds {
-			perShardTaskProcessorKey := fmt.Sprintf(clusterCallbackKey, clusterName, sourceShardId)
-			// The metadata triggers an update when the following fields update: 1. Enabled 2. Initial Failover Version 3. Cluster address
-			// The callback covers three cases:
-			// Case 1: Remove a cluster Case 2: Add a new cluster Case 3: Refresh cluster metadata.
-			if processor, ok := r.taskProcessors[perShardTaskProcessorKey]; ok {
-				// Case 1 and Case 3
-				processor.Stop()
-				delete(r.taskProcessors, perShardTaskProcessorKey)
-			}
-			if clusterInfo := newClusterMetadata[clusterName]; clusterInfo != nil && clusterInfo.Enabled {
-				// Case 2 and Case 3
-				fetcher := r.replicationTaskFetcherFactory.GetOrCreateFetcher(clusterName)
-				replicationTaskProcessor := NewTaskProcessor(
-					sourceShardId,
-					r.shard,
-					r.engine,
-					r.config,
-					r.shard.GetMetricsHandler(),
-					fetcher,
-					r.taskExecutorProvider(TaskExecutorParams{
-						RemoteCluster: clusterName,
-						Shard: r.shard,
-						HistoryResender: r.resender,
-						DeleteManager: r.deleteMgr,
-						WorkflowCache: r.workflowCache,
-					}),
-					r.eventSerializer,
-				)
-				replicationTaskProcessor.Start()
-				r.taskProcessors[perShardTaskProcessorKey] = replicationTaskProcessor
-			}
+			fetcher := r.replicationTaskFetcherFactory.GetOrCreateFetcher(clusterName)
+			replicationTaskProcessor := NewTaskProcessor(
+				sourceShardId,
+				r.shard,
+				r.engine,
+				r.config,
+				r.shard.GetMetricsHandler(),
+				fetcher,
+				r.taskExecutorProvider(TaskExecutorParams{
+					RemoteCluster: clusterName,
+					Shard: r.shard,
+					HistoryResender: r.resender,
+					DeleteManager: r.deleteMgr,
+					WorkflowCache: r.workflowCache,
+				}),
+				r.eventSerializer,
+			)
			replicationTaskProcessor.Start()
+			processors = append(processors, replicationTaskProcessor)
 		}
+		r.taskProcessors[clusterName] = processors
 	}
 }
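To summarize the restructured callback, here is a minimal, self-contained sketch with invented names (processor, manager, handleUpdate) rather than the committed types, and with the current-cluster check, locking, and fetcher wiring omitted: processors are now grouped per remote cluster, torn down for every cluster present in the old metadata, and rebuilt for enabled clusters in the new metadata; a cluster whose source shard IDs cannot be resolved is logged and skipped instead of panicking the history service.

package main

import "log"

// processor stands in for the real TaskProcessor interface.
type processor interface {
	Start()
	Stop()
}

type fakeProcessor struct{ shardID int32 }

func (fakeProcessor) Start() {}
func (fakeProcessor) Stop()  {}

type manager struct {
	// Keyed by remote cluster name; each cluster owns one processor per source shard.
	processors map[string][]processor
	shardIDs   func(cluster string) ([]int32, error)
}

func (m *manager) handleUpdate(oldClusters, newClusters map[string]bool) {
	// Case 1 and Case 3: stop and drop everything owned by clusters in the old metadata.
	for name := range oldClusters {
		for _, p := range m.processors[name] {
			p.Stop()
		}
		delete(m.processors, name)
	}
	// Case 2 and Case 3: rebuild processors for enabled clusters in the new metadata.
	for name, enabled := range newClusters {
		if !enabled {
			continue
		}
		ids, err := m.shardIDs(name)
		if err != nil {
			// Before this commit an invalid shard-count pairing panicked here.
			log.Printf("skipping cluster %s: %v", name, err)
			continue
		}
		var ps []processor
		for _, id := range ids {
			p := fakeProcessor{shardID: id}
			p.Start()
			ps = append(ps, p)
		}
		m.processors[name] = ps
	}
}

func main() {
	m := &manager{
		processors: map[string][]processor{},
		shardIDs:   func(string) ([]int32, error) { return []int32{1, 5, 9, 13}, nil },
	}
	m.handleUpdate(nil, map[string]bool{"cluster-b": true})
	log.Printf("cluster-b now has %d processors", len(m.processors["cluster-b"]))
}

Keying the map by cluster name rather than by a per-shard key lets the teardown path stop every per-shard processor for a cluster in one pass, which is what the new Stop() loop and the Case 1 branch in the diff above rely on.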
