Skip to content

Commit 1359496

Browse files
authored
Merge a2b61a5 into 149e582
2 parents 149e582 + a2b61a5 commit 1359496

File tree

5 files changed

+164
-322
lines changed

5 files changed

+164
-322
lines changed

src/transport/SecureSession.h

+16-27
Original file line numberDiff line numberDiff line change
@@ -49,43 +49,31 @@ static constexpr uint32_t kUndefinedMessageIndex = UINT32_MAX;
4949
class SecureSession
5050
{
5151
public:
52-
SecureSession() : mPeerAddress(PeerAddress::Uninitialized()) {}
53-
SecureSession(const PeerAddress & addr) : mPeerAddress(addr) {}
54-
SecureSession(PeerAddress && addr) : mPeerAddress(addr) {}
52+
SecureSession(uint16_t localSessionId, NodeId peerNodeId, uint16_t peerSessionId, FabricIndex fabric, uint64_t time) :
53+
mPeerNodeId(peerNodeId), mLocalSessionId(localSessionId), mPeerSessionId(peerSessionId), mFabric(fabric)
54+
{
55+
SetLastActivityTimeMs(time);
56+
}
5557

56-
SecureSession(SecureSession &&) = default;
57-
SecureSession(const SecureSession &) = default;
58-
SecureSession & operator=(const SecureSession &) = default;
59-
SecureSession & operator=(SecureSession &&) = default;
58+
SecureSession(SecureSession &&) = delete;
59+
SecureSession(const SecureSession &) = delete;
60+
SecureSession & operator=(const SecureSession &) = delete;
61+
SecureSession & operator=(SecureSession &&) = delete;
6062

6163
const PeerAddress & GetPeerAddress() const { return mPeerAddress; }
6264
PeerAddress & GetPeerAddress() { return mPeerAddress; }
6365
void SetPeerAddress(const PeerAddress & address) { mPeerAddress = address; }
6466

6567
NodeId GetPeerNodeId() const { return mPeerNodeId; }
66-
void SetPeerNodeId(NodeId peerNodeId) { mPeerNodeId = peerNodeId; }
67-
68-
uint16_t GetPeerSessionId() const { return mPeerSessionId; }
69-
void SetPeerSessionId(uint16_t id) { mPeerSessionId = id; }
70-
71-
// TODO: Rename KeyID to SessionID
7268
uint16_t GetLocalSessionId() const { return mLocalSessionId; }
73-
void SetLocalSessionId(uint16_t id) { mLocalSessionId = id; }
69+
uint16_t GetPeerSessionId() const { return mPeerSessionId; }
70+
FabricIndex GetFabricIndex() const { return mFabric; }
7471

7572
uint64_t GetLastActivityTimeMs() const { return mLastActivityTimeMs; }
7673
void SetLastActivityTimeMs(uint64_t value) { mLastActivityTimeMs = value; }
7774

7875
CryptoContext & GetCryptoContext() { return mCryptoContext; }
7976

80-
FabricIndex GetFabricIndex() const { return mFabric; }
81-
void SetFabricIndex(FabricIndex fabricIndex) { mFabric = fabricIndex; }
82-
83-
bool IsInitialized()
84-
{
85-
return (mPeerAddress.IsInitialized() || mPeerNodeId != kUndefinedNodeId || mPeerSessionId != UINT16_MAX ||
86-
mLocalSessionId != UINT16_MAX);
87-
}
88-
8977
CHIP_ERROR EncryptBeforeSend(const uint8_t * input, size_t input_length, uint8_t * output, PacketHeader & header,
9078
MessageAuthenticationCode & mac) const
9179
{
@@ -101,14 +89,15 @@ class SecureSession
10189
SessionMessageCounter & GetSessionMessageCounter() { return mSessionMessageCounter; }
10290

10391
private:
92+
const NodeId mPeerNodeId;
93+
const uint16_t mLocalSessionId;
94+
const uint16_t mPeerSessionId;
95+
const FabricIndex mFabric;
96+
10497
PeerAddress mPeerAddress;
105-
NodeId mPeerNodeId = kUndefinedNodeId;
106-
uint16_t mPeerSessionId = UINT16_MAX;
107-
uint16_t mLocalSessionId = UINT16_MAX;
10898
uint64_t mLastActivityTimeMs = 0;
10999
CryptoContext mCryptoContext;
110100
SessionMessageCounter mSessionMessageCounter;
111-
FabricIndex mFabric = kUndefinedFabricIndex;
112101
};
113102

114103
} // namespace Transport

