Skip to content

Commit c0aafa1

Browse files
authored
feat: batch computation of ECChain Keys in chainexchange (#918)
* feat: batch computation of ECChain Keys in chainexchange. Adjust the for loop so that it does not repeat the full chain. EqualExportedValues is required in the tests because we never explicitly call Key on these values, leaving the cached key empty. Signed-off-by: Jakub Sztandera <oss@kubuxu.com> * fix: make fake MarshalPayloadForSigning use real signing. It conflicts with partial message validation. Signed-off-by: Jakub Sztandera <oss@kubuxu.com> * Address review. Signed-off-by: Jakub Sztandera <oss@kubuxu.com> --------- Signed-off-by: Jakub Sztandera <oss@kubuxu.com>
1 parent 9e81679 commit c0aafa1

File tree

11 files changed

+249
-57
lines changed

11 files changed

+249
-57
lines changed

chainexchange/options.go

+13
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"time"
66

77
"github.com/filecoin-project/go-f3/gpbft"
8+
"github.com/filecoin-project/go-f3/internal/clock"
89
"github.com/filecoin-project/go-f3/internal/psutil"
910
"github.com/filecoin-project/go-f3/manifest"
1011
pubsub "github.com/libp2p/go-libp2p-pubsub"
@@ -25,6 +26,7 @@ type options struct {
2526
listener Listener
2627
maxTimestampAge time.Duration
2728
compression bool
29+
clk clock.Clock
2830
}
2931

3032
func newOptions(o ...Option) (*options, error) {
@@ -35,6 +37,7 @@ func newOptions(o ...Option) (*options, error) {
3537
maxInstanceLookahead: manifest.DefaultCommitteeLookback,
3638
maxDiscoveredChainsPerInstance: 1000,
3739
maxWantedChainsPerInstance: 1000,
40+
clk: clock.RealClock,
3841
}
3942
for _, apply := range o {
4043
if err := apply(opts); err != nil {
@@ -163,3 +166,13 @@ func WithMaxTimestampAge(max time.Duration) Option {
163166
return nil
164167
}
165168
}
169+
170+
func WithClock(clk clock.Clock) Option {
171+
return func(o *options) error {
172+
if clk == nil {
173+
return errors.New("clock cannot be nil")
174+
}
175+
o.clk = clk
176+
return nil
177+
}
178+
}

chainexchange/pubsub.go

+10-9
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ func (p *PubSubChainExchange) validatePubSubMessage(ctx context.Context, _ peer.
245245
return pubsub.ValidationReject
246246
}
247247
}
248-
now := time.Now().UnixMilli()
248+
now := p.clk.Now().UnixMilli()
249249
lowerBound := now - p.maxTimestampAge.Milliseconds()
250250
if lowerBound > cmsg.Timestamp || cmsg.Timestamp > now {
251251
// The timestamp is too old or too far ahead. Ignore the message to avoid
@@ -263,11 +263,11 @@ func (p *PubSubChainExchange) cacheAsDiscoveredChain(ctx context.Context, cmsg M
263263
wanted := p.getChainsDiscoveredAt(ctx, cmsg.Instance)
264264
discovered := p.getChainsDiscoveredAt(ctx, cmsg.Instance)
265265

266-
for offset := cmsg.Chain.Len(); offset >= 0 && ctx.Err() == nil; offset-- {
267-
// TODO: Expose internals of merkle.go so that keys can be generated
268-
// cumulatively for a more efficient prefix chain key generation.
269-
prefix := cmsg.Chain.Prefix(offset)
266+
allPrefixes := cmsg.Chain.AllPrefixes()
267+
for i := len(allPrefixes) - 1; i >= 0 && ctx.Err() == nil; i-- {
268+
prefix := allPrefixes[i]
270269
key := prefix.Key()
270+
271271
if portion, found := wanted.Peek(key); !found {
272272
// Not a wanted key; add it to discovered chains if they are not there already,
273273
// i.e. without modifying the recent-ness of any of the discovered values.
@@ -329,11 +329,12 @@ type discovery struct {
329329
func (p *PubSubChainExchange) cacheAsWantedChain(ctx context.Context, cmsg Message) {
330330
var notifications []discovery
331331
wanted := p.getChainsWantedAt(ctx, cmsg.Instance)
332-
for offset := cmsg.Chain.Len(); offset >= 0 && ctx.Err() == nil; offset-- {
333-
// TODO: Expose internals of merkle.go so that keys can be generated
334-
// cumulatively for a more efficient prefix chain key generation.
335-
prefix := cmsg.Chain.Prefix(offset)
332+
333+
allPrefixes := cmsg.Chain.AllPrefixes()
334+
for i := len(allPrefixes) - 1; i >= 0 && ctx.Err() == nil; i-- {
335+
prefix := allPrefixes[i]
336336
key := prefix.Key()
337+
337338
if portion, found := wanted.Peek(key); !found || portion.IsPlaceholder() {
338339
wanted.Add(key, &chainPortion{
339340
chain: prefix,

chainexchange/pubsub_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -88,25 +88,25 @@ func TestPubSubChainExchange_Broadcast(t *testing.T) {
8888
chain, found = subject.GetChainByInstance(ctx, instance, key)
8989
return found
9090
}, time.Second, 100*time.Millisecond)
91-
require.Equal(t, ecChain, chain)
91+
require.EqualExportedValues(t, ecChain, chain)
9292

9393
baseChain := ecChain.BaseChain()
9494
baseKey := baseChain.Key()
9595
require.Eventually(t, func() bool {
9696
chain, found = subject.GetChainByInstance(ctx, instance, baseKey)
9797
return found
9898
}, time.Second, 100*time.Millisecond)
99-
require.Equal(t, baseChain, chain)
99+
require.EqualExportedValues(t, baseChain, chain)
100100

101101
// Assert that we have received 2 notifications, because ecChain has 2 tipsets.
102102
// First should be the ecChain, second should be the baseChain.
103103

104104
notifications := testListener.getNotifications()
105105
require.Len(t, notifications, 2)
106106
require.Equal(t, instance, notifications[1].instance)
107-
require.Equal(t, baseChain, notifications[1].chain)
107+
require.EqualExportedValues(t, baseChain, notifications[1].chain)
108108
require.Equal(t, instance, notifications[0].instance)
109-
require.Equal(t, ecChain, notifications[0].chain)
109+
require.EqualExportedValues(t, ecChain, notifications[0].chain)
110110

111111
require.NoError(t, subject.Shutdown(ctx))
112112
})

gpbft/chain.go

+45
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,51 @@ func (c *ECChain) Key() ECChainKey {
425425
return c.key
426426
}
427427

428+
// KeysForPrefixes returns a batch of keys, one for every prefix of the chain.
429+
// Index i of the returned slice corresponds to the key of Prefix(i).
430+
func (c *ECChain) KeysForPrefixes() []ECChainKey {
431+
if c.IsZero() {
432+
return nil
433+
}
434+
values := make([][]byte, c.Len())
435+
for i, ts := range c.TipSets {
436+
values[i] = ts.MarshalForSigning()
437+
}
438+
439+
batch := merkle.BatchTree(values)
440+
res := make([]ECChainKey, c.Len())
441+
for i := range c.Len() {
442+
res[i] = batch[i]
443+
}
444+
return res
445+
}
446+
447+
// AllPrefixes returns a slice of all prefix chains of c, including c itself.
448+
// It precomputes keys for them as well, populating the key cache.
449+
func (c *ECChain) AllPrefixes() []*ECChain {
450+
if c.IsZero() {
451+
return nil
452+
}
453+
values := make([][]byte, c.Len())
454+
for i, ts := range c.TipSets {
455+
values[i] = ts.MarshalForSigning()
456+
}
457+
batch := merkle.BatchTree(values)
458+
459+
res := make([]*ECChain, len(c.TipSets))
460+
for i := range len(c.TipSets) {
461+
var prefix ECChain
462+
prefix.TipSets = c.TipSets[: i+1 : i+1]
463+
464+
copy(prefix.key[:], batch[i][:]) // populate the key cache
465+
prefix.keyLazyLoader.Do(func() {})
466+
467+
res[i] = &prefix
468+
}
469+
470+
return res
471+
}
472+
428473
func (c *ECChain) String() string {
429474
if c.IsZero() {
430475
return "丄"

gpbft/chain_test.go

+49
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,55 @@ func TestECChain(t *testing.T) {
189189
})
190190
}
191191

192+
func TestECChainKeysForPrefixes(t *testing.T) {
193+
t.Parallel()
194+
195+
pt1Cid := gpbft.MakeCid([]byte("pt1"))
196+
pt2Cid := gpbft.MakeCid([]byte("pt2"))
197+
pt3Cid := gpbft.MakeCid([]byte("pt3"))
198+
tipSets := []*gpbft.TipSet{
199+
{Epoch: 0, Key: []byte{0}, PowerTable: pt1Cid},
200+
{Epoch: 1, Key: []byte{1}, PowerTable: pt2Cid},
201+
{Epoch: 2, Key: []byte{2}, PowerTable: pt3Cid},
202+
}
203+
204+
chain := &gpbft.ECChain{TipSets: tipSets}
205+
keys := chain.KeysForPrefixes()
206+
207+
require.Equal(t, len(keys), len(tipSets), "KeysForPrefixes should return a key for each prefix")
208+
209+
for i, key := range keys {
210+
prefixChain := chain.Prefix(i)
211+
require.Equal(t, key, prefixChain.Key(), "Key for prefix %d should match", i)
212+
}
213+
}
214+
215+
func TestECChainAllPrefixes(t *testing.T) {
216+
t.Parallel()
217+
218+
pt1Cid := gpbft.MakeCid([]byte("pt1"))
219+
pt2Cid := gpbft.MakeCid([]byte("pt2"))
220+
pt3Cid := gpbft.MakeCid([]byte("pt3"))
221+
tipSets := []*gpbft.TipSet{
222+
{Epoch: 0, Key: []byte{0}, PowerTable: pt1Cid},
223+
{Epoch: 1, Key: []byte{1}, PowerTable: pt2Cid},
224+
{Epoch: 2, Key: []byte{2}, PowerTable: pt3Cid},
225+
}
226+
227+
chain := &gpbft.ECChain{TipSets: tipSets}
228+
prefixes := chain.AllPrefixes()
229+
230+
require.Equal(t, len(prefixes), len(tipSets), "AllPrefixes should return a prefix for each tipset")
231+
232+
for i, prefix := range prefixes {
233+
require.Equal(t, prefix.Len(), i+1, "Prefix %d should have length %d", i, i+1)
234+
expected := chain.Prefix(i)
235+
require.True(t, expected.Eq(prefix))
236+
require.Equal(t, expected.Key(), prefix.Key())
237+
require.NotSame(t, expected, prefix)
238+
}
239+
}
240+
192241
func TestECChain_Eq(t *testing.T) {
193242
t.Parallel()
194243
var (

host.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ func newRunner(
147147
runner.msgEncoding = encoding.NewCBOR[*PartialGMessage]()
148148
}
149149

150-
runner.pmm, err = newPartialMessageManager(runner.Progress, ps, m)
150+
runner.pmm, err = newPartialMessageManager(runner.Progress, ps, m, runner.clock)
151151
if err != nil {
152152
return nil, fmt.Errorf("creating partial message manager: %w", err)
153153
}

internal/clock/clock.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@ func WithMockClock(ctx context.Context) (context.Context, *Mock) {
2626
return context.WithValue(ctx, clockKey, (Clock)(clk)), clk
2727
}
2828

29-
var realClock = clock.New()
29+
var RealClock = clock.New()
3030

3131
// GetClock either retrieves a mock clock from the context or returns a realtime clock.
3232
func GetClock(ctx context.Context) Clock {
3333
clk := ctx.Value(clockKey)
3434
if clk == nil {
35-
return realClock
35+
return RealClock
3636
}
3737
return clk.(Clock)
3838
}

merkle/merkle.go

+30-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func TreeWithProofs(values [][]byte) (Digest, [][]Digest) {
2929

3030
// Tree returns the root of the merkle-tree of the given values.
3131
func Tree(values [][]byte) Digest {
32-
return buildTree(bits.Len(uint(len(values))-1), values, nil)
32+
return buildTree(depth(values), values, nil)
3333
}
3434

3535
// VerifyProof verifies that the given value maps to the given index in the merkle-tree with the
@@ -123,3 +123,32 @@ func buildTree(depth int, values [][]byte, proofs [][]Digest) Digest {
123123

124124
return internalHash(leftHash, rightHash)
125125
}
126+
127+
// BatchTree considers every prefix of values: [[0], [0, 1], [0, 1, 2], ...]
128+
// and returns the merkle-tree root of each; element i of the result is the root of values[:i+1].
129+
func BatchTree(values [][]byte) []Digest {
130+
n := len(values)
131+
roots := make([]Digest, n+1) // roots[0] is empty
132+
133+
for k := 1; k <= n; k++ {
134+
if k == 1 {
135+
roots[k] = leafHash(values[0])
136+
continue
137+
}
138+
139+
depth := bits.Len(uint(k - 1))
140+
split := 1 << (depth - 1)
141+
142+
// Reuse left subtree root from previous computation
143+
leftRoot := roots[split]
144+
145+
// compute right subtree root
146+
// this could be optimized further, but only at the cost of messier code
147+
rightValues := values[split:k]
148+
rightRoot := buildTree(depth-1, rightValues, nil)
149+
150+
roots[k] = internalHash(leftRoot, rightRoot)
151+
}
152+
153+
return roots[1:]
154+
}

merkle/merkle_test.go

+86
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package merkle
22

33
import (
4+
"encoding/hex"
45
"fmt"
6+
"runtime"
57
"testing"
68

79
"github.com/stretchr/testify/assert"
@@ -65,3 +67,87 @@ func TestHashZero(t *testing.T) {
6567
valid, _ := VerifyProof(root, 0, nil, nil)
6668
assert.False(t, valid)
6769
}
70+
71+
func generateInputs(dst string, N int) [][]byte {
72+
var res [][]byte
73+
for i := 0; i < N; i++ {
74+
res = append(res, []byte(dst+fmt.Sprintf("%09d", i)))
75+
}
76+
return res
77+
}
78+
79+
func TestHashBatch(t *testing.T) {
80+
for N := 1; N < 300; {
81+
inputs := generateInputs(fmt.Sprintf("batch-%d", N), N)
82+
expected := make([]Digest, 0, N-1)
83+
84+
for i := 0; i < N; i++ {
85+
expected = append(expected, Tree(inputs[0:i+1]))
86+
}
87+
88+
batchRes := BatchTree(inputs)
89+
require.Equal(t, expected, batchRes, "failed at size %d", N)
90+
if N < 32 {
91+
N++
92+
} else {
93+
N += 13
94+
}
95+
}
96+
}
97+
98+
func TestHashTreeGolden(t *testing.T) {
99+
expectedHex := []string{
100+
"3d4395573ce4d2acbce4fe8a4be67ca5e7cdfb8ee2e85b2f6733c16b24c3b175",
101+
"91b7c899421ca7f3228e10265c6970a03bc2ccba44367b1d44a9d8597b20a32e",
102+
"69abe78dc2390b4666b60d0582e1799e73e48766f6e502c515e79d6cd2ae3c45",
103+
"bc4ce8dbf993eb2e87c02bbf19cd4faeb3a0672188bc6be6c8d867cef9b08917",
104+
"538cfd0c1f6b7ab4c3d20466d4e01b438972212fe5257eae213ae0a040da977f",
105+
"e28aa108b0263820dfe2c7f051ddc8794ab48ebd3c1813db28bf9f06bedc52f3",
106+
"875cb1d5027522b344b8adc62cd6bd110d97eaedd40a35bcb2fe142a9cb4612b",
107+
"63804e8b6cb16993d5d43d9d7faf17ba967365dac141a4afbce1d794157a1b8e",
108+
"07105bd8716bebc90036c8ebfe23a92bd775c09664b076ffa1d9a29d30647f91",
109+
"960b7eb6440789f76f5d53965e8b208e34777bc4aab78edf6827d71c7eea4933",
110+
"d55e07222c786722e1ad1b5bcc2ebaf04b2b4e92c07f3f7b61b0fbf0fd78fb9b",
111+
"ee5a34dfae748e088a1b99386274158266f44ceeb2c5190f4e9bbc39cd8a4d26",
112+
"15def4fc077ccfb0e48b32bc07ea3b91acecc5b73ed9caf13b10adf17052c371",
113+
"07cfe4ec2efa9075763f921e9f29794ec6b945694e41cc19911101270d8b1087",
114+
"84cdf541cbb3b9b3f26dbdeb9ca3a2721d15447a8b47074c3b08b560f79e5d85",
115+
"af8e9fc2f15aaedadb96da1afb339b93e3174661327dcc6aad70ea67e183369d",
116+
}
117+
var results []string
118+
N := 16
119+
for i := 0; i < N; i++ {
120+
inputs := generateInputs("golden", i+1)
121+
res := Tree(inputs)
122+
resHex := hex.EncodeToString(res[:])
123+
assert.Equal(t, expectedHex[i], resHex)
124+
results = append(results, resHex)
125+
}
126+
t.Logf("results: %#v", results)
127+
128+
batchRes := BatchTree(generateInputs("golden", N))
129+
batchResHash := make([]string, N)
130+
for i := 0; i < N; i++ {
131+
batchResHash[i] = hex.EncodeToString(batchRes[i][:])
132+
}
133+
assert.Equal(t, expectedHex, batchResHash)
134+
}
135+
136+
func BenchmarkPrefixes(b *testing.B) {
137+
K := 128
138+
inputs := generateInputs("golden", K)
139+
b.Run("IndividualPrefix", func(b *testing.B) {
140+
b.ReportAllocs()
141+
for range b.N {
142+
for i := range len(inputs) {
143+
runtime.KeepAlive(Tree(inputs[:i+1]))
144+
}
145+
}
146+
})
147+
b.Run("BatchPrefix", func(b *testing.B) {
148+
b.ReportAllocs()
149+
for range b.N {
150+
runtime.KeepAlive(BatchTree(inputs))
151+
}
152+
})
153+
}

0 commit comments

Comments
 (0)