Skip to content

Commit 1ff7a2a

Browse files
committed
base: Add media retention policy to EventCacheStore
Signed-off-by: Kévin Commaille <zecakeh@tedomum.fr>
1 parent 06fc220 commit 1ff7a2a

File tree

9 files changed

+1188
-102
lines changed

9 files changed

+1188
-102
lines changed

crates/matrix-sdk-base/src/client.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ use crate::RoomMemberships;
7171
use crate::{
7272
deserialized_responses::{RawAnySyncOrStrippedTimelineEvent, SyncTimelineEvent},
7373
error::{Error, Result},
74-
event_cache_store::DynEventCacheStore,
74+
event_cache_store::EventCacheStoreWrapper,
7575
rooms::{
7676
normal::{RoomInfoNotableUpdate, RoomInfoNotableUpdateReasons},
7777
Room, RoomInfo, RoomState,
@@ -93,7 +93,7 @@ pub struct BaseClient {
9393
/// Database
9494
pub(crate) store: Store,
9595
/// The store used by the event cache.
96-
event_cache_store: Arc<DynEventCacheStore>,
96+
event_cache_store: EventCacheStoreWrapper,
9797
/// The store used for encryption.
9898
///
9999
/// This field is only meant to be used for `OlmMachine` initialization.
@@ -147,7 +147,7 @@ impl BaseClient {
147147

148148
BaseClient {
149149
store: Store::new(config.state_store),
150-
event_cache_store: config.event_cache_store,
150+
event_cache_store: EventCacheStoreWrapper::new(config.event_cache_store),
151151
#[cfg(feature = "e2e-encryption")]
152152
crypto_store: config.crypto_store,
153153
#[cfg(feature = "e2e-encryption")]
@@ -222,8 +222,8 @@ impl BaseClient {
222222
}
223223

224224
/// Get a reference to the event cache store.
225-
pub fn event_cache_store(&self) -> &DynEventCacheStore {
226-
&*self.event_cache_store
225+
pub fn event_cache_store(&self) -> &EventCacheStoreWrapper {
226+
&self.event_cache_store
227227
}
228228

229229
/// Is the client logged in.

crates/matrix-sdk-base/src/event_cache_store/integration_tests.rs

+468-29
Large diffs are not rendered by default.

crates/matrix-sdk-base/src/event_cache_store/memory_store.rs

+176-25
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ use std::{num::NonZeroUsize, sync::RwLock as StdRwLock};
1616

1717
use async_trait::async_trait;
1818
use matrix_sdk_common::ring_buffer::RingBuffer;
19-
use ruma::{MxcUri, OwnedMxcUri};
19+
use ruma::{time::SystemTime, MxcUri, OwnedMxcUri};
2020

21-
use super::{EventCacheStore, EventCacheStoreError, Result};
21+
use super::{EventCacheStore, EventCacheStoreError, MediaRetentionPolicy, Result};
2222
use crate::media::{MediaRequest, UniqueKey as _};
2323

2424
/// In-memory, non-persistent implementation of the `EventCacheStore`.
@@ -27,15 +27,41 @@ use crate::media::{MediaRequest, UniqueKey as _};
2727
#[allow(clippy::type_complexity)]
2828
#[derive(Debug)]
2929
pub struct MemoryStore {
30-
media: StdRwLock<RingBuffer<(OwnedMxcUri, String /* unique key */, Vec<u8>)>>,
30+
inner: StdRwLock<MemoryStoreInner>,
31+
}
32+
33+
#[derive(Debug)]
34+
struct MemoryStoreInner {
35+
/// The media retention policy to use on cleanups.
36+
media_retention_policy: Option<MediaRetentionPolicy>,
37+
/// Media content.
38+
media: RingBuffer<MediaContent>,
3139
}
3240

3341
// SAFETY: `new_unchecked` is safe because 20 is not zero.
3442
const NUMBER_OF_MEDIAS: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(20) };
3543

44+
/// A media content.
45+
#[derive(Debug, Clone)]
46+
struct MediaContent {
47+
/// The Matrix URI of the media.
48+
uri: OwnedMxcUri,
49+
/// The unique key of the media request.
50+
key: String,
51+
/// The content of the media.
52+
data: Vec<u8>,
53+
/// The last access time of the media.
54+
last_access: SystemTime,
55+
}
56+
3657
impl Default for MemoryStore {
3758
fn default() -> Self {
38-
Self { media: StdRwLock::new(RingBuffer::new(NUMBER_OF_MEDIAS)) }
59+
let inner = MemoryStoreInner {
60+
media_retention_policy: Default::default(),
61+
media: RingBuffer::new(NUMBER_OF_MEDIAS),
62+
};
63+
64+
Self { inner: StdRwLock::new(inner) }
3965
}
4066
}
4167

@@ -51,53 +77,178 @@ impl MemoryStore {
5177
impl EventCacheStore for MemoryStore {
5278
type Error = EventCacheStoreError;
5379

54-
async fn add_media_content(&self, request: &MediaRequest, data: Vec<u8>) -> Result<()> {
80+
async fn media_retention_policy(&self) -> Result<Option<MediaRetentionPolicy>, Self::Error> {
81+
Ok(self.inner.read().unwrap().media_retention_policy)
82+
}
83+
84+
async fn set_media_retention_policy(
85+
&self,
86+
policy: MediaRetentionPolicy,
87+
) -> Result<(), Self::Error> {
88+
let mut inner = self.inner.write().unwrap();
89+
inner.media_retention_policy = Some(policy);
90+
91+
Ok(())
92+
}
93+
94+
async fn add_media_content(
95+
&self,
96+
request: &MediaRequest,
97+
data: Vec<u8>,
98+
current_time: SystemTime,
99+
policy: MediaRetentionPolicy,
100+
) -> Result<()> {
55101
// Avoid duplication. Let's try to remove it first.
56102
self.remove_media_content(request).await?;
103+
104+
if policy.exceeds_max_file_size(data.len()) {
105+
// The content is too big to be cached.
106+
return Ok(());
107+
}
108+
57109
// Now, let's add it.
58-
self.media.write().unwrap().push((request.uri().to_owned(), request.unique_key(), data));
110+
let content = MediaContent {
111+
uri: request.uri().to_owned(),
112+
key: request.unique_key(),
113+
data,
114+
last_access: current_time,
115+
};
116+
self.inner.write().unwrap().media.push(content);
59117

60118
Ok(())
61119
}
62120

63-
async fn get_media_content(&self, request: &MediaRequest) -> Result<Option<Vec<u8>>> {
64-
let media = self.media.read().unwrap();
121+
async fn get_media_content(
122+
&self,
123+
request: &MediaRequest,
124+
current_time: SystemTime,
125+
) -> Result<Option<Vec<u8>>> {
126+
let mut inner = self.inner.write().unwrap();
65127
let expected_key = request.unique_key();
66128

67-
Ok(media.iter().find_map(|(_media_uri, media_key, media_content)| {
68-
(media_key == &expected_key).then(|| media_content.to_owned())
69-
}))
129+
// First get the content out of the buffer.
130+
let Some(index) = inner.media.iter().position(|media| media.key == expected_key) else {
131+
return Ok(None);
132+
};
133+
let Some(mut content) = inner.media.remove(index) else {
134+
return Ok(None);
135+
};
136+
137+
// Clone the data.
138+
let data = content.data.clone();
139+
140+
// Update the last access time.
141+
content.last_access = current_time;
142+
143+
// Put it back in the buffer.
144+
inner.media.push(content);
145+
146+
Ok(Some(data))
70147
}
71148

72149
async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> {
73-
let mut media = self.media.write().unwrap();
150+
let mut inner = self.inner.write().unwrap();
151+
74152
let expected_key = request.unique_key();
75-
let Some(index) = media
76-
.iter()
77-
.position(|(_media_uri, media_key, _media_content)| media_key == &expected_key)
78-
else {
153+
let Some(index) = inner.media.iter().position(|media| media.key == expected_key) else {
79154
return Ok(());
80155
};
81156

82-
media.remove(index);
157+
inner.media.remove(index);
83158

84159
Ok(())
85160
}
86161

87162
async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> {
88-
let mut media = self.media.write().unwrap();
89-
let expected_key = uri.to_owned();
90-
let positions = media
163+
let mut inner = self.inner.write().unwrap();
164+
let positions = inner
165+
.media
91166
.iter()
92167
.enumerate()
93-
.filter_map(|(position, (media_uri, _media_key, _media_content))| {
94-
(media_uri == &expected_key).then_some(position)
95-
})
168+
.filter_map(|(position, media)| (media.uri == uri).then_some(position))
96169
.collect::<Vec<_>>();
97170

98171
// Iterate in reverse-order so that positions stay valid after first removals.
99172
for position in positions.into_iter().rev() {
100-
media.remove(position);
173+
inner.media.remove(position);
174+
}
175+
176+
Ok(())
177+
}
178+
179+
async fn clean_up_media_cache(
180+
&self,
181+
policy: MediaRetentionPolicy,
182+
current_time: SystemTime,
183+
) -> Result<(), Self::Error> {
184+
if !policy.has_limitations() {
185+
// We can safely skip all the checks.
186+
return Ok(());
187+
}
188+
189+
let mut inner = self.inner.write().unwrap();
190+
191+
// First, check media content that exceed the max filesize.
192+
if policy.max_file_size.is_some() || policy.max_cache_size.is_some() {
193+
inner.media.retain(|content| !policy.exceeds_max_file_size(content.data.len()));
194+
}
195+
196+
// Then, clean up expired media content.
197+
if policy.last_access_expiry.is_some() {
198+
inner
199+
.media
200+
.retain(|content| !policy.has_content_expired(current_time, content.last_access));
201+
}
202+
203+
// Finally, if the cache size is too big, remove old items until it fits.
204+
if let Some(max_cache_size) = policy.max_cache_size {
205+
// Reverse the iterator because in case the cache size is overflowing, we want
206+
// to count the number of old items to remove, and old items are at
207+
// the start.
208+
let (cache_size, overflowing_count) = inner.media.iter().rev().fold(
209+
(0usize, 0u8),
210+
|(cache_size, overflowing_count), content| {
211+
if overflowing_count > 0 {
212+
// Assume that all data is overflowing now. Overflowing count cannot
213+
// overflow because the number of items is limited to 20.
214+
(cache_size, overflowing_count + 1)
215+
} else {
216+
match cache_size.checked_add(content.data.len()) {
217+
Some(cache_size) => (cache_size, 0),
218+
// The cache size is overflowing, let's count the number of overflowing
219+
// items to be able to remove them, since the max cache size cannot be
220+
// bigger than usize::MAX.
221+
None => (cache_size, 1),
222+
}
223+
}
224+
},
225+
);
226+
227+
// If the cache size is overflowing, remove the number of old items we counted.
228+
for _position in 0..overflowing_count {
229+
inner.media.pop();
230+
}
231+
232+
if cache_size > max_cache_size {
233+
let difference = cache_size - max_cache_size;
234+
235+
// Count the number of old items to remove to reach the difference.
236+
let mut accumulated_items_size = 0usize;
237+
let mut remove_items_count = 0u8;
238+
for content in inner.media.iter() {
239+
remove_items_count += 1;
240+
// Cannot overflow since we already removed overflowing items.
241+
accumulated_items_size += content.data.len();
242+
243+
if accumulated_items_size >= difference {
244+
break;
245+
}
246+
}
247+
248+
for _position in 0..remove_items_count {
249+
inner.media.pop();
250+
}
251+
}
101252
}
102253

103254
Ok(())
@@ -112,5 +263,5 @@ mod tests {
112263
Ok(MemoryStore::new())
113264
}
114265

115-
event_cache_store_integration_tests!();
266+
event_cache_store_integration_tests!(with_media_size_tests);
116267
}

0 commit comments

Comments
 (0)