diff --git a/Cargo.lock b/Cargo.lock index f71b9b74f9..f5c6fa5c6d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -820,6 +820,7 @@ dependencies = [ "linkerd-stack", "pin-project", "tokio", + "tokio-util", "tower", "tracing", ] @@ -2170,9 +2171,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.6.3" +version = "0.6.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebb7cb2f00c5ae8df755b252306272cd1790d39728363936e01827e11f0b017b" +checksum = "5143d049e85af7fbc36f5454d990e62c2df705b3589f123b71f441b6b59f443f" dependencies = [ "bytes", "futures-core", diff --git a/linkerd/concurrency-limit/Cargo.toml b/linkerd/concurrency-limit/Cargo.toml index f2ce70ad5b..2f43b1c101 100644 --- a/linkerd/concurrency-limit/Cargo.toml +++ b/linkerd/concurrency-limit/Cargo.toml @@ -11,6 +11,7 @@ publish = false futures = "0.3.9" linkerd-stack = { path = "../stack" } tokio = { version = "1", features = ["sync"] } +tokio-util = "0.6.5" tower = { version = "0.4.5", default-features = false } tracing = "0.1.23" pin-project = "1" diff --git a/linkerd/concurrency-limit/src/lib.rs b/linkerd/concurrency-limit/src/lib.rs index 0c13fc963a..bdd50822ec 100644 --- a/linkerd/concurrency-limit/src/lib.rs +++ b/linkerd/concurrency-limit/src/lib.rs @@ -9,14 +9,13 @@ use linkerd_stack::layer; use pin_project::pin_project; use std::{ - fmt, future::Future, - mem, pin::Pin, sync::Arc, task::{Context, Poll}, }; -use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use tokio::sync::{OwnedSemaphorePermit as Permit, Semaphore}; +use tokio_util::sync::PollSemaphore; use tower::Service; use tracing::trace; @@ -24,14 +23,8 @@ use tracing::trace; #[derive(Debug)] pub struct ConcurrencyLimit { inner: T, - semaphore: Arc, - state: State, -} - -enum State { - Waiting(Pin + Send + Sync + 'static>>), - Ready(OwnedSemaphorePermit), - Empty, + semaphore: PollSemaphore, + permit: Option, } /// Future for the `ConcurrencyLimit` service. @@ -41,7 +34,7 @@ pub struct ResponseFuture { #[pin] inner: T, // The permit is held until the future becomes ready. - permit: Option, + permit: Option, } impl ConcurrencyLimit { @@ -54,8 +47,8 @@ impl ConcurrencyLimit { fn new(inner: S, semaphore: Arc) -> Self { ConcurrencyLimit { inner, - semaphore, - state: State::Empty, + semaphore: PollSemaphore::new(semaphore), + permit: None, } } } @@ -69,39 +62,30 @@ where type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - trace!(available = %self.semaphore.available_permits(), "acquiring permit"); + trace!( + // available = %self.semaphore.available_permits(), + "acquiring permit" + ); loop { - self.state = match self.state { - State::Ready(_) => { - trace!(available = %self.semaphore.available_permits(), "permit acquired"); - return self.inner.poll_ready(cx); - } - State::Waiting(ref mut fut) => { - tokio::pin!(fut); - let permit = futures::ready!(fut.poll(cx)); - State::Ready(permit) - } - State::Empty => { - let semaphore = self.semaphore.clone(); - State::Waiting(Box::pin(async move { - semaphore - .acquire_owned() - .await - .expect("Semaphore cannot close") - })) - } - }; + if self.permit.is_some() { + trace!("permit already acquired; polling service"); + return self.inner.poll_ready(cx); + } + + let permit = + futures::ready!(self.semaphore.poll_acquire(cx)).expect("Semaphore must not close"); + self.permit = Some(permit); + trace!("permit acquired"); } } fn call(&mut self, request: Request) -> Self::Future { // Make sure a permit has been acquired - let permit = match mem::replace(&mut self.state, State::Empty) { - // Take the permit. - State::Ready(permit) => Some(permit), - // whoopsie! - _ => panic!("max requests in-flight; poll_ready must be called first"), - }; + let permit = self.permit.take(); + assert!( + permit.is_some(), + "max requests in-flight; poll_ready must be called first" + ); // Call the inner service let inner = self.inner.call(request); @@ -118,7 +102,7 @@ where ConcurrencyLimit { inner: self.inner.clone(), semaphore: self.semaphore.clone(), - state: State::Empty, + permit: None, } } } @@ -141,16 +125,3 @@ where Poll::Ready(res) } } - -impl fmt::Debug for State { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - State::Waiting(_) => f - .debug_tuple("State::Waiting") - .field(&format_args!("...")) - .finish(), - State::Ready(ref r) => f.debug_tuple("State::Ready").field(&r).finish(), - State::Empty => f.debug_tuple("State::Empty").finish(), - } - } -}