-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcoordinator.go
100 lines (82 loc) · 2.04 KB
/
coordinator.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
// Package coordinator implements generic coordinator functionality for helium nodes.
// The coordinator orchestrates the execution of the MHE-based MPC protocol by executing
// its sub-protocols and routines.
package coordinator
import (
"context"
"fmt"
"sync"
"github.com/ChristianMct/helium/sessions"
)
type (
	// EventType constrains the events a coordinator carries. Its method set is
	// empty, so every type satisfies it.
	EventType any

	// Log is a sequence of events of type T.
	Log[T EventType] []T
)
// Channel is a node's view of a coordinator event stream, made of a
// receive-only side and a send-only side.
type Channel[T EventType] struct {
// Incoming is where the node receives events.
Incoming <-chan T
// Outgoing is where the node may send events.
Outgoing chan<- T
}
// channel holds the bidirectional channels that back a Channel. The coordinator
// keeps this form so it can both send and receive on each side; Channel() hands
// out the direction-restricted view.
type channel[T EventType] struct {
// incoming is exposed as Channel.Incoming.
incoming chan T
// outgoing is exposed as Channel.Outgoing.
outgoing chan T
}
// Channel returns a freshly allocated Channel[T] that aliases c's underlying
// channels: Incoming is a receive-only view of c.incoming and Outgoing is a
// send-only view of c.outgoing.
func (c *channel[T]) Channel() *Channel[T] {
	pub := new(Channel[T])
	pub.Incoming, pub.Outgoing = c.incoming, c.outgoing
	return pub
}
// Coordinator hands out per-node event channels.
type Coordinator[T EventType] interface {
// Register returns the event channel for the calling node, or an error.
// TestCoordinator reads the node ID from ctx via sessions.NodeIDFromContext and
// reports in present how many past events are already queued on the returned
// channel; other implementations presumably follow the same contract — confirm.
Register(ctx context.Context) (evChan *Channel[T], present int, err error)
}
// TestCoordinator is an in-memory Coordinator built on Go channels (as the name
// suggests, presumably meant for tests). The node identified by hid receives the
// coordinator's own channel pair from Register. Every event that node sends on
// its Outgoing side is appended to log and forwarded to every other registered
// node. Nodes that register later first receive the events already in log.
type TestCoordinator[T EventType] struct {
// hid is the ID of the node that is given the coordinator's own channel pair.
hid sessions.NodeID
// log records every event received from hid, in arrival order.
log Log[T]
// closed is set once hid's outgoing channel has been closed; later
// registrations receive an already-closed channel after the log replay.
closed bool
// c is the channel pair handed to hid by Register.
c channel[T]
// clients are the incoming channels of non-hid nodes that registered before
// closed was set; they are all closed when hid's outgoing channel closes.
clients []chan T
// l guards log, closed and clients.
l sync.Mutex
}
// NewTestCoordinator returns a TestCoordinator for which the node hid receives
// the coordinator's own channel pair from Register.
//
// It starts a forwarding goroutine that does the following:
//   - Each event hid sends on its Outgoing channel is appended to the log.
//   - The event is then delivered, in order, to every client registered at that moment.
//   - When hid closes its Outgoing channel, the goroutine marks the coordinator
//     closed and closes every client channel.
//
// The goroutine exits only after hid closes its Outgoing channel and all pending
// deliveries have completed.
func NewTestCoordinator[T EventType](hid sessions.NodeID) *TestCoordinator[T] {
	tc := &TestCoordinator[T]{
		hid:     hid,
		log:     make([]T, 0),
		c:       channel[T]{incoming: make(chan T), outgoing: make(chan T)},
		clients: make([]chan T, 0),
	}

	go func() {
		for ev := range tc.c.outgoing {
			// Record the event and snapshot the subscriber list atomically with
			// respect to Register, but deliver outside the lock. A send blocks until
			// the client reads; holding tc.l during that wait would stall every
			// concurrent Register call, and deadlock if the registering goroutine
			// is the one expected to drain an earlier channel.
			// A node that registers after the unlock gets ev from the log replay
			// and is not in the snapshot, so it sees ev exactly once.
			tc.l.Lock()
			tc.log = append(tc.log, ev)
			// Copying the slice header is race-free: Register only appends, which
			// never rewrites elements below the snapshot's length.
			clients := tc.clients
			tc.l.Unlock()

			for _, cli := range clients {
				cli <- ev
			}
		}

		// hid closed its outgoing channel, so no more events will arrive. All sends
		// above ran on this goroutine and have completed, so closing is safe.
		tc.l.Lock()
		tc.closed = true
		for _, cli := range tc.clients {
			close(cli)
		}
		tc.l.Unlock()
	}()

	return tc
}
// Close closes the Incoming side of the channel pair that Register hands to the
// node tc.hid, so a receiver ranging over it terminates. It does not close the
// Outgoing side or any client channel. Calling Close twice panics, because a
// channel can only be closed once.
func (tc *TestCoordinator[T]) Close() {
	in := tc.c.incoming
	close(in)
}
// Register returns the event channel for the node whose ID is stored in ctx.
//
// The result depends on the caller:
//   - If ctx carries no node ID, Register returns an error.
//   - The node tc.hid receives the coordinator's own channel pair, with present = 0.
//   - Any other node receives a new channel pre-loaded with every event logged so
//     far, and present is the number of replayed events.
//
// If tc.hid has already closed its outgoing side, the returned Incoming channel
// is closed after the replay; otherwise the node is subscribed to future events.
// Nothing in this file reads the Outgoing side handed to a non-hid node.
func (tc *TestCoordinator[T]) Register(ctx context.Context) (evChan *Channel[T], present int, err error) {
	tc.l.Lock()
	defer tc.l.Unlock()

	nodeID, found := sessions.NodeIDFromContext(ctx)
	switch {
	case !found:
		return nil, 0, fmt.Errorf("no node id found in context")
	case nodeID == tc.hid:
		return tc.c.Channel(), 0, nil
	}

	// Size the buffer to the backlog exactly, so the replay below never blocks.
	replay := make(chan T, len(tc.log))
	for i := range tc.log {
		replay <- tc.log[i]
	}

	if !tc.closed {
		tc.clients = append(tc.clients, replay)
	} else {
		close(replay)
	}

	sub := channel[T]{incoming: replay, outgoing: make(chan T)}
	return sub.Channel(), cap(replay), nil
}