@@ -16,9 +16,9 @@ use std::{num::NonZeroUsize, sync::RwLock as StdRwLock};
16
16
17
17
use async_trait:: async_trait;
18
18
use matrix_sdk_common:: ring_buffer:: RingBuffer ;
19
- use ruma:: { MxcUri , OwnedMxcUri } ;
19
+ use ruma:: { time :: SystemTime , MxcUri , OwnedMxcUri } ;
20
20
21
- use super :: { EventCacheStore , EventCacheStoreError , Result } ;
21
+ use super :: { EventCacheStore , EventCacheStoreError , MediaRetentionPolicy , Result } ;
22
22
use crate :: media:: { MediaRequest , UniqueKey as _} ;
23
23
24
24
/// In-memory, non-persistent implementation of the `EventCacheStore`.
@@ -27,15 +27,41 @@ use crate::media::{MediaRequest, UniqueKey as _};
27
27
#[ allow( clippy:: type_complexity) ]
28
28
#[ derive( Debug ) ]
29
29
pub struct MemoryStore {
30
- media : StdRwLock < RingBuffer < ( OwnedMxcUri , String /* unique key */ , Vec < u8 > ) > > ,
30
+ inner : StdRwLock < MemoryStoreInner > ,
31
+ }
32
+
33
+ #[ derive( Debug ) ]
34
+ struct MemoryStoreInner {
35
+ /// The media retention policy to use on cleanups.
36
+ media_retention_policy : Option < MediaRetentionPolicy > ,
37
+ /// Media content.
38
+ media : RingBuffer < MediaContent > ,
31
39
}
32
40
33
41
// SAFETY: `new_unchecked` is safe because 20 is not zero.
34
42
const NUMBER_OF_MEDIAS : NonZeroUsize = unsafe { NonZeroUsize :: new_unchecked ( 20 ) } ;
35
43
44
+ /// A media content.
45
+ #[ derive( Debug , Clone ) ]
46
+ struct MediaContent {
47
+ /// The Matrix URI of the media.
48
+ uri : OwnedMxcUri ,
49
+ /// The unique key of the media request.
50
+ key : String ,
51
+ /// The content of the media.
52
+ data : Vec < u8 > ,
53
+ /// The last access time of the media.
54
+ last_access : SystemTime ,
55
+ }
56
+
36
57
impl Default for MemoryStore {
37
58
fn default ( ) -> Self {
38
- Self { media : StdRwLock :: new ( RingBuffer :: new ( NUMBER_OF_MEDIAS ) ) }
59
+ let inner = MemoryStoreInner {
60
+ media_retention_policy : Default :: default ( ) ,
61
+ media : RingBuffer :: new ( NUMBER_OF_MEDIAS ) ,
62
+ } ;
63
+
64
+ Self { inner : StdRwLock :: new ( inner) }
39
65
}
40
66
}
41
67
@@ -51,53 +77,178 @@ impl MemoryStore {
51
77
impl EventCacheStore for MemoryStore {
52
78
type Error = EventCacheStoreError ;
53
79
54
- async fn add_media_content ( & self , request : & MediaRequest , data : Vec < u8 > ) -> Result < ( ) > {
80
+ async fn media_retention_policy ( & self ) -> Result < Option < MediaRetentionPolicy > , Self :: Error > {
81
+ Ok ( self . inner . read ( ) . unwrap ( ) . media_retention_policy )
82
+ }
83
+
84
+ async fn set_media_retention_policy (
85
+ & self ,
86
+ policy : MediaRetentionPolicy ,
87
+ ) -> Result < ( ) , Self :: Error > {
88
+ let mut inner = self . inner . write ( ) . unwrap ( ) ;
89
+ inner. media_retention_policy = Some ( policy) ;
90
+
91
+ Ok ( ( ) )
92
+ }
93
+
94
+ async fn add_media_content (
95
+ & self ,
96
+ request : & MediaRequest ,
97
+ data : Vec < u8 > ,
98
+ current_time : SystemTime ,
99
+ policy : MediaRetentionPolicy ,
100
+ ) -> Result < ( ) > {
55
101
// Avoid duplication. Let's try to remove it first.
56
102
self . remove_media_content ( request) . await ?;
103
+
104
+ if policy. exceeds_max_file_size ( data. len ( ) ) {
105
+ // The content is too big to be cached.
106
+ return Ok ( ( ) ) ;
107
+ }
108
+
57
109
// Now, let's add it.
58
- self . media . write ( ) . unwrap ( ) . push ( ( request. uri ( ) . to_owned ( ) , request. unique_key ( ) , data) ) ;
110
+ let content = MediaContent {
111
+ uri : request. uri ( ) . to_owned ( ) ,
112
+ key : request. unique_key ( ) ,
113
+ data,
114
+ last_access : current_time,
115
+ } ;
116
+ self . inner . write ( ) . unwrap ( ) . media . push ( content) ;
59
117
60
118
Ok ( ( ) )
61
119
}
62
120
63
- async fn get_media_content ( & self , request : & MediaRequest ) -> Result < Option < Vec < u8 > > > {
64
- let media = self . media . read ( ) . unwrap ( ) ;
121
+ async fn get_media_content (
122
+ & self ,
123
+ request : & MediaRequest ,
124
+ current_time : SystemTime ,
125
+ ) -> Result < Option < Vec < u8 > > > {
126
+ let mut inner = self . inner . write ( ) . unwrap ( ) ;
65
127
let expected_key = request. unique_key ( ) ;
66
128
67
- Ok ( media. iter ( ) . find_map ( |( _media_uri, media_key, media_content) | {
68
- ( media_key == & expected_key) . then ( || media_content. to_owned ( ) )
69
- } ) )
129
+ // First get the content out of the buffer.
130
+ let Some ( index) = inner. media . iter ( ) . position ( |media| media. key == expected_key) else {
131
+ return Ok ( None ) ;
132
+ } ;
133
+ let Some ( mut content) = inner. media . remove ( index) else {
134
+ return Ok ( None ) ;
135
+ } ;
136
+
137
+ // Clone the data.
138
+ let data = content. data . clone ( ) ;
139
+
140
+ // Update the last access time.
141
+ content. last_access = current_time;
142
+
143
+ // Put it back in the buffer.
144
+ inner. media . push ( content) ;
145
+
146
+ Ok ( Some ( data) )
70
147
}
71
148
72
149
async fn remove_media_content ( & self , request : & MediaRequest ) -> Result < ( ) > {
73
- let mut media = self . media . write ( ) . unwrap ( ) ;
150
+ let mut inner = self . inner . write ( ) . unwrap ( ) ;
151
+
74
152
let expected_key = request. unique_key ( ) ;
75
- let Some ( index) = media
76
- . iter ( )
77
- . position ( |( _media_uri, media_key, _media_content) | media_key == & expected_key)
78
- else {
153
+ let Some ( index) = inner. media . iter ( ) . position ( |media| media. key == expected_key) else {
79
154
return Ok ( ( ) ) ;
80
155
} ;
81
156
82
- media. remove ( index) ;
157
+ inner . media . remove ( index) ;
83
158
84
159
Ok ( ( ) )
85
160
}
86
161
87
162
async fn remove_media_content_for_uri ( & self , uri : & MxcUri ) -> Result < ( ) > {
88
- let mut media = self . media . write ( ) . unwrap ( ) ;
89
- let expected_key = uri . to_owned ( ) ;
90
- let positions = media
163
+ let mut inner = self . inner . write ( ) . unwrap ( ) ;
164
+ let positions = inner
165
+ . media
91
166
. iter ( )
92
167
. enumerate ( )
93
- . filter_map ( |( position, ( media_uri, _media_key, _media_content) ) | {
94
- ( media_uri == & expected_key) . then_some ( position)
95
- } )
168
+ . filter_map ( |( position, media) | ( media. uri == uri) . then_some ( position) )
96
169
. collect :: < Vec < _ > > ( ) ;
97
170
98
171
// Iterate in reverse-order so that positions stay valid after first removals.
99
172
for position in positions. into_iter ( ) . rev ( ) {
100
- media. remove ( position) ;
173
+ inner. media . remove ( position) ;
174
+ }
175
+
176
+ Ok ( ( ) )
177
+ }
178
+
179
+ async fn clean_up_media_cache (
180
+ & self ,
181
+ policy : MediaRetentionPolicy ,
182
+ current_time : SystemTime ,
183
+ ) -> Result < ( ) , Self :: Error > {
184
+ if !policy. has_limitations ( ) {
185
+ // We can safely skip all the checks.
186
+ return Ok ( ( ) ) ;
187
+ }
188
+
189
+ let mut inner = self . inner . write ( ) . unwrap ( ) ;
190
+
191
+ // First, check media content that exceed the max filesize.
192
+ if policy. max_file_size . is_some ( ) || policy. max_cache_size . is_some ( ) {
193
+ inner. media . retain ( |content| !policy. exceeds_max_file_size ( content. data . len ( ) ) ) ;
194
+ }
195
+
196
+ // Then, clean up expired media content.
197
+ if policy. last_access_expiry . is_some ( ) {
198
+ inner
199
+ . media
200
+ . retain ( |content| !policy. has_content_expired ( current_time, content. last_access ) ) ;
201
+ }
202
+
203
+ // Finally, if the cache size is too big, remove old items until it fits.
204
+ if let Some ( max_cache_size) = policy. max_cache_size {
205
+ // Reverse the iterator because in case the cache size is overflowing, we want
206
+ // to count the number of old items to remove, and old items are at
207
+ // the start.
208
+ let ( cache_size, overflowing_count) = inner. media . iter ( ) . rev ( ) . fold (
209
+ ( 0usize , 0u8 ) ,
210
+ |( cache_size, overflowing_count) , content| {
211
+ if overflowing_count > 0 {
212
+ // Assume that all data is overflowing now. Overflowing count cannot
213
+ // overflow because the number of items is limited to 20.
214
+ ( cache_size, overflowing_count + 1 )
215
+ } else {
216
+ match cache_size. checked_add ( content. data . len ( ) ) {
217
+ Some ( cache_size) => ( cache_size, 0 ) ,
218
+ // The cache size is overflowing, let's count the number of overflowing
219
+ // items to be able to remove them, since the max cache size cannot be
220
+ // bigger than usize::MAX.
221
+ None => ( cache_size, 1 ) ,
222
+ }
223
+ }
224
+ } ,
225
+ ) ;
226
+
227
+ // If the cache size is overflowing, remove the number of old items we counted.
228
+ for _position in 0 ..overflowing_count {
229
+ inner. media . pop ( ) ;
230
+ }
231
+
232
+ if cache_size > max_cache_size {
233
+ let difference = cache_size - max_cache_size;
234
+
235
+ // Count the number of old items to remove to reach the difference.
236
+ let mut accumulated_items_size = 0usize ;
237
+ let mut remove_items_count = 0u8 ;
238
+ for content in inner. media . iter ( ) {
239
+ remove_items_count += 1 ;
240
+ // Cannot overflow since we already removed overflowing items.
241
+ accumulated_items_size += content. data . len ( ) ;
242
+
243
+ if accumulated_items_size >= difference {
244
+ break ;
245
+ }
246
+ }
247
+
248
+ for _position in 0 ..remove_items_count {
249
+ inner. media . pop ( ) ;
250
+ }
251
+ }
101
252
}
102
253
103
254
Ok ( ( ) )
@@ -112,5 +263,5 @@ mod tests {
112
263
Ok ( MemoryStore :: new ( ) )
113
264
}
114
265
115
- event_cache_store_integration_tests ! ( ) ;
266
+ event_cache_store_integration_tests ! ( with_media_size_tests ) ;
116
267
}
0 commit comments