From 0709230ffccc85496d0611dd0fa6d35d83c8a141 Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Wed, 7 Apr 2021 10:18:55 -0700 Subject: [PATCH] concurrency-limit: use `tokio-util`'s `PollSemaphore` Currently, our concurrency limit middleware uses `tokio::sync::Semaphore` to implement the concurrency limit. Because acquiring permits from the `Semaphore` type is only possible via the `acquire` and `acquire_owned` methods, which are `async fn`s returning unameable futures that are `!Unpin`, we cannot simply poll the semaphore in `poll_ready`. Instead, to allow polling the concurrency limit service to acquire a permit, we box the `acquire_owned` future in `poll_ready`. This means that every time we try to acquire a new concurrency limit permit, we must allocate a new box. The `tokio-util` crate has a [`PollSemaphore`][1] type which allows wrapping a `tokio::sync::Semaphore` to add a `poll_acquire` method to poll the semaphore to acquire a permit. Unlike our implementation, `PollSemaphore` uses the [`tokio_util::sync::ReusableBoxFuture`][2] type for boxing the `acquire_owned` future returned by the `Semaphore`. `ReusableBoxFuture` is a safe abstraction around unsafe code that allows *one* allocation to be reused for storing multiple type-erased `Box::pin`ned futures. This means creating a single allocation for each clone of a `ConcurrencyLimit` service, rather than having each clone of the concurrency limit service allocate once *every* time a new permit is acquired and drop that allocation once the permit is acquired. This should reduce the overhead of the concurrency limit layer a bit. It also makes the code much simpler! :) [1]: https://docs.rs/tokio-util/0.6.5/tokio_util/sync/struct.PollSemaphore.html [2]: https://docs.rs/tokio-util/0.6.5/tokio_util/sync/struct.ReusableBoxFuture.html Signed-off-by: Eliza Weisman --- Cargo.lock | 5 +- linkerd/concurrency-limit/Cargo.toml | 1 + linkerd/concurrency-limit/src/lib.rs | 81 +++++++++------------------- 3 files changed, 30 insertions(+), 57 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f71b9b74f9..f5c6fa5c6d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -820,6 +820,7 @@ dependencies = [ "linkerd-stack", "pin-project", "tokio", + "tokio-util", "tower", "tracing", ] @@ -2170,9 +2171,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.6.3" +version = "0.6.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebb7cb2f00c5ae8df755b252306272cd1790d39728363936e01827e11f0b017b" +checksum = "5143d049e85af7fbc36f5454d990e62c2df705b3589f123b71f441b6b59f443f" dependencies = [ "bytes", "futures-core", diff --git a/linkerd/concurrency-limit/Cargo.toml b/linkerd/concurrency-limit/Cargo.toml index f2ce70ad5b..2f43b1c101 100644 --- a/linkerd/concurrency-limit/Cargo.toml +++ b/linkerd/concurrency-limit/Cargo.toml @@ -11,6 +11,7 @@ publish = false futures = "0.3.9" linkerd-stack = { path = "../stack" } tokio = { version = "1", features = ["sync"] } +tokio-util = "0.6.5" tower = { version = "0.4.5", default-features = false } tracing = "0.1.23" pin-project = "1" diff --git a/linkerd/concurrency-limit/src/lib.rs b/linkerd/concurrency-limit/src/lib.rs index 0c13fc963a..bdd50822ec 100644 --- a/linkerd/concurrency-limit/src/lib.rs +++ b/linkerd/concurrency-limit/src/lib.rs @@ -9,14 +9,13 @@ use linkerd_stack::layer; use pin_project::pin_project; use std::{ - fmt, future::Future, - mem, pin::Pin, sync::Arc, task::{Context, Poll}, }; -use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use tokio::sync::{OwnedSemaphorePermit as Permit, Semaphore}; +use tokio_util::sync::PollSemaphore; use tower::Service; use tracing::trace; @@ -24,14 +23,8 @@ use tracing::trace; #[derive(Debug)] pub struct ConcurrencyLimit { inner: T, - semaphore: Arc, - state: State, -} - -enum State { - Waiting(Pin + Send + Sync + 'static>>), - Ready(OwnedSemaphorePermit), - Empty, + semaphore: PollSemaphore, + permit: Option, } /// Future for the `ConcurrencyLimit` service. @@ -41,7 +34,7 @@ pub struct ResponseFuture { #[pin] inner: T, // The permit is held until the future becomes ready. - permit: Option, + permit: Option, } impl ConcurrencyLimit { @@ -54,8 +47,8 @@ impl ConcurrencyLimit { fn new(inner: S, semaphore: Arc) -> Self { ConcurrencyLimit { inner, - semaphore, - state: State::Empty, + semaphore: PollSemaphore::new(semaphore), + permit: None, } } } @@ -69,39 +62,30 @@ where type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - trace!(available = %self.semaphore.available_permits(), "acquiring permit"); + trace!( + // available = %self.semaphore.available_permits(), + "acquiring permit" + ); loop { - self.state = match self.state { - State::Ready(_) => { - trace!(available = %self.semaphore.available_permits(), "permit acquired"); - return self.inner.poll_ready(cx); - } - State::Waiting(ref mut fut) => { - tokio::pin!(fut); - let permit = futures::ready!(fut.poll(cx)); - State::Ready(permit) - } - State::Empty => { - let semaphore = self.semaphore.clone(); - State::Waiting(Box::pin(async move { - semaphore - .acquire_owned() - .await - .expect("Semaphore cannot close") - })) - } - }; + if self.permit.is_some() { + trace!("permit already acquired; polling service"); + return self.inner.poll_ready(cx); + } + + let permit = + futures::ready!(self.semaphore.poll_acquire(cx)).expect("Semaphore must not close"); + self.permit = Some(permit); + trace!("permit acquired"); } } fn call(&mut self, request: Request) -> Self::Future { // Make sure a permit has been acquired - let permit = match mem::replace(&mut self.state, State::Empty) { - // Take the permit. - State::Ready(permit) => Some(permit), - // whoopsie! - _ => panic!("max requests in-flight; poll_ready must be called first"), - }; + let permit = self.permit.take(); + assert!( + permit.is_some(), + "max requests in-flight; poll_ready must be called first" + ); // Call the inner service let inner = self.inner.call(request); @@ -118,7 +102,7 @@ where ConcurrencyLimit { inner: self.inner.clone(), semaphore: self.semaphore.clone(), - state: State::Empty, + permit: None, } } } @@ -141,16 +125,3 @@ where Poll::Ready(res) } } - -impl fmt::Debug for State { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - State::Waiting(_) => f - .debug_tuple("State::Waiting") - .field(&format_args!("...")) - .finish(), - State::Ready(ref r) => f.debug_tuple("State::Ready").field(&r).finish(), - State::Empty => f.debug_tuple("State::Empty").finish(), - } - } -}