Skip to content

Commit a2f72dd

Browse files
yunhanw-googleshgutte
authored andcommittedSep 10, 2024
Fix multiple check-in/peer nodeId handling in icd client side (project-chip#35304)
1 parent ce4845f commit a2f72dd

14 files changed

+119
-47
lines changed
 

‎examples/chip-tool/commands/clusters/ClusterCommand.h

+21-6
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class ClusterCommand : public InteractionModelCommands, public ModelCommand, pub
6060
const chip::app::Clusters::IcdManagement::Commands::UnregisterClient::Type & value)
6161
{
6262
ReturnErrorOnFailure(InteractionModelCommands::SendCommand(device, endpointId, clusterId, commandId, value));
63-
mScopedNodeId = chip::ScopedNodeId(value.checkInNodeID, device->GetSecureSession().Value()->GetFabricIndex());
63+
mPeerNodeId = chip::ScopedNodeId(device->GetDeviceId(), device->GetSecureSession().Value()->GetFabricIndex());
6464
return CHIP_NO_ERROR;
6565
}
6666

@@ -69,7 +69,8 @@ class ClusterCommand : public InteractionModelCommands, public ModelCommand, pub
6969
const chip::app::Clusters::IcdManagement::Commands::RegisterClient::Type & value)
7070
{
7171
ReturnErrorOnFailure(InteractionModelCommands::SendCommand(device, endpointId, clusterId, commandId, value));
72-
mScopedNodeId = chip::ScopedNodeId(value.checkInNodeID, device->GetSecureSession().Value()->GetFabricIndex());
72+
mPeerNodeId = chip::ScopedNodeId(device->GetDeviceId(), device->GetSecureSession().Value()->GetFabricIndex());
73+
mCheckInNodeId = chip::ScopedNodeId(value.checkInNodeID, device->GetSecureSession().Value()->GetFabricIndex());
7374
mMonitoredSubject = value.monitoredSubject;
7475
mClientType = value.clientType;
7576
memcpy(mICDSymmetricKey, value.key.data(), value.key.size());
@@ -147,7 +148,9 @@ class ClusterCommand : public InteractionModelCommands, public ModelCommand, pub
147148
return;
148149
}
149150
chip::app::ICDClientInfo clientInfo;
150-
clientInfo.peer_node = mScopedNodeId;
151+
152+
clientInfo.peer_node = mPeerNodeId;
153+
clientInfo.check_in_node = mCheckInNodeId;
151154
clientInfo.monitored_subject = mMonitoredSubject;
152155
clientInfo.start_icd_counter = value.ICDCounter;
153156
clientInfo.client_type = mClientType;
@@ -159,7 +162,7 @@ class ClusterCommand : public InteractionModelCommands, public ModelCommand, pub
159162
if ((path.mEndpointId == chip::kRootEndpointId) && (path.mClusterId == chip::app::Clusters::IcdManagement::Id) &&
160163
(path.mCommandId == chip::app::Clusters::IcdManagement::Commands::UnregisterClient::Id))
161164
{
162-
ClearICDEntry(mScopedNodeId);
165+
ClearICDEntry(mPeerNodeId);
163166
}
164167
}
165168

@@ -260,9 +263,21 @@ class ClusterCommand : public InteractionModelCommands, public ModelCommand, pub
260263
private:
261264
chip::ClusterId mClusterId;
262265
chip::CommandId mCommandId;
263-
chip::ScopedNodeId mScopedNodeId;
264-
uint64_t mMonitoredSubject = static_cast<uint64_t>(0);
266+
// The scoped node ID to which RegisterClient and UnregisterClient command will be sent. Not set for other commands.
267+
chip::ScopedNodeId mPeerNodeId;
268+
// The scoped node ID to which a Check-In message will be sent. Only set for the RegisterClient command.
269+
chip::ScopedNodeId mCheckInNodeId;
270+
271+
// Used to determine if a particular client has an active subscription for the given entry.
272+
// The MonitoredSubject, when it is a NodeID, MAY be the same as the CheckInNodeID.
273+
// The MonitoredSubject gives the registering client the flexibility of having a different
274+
// CheckInNodeID from the MonitoredSubject.
275+
uint64_t mMonitoredSubject = static_cast<uint64_t>(0);
276+
277+
// Client type of the client registering
265278
chip::app::Clusters::IcdManagement::ClientTypeEnum mClientType = chip::app::Clusters::IcdManagement::ClientTypeEnum::kPermanent;
279+
280+
// Shared secret between the client and the ICD to encrypt the Check-In message.
266281
uint8_t mICDSymmetricKey[chip::Crypto::kAES_CCM128_Key_Length];
267282

