@@ -28,6 +28,7 @@ use crate::util::WakeList;
28
28
29
29
use crate :: loom:: sync:: atomic:: AtomicU64 ;
30
30
use std:: fmt;
31
+ use std:: sync:: RwLock ;
31
32
use std:: { num:: NonZeroU64 , ptr:: NonNull } ;
32
33
33
34
struct AtomicOptionNonZeroU64 ( AtomicU64 ) ;
@@ -115,7 +116,10 @@ struct Inner {
115
116
next_wake : AtomicOptionNonZeroU64 ,
116
117
117
118
/// Sharded Timer wheels.
118
- wheels : Box < [ Mutex < wheel:: Wheel > ] > ,
119
+ wheels : RwLock < ShardedWheel > ,
120
+
121
+ /// Number of entries in the sharded timer wheels.
122
+ wheels_len : u32 ,
119
123
120
124
/// True if the driver is being shutdown.
121
125
pub ( super ) is_shutdown : AtomicBool ,
@@ -130,6 +134,9 @@ struct Inner {
130
134
did_wake : AtomicBool ,
131
135
}
132
136
137
+ /// Wrapper around the sharded timer wheels.
138
+ struct ShardedWheel ( Box < [ Mutex < wheel:: Wheel > ] > ) ;
139
+
133
140
// ===== impl Driver =====
134
141
135
142
impl Driver {
@@ -149,7 +156,8 @@ impl Driver {
149
156
time_source,
150
157
inner : Inner {
151
158
next_wake : AtomicOptionNonZeroU64 :: new ( None ) ,
152
- wheels : wheels. into_boxed_slice ( ) ,
159
+ wheels : RwLock :: new ( ShardedWheel ( wheels. into_boxed_slice ( ) ) ) ,
160
+ wheels_len : shards,
153
161
is_shutdown : AtomicBool :: new ( false ) ,
154
162
#[ cfg( feature = "test-util" ) ]
155
163
did_wake : AtomicBool :: new ( false ) ,
@@ -190,23 +198,27 @@ impl Driver {
190
198
assert ! ( !handle. is_shutdown( ) ) ;
191
199
192
200
// Finds out the min expiration time to park.
193
- let locks = ( 0 ..rt_handle. time ( ) . inner . get_shard_size ( ) )
194
- . map ( |id| rt_handle. time ( ) . inner . lock_sharded_wheel ( id) )
195
- . collect :: < Vec < _ > > ( ) ;
196
-
197
- let expiration_time = locks
198
- . iter ( )
199
- . filter_map ( |lock| lock. next_expiration_time ( ) )
200
- . min ( ) ;
201
-
202
- rt_handle
203
- . time ( )
204
- . inner
205
- . next_wake
206
- . store ( next_wake_time ( expiration_time) ) ;
207
-
208
- // Safety: After updating the `next_wake`, we drop all the locks.
209
- drop ( locks) ;
201
+ let expiration_time = {
202
+ let mut wheels_lock = rt_handle
203
+ . time ( )
204
+ . inner
205
+ . wheels
206
+ . write ( )
207
+ . expect ( "Timer wheel shards poisoned" ) ;
208
+ let expiration_time = wheels_lock
209
+ . 0
210
+ . iter_mut ( )
211
+ . filter_map ( |wheel| wheel. get_mut ( ) . next_expiration_time ( ) )
212
+ . min ( ) ;
213
+
214
+ rt_handle
215
+ . time ( )
216
+ . inner
217
+ . next_wake
218
+ . store ( next_wake_time ( expiration_time) ) ;
219
+
220
+ expiration_time
221
+ } ;
210
222
211
223
match expiration_time {
212
224
Some ( when) => {
@@ -312,7 +324,12 @@ impl Handle {
312
324
// Returns the next wakeup time of this shard.
313
325
pub ( self ) fn process_at_sharded_time ( & self , id : u32 , mut now : u64 ) -> Option < u64 > {
314
326
let mut waker_list = WakeList :: new ( ) ;
315
- let mut lock = self . inner . lock_sharded_wheel ( id) ;
327
+ let mut wheels_lock = self
328
+ . inner
329
+ . wheels
330
+ . read ( )
331
+ . expect ( "Timer wheel shards poisoned" ) ;
332
+ let mut lock = wheels_lock. lock_sharded_wheel ( id) ;
316
333
317
334
if now < lock. elapsed ( ) {
318
335
// Time went backwards! This normally shouldn't happen as the Rust language
@@ -334,15 +351,22 @@ impl Handle {
334
351
if !waker_list. can_push ( ) {
335
352
// Wake a batch of wakers. To avoid deadlock, we must do this with the lock temporarily dropped.
336
353
drop ( lock) ;
354
+ drop ( wheels_lock) ;
337
355
338
356
waker_list. wake_all ( ) ;
339
357
340
- lock = self . inner . lock_sharded_wheel ( id) ;
358
+ wheels_lock = self
359
+ . inner
360
+ . wheels
361
+ . read ( )
362
+ . expect ( "Timer wheel shards poisoned" ) ;
363
+ lock = wheels_lock. lock_sharded_wheel ( id) ;
341
364
}
342
365
}
343
366
}
344
367
let next_wake_up = lock. poll_at ( ) ;
345
368
drop ( lock) ;
369
+ drop ( wheels_lock) ;
346
370
347
371
waker_list. wake_all ( ) ;
348
372
next_wake_up
@@ -360,7 +384,12 @@ impl Handle {
360
384
/// `add_entry` must not be called concurrently.
361
385
pub ( self ) unsafe fn clear_entry ( & self , entry : NonNull < TimerShared > ) {
362
386
unsafe {
363
- let mut lock = self . inner . lock_sharded_wheel ( entry. as_ref ( ) . shard_id ( ) ) ;
387
+ let wheels_lock = self
388
+ . inner
389
+ . wheels
390
+ . read ( )
391
+ . expect ( "Timer wheel shards poisoned" ) ;
392
+ let mut lock = wheels_lock. lock_sharded_wheel ( entry. as_ref ( ) . shard_id ( ) ) ;
364
393
365
394
if entry. as_ref ( ) . might_be_registered ( ) {
366
395
lock. remove ( entry) ;
@@ -383,7 +412,13 @@ impl Handle {
383
412
entry : NonNull < TimerShared > ,
384
413
) {
385
414
let waker = unsafe {
386
- let mut lock = self . inner . lock_sharded_wheel ( entry. as_ref ( ) . shard_id ( ) ) ;
415
+ let wheels_lock = self
416
+ . inner
417
+ . wheels
418
+ . read ( )
419
+ . expect ( "Timer wheel shards poisoned" ) ;
420
+
421
+ let mut lock = wheels_lock. lock_sharded_wheel ( entry. as_ref ( ) . shard_id ( ) ) ;
387
422
388
423
// We may have raced with a firing/deregistration, so check before
389
424
// deregistering.
@@ -443,24 +478,14 @@ impl Handle {
443
478
// ===== impl Inner =====
444
479
445
480
impl Inner {
446
- /// Locks the driver's sharded wheel structure.
447
- pub ( super ) fn lock_sharded_wheel (
448
- & self ,
449
- shard_id : u32 ,
450
- ) -> crate :: loom:: sync:: MutexGuard < ' _ , Wheel > {
451
- let index = shard_id % ( self . wheels . len ( ) as u32 ) ;
452
- // Safety: This modulo operation ensures that the index is not out of bounds.
453
- unsafe { self . wheels . get_unchecked ( index as usize ) . lock ( ) }
454
- }
455
-
456
481
// Check whether the driver has been shutdown
457
482
pub ( super ) fn is_shutdown ( & self ) -> bool {
458
483
self . is_shutdown . load ( Ordering :: SeqCst )
459
484
}
460
485
461
486
// Gets the number of shards.
462
487
fn get_shard_size ( & self ) -> u32 {
463
- self . wheels . len ( ) as u32
488
+ self . wheels_len
464
489
}
465
490
}
466
491
@@ -470,5 +495,19 @@ impl fmt::Debug for Inner {
470
495
}
471
496
}
472
497
498
+ // ===== impl ShardedWheel =====
499
+
500
+ impl ShardedWheel {
501
+ /// Locks the driver's sharded wheel structure.
502
+ pub ( super ) fn lock_sharded_wheel (
503
+ & self ,
504
+ shard_id : u32 ,
505
+ ) -> crate :: loom:: sync:: MutexGuard < ' _ , Wheel > {
506
+ let index = shard_id % ( self . 0 . len ( ) as u32 ) ;
507
+ // Safety: This modulo operation ensures that the index is not out of bounds.
508
+ unsafe { self . 0 . get_unchecked ( index as usize ) } . lock ( )
509
+ }
510
+ }
511
+
473
512
#[ cfg( test) ]
474
513
mod tests;
0 commit comments