Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify CASEClient initialization code #24079

Merged
merged 3 commits into from
Dec 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions src/app/CASEClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,20 @@

namespace chip {

CASEClient::CASEClient(const CASEClientInitParams & params) : mInitParams(params) {}

void CASEClient::SetRemoteMRPIntervals(const ReliableMessageProtocolConfig & remoteMRPConfig)
{
mCASESession.SetRemoteMRPConfig(remoteMRPConfig);
}

CHIP_ERROR CASEClient::EstablishSession(const ScopedNodeId & peer, const Transport::PeerAddress & peerAddress,
CHIP_ERROR CASEClient::EstablishSession(const CASEClientInitParams & params, const ScopedNodeId & peer,
const Transport::PeerAddress & peerAddress,
const ReliableMessageProtocolConfig & remoteMRPConfig,
SessionEstablishmentDelegate * delegate)
{
VerifyOrReturnError(mInitParams.fabricTable != nullptr, CHIP_ERROR_INVALID_ARGUMENT);
VerifyOrReturnError(params.fabricTable != nullptr, CHIP_ERROR_INVALID_ARGUMENT);

// Create a UnauthenticatedSession for CASE pairing.
Optional<SessionHandle> session = mInitParams.sessionManager->CreateUnauthenticatedSession(peerAddress, remoteMRPConfig);
Optional<SessionHandle> session = params.sessionManager->CreateUnauthenticatedSession(peerAddress, remoteMRPConfig);
VerifyOrReturnError(session.HasValue(), CHIP_ERROR_NO_MEMORY);

// Allocate the exchange immediately before calling CASESession::EstablishSession.
Expand All @@ -42,13 +41,13 @@ CHIP_ERROR CASEClient::EstablishSession(const ScopedNodeId & peer, const Transpo
// free it on error, but can only do this if it is actually called.
// Allocating the exchange context right before calling EstablishSession
// ensures that if allocation succeeds, CASESession has taken ownership.
Messaging::ExchangeContext * exchange = mInitParams.exchangeMgr->NewContext(session.Value(), &mCASESession);
Messaging::ExchangeContext * exchange = params.exchangeMgr->NewContext(session.Value(), &mCASESession);
VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_INTERNAL);

mCASESession.SetGroupDataProvider(mInitParams.groupDataProvider);
ReturnErrorOnFailure(mCASESession.EstablishSession(*mInitParams.sessionManager, mInitParams.fabricTable, peer, exchange,
mInitParams.sessionResumptionStorage, mInitParams.certificateValidityPolicy,
delegate, mInitParams.mrpLocalConfig));
mCASESession.SetGroupDataProvider(params.groupDataProvider);
ReturnErrorOnFailure(mCASESession.EstablishSession(*params.sessionManager, params.fabricTable, peer, exchange,
params.sessionResumptionStorage, params.certificateValidityPolicy, delegate,
params.mrpLocalConfig));

return CHIP_NO_ERROR;
}
Expand Down
22 changes: 15 additions & 7 deletions src/app/CASEClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,31 @@ struct CASEClientInitParams
Messaging::ExchangeManager * exchangeMgr = nullptr;
FabricTable * fabricTable = nullptr;
Credentials::GroupDataProvider * groupDataProvider = nullptr;
Optional<ReliableMessageProtocolConfig> mrpLocalConfig = Optional<ReliableMessageProtocolConfig>::Missing();

Optional<ReliableMessageProtocolConfig> mrpLocalConfig = Optional<ReliableMessageProtocolConfig>::Missing();
CHIP_ERROR Validate() const
{
// sessionResumptionStorage can be nullptr when resumption is disabled.
// certificateValidityPolicy is optional, too.
ReturnErrorCodeIf(sessionManager == nullptr, CHIP_ERROR_INCORRECT_STATE);
ReturnErrorCodeIf(exchangeMgr == nullptr, CHIP_ERROR_INCORRECT_STATE);
ReturnErrorCodeIf(fabricTable == nullptr, CHIP_ERROR_INCORRECT_STATE);
ReturnErrorCodeIf(groupDataProvider == nullptr, CHIP_ERROR_INCORRECT_STATE);

return CHIP_NO_ERROR;
}
};

class DLL_EXPORT CASEClient
{
public:
CASEClient(const CASEClientInitParams & params);

void SetRemoteMRPIntervals(const ReliableMessageProtocolConfig & remoteMRPConfig);

CHIP_ERROR EstablishSession(const ScopedNodeId & peer, const Transport::PeerAddress & peerAddress,
const ReliableMessageProtocolConfig & remoteMRPConfig, SessionEstablishmentDelegate * delegate);
CHIP_ERROR EstablishSession(const CASEClientInitParams & params, const ScopedNodeId & peer,
const Transport::PeerAddress & peerAddress, const ReliableMessageProtocolConfig & remoteMRPConfig,
SessionEstablishmentDelegate * delegate);

private:
CASEClientInitParams mInitParams;

CASESession mCASESession;
};

Expand Down
4 changes: 2 additions & 2 deletions src/app/CASEClientPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace chip {
class CASEClientPoolDelegate
{
public:
virtual CASEClient * Allocate(CASEClientInitParams params) = 0;
virtual CASEClient * Allocate() = 0;

virtual void Release(CASEClient * client) = 0;

Expand All @@ -38,7 +38,7 @@ class CASEClientPool : public CASEClientPoolDelegate
public:
~CASEClientPool() override { mClientPool.ReleaseAll(); }

CASEClient * Allocate(CASEClientInitParams params) override { return mClientPool.CreateObject(params); }
CASEClient * Allocate() override { return mClientPool.CreateObject(); }

void Release(CASEClient * client) override { mClientPool.ReleaseObject(client); }

Expand Down
4 changes: 2 additions & 2 deletions src/app/CASESessionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ void CASESessionManager::FindOrEstablishSession(const ScopedNodeId & peerId, Cal
{
ChipLogDetail(CASESessionManager, "FindOrEstablishSession: No existing OperationalSessionSetup instance found");

session = mConfig.sessionSetupPool->Allocate(mConfig.sessionInitParams, peerId, this);
session = mConfig.sessionSetupPool->Allocate(mConfig.sessionInitParams, mConfig.clientPool, peerId, this);

if (session == nullptr)
{
Expand Down Expand Up @@ -83,7 +83,7 @@ void CASESessionManager::UpdatePeerAddress(ScopedNodeId peerId)
{
ChipLogDetail(CASESessionManager, "UpdatePeerAddress: No existing OperationalSessionSetup instance found");

session = mConfig.sessionSetupPool->Allocate(mConfig.sessionInitParams, peerId, this);
session = mConfig.sessionSetupPool->Allocate(mConfig.sessionInitParams, mConfig.clientPool, peerId, this);
if (session == nullptr)
{
ChipLogDetail(CASESessionManager, "UpdatePeerAddress: Failed to allocate OperationalSessionSetup instance");
Expand Down
3 changes: 2 additions & 1 deletion src/app/CASESessionManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ class OperationalSessionSetupPoolDelegate;

struct CASESessionManagerConfig
{
DeviceProxyInitParams sessionInitParams;
CASEClientInitParams sessionInitParams;
CASEClientPoolDelegate * clientPool = nullptr;
OperationalSessionSetupPoolDelegate * sessionSetupPool = nullptr;
};

Expand Down
12 changes: 5 additions & 7 deletions src/app/OperationalSessionSetup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,10 @@ void OperationalSessionSetup::UpdateDeviceData(const Transport::PeerAddress & ad

CHIP_ERROR OperationalSessionSetup::EstablishConnection(const ReliableMessageProtocolConfig & config)
{
mCASEClient = mInitParams.clientPool->Allocate(CASEClientInitParams{
mInitParams.sessionManager, mInitParams.sessionResumptionStorage, mInitParams.certificateValidityPolicy,
mInitParams.exchangeMgr, mFabricTable, mInitParams.groupDataProvider, mInitParams.mrpLocalConfig });
mCASEClient = mClientPool->Allocate();
ReturnErrorCodeIf(mCASEClient == nullptr, CHIP_ERROR_NO_MEMORY);

CHIP_ERROR err = mCASEClient->EstablishSession(mPeerId, mDeviceAddress, config, this);
CHIP_ERROR err = mCASEClient->EstablishSession(mInitParams, mPeerId, mDeviceAddress, config, this);
if (err != CHIP_NO_ERROR)
{
CleanupCASEClient();
Expand Down Expand Up @@ -330,7 +328,7 @@ void OperationalSessionSetup::CleanupCASEClient()
{
if (mCASEClient)
{
mInitParams.clientPool->Release(mCASEClient);
mClientPool->Release(mCASEClient);
mCASEClient = nullptr;
}
}
Expand Down Expand Up @@ -364,7 +362,7 @@ OperationalSessionSetup::~OperationalSessionSetup()
if (mCASEClient)
{
// Make sure we don't leak it.
mInitParams.clientPool->Release(mCASEClient);
mClientPool->Release(mCASEClient);
}
}

Expand All @@ -382,7 +380,7 @@ CHIP_ERROR OperationalSessionSetup::LookupPeerAddress()
return CHIP_NO_ERROR;
}

auto const * fabricInfo = mFabricTable->FindFabricWithIndex(mPeerId.GetFabricIndex());
auto const * fabricInfo = mInitParams.fabricTable->FindFabricWithIndex(mPeerId.GetFabricIndex());
VerifyOrReturnError(fabricInfo != nullptr, CHIP_ERROR_INVALID_FABRIC_INDEX);

PeerId peerId(fabricInfo->GetCompressedFabricId(), mPeerId.GetNodeId());
Expand Down
35 changes: 5 additions & 30 deletions src/app/OperationalSessionSetup.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,31 +45,6 @@

namespace chip {

struct DeviceProxyInitParams
{
SessionManager * sessionManager = nullptr;
SessionResumptionStorage * sessionResumptionStorage = nullptr;
Credentials::CertificateValidityPolicy * certificateValidityPolicy = nullptr;
Messaging::ExchangeManager * exchangeMgr = nullptr;
FabricTable * fabricTable = nullptr;
CASEClientPoolDelegate * clientPool = nullptr;
Credentials::GroupDataProvider * groupDataProvider = nullptr;

Optional<ReliableMessageProtocolConfig> mrpLocalConfig = Optional<ReliableMessageProtocolConfig>::Missing();

CHIP_ERROR Validate() const
{
ReturnErrorCodeIf(sessionManager == nullptr, CHIP_ERROR_INCORRECT_STATE);
// sessionResumptionStorage can be nullptr when resumption is disabled
ReturnErrorCodeIf(exchangeMgr == nullptr, CHIP_ERROR_INCORRECT_STATE);
ReturnErrorCodeIf(fabricTable == nullptr, CHIP_ERROR_INCORRECT_STATE);
ReturnErrorCodeIf(groupDataProvider == nullptr, CHIP_ERROR_INCORRECT_STATE);
ReturnErrorCodeIf(clientPool == nullptr, CHIP_ERROR_INCORRECT_STATE);

return CHIP_NO_ERROR;
}
};

class OperationalSessionSetup;

/**
Expand Down Expand Up @@ -171,20 +146,20 @@ class DLL_EXPORT OperationalSessionSetup : public SessionDelegate,
public:
~OperationalSessionSetup() override;

OperationalSessionSetup(DeviceProxyInitParams & params, ScopedNodeId peerId,
OperationalSessionSetup(const CASEClientInitParams & params, CASEClientPoolDelegate * clientPool, ScopedNodeId peerId,
OperationalSessionReleaseDelegate * releaseDelegate) :
mSecureSession(*this)
{
mInitParams = params;
if (params.Validate() != CHIP_NO_ERROR || releaseDelegate == nullptr)
if (params.Validate() != CHIP_NO_ERROR || clientPool == nullptr || releaseDelegate == nullptr)
{
mState = State::Uninitialized;
return;
}

mClientPool = clientPool;
mSystemLayer = params.exchangeMgr->GetSessionManager()->SystemLayer();
mPeerId = peerId;
mFabricTable = params.fabricTable;
mReleaseDelegate = releaseDelegate;
mState = State::NeedsAddress;
mAddressLookupHandle.SetListener(this);
Expand Down Expand Up @@ -260,8 +235,8 @@ class DLL_EXPORT OperationalSessionSetup : public SessionDelegate,
SecureConnected, // CASE session established.
};

DeviceProxyInitParams mInitParams;
FabricTable * mFabricTable = nullptr;
CASEClientInitParams mInitParams;
CASEClientPoolDelegate * mClientPool = nullptr;
System::Layer * mSystemLayer;

// mCASEClient is only non-null if we are in State::Connecting or just
Expand Down
10 changes: 5 additions & 5 deletions src/app/OperationalSessionSetupPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ namespace chip {
class OperationalSessionSetupPoolDelegate
{
public:
virtual OperationalSessionSetup * Allocate(DeviceProxyInitParams & params, ScopedNodeId peerId,
OperationalSessionReleaseDelegate * releaseDelegate) = 0;
virtual OperationalSessionSetup * Allocate(const CASEClientInitParams & params, CASEClientPoolDelegate * clientPool,
ScopedNodeId peerId, OperationalSessionReleaseDelegate * releaseDelegate) = 0;

virtual void Release(OperationalSessionSetup * device) = 0;

Expand All @@ -47,10 +47,10 @@ class OperationalSessionSetupPool : public OperationalSessionSetupPoolDelegate
public:
~OperationalSessionSetupPool() override { mSessionSetupPool.ReleaseAll(); }

OperationalSessionSetup * Allocate(DeviceProxyInitParams & params, ScopedNodeId peerId,
OperationalSessionReleaseDelegate * releaseDelegate) override
OperationalSessionSetup * Allocate(const CASEClientInitParams & params, CASEClientPoolDelegate * clientPool,
ScopedNodeId peerId, OperationalSessionReleaseDelegate * releaseDelegate) override
{
return mSessionSetupPool.CreateObject(params, peerId, releaseDelegate);
return mSessionSetupPool.CreateObject(params, clientPool, peerId, releaseDelegate);
}

void Release(OperationalSessionSetup * device) override { mSessionSetupPool.ReleaseObject(device); }
Expand Down
4 changes: 2 additions & 2 deletions src/app/server/Server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,11 +290,11 @@ CHIP_ERROR Server::Init(const ServerInitParams & initParams)
.certificateValidityPolicy = mCertificateValidityPolicy,
.exchangeMgr = &mExchangeMgr,
.fabricTable = &mFabrics,
.clientPool = &mCASEClientPool,
.groupDataProvider = mGroupsProvider,
.mrpLocalConfig = GetLocalMRPConfig(),
},
.sessionSetupPool = &mSessionSetupPool,
.clientPool = &mCASEClientPool,
.sessionSetupPool = &mSessionSetupPool,
};

err = mCASESessionManager.Init(&DeviceLayer::SystemLayer(), caseSessionManagerConfig);
Expand Down
2 changes: 1 addition & 1 deletion src/app/tests/TestOperationalDeviceProxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void TestOperationalDeviceProxy_EstablishSessionDirectly(nlTestSuite * inSuite,
VerifyOrDie(groupDataProvider.Init() == CHIP_NO_ERROR);
// TODO: Set IPK in groupDataProvider

DeviceProxyInitParams params = {
CASEClientInitParams params = {
.sessionManager = &sessionManager,
.sessionResumptionStorage = &sessionResumptionStorage,
.exchangeMgr = &exchangeMgr,
Expand Down
6 changes: 3 additions & 3 deletions src/controller/CHIPDeviceControllerFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,18 +245,18 @@ CHIP_ERROR DeviceControllerFactory::InitSystemState(FactoryInitParams params)
stateParams.sessionSetupPool = Platform::New<DeviceControllerSystemStateParams::SessionSetupPool>();
stateParams.caseClientPool = Platform::New<DeviceControllerSystemStateParams::CASEClientPool>();

DeviceProxyInitParams deviceInitParams = {
CASEClientInitParams sessionInitParams = {
.sessionManager = stateParams.sessionMgr,
.sessionResumptionStorage = stateParams.sessionResumptionStorage.get(),
.exchangeMgr = stateParams.exchangeMgr,
.fabricTable = stateParams.fabricTable,
.clientPool = stateParams.caseClientPool,
.groupDataProvider = stateParams.groupDataProvider,
.mrpLocalConfig = GetLocalMRPConfig(),
};

CASESessionManagerConfig sessionManagerConfig = {
.sessionInitParams = deviceInitParams,
.sessionInitParams = sessionInitParams,
.clientPool = stateParams.caseClientPool,
.sessionSetupPool = stateParams.sessionSetupPool,
};

Expand Down
2 changes: 1 addition & 1 deletion src/protocols/secure_channel/CASEServer.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class CASEServer : public SessionEstablishmentDelegate,
void OnResponseTimeout(Messaging::ExchangeContext * ec) override {}
Messaging::ExchangeMessageDispatch & GetMessageDispatch() override { return GetSession().GetMessageDispatch(); }

virtual CASESession & GetSession() { return mPairingSession; }
CASESession & GetSession() { return mPairingSession; }

private:
Messaging::ExchangeManager * mExchangeManager = nullptr;
Expand Down
11 changes: 1 addition & 10 deletions src/protocols/secure_channel/tests/TestCASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,6 @@ class TestCASESecurePairingDelegate : public SessionEstablishmentDelegate
uint32_t mNumPairingComplete = 0;
};

class CASEServerForTest : public CASEServer
{
public:
CASESession & GetSession() override { return mCaseSession; }

private:
CASESession mCaseSession;
};

class TestOperationalKeystore : public chip::Crypto::OperationalKeystore
{
public:
Expand Down Expand Up @@ -469,7 +460,7 @@ void TestCASESession::SecurePairingHandshakeTest(nlTestSuite * inSuite, void * i
SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner, delegateCommissioner);
}

CASEServerForTest gPairingServer;
CASEServer gPairingServer;

void TestCASESession::SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inContext)
{
Expand Down