268283
CHIP_ERROR mError = CHIP_NO_ERROR;

‎examples/chip-tool/commands/icd/ICDCommand.cpp

+6-3
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,17 @@ CHIP_ERROR ICDListCommand::RunCommand()
4040
fprintf(stderr, " +------------------------------------------------------------------------------------------+\n");
4141
fprintf(stderr, " | %-88s |\n", "Known ICDs:");
4242
fprintf(stderr, " +------------------------------------------------------------------------------------------+\n");
43-
fprintf(stderr, " | %20s | %15s | %15s | %16s | %10s |\n", "Fabric Index:Node ID", "Start Counter", "Counter Offset",
44-
"MonitoredSubject", "ClientType");
43+
fprintf(stderr, " | %20s | %20s | %15s | %15s | %16s | %10s |\n", "Fabric Index:Peer Node ID", "Fabric Index:CheckIn Node ID",
44+
"Start Counter", "Counter Offset", "MonitoredSubject", "ClientType");
4545

4646
while (iter->Next(info))
4747
{
4848
fprintf(stderr, " +------------------------------------------------------------------------------------------+\n");
49-
fprintf(stderr, " | %3" PRIu32 ":" ChipLogFormatX64 " | %15" PRIu32 " | %15" PRIu32 " | " ChipLogFormatX64 " | %10u |\n",
49+
fprintf(stderr,
50+
" | %3" PRIu32 ":" ChipLogFormatX64 " | %3" PRIu32 ":" ChipLogFormatX64 " | %15" PRIu32 " | %15" PRIu32
51+
" | " ChipLogFormatX64 " | %10u |\n",
5052
static_cast<uint32_t>(info.peer_node.GetFabricIndex()), ChipLogValueX64(info.peer_node.GetNodeId()),
53+
static_cast<uint32_t>(info.check_in_node.GetFabricIndex()), ChipLogValueX64(info.check_in_node.GetNodeId()),
5154
info.start_icd_counter, info.offset, ChipLogValueX64(info.monitored_subject),
5255
static_cast<uint8_t>(info.client_type));
5356

‎examples/chip-tool/commands/pairing/PairingCommand.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,8 @@ void PairingCommand::OnICDRegistrationComplete(ScopedNodeId nodeId, uint32_t icd
469469
sizeof(icdSymmetricKeyHex), chip::Encoding::HexFlags::kNullTerminate);
470470

471471
app::ICDClientInfo clientInfo;
472-
clientInfo.peer_node = chip::ScopedNodeId(mICDCheckInNodeId.Value(), nodeId.GetFabricIndex());
472+
clientInfo.check_in_node = chip::ScopedNodeId(mICDCheckInNodeId.Value(), nodeId.GetFabricIndex());
473+
clientInfo.peer_node = nodeId;
473474
clientInfo.monitored_subject = mICDMonitoredSubject.Value();
474475
clientInfo.start_icd_counter = icdCounter;
475476

‎src/app/icd/client/DefaultICDClientStorage.cpp

+8-1
Original file line numberDiff line numberDiff line change
@@ -235,16 +235,22 @@ CHIP_ERROR DefaultICDClientStorage::Load(FabricIndex fabricIndex, std::vector<IC
235235
ICDClientInfo clientInfo;
236236
TLV::TLVType ICDClientInfoType;
237237
NodeId nodeId;
238+
NodeId checkInNodeId;
238239
FabricIndex fabric;
239240
ReturnErrorOnFailure(reader.EnterContainer(ICDClientInfoType));
240241
// Peer Node ID
241242
ReturnErrorOnFailure(reader.Next(TLV::ContextTag(ClientInfoTag::kPeerNodeId)));
242243
ReturnErrorOnFailure(reader.Get(nodeId));
243244

245+
ReturnErrorOnFailure(reader.Next(TLV::ContextTag(ClientInfoTag::kCheckInNodeId)));
246+
ReturnErrorOnFailure(reader.Get(checkInNodeId));
247+
244248
// Fabric Index
245249
ReturnErrorOnFailure(reader.Next(TLV::ContextTag(ClientInfoTag::kFabricIndex)));
246250
ReturnErrorOnFailure(reader.Get(fabric));
247-
clientInfo.peer_node = ScopedNodeId(nodeId, fabric);
251+
252+
clientInfo.peer_node = ScopedNodeId(nodeId, fabric);
253+
clientInfo.check_in_node = ScopedNodeId(checkInNodeId, fabric);
248254

249255
// Start ICD Counter
250256
ReturnErrorOnFailure(reader.Next(TLV::ContextTag(ClientInfoTag::kStartICDCounter)));
@@ -323,6 +329,7 @@ CHIP_ERROR DefaultICDClientStorage::SerializeToTlv(TLV::TLVWriter & writer, cons
323329
TLV::TLVType ICDClientInfoContainerType;
324330
ReturnErrorOnFailure(writer.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, ICDClientInfoContainerType));
325331
ReturnErrorOnFailure(writer.Put(TLV::ContextTag(ClientInfoTag::kPeerNodeId), clientInfo.peer_node.GetNodeId()));
332+
ReturnErrorOnFailure(writer.Put(TLV::ContextTag(ClientInfoTag::kCheckInNodeId), clientInfo.check_in_node.GetNodeId()));
326333
ReturnErrorOnFailure(writer.Put(TLV::ContextTag(ClientInfoTag::kFabricIndex), clientInfo.peer_node.GetFabricIndex()));
327334
ReturnErrorOnFailure(writer.Put(TLV::ContextTag(ClientInfoTag::kStartICDCounter), clientInfo.start_icd_counter));
328335
ReturnErrorOnFailure(writer.Put(TLV::ContextTag(ClientInfoTag::kOffset), clientInfo.offset));

‎src/app/icd/client/DefaultICDClientStorage.h

+11-9
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,14 @@ class DefaultICDClientStorage : public ICDClientStorage
124124
enum class ClientInfoTag : uint8_t
125125
{
126126
kPeerNodeId = 1,
127-
kFabricIndex = 2,
128-
kStartICDCounter = 3,
129-
kOffset = 4,
130-
kMonitoredSubject = 5,
131-
kAesKeyHandle = 6,
132-
kHmacKeyHandle = 7,
133-
kClientType = 8,
127+
kCheckInNodeId = 2,
128+
kFabricIndex = 3,
129+
kStartICDCounter = 4,
130+
kOffset = 5,
131+
kMonitoredSubject = 6,
132+
kAesKeyHandle = 7,
133+
kHmacKeyHandle = 8,
134+
kClientType = 9,
134135
};
135136

136137
enum class CounterTag : uint8_t
@@ -158,8 +159,9 @@ class DefaultICDClientStorage : public ICDClientStorage
158159
{
159160
// All the fields added together
160161
return TLV::EstimateStructOverhead(
161-
sizeof(NodeId), sizeof(FabricIndex), sizeof(uint32_t) /*start_icd_counter*/, sizeof(uint32_t) /*offset*/,
162-
sizeof(uint64_t) /*monitored_subject*/, sizeof(Crypto::Symmetric128BitsKeyByteArray) /*aes_key_handle*/,
162+
sizeof(NodeId), sizeof(NodeId), sizeof(FabricIndex), sizeof(uint32_t) /*start_icd_counter*/,
163+
sizeof(uint32_t) /*offset*/, sizeof(uint64_t) /*monitored_subject*/,
164+
sizeof(Crypto::Symmetric128BitsKeyByteArray) /*aes_key_handle*/,
163165
sizeof(Crypto::Symmetric128BitsKeyByteArray) /*hmac_key_handle*/, sizeof(uint8_t) /*client_type*/);
164166
}
165167

‎src/app/icd/client/ICDClientInfo.h

+2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ namespace app {
3131
struct ICDClientInfo
3232
{
3333
ScopedNodeId peer_node;
34+
ScopedNodeId check_in_node;
3435
uint32_t start_icd_counter = 0;
3536
uint32_t offset = 0;
3637
Clusters::IcdManagement::ClientTypeEnum client_type = Clusters::IcdManagement::ClientTypeEnum::kPermanent;
@@ -44,6 +45,7 @@ struct ICDClientInfo
4445
ICDClientInfo & operator=(const ICDClientInfo & other)
4546
{
4647
peer_node = other.peer_node;
48+
check_in_node = other.check_in_node;
4749
start_icd_counter = other.start_icd_counter;
4850
offset = other.offset;
4951
client_type = other.client_type;

‎src/app/icd/client/RefreshKeySender.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ CHIP_ERROR RefreshKeySender::RegisterClientWithNewKey(Messaging::ExchangeManager
7777
EndpointId endpointId = 0;
7878

7979
Clusters::IcdManagement::Commands::RegisterClient::Type registerClientCommand;
80-
registerClientCommand.checkInNodeID = mICDClientInfo.peer_node.GetNodeId();
80+
registerClientCommand.checkInNodeID = mICDClientInfo.check_in_node.GetNodeId();
8181
registerClientCommand.monitoredSubject = mICDClientInfo.monitored_subject;
8282
registerClientCommand.key = mNewKey.Span();
8383
return Controller::InvokeCommandRequest(&exchangeMgr, sessionHandle, endpointId, registerClientCommand, onSuccess, onFailure);

‎src/controller/java/AndroidCheckInDelegate.cpp

+18-13
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@
2323
#include <lib/support/JniReferences.h>
2424
#include <lib/support/logging/CHIPLogging.h>
2525

26-
#define PARSE_CLIENT_INFO(_clientInfo, _peerNodeId, _startCounter, _offset, _monitoredSubject, _jniICDAesKey, _jniICDHmacKey) \
26+
#define PARSE_CLIENT_INFO(_clientInfo, _peerNodeId, _checkInNodeId, _startCounter, _offset, _monitoredSubject, _jniICDAesKey, \
27+
_jniICDHmacKey) \
2728
jlong _peerNodeId = static_cast<jlong>(_clientInfo.peer_node.GetNodeId()); \
29+
jlong _checkInNodeId = static_cast<jlong>(_clientInfo.check_in_node.GetNodeId()); \
2830
jlong _startCounter = static_cast<jlong>(_clientInfo.start_icd_counter); \
2931
jlong _offset = static_cast<jlong>(_clientInfo.offset); \
3032
jlong _monitoredSubject = static_cast<jlong>(_clientInfo.monitored_subject); \
@@ -53,24 +55,26 @@ CHIP_ERROR AndroidCheckInDelegate::SetDelegate(jobject checkInDelegateObj)
5355

5456
void AndroidCheckInDelegate::OnCheckInComplete(const ICDClientInfo & clientInfo)
5557
{
56-
ChipLogProgress(
57-
ICD, "Check In Message processing complete: start_counter=%" PRIu32 " offset=%" PRIu32 " nodeid=" ChipLogFormatScopedNodeId,
58-
clientInfo.start_icd_counter, clientInfo.offset, ChipLogValueScopedNodeId(clientInfo.peer_node));
58+
ChipLogProgress(ICD,
59+
"Check In Message processing complete: start_counter=%" PRIu32 " offset=%" PRIu32
60+
" peernodeid=" ChipLogFormatScopedNodeId " checkinnodeid=" ChipLogFormatScopedNodeId,
61+
clientInfo.start_icd_counter, clientInfo.offset, ChipLogValueScopedNodeId(clientInfo.peer_node),
62+
ChipLogValueScopedNodeId(clientInfo.check_in_node));
5963

6064
VerifyOrReturn(mCheckInDelegate.HasValidObjectRef(), ChipLogProgress(ICD, "check-in delegate is not implemented!"));
6165

6266
JNIEnv * env = chip::JniReferences::GetInstance().GetEnvForCurrentThread();
6367
VerifyOrReturn(env != nullptr, ChipLogError(Controller, "JNIEnv is null!"));
64-
PARSE_CLIENT_INFO(clientInfo, peerNodeId, startCounter, offset, monitoredSubject, jniICDAesKey, jniICDHmacKey)
68+
PARSE_CLIENT_INFO(clientInfo, peerNodeId, checkInNodeId, startCounter, offset, monitoredSubject, jniICDAesKey, jniICDHmacKey)
6569

6670
jmethodID onCheckInCompleteMethodID = nullptr;
6771
CHIP_ERROR err = chip::JniReferences::GetInstance().FindMethod(env, mCheckInDelegate.ObjectRef(), "onCheckInComplete",
68-
"(JJJJ[B[B)V", &onCheckInCompleteMethodID);
72+
"(JJJJJ[B[B)V", &onCheckInCompleteMethodID);
6973
VerifyOrReturn(err == CHIP_NO_ERROR,
7074
ChipLogProgress(ICD, "onCheckInComplete - FindMethod is failed! : %" CHIP_ERROR_FORMAT, err.Format()));
7175

72-
env->CallVoidMethod(mCheckInDelegate.ObjectRef(), onCheckInCompleteMethodID, peerNodeId, startCounter, offset, monitoredSubject,
73-
jniICDAesKey.jniValue(), jniICDHmacKey.jniValue());
76+
env->CallVoidMethod(mCheckInDelegate.ObjectRef(), onCheckInCompleteMethodID, peerNodeId, checkInNodeId, startCounter, offset,
77+
monitoredSubject, jniICDAesKey.jniValue(), jniICDHmacKey.jniValue());
7478
}
7579

7680
RefreshKeySender * AndroidCheckInDelegate::OnKeyRefreshNeeded(ICDClientInfo & clientInfo, ICDClientStorage * clientStorage)
@@ -84,17 +88,18 @@ RefreshKeySender * AndroidCheckInDelegate::OnKeyRefreshNeeded(ICDClientInfo & cl
8488
JNIEnv * env = chip::JniReferences::GetInstance().GetEnvForCurrentThread();
8589
VerifyOrReturnValue(env != nullptr, nullptr, ChipLogError(Controller, "JNIEnv is null!"));
8690

87-
PARSE_CLIENT_INFO(clientInfo, peerNodeId, startCounter, offset, monitoredSubject, jniICDAesKey, jniICDHmacKey)
91+
PARSE_CLIENT_INFO(clientInfo, peerNodeId, checkInNodeId, startCounter, offset, monitoredSubject, jniICDAesKey,
92+
jniICDHmacKey)
8893

8994
jmethodID onKeyRefreshNeededMethodID = nullptr;
90-
err = chip::JniReferences::GetInstance().FindMethod(env, mCheckInDelegate.ObjectRef(), "onKeyRefreshNeeded", "(JJJJ[B[B)V",
95+
err = chip::JniReferences::GetInstance().FindMethod(env, mCheckInDelegate.ObjectRef(), "onKeyRefreshNeeded", "(JJJJJ[B[B)V",
9196
&onKeyRefreshNeededMethodID);
9297
VerifyOrReturnValue(err == CHIP_NO_ERROR, nullptr,
9398
ChipLogProgress(ICD, "onKeyRefreshNeeded - FindMethod is failed! : %" CHIP_ERROR_FORMAT, err.Format()));
9499

95-
jbyteArray key = static_cast<jbyteArray>(env->CallObjectMethod(mCheckInDelegate.ObjectRef(), onKeyRefreshNeededMethodID,
96-
peerNodeId, startCounter, offset, monitoredSubject,
97-
jniICDAesKey.jniValue(), jniICDHmacKey.jniValue()));
100+
jbyteArray key = static_cast<jbyteArray>(
101+
env->CallObjectMethod(mCheckInDelegate.ObjectRef(), onKeyRefreshNeededMethodID, peerNodeId, checkInNodeId, startCounter,
102+
offset, monitoredSubject, jniICDAesKey.jniValue(), jniICDHmacKey.jniValue()));
98103

99104
if (key != nullptr)
100105
{

‎src/controller/java/AndroidDeviceControllerWrapper.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -1015,6 +1015,8 @@ void AndroidDeviceControllerWrapper::OnICDRegistrationComplete(chip::ScopedNodeI
10151015
CHIP_ERROR err = CHIP_NO_ERROR;
10161016
chip::app::ICDClientInfo clientInfo;
10171017
clientInfo.peer_node = icdNodeId;
1018+
clientInfo.check_in_node = chip::ScopedNodeId(mAutoCommissioner.GetCommissioningParameters().GetICDCheckInNodeId().Value(),
1019+
icdNodeId.GetFabricIndex());
10181020
clientInfo.monitored_subject = mAutoCommissioner.GetCommissioningParameters().GetICDMonitoredSubject().Value();
10191021
clientInfo.start_icd_counter = icdCounter;
10201022

@@ -1056,7 +1058,7 @@ void AndroidDeviceControllerWrapper::OnICDRegistrationComplete(chip::ScopedNodeI
10561058
methodErr = chip::JniReferences::GetInstance().GetLocalClassRef(env, "chip/devicecontroller/ICDDeviceInfo", icdDeviceInfoClass);
10571059
VerifyOrReturn(methodErr == CHIP_NO_ERROR, ChipLogError(Controller, "Could not find class ICDDeviceInfo"));
10581060

1059-
icdDeviceInfoStructCtor = env->GetMethodID(icdDeviceInfoClass, "<init>", "([BILjava/lang/String;JJIJJJJI)V");
1061+
icdDeviceInfoStructCtor = env->GetMethodID(icdDeviceInfoClass, "<init>", "([BILjava/lang/String;JJIJJJJJI)V");
10601062
VerifyOrReturn(icdDeviceInfoStructCtor != nullptr, ChipLogError(Controller, "Could not find ICDDeviceInfo constructor"));
10611063

10621064
methodErr =
@@ -1069,7 +1071,9 @@ void AndroidDeviceControllerWrapper::OnICDRegistrationComplete(chip::ScopedNodeI
10691071
icdDeviceInfoObj = env->NewObject(
10701072
icdDeviceInfoClass, icdDeviceInfoStructCtor, jSymmetricKey, static_cast<jint>(mUserActiveModeTriggerHint.Raw()),
10711073
jUserActiveModeTriggerInstruction, static_cast<jlong>(mIdleModeDuration), static_cast<jlong>(mActiveModeDuration),
1072-
static_cast<jint>(mActiveModeThreshold), static_cast<jlong>(icdNodeId.GetNodeId()), static_cast<jlong>(icdCounter),
1074+
static_cast<jint>(mActiveModeThreshold), static_cast<jlong>(icdNodeId.GetNodeId()),
1075+
static_cast<jlong>(mAutoCommissioner.GetCommissioningParameters().GetICDCheckInNodeId().Value()),
1076+
static_cast<jlong>(icdCounter),
10731077
static_cast<jlong>(mAutoCommissioner.GetCommissioningParameters().GetICDMonitoredSubject().Value()),
10741078
static_cast<jlong>(Controller()->GetFabricId()), static_cast<jint>(Controller()->GetFabricIndex()));
10751079

0 commit comments

Comments
 (0)