@@ -82,18 +82,45 @@ type impl struct {
82
82
receivers []Receiver
83
83
}
84
84
85
+ // interfaceWrapper is concrete type that wraps an interface. Necessary because
86
+ // atomic.Value needs the same type and can not Store(nil). This indirection
87
+ // allows us to store nil.
88
+ type interfaceWrapper [T any ] struct {
89
+ t T
90
+ }
91
+ type atomicInterface [T any ] struct {
92
+ iface atomic.Value
93
+ }
94
+
95
+ func (a * atomicInterface [T ]) Load () T {
96
+ var v T
97
+ x := a .iface .Load ()
98
+ if x != nil {
99
+ return x .(interfaceWrapper [T ]).t
100
+ }
101
+ return v
102
+ }
103
+
104
+ func (a * atomicInterface [T ]) Store (v T ) {
105
+ a .iface .Store (interfaceWrapper [T ]{v })
106
+ }
107
+
85
108
type streamMessageSender struct {
86
- to peer.ID
87
- stream network.Stream
88
- connected bool
89
- bsnet * impl
90
- opts * MessageSenderOpts
109
+ to peer.ID
110
+ stream atomicInterface [network.Stream ]
111
+ bsnet * impl
112
+ opts * MessageSenderOpts
113
+ }
114
+
115
+ type HasContext interface {
116
+ Context () context.Context
91
117
}
92
118
93
119
// Open a stream to the remote peer
94
120
func (s * streamMessageSender ) Connect (ctx context.Context ) (network.Stream , error ) {
95
- if s .connected {
96
- return s .stream , nil
121
+ stream := s .stream .Load ()
122
+ if stream != nil {
123
+ return stream , nil
97
124
}
98
125
99
126
tctx , cancel := context .WithTimeout (ctx , s .opts .SendTimeout )
@@ -107,30 +134,45 @@ func (s *streamMessageSender) Connect(ctx context.Context) (network.Stream, erro
107
134
if err != nil {
108
135
return nil , err
109
136
}
137
+ if withCtx , ok := stream .Conn ().(HasContext ); ok {
138
+ context .AfterFunc (withCtx .Context (), func () {
139
+ s .stream .Store (nil )
140
+ })
141
+ }
110
142
111
- s .stream = stream
112
- s .connected = true
113
- return s .stream , nil
143
+ s .stream .Store (stream )
144
+ return stream , nil
114
145
}
115
146
116
147
// Reset the stream
117
148
func (s * streamMessageSender ) Reset () error {
118
- if s .stream != nil {
119
- err := s .stream .Reset ()
120
- s .connected = false
149
+ stream := s .stream .Load ()
150
+ if stream != nil {
151
+ err := stream .Reset ()
152
+ s .stream .Store (nil )
121
153
return err
122
154
}
123
155
return nil
124
156
}
125
157
126
158
// Close the stream
127
159
func (s * streamMessageSender ) Close () error {
128
- return s .stream .Close ()
160
+ stream := s .stream .Load ()
161
+ if stream != nil {
162
+ err := stream .Close ()
163
+ s .stream .Store (nil )
164
+ return err
165
+ }
166
+ return nil
129
167
}
130
168
131
169
// Indicates whether the peer supports HAVE / DONT_HAVE messages
132
170
func (s * streamMessageSender ) SupportsHave () bool {
133
- return s .bsnet .SupportsHave (s .stream .Protocol ())
171
+ stream := s .stream .Load ()
172
+ if stream == nil {
173
+ return false
174
+ }
175
+ return s .bsnet .SupportsHave (stream .Protocol ())
134
176
}
135
177
136
178
// Send a message to the peer, attempting multiple times
0 commit comments