src/transport/SecureSessionTable.h

+27-129
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include <lib/core/CHIPError.h>
2020
#include <lib/support/CodeUtils.h>
21+
#include <lib/support/Pool.h>
2122
#include <system/TimeSource.h>
2223
#include <transport/SecureSession.h>
2324

@@ -43,9 +44,10 @@ class SecureSessionTable
4344
/**
4445
* Allocates a new secure session out of the internal resource pool.
4546
*
46-
* @param peerNode represents peer Node's ID
47-
* @param peerSessionId represents the encryption key ID assigned by peer node
4847
* @param localSessionId represents the encryption key ID assigned by local node
48+
* @param peerNodeId represents peer Node's ID
49+
* @param peerSessionId represents the encryption key ID assigned by peer node
50+
* @param fabric represents fabric ID for the session
4951
* @param state [out] will contain the session if one was available. May be null if no return value is desired.
5052
*
5153
* @note the newly created state will have an 'active' time set based on the current time source.
@@ -54,70 +56,17 @@ class SecureSessionTable
5456
* has been reached (with CHIP_ERROR_NO_MEMORY).
5557
*/
5658
CHECK_RETURN_VALUE
57-
CHIP_ERROR CreateNewSecureSession(NodeId peerNode, uint16_t peerSessionId, uint16_t localSessionId, SecureSession ** state)
59+
SecureSession * CreateNewSecureSession(uint16_t localSessionId, NodeId peerNodeId, uint16_t peerSessionId, FabricIndex fabric)
5860
{
59-
CHIP_ERROR err = CHIP_ERROR_NO_MEMORY;
60-
61-
if (state)
62-
{
63-
*state = nullptr;
64-
}
65-
66-
for (size_t i = 0; i < kMaxSessionCount; i++)
67-
{
68-
if (!mStates[i].IsInitialized())
69-
{
70-
mStates[i] = SecureSession();
71-
mStates[i].SetPeerNodeId(peerNode);
72-
mStates[i].SetPeerSessionId(peerSessionId);
73-
mStates[i].SetLocalSessionId(localSessionId);
74-
mStates[i].SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs());
75-
76-
if (state)
77-
{
78-
*state = &mStates[i];
79-
}
80-
81-
err = CHIP_NO_ERROR;
82-
break;
83-
}
84-
}
85-
86-
return err;
61+
return mEntries.CreateObject(localSessionId, peerNodeId, peerSessionId, fabric, mTimeSource.GetCurrentMonotonicTimeMs());
8762
}
8863

