Skip to content

Commit 1038256

Browse files
kghostpull[bot]
authored andcommitted
Add a simple variant implementation similiar to std::variant (#6624)
1 parent ac169b6 commit 1038256

File tree

6 files changed

+423
-54
lines changed

6 files changed

+423
-54
lines changed

src/channel/ChannelContext.cpp

+37-36
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,17 @@ void ChannelContext::Start(const ChannelBuilder & builder)
3838
ExchangeContext * ChannelContext::NewExchange(ExchangeDelegate * delegate)
3939
{
4040
assert(GetState() == ChannelState::kReady);
41-
return mExchangeManager->NewContext(mStateVars.mReady.mSession, delegate);
41+
return mExchangeManager->NewContext(GetReadyVars().mSession, delegate);
4242
}
4343

4444
bool ChannelContext::MatchNodeId(NodeId nodeId)
4545
{
4646
switch (mState)
4747
{
4848
case ChannelState::kPreparing:
49-
return nodeId == mStateVars.mPreparing.mBuilder.GetPeerNodeId();
49+
return nodeId == GetPrepareVars().mBuilder.GetPeerNodeId();
5050
case ChannelState::kReady: {
51-
auto state = mExchangeManager->GetSessionMgr()->GetPeerConnectionState(mStateVars.mReady.mSession);
51+
auto state = mExchangeManager->GetSessionMgr()->GetPeerConnectionState(GetReadyVars().mSession);
5252
if (state == nullptr)
5353
return false;
5454
return nodeId == state->GetPeerNodeId();
@@ -63,7 +63,7 @@ bool ChannelContext::MatchTransport(Transport::Type transport)
6363
switch (mState)
6464
{
6565
case ChannelState::kPreparing:
66-
switch (mStateVars.mPreparing.mBuilder.GetTransportPreference())
66+
switch (GetPrepareVars().mBuilder.GetTransportPreference())
6767
{
6868
case ChannelBuilder::TransportPreference::kPreferConnectionOriented:
6969
case ChannelBuilder::TransportPreference::kConnectionOriented:
@@ -73,7 +73,7 @@ bool ChannelContext::MatchTransport(Transport::Type transport)
7373
}
7474
return false;
7575
case ChannelState::kReady: {
76-
auto state = mExchangeManager->GetSessionMgr()->GetPeerConnectionState(mStateVars.mReady.mSession);
76+
auto state = mExchangeManager->GetSessionMgr()->GetPeerConnectionState(GetReadyVars().mSession);
7777
if (state == nullptr)
7878
return false;
7979
return transport == state->GetPeerAddress().GetTransportType();
@@ -118,36 +118,38 @@ bool ChannelContext::MatchesBuilder(const ChannelBuilder & builder)
118118

119119
bool ChannelContext::IsCasePairing()
120120
{
121-
return mState == ChannelState::kPreparing && mStateVars.mPreparing.mState == PrepareState::kCasePairing;
121+
return mState == ChannelState::kPreparing && GetPrepareVars().mState == PrepareState::kCasePairing;
122122
}
123123

124124
bool ChannelContext::MatchesSession(SecureSessionHandle session, SecureSessionMgr * ssm)
125125
{
126126
switch (mState)
127127
{
128128
case ChannelState::kPreparing: {
129-
switch (mStateVars.mPreparing.mState)
129+
switch (GetPrepareVars().mState)
130130
{
131131
case PrepareState::kCasePairing: {
132132
auto state = ssm->GetPeerConnectionState(session);
133-
return (state->GetPeerNodeId() == mStateVars.mPreparing.mBuilder.GetPeerNodeId() &&
134-
state->GetPeerKeyID() == mStateVars.mPreparing.mBuilder.GetPeerKeyID());
133+
return (state->GetPeerNodeId() == GetPrepareVars().mBuilder.GetPeerNodeId() &&
134+
state->GetPeerKeyID() == GetPrepareVars().mBuilder.GetPeerKeyID());
135135
}
136136
default:
137137
return false;
138138
}
139139
}
140140
case ChannelState::kReady:
141-
return mStateVars.mReady.mSession == session;
141+
return GetReadyVars().mSession == session;
142142
default:
143143
return false;
144144
}
145145
}
146146

147147
void ChannelContext::EnterPreparingState(const ChannelBuilder & builder)
148148
{
149-
mState = ChannelState::kPreparing;
150-
mStateVars.mPreparing.mBuilder = builder;
149+
mState = ChannelState::kPreparing;
150+
151+
mStateVars.Set<PrepareVars>();
152+
GetPrepareVars().mBuilder = builder;
151153

152154
EnterAddressResolve();
153155
}
@@ -157,14 +159,14 @@ void ChannelContext::ExitPreparingState() {}
157159
// Address resolve
158160
void ChannelContext::EnterAddressResolve()
159161
{
160-
mStateVars.mPreparing.mState = PrepareState::kAddressResolving;
162+
GetPrepareVars().mState = PrepareState::kAddressResolving;
161163

162164
// Skip address resolve if the address is provided
163165
{
164-
auto addr = mStateVars.mPreparing.mBuilder.GetForcePeerAddress();
166+
auto addr = GetPrepareVars().mBuilder.GetForcePeerAddress();
165167
if (addr.HasValue())
166168
{
167-
mStateVars.mPreparing.mAddress = addr.Value();
169+
GetPrepareVars().mAddress = addr.Value();
168170
ExitAddressResolve();
169171
// Only CASE session is supported
170172
EnterCasePairingState();
@@ -174,10 +176,10 @@ void ChannelContext::EnterAddressResolve()
174176

175177
// TODO: call mDNS Scanner::SubscribeNode after PR #4459 is ready
176178
// Scanner::RegisterScannerDelegate(this)
177-
// Scanner::SubscribeNode(mStateVars.mPreparing.mBuilder.GetPeerNodeId())
179+
// Scanner::SubscribeNode(GetPrepareVars().mBuilder.GetPeerNodeId())
178180

179181
// The HandleNodeIdResolve may already have been called, recheck the state here before set up the timer
180-
if (mState == ChannelState::kPreparing && mStateVars.mPreparing.mState == PrepareState::kAddressResolving)
182+
if (mState == ChannelState::kPreparing && GetPrepareVars().mState == PrepareState::kAddressResolving)
181183
{
182184
System::Layer * layer = mExchangeManager->GetSessionMgr()->SystemLayer();
183185
layer->StartTimer(CHIP_CONFIG_NODE_ADDRESS_RESOLVE_TIMEOUT_MSECS, AddressResolveTimeout, this);
@@ -196,7 +198,7 @@ void ChannelContext::AddressResolveTimeout()
196198
{
197199
if (mState != ChannelState::kPreparing)
198200
return;
199-
if (mStateVars.mPreparing.mState != PrepareState::kAddressResolving)
201+
if (GetPrepareVars().mState != PrepareState::kAddressResolving)
200202
return;
201203

202204
ExitAddressResolve();
@@ -219,7 +221,7 @@ void ChannelContext::HandleNodeIdResolve(CHIP_ERROR error, uint64_t nodeId, cons
219221
return;
220222
}
221223
case ChannelState::kPreparing: {
222-
switch (mStateVars.mPreparing.mState)
224+
switch (GetPrepareVars().mState)
223225
{
224226
case PrepareState::kAddressResolving: {
225227
if (error != CHIP_NO_ERROR)
@@ -232,8 +234,8 @@ void ChannelContext::HandleNodeIdResolve(CHIP_ERROR error, uint64_t nodeId, cons
232234

233235
if (!address.mAddress.HasValue())
234236
return;
235-
mStateVars.mPreparing.mAddressType = address.mAddressType;
236-
mStateVars.mPreparing.mAddress = address.mAddress.Value();
237+
GetPrepareVars().mAddressType = address.mAddressType;
238+
GetPrepareVars().mAddress = address.mAddress.Value();
237239
ExitAddressResolve();
238240
EnterCasePairingState();
239241
return;
@@ -253,18 +255,18 @@ void ChannelContext::HandleNodeIdResolve(CHIP_ERROR error, uint64_t nodeId, cons
253255

254256
void ChannelContext::EnterCasePairingState()
255257
{
256-
mStateVars.mPreparing.mState = PrepareState::kCasePairing;
257-
mStateVars.mPreparing.mCasePairingSession = Platform::New<CASESession>();
258+
auto & prepare = GetPrepareVars();
259+
prepare.mCasePairingSession = Platform::New<CASESession>();
258260

259-
ExchangeContext * ctxt = mExchangeManager->NewContext(SecureSessionHandle(), mStateVars.mPreparing.mCasePairingSession);
261+
ExchangeContext * ctxt = mExchangeManager->NewContext(SecureSessionHandle(), prepare.mCasePairingSession);
260262
VerifyOrReturn(ctxt != nullptr);
261263

262264
// TODO: currently only supports IP/UDP paring
263265
Transport::PeerAddress addr;
264-
addr.SetTransportType(Transport::Type::kUdp).SetIPAddress(mStateVars.mPreparing.mAddress);
265-
CHIP_ERROR err = mStateVars.mPreparing.mCasePairingSession->EstablishSession(
266-
addr, &mStateVars.mPreparing.mBuilder.GetOperationalCredentialSet(), mStateVars.mPreparing.mBuilder.GetPeerNodeId(),
267-
mExchangeManager->GetNextKeyId(), ctxt, this);
266+
addr.SetTransportType(Transport::Type::kUdp).SetIPAddress(prepare.mAddress);
267+
CHIP_ERROR err = prepare.mCasePairingSession->EstablishSession(addr, &prepare.mBuilder.GetOperationalCredentialSet(),
268+
prepare.mBuilder.GetPeerNodeId(),
269+
mExchangeManager->GetNextKeyId(), ctxt, this);
268270
if (err != CHIP_NO_ERROR)
269271
{
270272
ExitCasePairingState();
@@ -275,14 +277,14 @@ void ChannelContext::EnterCasePairingState()
275277

276278
void ChannelContext::ExitCasePairingState()
277279
{
278-
Platform::Delete(mStateVars.mPreparing.mCasePairingSession);
280+
Platform::Delete(GetPrepareVars().mCasePairingSession);
279281
}
280282

281283
void ChannelContext::OnSessionEstablishmentError(CHIP_ERROR error)
282284
{
283285
if (mState != ChannelState::kPreparing)
284286
return;
285-
switch (mStateVars.mPreparing.mState)
287+
switch (GetPrepareVars().mState)
286288
{
287289
case PrepareState::kCasePairing:
288290
ExitCasePairingState();
@@ -298,11 +300,11 @@ void ChannelContext::OnSessionEstablished()
298300
{
299301
if (mState != ChannelState::kPreparing)
300302
return;
301-
switch (mStateVars.mPreparing.mState)
303+
switch (GetPrepareVars().mState)
302304
{
303305
case PrepareState::kCasePairing:
304306
ExitCasePairingState();
305-
mStateVars.mPreparing.mState = PrepareState::kCasePairingDone;
307+
GetPrepareVars().mState = PrepareState::kCasePairingDone;
306308
// TODO: current CASE paring session API doesn't show how to derive a secure session
307309
return;
308310
default:
@@ -314,7 +316,7 @@ void ChannelContext::OnNewConnection(SecureSessionHandle session)
314316
{
315317
if (mState != ChannelState::kPreparing)
316318
return;
317-
if (mStateVars.mPreparing.mState != PrepareState::kCasePairingDone)
319+
if (GetPrepareVars().mState != PrepareState::kCasePairingDone)
318320
return;
319321

320322
ExitPreparingState();
@@ -324,8 +326,7 @@ void ChannelContext::OnNewConnection(SecureSessionHandle session)
324326
void ChannelContext::EnterReadyState(SecureSessionHandle session)
325327
{
326328
mState = ChannelState::kReady;
327-
328-
mStateVars.mReady.mSession = session;
329+
mStateVars.Set<ReadyVars>(session);
329330
mChannelManager->NotifyChannelEvent(this, [](ChannelDelegate * delegate) { delegate->OnEstablished(); });
330331
}
331332

@@ -344,7 +345,7 @@ void ChannelContext::ExitReadyState()
344345
// Currently SecureSessionManager doesn't provide an interface to close a session
345346

346347
// TODO: call mDNS Scanner::UnubscribeNode after PR #4459 is ready
347-
// Scanner::UnsubscribeNode(mStateVars.mPreparing.mBuilder.GetPeerNodeId())
348+
// Scanner::UnsubscribeNode(GetPrepareVars().mBuilder.GetPeerNodeId())
348349
}
349350

350351
void ChannelContext::EnterFailedState(CHIP_ERROR error)

src/channel/ChannelContext.h

+22-18
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <channel/Channel.h>
2828
#include <lib/core/ReferenceCounted.h>
2929
#include <lib/mdns/platform/Mdns.h>
30+
#include <lib/support/Variant.h>
3031
#include <protocols/secure_channel/CASESession.h>
3132
#include <transport/PeerConnectionState.h>
3233
#include <transport/SecureSessionMgr.h>
@@ -129,25 +130,28 @@ class ChannelContext : public ReferenceCounted<ChannelContext, ChannelContextDel
129130
kCasePairingDone,
130131
};
131132

132-
union StateVars
133+
// mPreparing is pretty big, consider move it outside
134+
struct PrepareVars
133135
{
134-
StateVars() {}
135-
136-
// mPreparing is pretty big, consider move it outside
137-
struct PrepareVars
138-
{
139-
PrepareState mState;
140-
Inet::IPAddressType mAddressType;
141-
Inet::IPAddress mAddress;
142-
CASESession * mCasePairingSession;
143-
ChannelBuilder mBuilder;
144-
} mPreparing;
145-
146-
struct ReadyVars
147-
{
148-
SecureSessionHandle mSession;
149-
} mReady;
150-
} mStateVars;
136+
static constexpr const size_t VariantId = 1;
137+
PrepareState mState;
138+
Inet::IPAddressType mAddressType;
139+
Inet::IPAddress mAddress;
140+
CASESession * mCasePairingSession;
141+
ChannelBuilder mBuilder;
142+
};
143+
144+
struct ReadyVars
145+
{
146+
static constexpr const size_t VariantId = 2;
147+
ReadyVars(SecureSessionHandle session) : mSession(session) {}
148+
const SecureSessionHandle mSession;
149+
};
150+
151+
Variant<PrepareVars, ReadyVars> mStateVars;
152+
153+
PrepareVars & GetPrepareVars() { return mStateVars.Get<PrepareVars>(); }
154+
ReadyVars & GetReadyVars() { return mStateVars.Get<ReadyVars>(); }
151155

152156
// State machine functions
153157
void EnterPreparingState(const ChannelBuilder & builder);

src/lib/support/BUILD.gn

+1
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ static_library("support") {
9292
"TimeUtils.h",
9393
"UnitTestRegistration.cpp",
9494
"UnitTestRegistration.h",
95+
"Variant.h",
9596
"logging/CHIPLogging.cpp",
9697
"logging/CHIPLogging.h",
9798
"verhoeff/Verhoeff.cpp",

0 commit comments

Comments
 (0)