
Commit 51fd1dd

Allow starting many replication pollers in one shard (#3790)

When connected clusters run different history shard counts, a single local shard may need to poll several source-cluster shards; this commit lets one shard host one replication poller per mapped source shard.

1 parent 0bf1137, commit 51fd1dd

6 files changed, +68 -42 lines

service/history/replication/poller_manager.go (+11 -5)

@@ -31,12 +31,18 @@ import (
 )
 
 type (
+	pollerManager interface {
+		getSourceClusterShardIDs(sourceClusterName string) []int32
+	}
+
 	pollerManagerImpl struct {
 		currentShardId  int32
 		clusterMetadata cluster.Metadata
 	}
 )
 
+var _ pollerManager = (*pollerManagerImpl)(nil)
+
 func newPollerManager(
 	currentShardId int32,
 	clusterMetadata cluster.Metadata,
@@ -47,21 +53,21 @@ func newPollerManager(
 	}
 }
 
-func (p pollerManagerImpl) getPollingShardIDs(remoteClusterName string) []int32 {
+func (p pollerManagerImpl) getSourceClusterShardIDs(sourceClusterName string) []int32 {
 	currentCluster := p.clusterMetadata.GetCurrentClusterName()
 	allClusters := p.clusterMetadata.GetAllClusterInfo()
 	currentClusterInfo, ok := allClusters[currentCluster]
 	if !ok {
 		panic("Cannot get current cluster info from cluster metadata cache")
 	}
-	remoteClusterInfo, ok := allClusters[remoteClusterName]
+	remoteClusterInfo, ok := allClusters[sourceClusterName]
 	if !ok {
-		panic(fmt.Sprintf("Cannot get remote cluster %s info from cluster metadata cache", remoteClusterName))
+		panic(fmt.Sprintf("Cannot get source cluster %s info from cluster metadata cache", sourceClusterName))
 	}
-	return generatePollingShardIDs(p.currentShardId, currentClusterInfo.ShardCount, remoteClusterInfo.ShardCount)
+	return generateShardIDs(p.currentShardId, currentClusterInfo.ShardCount, remoteClusterInfo.ShardCount)
 }
 
-func generatePollingShardIDs(localShardId int32, localShardCount int32, remoteShardCount int32) []int32 {
+func generateShardIDs(localShardId int32, localShardCount int32, remoteShardCount int32) []int32 {
 	var pollingShards []int32
 	if remoteShardCount <= localShardCount {
 		if localShardId <= remoteShardCount {
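The final hunk is truncated by the diff view. Below is a sketch of how generateShardIDs plausibly completes, consistent with the visible branches and with the panic that the test file below recovers from; the 1-based shard IDs and the exact-multiple check are assumptions, not confirmed by this diff.

package replication

import "fmt"

// Sketch only, not the committed code. Assumes 1-based shard IDs and that
// mismatched shard counts must be exact multiples (the test below expects
// a panic in some cases, which suggests such a check exists).
func generateShardIDs(localShardId int32, localShardCount int32, remoteShardCount int32) []int32 {
	var pollingShards []int32
	if remoteShardCount <= localShardCount {
		// At most one source shard maps to this local shard.
		if localShardId <= remoteShardCount {
			pollingShards = append(pollingShards, localShardId)
		}
		return pollingShards
	}
	// More source shards than local shards: assumed precondition that the
	// counts divide evenly, so striding by localShardCount covers every
	// source shard exactly once across all local shards.
	if remoteShardCount%localShardCount != 0 {
		panic(fmt.Sprintf("remote shard count %d is not a multiple of local shard count %d", remoteShardCount, localShardCount))
	}
	for shardId := localShardId; shardId <= remoteShardCount; shardId += localShardCount {
		pollingShards = append(pollingShards, shardId)
	}
	return pollingShards
}

Under this sketch, with 4 local shards and 16 source shards, local shard 1 polls source shards 1, 5, 9, and 13, so every source shard is owned by exactly one local poller.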

service/history/replication/poller_manager_test.go (+1 -1)

@@ -90,7 +90,7 @@ func TestGetPollingShardIds(t *testing.T) {
 				t.Errorf("The code did not panic")
 			}
 		}()
-		shardIDs := generatePollingShardIDs(tt.shardID, tt.localShardCount, tt.remoteShardCount)
+		shardIDs := generateShardIDs(tt.shardID, tt.localShardCount, tt.remoteShardCount)
 		assert.Equal(t, tt.expectedShardIDs, shardIDs)
 	})
 }
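The committed test table is not shown in this diff. These are illustrative entries that the sketch above would satisfy (names and values are made up, not the real test data):

package replication

// Illustrative table entries for TestGetPollingShardIds, derived from the
// generateShardIDs sketch above.
var exampleCases = []struct {
	shardID, localShardCount, remoteShardCount int32
	expectedShardIDs                           []int32
}{
	{shardID: 1, localShardCount: 4, remoteShardCount: 4, expectedShardIDs: []int32{1}},
	{shardID: 4, localShardCount: 4, remoteShardCount: 2, expectedShardIDs: nil}, // no source shard maps here
	{shardID: 1, localShardCount: 4, remoteShardCount: 16, expectedShardIDs: []int32{1, 5, 9, 13}},
	// A non-multiple pair such as (1, 4, 6) would panic under the assumed check.
}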

service/history/replication/task_processor.go (+11 -6)

@@ -74,9 +74,10 @@ type (
 
 	// taskProcessorImpl is responsible for processing replication tasks for a shard.
 	taskProcessorImpl struct {
-		currentCluster    string
+		status int32
+
 		sourceCluster     string
-		status            int32
+		sourceShardID     int32
 		shard             shard.Context
 		historyEngine     shard.Engine
 		historySerializer serialization.Serializer
@@ -109,6 +110,7 @@ type (
 
 // NewTaskProcessor creates a new replication task processor.
 func NewTaskProcessor(
+	sourceShardID int32,
 	shard shard.Context,
 	historyEngine shard.Engine,
 	config *configs.Config,
@@ -132,9 +134,9 @@ func NewTaskProcessor(
 		WithExpirationInterval(config.ReplicationTaskProcessorErrorRetryExpiration(shardID))
 
 	return &taskProcessorImpl{
-		currentCluster:    shard.GetClusterMetadata().GetCurrentClusterName(),
-		sourceCluster:     replicationTaskFetcher.getSourceCluster(),
 		status:            common.DaemonStatusInitialized,
+		sourceShardID:     sourceShardID,
+		sourceCluster:     replicationTaskFetcher.getSourceCluster(),
 		shard:             shard,
 		historyEngine:     historyEngine,
 		historySerializer: eventSerializer,
@@ -383,6 +385,7 @@ func (p *taskProcessorImpl) convertTaskToDLQTask(
 	switch replicationTask.TaskType {
 	case enumsspb.REPLICATION_TASK_TYPE_SYNC_ACTIVITY_TASK:
 		taskAttributes := replicationTask.GetSyncActivityTaskAttributes()
+		// TODO: GetShardID will break GetDLQReplicationMessages; we need to handle DLQ for cross-shard replication.
 		return &persistence.PutReplicationTaskToDLQRequest{
 			ShardID:           p.shard.GetShardID(),
 			SourceClusterName: p.sourceCluster,
@@ -414,6 +417,7 @@ func (p *taskProcessorImpl) convertTaskToDLQTask(
 		// NOTE: last event vs next event, next event ID is exclusive
 		nextEventID := lastEvent.GetEventId() + 1
 
+		// TODO: GetShardID will break GetDLQReplicationMessages; we need to handle DLQ for cross-shard replication.
 		return &persistence.PutReplicationTaskToDLQRequest{
 			ShardID:           p.shard.GetShardID(),
 			SourceClusterName: p.sourceCluster,
@@ -442,6 +446,7 @@ func (p *taskProcessorImpl) convertTaskToDLQTask(
 			return nil, err
 		}
 
+		// TODO: GetShardID will break GetDLQReplicationMessages; we need to handle DLQ for cross-shard replication.
 		return &persistence.PutReplicationTaskToDLQRequest{
 			ShardID:           p.shard.GetShardID(),
 			SourceClusterName: p.sourceCluster,
@@ -464,7 +469,7 @@ func (p *taskProcessorImpl) paginationFn(_ []byte) ([]interface{}, []byte, error
 	respChan := make(chan *replicationspb.ReplicationMessages, 1)
 	p.requestChan <- &replicationTaskRequest{
 		token: &replicationspb.ReplicationToken{
-			ShardId:                     p.shard.GetShardID(),
+			ShardId:                     p.sourceShardID,
 			LastProcessedMessageId:      p.maxRxProcessedTaskID,
 			LastProcessedVisibilityTime: &p.maxRxProcessedTimestamp,
 			LastRetrievedMessageId:      p.maxRxReceivedTaskID,
@@ -499,7 +504,7 @@ func (p *taskProcessorImpl) paginationFn(_ []byte) ([]interface{}, []byte, error
 	if resp.GetHasMore() {
 		p.rxTaskBackoff = time.Duration(0)
 	} else {
-		p.rxTaskBackoff = p.config.ReplicationTaskProcessorNoTaskRetryWait(p.shard.GetShardID())
+		p.rxTaskBackoff = p.config.ReplicationTaskProcessorNoTaskRetryWait(p.sourceShardID)
 	}
 	return tasks, nil, nil
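The behavioral core of this file's change is in paginationFn: the replication token now names the poller's assigned source shard instead of the local shard, so the fetcher asks the source cluster for exactly that shard's tasks. A minimal illustration, reusing the file's replicationspb alias (the field values are hypothetical):

// Values are hypothetical; only the source of ShardId changed.
token := &replicationspb.ReplicationToken{
	ShardId:                9,   // p.sourceShardID (previously p.shard.GetShardID())
	LastProcessedMessageId: 100, // ack level, as before
	LastRetrievedMessageId: 250, // read level, as before
}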

service/history/replication/task_processor_manager.go (+40 -30)

@@ -26,6 +26,7 @@ package replication
 
 import (
 	"context"
+	"fmt"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -49,6 +50,10 @@ import (
 	wcache "go.temporal.io/server/service/history/workflow/cache"
 )
 
+const (
+	clusterCallbackKey = "%s-%d" // <cluster name>-<polling shard id>
+)
+
 type (
 	// taskProcessorManagerImpl is to manage replication task processors
 	taskProcessorManagerImpl struct {
@@ -62,6 +67,7 @@ type (
 		workflowCache        wcache.Cache
 		resender             xdc.NDCHistoryResender
 		taskExecutorProvider TaskExecutorProvider
+		taskPollerManager    pollerManager
 		metricsHandler       metrics.Handler
 		logger               log.Logger
 
@@ -110,6 +116,7 @@ func NewTaskProcessorManager(
 		metricsHandler:       shard.GetMetricsHandler(),
 		taskProcessors:       make(map[string]TaskProcessor),
 		taskExecutorProvider: taskExecutorProvider,
+		taskPollerManager:    newPollerManager(shard.GetShardID(), shard.GetClusterMetadata()),
 		minTxAckedTaskID:     persistence.EmptyQueueMessageID,
 		shutdownChan:         make(chan struct{}),
 	}
@@ -167,36 +174,39 @@ func (r *taskProcessorManagerImpl) handleClusterMetadataUpdate(
 		if clusterName == currentClusterName {
 			continue
 		}
-		// The metadata triggers a update when the following fields update: 1. Enabled 2. Initial Failover Version 3. Cluster address
-		// The callback covers three cases:
-		// Case 1: Remove a cluster Case 2: Add a new cluster Case 3: Refresh cluster metadata.
-
-		if processor, ok := r.taskProcessors[clusterName]; ok {
-			// Case 1 and Case 3
-			processor.Stop()
-			delete(r.taskProcessors, clusterName)
-		}
-
-		if clusterInfo := newClusterMetadata[clusterName]; clusterInfo != nil && clusterInfo.Enabled {
-			// Case 2 and Case 3
-			fetcher := r.replicationTaskFetcherFactory.GetOrCreateFetcher(clusterName)
-			replicationTaskProcessor := NewTaskProcessor(
-				r.shard,
-				r.engine,
-				r.config,
-				r.shard.GetMetricsHandler(),
-				fetcher,
-				r.taskExecutorProvider(TaskExecutorParams{
-					RemoteCluster:   clusterName,
-					Shard:           r.shard,
-					HistoryResender: r.resender,
-					DeleteManager:   r.deleteMgr,
-					WorkflowCache:   r.workflowCache,
-				}),
-				r.eventSerializer,
-			)
-			replicationTaskProcessor.Start()
-			r.taskProcessors[clusterName] = replicationTaskProcessor
+		sourceShardIds := r.taskPollerManager.getSourceClusterShardIDs(clusterName)
+		for _, sourceShardId := range sourceShardIds {
+			perShardTaskProcessorKey := fmt.Sprintf(clusterCallbackKey, clusterName, sourceShardId)
+			// The metadata triggers an update when the following fields change: 1. Enabled 2. Initial Failover Version 3. Cluster address
+			// The callback covers three cases:
+			// Case 1: Remove a cluster. Case 2: Add a new cluster. Case 3: Refresh cluster metadata.
+			if processor, ok := r.taskProcessors[perShardTaskProcessorKey]; ok {
+				// Case 1 and Case 3
+				processor.Stop()
+				delete(r.taskProcessors, perShardTaskProcessorKey)
+			}
+			if clusterInfo := newClusterMetadata[clusterName]; clusterInfo != nil && clusterInfo.Enabled {
+				// Case 2 and Case 3
+				fetcher := r.replicationTaskFetcherFactory.GetOrCreateFetcher(clusterName)
+				replicationTaskProcessor := NewTaskProcessor(
+					sourceShardId,
+					r.shard,
+					r.engine,
+					r.config,
+					r.shard.GetMetricsHandler(),
+					fetcher,
+					r.taskExecutorProvider(TaskExecutorParams{
+						RemoteCluster:   clusterName,
+						Shard:           r.shard,
+						HistoryResender: r.resender,
+						DeleteManager:   r.deleteMgr,
+						WorkflowCache:   r.workflowCache,
+					}),
+					r.eventSerializer,
+				)
+				replicationTaskProcessor.Start()
+				r.taskProcessors[perShardTaskProcessorKey] = replicationTaskProcessor
+			}
 		}
 	}
 }
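Because one local shard can now run several pollers against the same source cluster, processors are keyed by (cluster, source shard) rather than by cluster name alone. A tiny illustration of the key format (the cluster name and shard IDs are made up):

// With clusterCallbackKey = "%s-%d", two pollers for the same source cluster
// get distinct map keys and can coexist in one local shard.
keyA := fmt.Sprintf(clusterCallbackKey, "cluster-b", 3) // "cluster-b-3"
keyB := fmt.Sprintf(clusterCallbackKey, "cluster-b", 7) // "cluster-b-7"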

service/history/replication/task_processor_test.go (+1 -0)

@@ -148,6 +148,7 @@ func (s *taskProcessorSuite) SetupTest() {
 	metricsClient := metrics.NoopMetricsHandler
 
 	s.replicationTaskProcessor = NewTaskProcessor(
+		s.shardID,
 		s.mockShard,
 		s.mockEngine,
 		s.config,

tests/test_cluster.go (+4 -0)

@@ -169,7 +169,10 @@ func NewCluster(options *TestClusterConfig, logger log.Logger) (*TestCluster, er
 		}
 	}
 
+	clusterInfoMap := make(map[string]cluster.ClusterInformation)
 	for clusterName, clusterInfo := range clusterMetadataConfig.ClusterInformation {
+		clusterInfo.ShardCount = options.HistoryConfig.NumHistoryShards
+		clusterInfoMap[clusterName] = clusterInfo
 		_, err := testBase.ClusterMetadataManager.SaveClusterMetadata(context.Background(), &persistence.SaveClusterMetadataRequest{
 			ClusterMetadata: persistencespb.ClusterMetadata{
 				HistoryShardCount: options.HistoryConfig.NumHistoryShards,
@@ -185,6 +188,7 @@ func NewCluster(options *TestClusterConfig, logger log.Logger) (*TestCluster, er
 			return nil, err
 		}
 	}
+	clusterMetadataConfig.ClusterInformation = clusterInfoMap
 
 	// This will save custom test search attributes to cluster metadata.
 	// Actual Elasticsearch fields are created from index template (testdata/es_v7_index_template.json).
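This test change records each registered cluster's shard count in its ClusterInformation entry, which is what getSourceClusterShardIDs reads to compute the shard mapping. A rough sketch of a resulting entry (values illustrative; the InitialFailoverVersion field is assumed from the metadata-update comment above, not shown in this diff):

// Illustrative only: a test cluster's ClusterInformation after this change.
clusterInfo := cluster.ClusterInformation{
	Enabled:                true,
	InitialFailoverVersion: 1, // assumed field
	ShardCount:             4, // now set to options.HistoryConfig.NumHistoryShards
}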