89-
/**
90-
* Get a secure session given a Node Id.
91-
*
92-
* @param nodeId is the session to find (based on nodeId).
93-
* @param begin If a member of the pool, will start search from the next item. Can be nullptr to search from start.
94-
*
95-
* @return the state found, nullptr if not found
96-
*/
97-
CHECK_RETURN_VALUE
98-
SecureSession * FindSecureSession(NodeId nodeId, SecureSession * begin)
99-
{
100-
SecureSession * state = nullptr;
101-
SecureSession * iter = &mStates[0];
102-
103-
if (begin >= iter && begin < &mStates[kMaxSessionCount])
104-
{
105-
iter = begin + 1;
106-
}
64+
void ReleaseSession(SecureSession * session) { mEntries.ReleaseObject(session); }
10765

108-
for (; iter < &mStates[kMaxSessionCount]; iter++)
109-
{
110-
if (!iter->IsInitialized())
111-
{
112-
continue;
113-
}
114-
if (iter->GetPeerNodeId() == nodeId)
115-
{
116-
state = iter;
117-
break;
118-
}
119-
}
120-
return state;
66+
template <typename Function>
67+
bool ForEachSession(Function && function)
68+
{
69+
return mEntries.ForEachActiveObject(std::forward<Function>(function));
12170
}
12271

12372
/**
@@ -129,66 +78,23 @@ class SecureSessionTable
12978
* @return the state found, nullptr if not found
13079
*/
13180
CHECK_RETURN_VALUE
132-
SecureSession * FindSecureSessionByLocalKey(uint16_t localSessionId, SecureSession * begin)
133-
{
134-
SecureSession * state = nullptr;
135-
SecureSession * iter = &mStates[0];
136-
137-
if (begin >= iter && begin < &mStates[kMaxSessionCount])
138-
{
139-
iter = begin + 1;
140-
}
141-
142-
for (; iter < &mStates[kMaxSessionCount]; iter++)
143-
{
144-
if (!iter->IsInitialized())
145-
{
146-
continue;
147-
}
148-
if (iter->GetLocalSessionId() == localSessionId)
149-
{
150-
state = iter;
151-
break;
152-
}
153-
}
154-
return state;
155-
}
156-
157-
/**
158-
* Get the first session that matches the given fabric index.
159-
*
160-
* @param fabric The fabric index to match
161-
*
162-
* @return the session found, nullptr if not found
163-
*/
164-
CHECK_RETURN_VALUE
165-
SecureSession * FindSecureSessionByFabric(FabricIndex fabric)
81+
SecureSession * FindSecureSessionByLocalKey(uint16_t localSessionId)
16682
{
167-
for (auto & state : mStates)
168-
{
169-
if (!state.IsInitialized())
170-
{
171-
continue;
172-
}
173-
if (state.GetFabricIndex() == fabric)
83+
SecureSession * result = nullptr;
84+
mEntries.ForEachActiveObject([&](auto session) {
85+
if (session->GetLocalSessionId() == localSessionId)
17486
{
175-
return &state;
87+
result = session;
88+
return false;
17689
}
177-
}
178-
return nullptr;
90+
return true;
91+
});
92+
return result;
17993
}
18094

18195
/// Convenience method to mark a session as active
18296
void MarkSessionActive(SecureSession * state) { state->SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs()); }
18397

184-
/// Convenience method to expired a session and fired the related callback
185-
template <typename Callback>
186-
void MarkSessionExpired(SecureSession * state, Callback callback)
187-
{
188-
callback(*state);
189-
*state = SecureSession(PeerAddress::Uninitialized());
190-
}
191-
19298
/**
19399
* Iterates through all active sessions and expires any sessions with an idle time
194100
* larger than the given amount.
@@ -199,30 +105,22 @@ class SecureSessionTable
199105
void ExpireInactiveSessions(uint64_t maxIdleTimeMs, Callback callback)
200106
{
201107
const uint64_t currentTime = mTimeSource.GetCurrentMonotonicTimeMs();
202-
203-
for (size_t i = 0; i < kMaxSessionCount; i++)
204-
{
205-
if (!mStates[i].IsInitialized())
206-
{
207-
continue; // not an active session
208-
}
209-
210-
uint64_t sessionActiveTime = mStates[i].GetLastActivityTimeMs();
211-
if (sessionActiveTime + maxIdleTimeMs >= currentTime)
108+
mEntries.ForEachActiveObject([&](auto session) {
109+
if (session->GetLastActivityTimeMs() + maxIdleTimeMs < currentTime)
212110
{
213-
continue; // not expired
111+
callback(*session);
112+
ReleaseSession(session);
214113
}
215-
216-
MarkSessionExpired(&mStates[i], callback);
217-
}
114+
return true;
115+
});
218116
}
219117

220118
/// Allows access to the underlying time source used for keeping track of session active time
221119
Time::TimeSource<kTimeSource> & GetTimeSource() { return mTimeSource; }
222120

223121
private:
224122
Time::TimeSource<kTimeSource> mTimeSource;
225-
SecureSession mStates[kMaxSessionCount];
123+
BitMapObjectPool<SecureSession, kMaxSessionCount> mEntries;
226124
};
227125

228126
} // namespace Transport

0 commit comments

Comments
 (0)