Skip to content

Commit 3e31880

Browse files
authored
Add a component to assign polling shards (#3775)
* Add a component to assign polling shards
1 parent aa76719 commit 3e31880

File tree

3 files changed

+180
-0
lines changed

3 files changed

+180
-0
lines changed

common/cluster/metadata.go

+1
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ type (
104104
InitialFailoverVersion int64 `yaml:"initialFailoverVersion"`
105105
// Address indicate the remote service address(Host:Port). Host can be DNS name.
106106
RPCAddress string `yaml:"rpcAddress"`
107+
ShardCount int32
107108
// private field to track cluster information updates
108109
version int64
109110
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// The MIT License
2+
//
3+
// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
4+
//
5+
// Copyright (c) 2020 Uber Technologies, Inc.
6+
//
7+
// Permission is hereby granted, free of charge, to any person obtaining a copy
8+
// of this software and associated documentation files (the "Software"), to deal
9+
// in the Software without restriction, including without limitation the rights
10+
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11+
// copies of the Software, and to permit persons to whom the Software is
12+
// furnished to do so, subject to the following conditions:
13+
//
14+
// The above copyright notice and this permission notice shall be included in
15+
// all copies or substantial portions of the Software.
16+
//
17+
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18+
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19+
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20+
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21+
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22+
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23+
// THE SOFTWARE.
24+
25+
package replication
26+
27+
import (
28+
"fmt"
29+
30+
"go.temporal.io/server/common/cluster"
31+
)
32+
33+
type (
34+
pollerManagerImpl struct {
35+
currentShardId int32
36+
clusterMetadata cluster.Metadata
37+
}
38+
)
39+
40+
func newPollerManager(
41+
currentShardId int32,
42+
clusterMetadata cluster.Metadata,
43+
) *pollerManagerImpl {
44+
return &pollerManagerImpl{
45+
currentShardId: currentShardId,
46+
clusterMetadata: clusterMetadata,
47+
}
48+
}
49+
50+
func (p pollerManagerImpl) getPollingShardIDs(remoteClusterName string) []int32 {
51+
currentCluster := p.clusterMetadata.GetCurrentClusterName()
52+
allClusters := p.clusterMetadata.GetAllClusterInfo()
53+
currentClusterInfo, ok := allClusters[currentCluster]
54+
if !ok {
55+
panic("Cannot get current cluster info from cluster metadata cache")
56+
}
57+
remoteClusterInfo, ok := allClusters[remoteClusterName]
58+
if !ok {
59+
panic(fmt.Sprintf("Cannot get remote cluster %s info from cluster metadata cache", remoteClusterName))
60+
}
61+
return generatePollingShardIDs(p.currentShardId, currentClusterInfo.ShardCount, remoteClusterInfo.ShardCount)
62+
}
63+
64+
func generatePollingShardIDs(localShardId int32, localShardCount int32, remoteShardCount int32) []int32 {
65+
var pollingShards []int32
66+
if remoteShardCount <= localShardCount {
67+
if localShardId <= remoteShardCount {
68+
pollingShards = append(pollingShards, localShardId)
69+
}
70+
return pollingShards
71+
}
72+
73+
// remoteShardCount > localShardCount, replication poller will poll from multiple remote shard.
74+
// The remote shard count and local shard count must be multiples.
75+
if remoteShardCount%localShardCount != 0 {
76+
panic(fmt.Sprintf("Remote shard count %d and local shard count %d are not multiples.", remoteShardCount, localShardCount))
77+
}
78+
for i := localShardId; i <= remoteShardCount; i += localShardCount {
79+
pollingShards = append(pollingShards, i)
80+
}
81+
return pollingShards
82+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// The MIT License
2+
//
3+
// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
4+
//
5+
// Copyright (c) 2020 Uber Technologies, Inc.
6+
//
7+
// Permission is hereby granted, free of charge, to any person obtaining a copy
8+
// of this software and associated documentation files (the "Software"), to deal
9+
// in the Software without restriction, including without limitation the rights
10+
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11+
// copies of the Software, and to permit persons to whom the Software is
12+
// furnished to do so, subject to the following conditions:
13+
//
14+
// The above copyright notice and this permission notice shall be included in
15+
// all copies or substantial portions of the Software.
16+
//
17+
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18+
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19+
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20+
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21+
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22+
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23+
// THE SOFTWARE.
24+
25+
package replication
26+
27+
import (
28+
"fmt"
29+
"testing"
30+
31+
"github.com/stretchr/testify/assert"
32+
)
33+
34+
func TestGetPollingShardIds(t *testing.T) {
35+
testCases := []struct {
36+
shardID int32
37+
remoteShardCount int32
38+
localShardCount int32
39+
expectedPanic bool
40+
expectedShardIDs []int32
41+
}{
42+
{
43+
1,
44+
4,
45+
4,
46+
false,
47+
[]int32{1},
48+
},
49+
{
50+
1,
51+
2,
52+
4,
53+
false,
54+
[]int32{1},
55+
},
56+
{
57+
3,
58+
2,
59+
4,
60+
false,
61+
[]int32{},
62+
},
63+
{
64+
1,
65+
16,
66+
4,
67+
false,
68+
[]int32{1, 5, 9, 13},
69+
},
70+
{
71+
4,
72+
16,
73+
4,
74+
false,
75+
[]int32{4, 8, 12, 16},
76+
},
77+
{
78+
4,
79+
17,
80+
4,
81+
true,
82+
[]int32{},
83+
},
84+
}
85+
for idx, tt := range testCases {
86+
t.Run(fmt.Sprintf("Testcase %d", idx), func(t *testing.T) {
87+
t.Parallel()
88+
defer func() {
89+
if r := recover(); tt.expectedPanic && r == nil {
90+
t.Errorf("The code did not panic")
91+
}
92+
}()
93+
shardIDs := generatePollingShardIDs(tt.shardID, tt.localShardCount, tt.remoteShardCount)
94+
assert.Equal(t, tt.expectedShardIDs, shardIDs)
95+
})
96+
}
97+
}

0 commit comments

Comments
 (0)