Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RUST-1222 Cancel in-progress operations when SDAM heartbeats time out #1249

Merged
merged 6 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/client/auth/aws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async fn authenticate_stream_inner(
);
let client_first = sasl_start.into_command();

let server_first_response = conn.send_command(client_first, None).await?;
let server_first_response = conn.send_message(client_first).await?;

let server_first = ServerFirst::parse(server_first_response.auth_response_body(MECH_NAME)?)?;
server_first.validate(&nonce)?;
Expand Down Expand Up @@ -135,7 +135,7 @@ async fn authenticate_stream_inner(

let client_second = sasl_continue.into_command();

let server_second_response = conn.send_command(client_second, None).await?;
let server_second_response = conn.send_message(client_second).await?;
let server_second = SaslResponse::parse(
MECH_NAME,
server_second_response.auth_response_body(MECH_NAME)?,
Expand Down
2 changes: 1 addition & 1 deletion src/client/auth/oidc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,7 @@ async fn send_sasl_command(
conn: &mut Connection,
command: crate::cmap::Command,
) -> Result<SaslResponse> {
let response = conn.send_command(command, None).await?;
let response = conn.send_message(command).await?;
SaslResponse::parse(
MONGODB_OIDC_STR,
response.auth_response_body(MONGODB_OIDC_STR)?,
Expand Down
2 changes: 1 addition & 1 deletion src/client/auth/plain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub(crate) async fn authenticate_stream(
)
.into_command();

let response = conn.send_command(sasl_start, None).await?;
let response = conn.send_message(sasl_start).await?;
let sasl_response = SaslResponse::parse("PLAIN", response.auth_response_body("PLAIN")?)?;

if !sasl_response.done {
Expand Down
6 changes: 3 additions & 3 deletions src/client/auth/scram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ impl ScramVersion {

let command = client_first.to_command(self);

let server_first = conn.send_command(command, None).await?;
let server_first = conn.send_message(command).await?;

Ok(FirstRound {
client_first,
Expand Down Expand Up @@ -215,7 +215,7 @@ impl ScramVersion {

let command = client_final.to_command();

let server_final_response = conn.send_command(command, None).await?;
let server_final_response = conn.send_message(command).await?;
let server_final = ServerFinal::parse(server_final_response.auth_response_body("SCRAM")?)?;
server_final.validate(salted_password.as_slice(), &client_final, self)?;

Expand All @@ -231,7 +231,7 @@ impl ScramVersion {
);
let command = noop.into_command();

let server_noop_response = conn.send_command(command, None).await?;
let server_noop_response = conn.send_message(command).await?;
let server_noop_response_document: Document =
server_noop_response.auth_response_body("SCRAM")?;

Expand Down
2 changes: 1 addition & 1 deletion src/client/auth/x509.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub(crate) async fn send_client_first(
) -> Result<RawCommandResponse> {
let command = build_client_first(credential, server_api);

conn.send_command(command, None).await
conn.send_message(command).await
}

/// Performs X.509 authentication for a given stream.
Expand Down
7 changes: 3 additions & 4 deletions src/client/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -614,13 +614,12 @@ impl Client {
}

let should_redact = cmd.should_redact();
let should_compress = cmd.should_compress();

let cmd_name = cmd.name.clone();
let target_db = cmd.target_db.clone();

#[allow(unused_mut)]
let mut message = Message::from_command(cmd, Some(request_id))?;
let mut message = Message::try_from(cmd)?;
message.request_id = Some(request_id);
#[cfg(feature = "in-use-encryption")]
{
let guard = self.inner.csfle.read().await;
Expand Down Expand Up @@ -652,7 +651,7 @@ impl Client {
.await;

let start_time = Instant::now();
let command_result = match connection.send_message(message, should_compress).await {
let command_result = match connection.send_message(message).await {
Ok(response) => {
async fn handle_response<T: Operation>(
client: &Client,
Expand Down
60 changes: 40 additions & 20 deletions src/cmap/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ use derive_where::derive_where;
use serde::Serialize;
use tokio::{
io::BufStream,
sync::{mpsc, Mutex},
sync::{
broadcast::{self, error::RecvError},
mpsc,
Mutex,
},
};

use self::wire::{Message, MessageFlags};
Expand Down Expand Up @@ -171,12 +175,42 @@ impl Connection {
self.error.is_some()
}

pub(crate) async fn send_message_with_cancellation(
&mut self,
message: impl TryInto<Message, Error = impl Into<Error>>,
cancellation_receiver: &mut broadcast::Receiver<()>,
) -> Result<RawCommandResponse> {
tokio::select! {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: should this have a biased; clause to make error behavior deterministic?

// A lagged error indicates that more heartbeats failed than the channel's capacity
// between checking out this connection and executing the operation. If this occurs,
// then proceed with cancelling the operation. RecvError::Closed can be ignored, as
// the sender (and by extension the connection pool) dropping does not indicate that
// the operation should be cancelled.
Comment on lines +186 to +190
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The lagged scenario I outlined here will probably never actually happen. The lifetime of a receiver is intentionally very short so that only relevant messages are received. The following would need to occur:

  • Connection is checked out for an operation, leading to a new receiver being constructed. This new receiver will only receive messages sent after its construction.
  • Execution path proceeds with the rest of the steps between checkout and actually sending the message, which is primarily building the command. In the meantime...
  • SDAM heartbeat times out, leading to a pool clear and a message to be stored in the channel.
  • Another SDAM heartbeat times out after waiting the full heartbeat interval and another cancellation message is sent out.
  • Slow command construction finishes and send_message_with_cancellation is called; recv below immediately returns a lagged error because the receiver has two unseen messages from the two heartbeat timeouts, which exceed the channel's capacity. In this case we still want to proceed with cancellation.

These receivers are kind of acting like oneshots in that they're created fresh for each checked-out connection and only call recv once (i.e. on the below line), so the important thing here is to determine whether something was sent during their lifetime.

Ok(_) | Err(RecvError::Lagged(_)) = cancellation_receiver.recv() => {
let error: Error = ErrorKind::ConnectionPoolCleared {
message: format!(
"Connection to {} interrupted due to server monitor timeout",
self.address,
)
}.into();
self.error = Some(error.clone());
Err(error)
}
// This future is not cancellation safe because it contains calls to methods that are
// not cancellation safe (e.g. AsyncReadExt::read_exact). However, in the case that
// this future is cancelled because a cancellation message was received, this
// connection will be closed upon being returned to the pool, so any data loss on its
// underlying stream is not an issue.
result = self.send_message(message) => result,
}
}

pub(crate) async fn send_message(
&mut self,
message: Message,
// This value is only read if a compression feature flag is enabled.
#[allow(unused_variables)] can_compress: bool,
message: impl TryInto<Message, Error = impl Into<Error>>,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

small refactor here to avoid needing to add an equivalent send_command_with_cancellation for pending connections

) -> Result<RawCommandResponse> {
let message = message.try_into().map_err(Into::into)?;

if self.more_to_come {
return Err(Error::internal(format!(
"attempted to send a new message to {} but moreToCome bit was set",
Expand All @@ -192,7 +226,7 @@ impl Connection {
feature = "snappy-compression"
))]
let write_result = match self.compressor {
Some(ref compressor) if can_compress => {
Some(ref compressor) if message.should_compress => {
message
.write_op_compressed_to(&mut self.stream, compressor)
.await
Expand Down Expand Up @@ -232,21 +266,6 @@ impl Connection {
))
}

/// Executes a `Command` and returns a `CommandResponse` containing the result from the server.
///
/// An `Ok(...)` result simply means the server received the command and that the
/// driver received the response; it does not imply anything about the success of the command
/// itself.
pub(crate) async fn send_command(
&mut self,
command: Command,
request_id: impl Into<Option<i32>>,
) -> Result<RawCommandResponse> {
let to_compress = command.should_compress();
let message = Message::from_command(command, request_id.into())?;
self.send_message(message, to_compress).await
}

/// Receive the next message from the connection.
/// This will return an error if the previous response on this connection did not include the
/// moreToCome flag.
Expand Down Expand Up @@ -378,6 +397,7 @@ pub(crate) struct PendingConnection {
pub(crate) generation: PoolGeneration,
pub(crate) event_emitter: CmapEventEmitter,
pub(crate) time_created: Instant,
pub(crate) cancellation_receiver: Option<broadcast::Receiver<()>>,
}

impl PendingConnection {
Expand Down
65 changes: 51 additions & 14 deletions src/cmap/conn/pooled.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@ use std::{
};

use derive_where::derive_where;
use tokio::sync::{mpsc, Mutex};
use tokio::sync::{broadcast, mpsc, Mutex};

use super::{
CmapEventEmitter,
Connection,
ConnectionGeneration,
ConnectionInfo,
Message,
PendingConnection,
PinnedConnectionHandle,
PoolManager,
RawCommandResponse,
};
use crate::{
bson::oid::ObjectId,
Expand Down Expand Up @@ -50,7 +52,7 @@ pub(crate) struct PooledConnection {
}

/// The state of a pooled connection.
#[derive(Clone, Debug)]
#[derive(Debug)]
enum PooledConnectionState {
/// The state associated with a connection checked into the connection pool.
CheckedIn { available_time: Instant },
Expand All @@ -59,6 +61,10 @@ enum PooledConnectionState {
CheckedOut {
/// The manager used to check this connection back into the pool.
pool_manager: PoolManager,

/// The receiver to receive a cancellation notice. Only present on non-load-balanced
/// connections.
cancellation_receiver: Option<broadcast::Receiver<()>>,
},

/// The state associated with a pinned connection.
Expand Down Expand Up @@ -140,6 +146,24 @@ impl PooledConnection {
.and_then(|sd| sd.service_id)
}

/// Sends a message on this connection.
///
/// If the connection is currently checked out and carries a cancellation receiver (only
/// present on non-load-balanced connections, per the `CheckedOut` state's field docs), the
/// send is raced against a cancellation notice so an SDAM heartbeat timeout can interrupt
/// the in-progress operation. In every other state (checked-in, pinned, or checked-out
/// without a receiver) the message is sent without cancellation support.
pub(crate) async fn send_message(
&mut self,
message: impl TryInto<Message, Error = impl Into<Error>>,
) -> Result<RawCommandResponse> {
match self.state {
PooledConnectionState::CheckedOut {
cancellation_receiver: Some(ref mut cancellation_receiver),
..
} => {
// Race the send against the broadcast channel; receiving a cancellation
// notice aborts the operation with a ConnectionPoolCleared error and marks
// the underlying connection as errored.
self.connection
.send_message_with_cancellation(message, cancellation_receiver)
.await
}
// No cancellation receiver available: plain, uninterruptible send.
_ => self.connection.send_message(message).await,
}
}

/// Updates the state of the connection to indicate that it is checked into the pool.
pub(crate) fn mark_checked_in(&mut self) {
if !matches!(self.state, PooledConnectionState::CheckedIn { .. }) {
Expand All @@ -155,8 +179,15 @@ impl PooledConnection {
}

/// Updates the state of the connection to indicate that it is checked out of the pool.
pub(crate) fn mark_checked_out(&mut self, pool_manager: PoolManager) {
self.state = PooledConnectionState::CheckedOut { pool_manager };
pub(crate) fn mark_checked_out(
&mut self,
pool_manager: PoolManager,
cancellation_receiver: Option<broadcast::Receiver<()>>,
) {
self.state = PooledConnectionState::CheckedOut {
pool_manager,
cancellation_receiver,
};
}

/// Whether this connection is idle.
Expand All @@ -175,15 +206,14 @@ impl PooledConnection {
Instant::now().duration_since(available_time) >= max_idle_time
}

/// Nullifies the internal state of this connection and returns it in a new [PooledConnection].
/// If a state is provided, then the new connection will contain that state; otherwise, this
/// connection's state will be cloned.
fn take(&mut self, state: impl Into<Option<PooledConnectionState>>) -> Self {
/// Nullifies the internal state of this connection and returns it in a new [PooledConnection]
/// with the given state.
fn take(&mut self, new_state: PooledConnectionState) -> Self {
Self {
connection: self.connection.take(),
generation: self.generation,
event_emitter: self.event_emitter.clone(),
state: state.into().unwrap_or_else(|| self.state.clone()),
state: new_state,
}
}

Expand All @@ -196,7 +226,9 @@ impl PooledConnection {
self.id
)))
}
PooledConnectionState::CheckedOut { ref pool_manager } => {
PooledConnectionState::CheckedOut {
ref pool_manager, ..
} => {
let (tx, rx) = mpsc::channel(1);
self.state = PooledConnectionState::Pinned {
// Mark the connection as in-use while the operation currently using the
Expand Down Expand Up @@ -286,10 +318,11 @@ impl Drop for PooledConnection {
// Nothing needs to be done when a checked-in connection is dropped.
PooledConnectionState::CheckedIn { .. } => Ok(()),
// A checked-out connection should be sent back to the connection pool.
PooledConnectionState::CheckedOut { pool_manager } => {
PooledConnectionState::CheckedOut { pool_manager, .. } => {
let pool_manager = pool_manager.clone();
let mut dropped_connection = self.take(None);
dropped_connection.mark_checked_in();
let dropped_connection = self.take(PooledConnectionState::CheckedIn {
available_time: Instant::now(),
});
pool_manager.check_in(dropped_connection)
}
// A pinned connection should be returned to its pinner or to the connection pool.
Expand Down Expand Up @@ -339,7 +372,11 @@ impl Drop for PooledConnection {
}
// The pinner of this connection has been dropped while the connection was
// sitting in its channel, so the connection should be returned to the pool.
PinnedState::Returned { .. } => pool_manager.check_in(self.take(None)),
PinnedState::Returned { .. } => {
pool_manager.check_in(self.take(PooledConnectionState::CheckedIn {
available_time: Instant::now(),
}))
}
}
}
};
Expand Down
Loading