Skip to content

Commit 38ba60b

Browse files
authored
RUST-1384 Implement auto encryption (#717)
1 parent 9314247 commit 38ba60b

File tree

32 files changed

+618
-125
lines changed

32 files changed

+618
-125
lines changed

.evergreen/feature-combinations.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@
66
export FEATURE_COMBINATIONS=(
77
'' # default features
88
'--no-default-features --features async-std-runtime,sync' # features that conflict w/ default features
9-
'--features tokio-sync,zstd-compression,snappy-compression,zlib-compression,openssl-tls,aws-auth' # additive features
9+
'--features tokio-sync,zstd-compression,snappy-compression,zlib-compression,openssl-tls,aws-auth,csfle' # additive features
1010
)

Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ zlib-compression = ["flate2"]
6666
snappy-compression = ["snap"]
6767

6868
# DO NOT USE; see https://jira.mongodb.org/browse/RUST-569 for the status of CSFLE support in the Rust driver.
69-
csfle = ["mongocrypt", "which"]
69+
csfle = ["mongocrypt", "which", "rayon"]
7070

7171
[dependencies]
7272
async-trait = "0.1.42"
@@ -89,6 +89,7 @@ openssl-probe = { version = "0.1.5", optional = true }
8989
os_info = { version = "3.0.1", default-features = false }
9090
percent-encoding = "2.0.0"
9191
rand = { version = "0.8.3", features = ["small_rng"] }
92+
rayon = { version = "1.5.3", optional = true }
9293
rustc_version_runtime = "0.2.1"
9394
rustls-pemfile = "0.3.0"
9495
serde_with = "1.3.1"

src/client/csfle.rs

+12-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
pub mod options;
2+
mod state_machine;
23

34
use std::{
45
path::Path,
@@ -7,6 +8,7 @@ use std::{
78

89
use derivative::Derivative;
910
use mongocrypt::Crypt;
11+
use rayon::ThreadPool;
1012

1113
use crate::{
1214
error::{Error, Result},
@@ -30,34 +32,37 @@ use super::WeakClient;
3032
#[derivative(Debug)]
3133
pub(super) struct ClientState {
3234
#[derivative(Debug = "ignore")]
33-
#[allow(dead_code)]
34-
crypt: Crypt,
35+
pub(crate) crypt: Crypt,
3536
mongocryptd_client: Option<Client>,
3637
aux_clients: AuxClients,
3738
opts: AutoEncryptionOptions,
39+
crypto_threads: ThreadPool,
3840
}
3941

4042
#[derive(Debug)]
4143
struct AuxClients {
42-
#[allow(dead_code)]
4344
key_vault_client: WeakClient,
44-
#[allow(dead_code)]
4545
metadata_client: Option<WeakClient>,
46-
#[allow(dead_code)]
47-
internal_client: Option<Client>,
46+
_internal_client: Option<Client>,
4847
}
4948

5049
impl ClientState {
5150
pub(super) async fn new(client: &Client, opts: AutoEncryptionOptions) -> Result<Self> {
5251
let crypt = Self::make_crypt(&opts)?;
5352
let mongocryptd_client = Self::spawn_mongocryptd_if_needed(&opts, &crypt).await?;
5453
let aux_clients = Self::make_aux_clients(client, &opts)?;
54+
let num_cpus = std::thread::available_parallelism()?.get();
55+
let crypto_threads = rayon::ThreadPoolBuilder::new()
56+
.num_threads(num_cpus)
57+
.build()
58+
.map_err(|e| Error::internal(format!("could not initialize thread pool: {}", e)))?;
5559

5660
Ok(Self {
5761
crypt,
5862
mongocryptd_client,
5963
aux_clients,
6064
opts,
65+
crypto_threads,
6166
})
6267
}
6368

@@ -171,7 +176,7 @@ impl ClientState {
171176
Ok(AuxClients {
172177
key_vault_client,
173178
metadata_client,
174-
internal_client,
179+
_internal_client: internal_client,
175180
})
176181
}
177182
}

src/client/csfle/state_machine.rs

+171
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
use std::convert::TryInto;
2+
3+
use bson::{Document, RawDocument, RawDocumentBuf};
4+
use futures_util::{stream, TryStreamExt};
5+
use mongocrypt::ctx::{Ctx, State};
6+
use tokio::{
7+
io::{AsyncReadExt, AsyncWriteExt},
8+
sync::oneshot,
9+
};
10+
11+
use crate::{
12+
client::options::ServerAddress,
13+
cmap::options::StreamOptions,
14+
error::{Error, Result},
15+
operation::{RawOutput, RunCommand},
16+
runtime::AsyncStream,
17+
Client,
18+
};
19+
20+
impl Client {
21+
pub(crate) async fn run_mongocrypt_ctx(
22+
&self,
23+
ctx: Ctx,
24+
db: Option<&str>,
25+
) -> Result<RawDocumentBuf> {
26+
let guard = self.inner.csfle.read().await;
27+
let csfle = match guard.as_ref() {
28+
Some(csfle) => csfle,
29+
None => return Err(Error::internal("no csfle state for mongocrypt ctx")),
30+
};
31+
let mut result = None;
32+
// This needs to be a `Result` so that the `Ctx` can be temporarily owned by the processing
33+
// thread for crypto finalization. An `Option` would also work here, but `Result` means we
34+
// can return a helpful error if things get into a broken state rather than panicing.
35+
let mut ctx = Ok(ctx);
36+
loop {
37+
let state = result_ref(&ctx)?.state()?;
38+
match state {
39+
State::NeedMongoCollinfo => {
40+
let ctx = result_mut(&mut ctx)?;
41+
let filter = raw_to_doc(ctx.mongo_op()?)?;
42+
let metadata_client = csfle
43+
.aux_clients
44+
.metadata_client
45+
.as_ref()
46+
.and_then(|w| w.upgrade())
47+
.ok_or_else(|| {
48+
Error::internal("metadata_client required for NeedMongoCollinfo state")
49+
})?;
50+
let db = metadata_client.database(db.as_ref().ok_or_else(|| {
51+
Error::internal("db required for NeedMongoCollinfo state")
52+
})?);
53+
let mut cursor = db.list_collections(filter, None).await?;
54+
if cursor.advance().await? {
55+
ctx.mongo_feed(cursor.current())?;
56+
}
57+
ctx.mongo_done()?;
58+
}
59+
State::NeedMongoMarkings => {
60+
let ctx = result_mut(&mut ctx)?;
61+
let command = ctx.mongo_op()?.to_raw_document_buf();
62+
let db = db.as_ref().ok_or_else(|| {
63+
Error::internal("db required for NeedMongoMarkings state")
64+
})?;
65+
let op = RawOutput(RunCommand::new_raw(db.to_string(), command, None, None)?);
66+
let mongocryptd_client = csfle
67+
.mongocryptd_client
68+
.as_ref()
69+
.ok_or_else(|| Error::internal("mongocryptd client not found"))?;
70+
let response = mongocryptd_client.execute_operation(op, None).await?;
71+
ctx.mongo_feed(response.raw_body())?;
72+
ctx.mongo_done()?;
73+
}
74+
State::NeedMongoKeys => {
75+
let ctx = result_mut(&mut ctx)?;
76+
let filter = raw_to_doc(ctx.mongo_op()?)?;
77+
let kv_ns = &csfle.opts.key_vault_namespace;
78+
let kv_client = csfle
79+
.aux_clients
80+
.key_vault_client
81+
.upgrade()
82+
.ok_or_else(|| Error::internal("key vault client dropped"))?;
83+
let kv_coll = kv_client
84+
.database(&kv_ns.db)
85+
.collection::<RawDocumentBuf>(&kv_ns.coll);
86+
let mut cursor = kv_coll.find(filter, None).await?;
87+
while cursor.advance().await? {
88+
ctx.mongo_feed(cursor.current())?;
89+
}
90+
ctx.mongo_done()?;
91+
}
92+
State::NeedKms => {
93+
let ctx = result_mut(&mut ctx)?;
94+
let scope = ctx.kms_scope();
95+
let mut kms_ctxen: Vec<Result<_>> = vec![];
96+
while let Some(kms_ctx) = scope.next_kms_ctx() {
97+
kms_ctxen.push(Ok(kms_ctx));
98+
}
99+
stream::iter(kms_ctxen)
100+
.try_for_each_concurrent(None, |mut kms_ctx| async move {
101+
let endpoint = kms_ctx.endpoint()?;
102+
let addr = ServerAddress::parse(endpoint)?;
103+
let provider = kms_ctx.kms_provider()?;
104+
let tls_options = csfle
105+
.opts()
106+
.tls_options
107+
.as_ref()
108+
.and_then(|tls| tls.get(&provider))
109+
.cloned()
110+
.unwrap_or_default();
111+
let mut stream = AsyncStream::connect(
112+
StreamOptions::builder()
113+
.address(addr)
114+
.tls_options(tls_options)
115+
.build(),
116+
)
117+
.await?;
118+
stream.write_all(kms_ctx.message()?).await?;
119+
let mut buf = vec![];
120+
while kms_ctx.bytes_needed() > 0 {
121+
let buf_size = kms_ctx.bytes_needed().try_into().map_err(|e| {
122+
Error::internal(format!("buffer size overflow: {}", e))
123+
})?;
124+
buf.resize(buf_size, 0);
125+
let count = stream.read(&mut buf).await?;
126+
kms_ctx.feed(&buf[0..count])?;
127+
}
128+
Ok(())
129+
})
130+
.await?;
131+
}
132+
State::NeedKmsCredentials => todo!("RUST-1314"),
133+
State::Ready => {
134+
let (tx, rx) = oneshot::channel();
135+
let mut thread_ctx = std::mem::replace(
136+
&mut ctx,
137+
Err(Error::internal("crypto context not present")),
138+
)?;
139+
csfle.crypto_threads.spawn(move || {
140+
let result = thread_ctx.finalize().map(|doc| doc.to_owned());
141+
let _ = tx.send((thread_ctx, result));
142+
});
143+
let (ctx_again, output) = rx
144+
.await
145+
.map_err(|_| Error::internal("crypto thread dropped"))?;
146+
ctx = Ok(ctx_again);
147+
result = Some(output?);
148+
}
149+
State::Done => break,
150+
s => return Err(Error::internal(format!("unhandled state {:?}", s))),
151+
}
152+
}
153+
match result {
154+
Some(doc) => Ok(doc),
155+
None => Err(Error::internal("libmongocrypt terminated without output")),
156+
}
157+
}
158+
}
159+
160+
/// Borrows the success value held in `r`, or hands back a clone of its error.
fn result_ref<T>(r: &Result<T>) -> Result<&T> {
    match r {
        Ok(value) => Ok(value),
        Err(err) => Err(err.clone()),
    }
}
163+
164+
/// Mutably borrows the success value held in `r`, or hands back a clone of its error.
fn result_mut<T>(r: &mut Result<T>) -> Result<&mut T> {
    match r {
        Ok(value) => Ok(value),
        Err(err) => Err(err.clone()),
    }
}
167+
168+
fn raw_to_doc(raw: &RawDocument) -> Result<Document> {
169+
raw.try_into()
170+
.map_err(|e| Error::internal(format!("could not parse raw document: {}", e)))
171+
}

src/client/executor.rs

+62
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
#[cfg(feature = "csfle")]
2+
use bson::RawDocumentBuf;
13
use bson::{doc, RawBsonRef, RawDocument, Timestamp};
4+
#[cfg(feature = "csfle")]
5+
use futures_core::future::BoxFuture;
26
use lazy_static::lazy_static;
37
use serde::de::DeserializeOwned;
48

@@ -598,6 +602,21 @@ impl Client {
598602
let target_db = cmd.target_db.clone();
599603

600604
let serialized = op.serialize_command(cmd)?;
605+
#[cfg(feature = "csfle")]
606+
let serialized = {
607+
let guard = self.inner.csfle.read().await;
608+
if let Some(ref csfle) = *guard {
609+
if csfle.opts().bypass_auto_encryption != Some(true) {
610+
self.auto_encrypt(csfle, RawDocument::from_bytes(&serialized)?, &target_db)
611+
.await?
612+
.into_bytes()
613+
} else {
614+
serialized
615+
}
616+
} else {
617+
serialized
618+
}
619+
};
601620
let raw_cmd = RawCommand {
602621
name: cmd_name.clone(),
603622
target_db,
@@ -750,6 +769,21 @@ impl Client {
750769
handler.handle_command_succeeded_event(command_succeeded_event);
751770
});
752771

772+
#[cfg(feature = "csfle")]
773+
let response = {
774+
let guard = self.inner.csfle.read().await;
775+
if let Some(ref csfle) = *guard {
776+
if csfle.opts().bypass_auto_encryption != Some(true) {
777+
let new_body = self.auto_decrypt(csfle, response.raw_body()).await?;
778+
RawCommandResponse::new_raw(response.source, new_body)
779+
} else {
780+
response
781+
}
782+
} else {
783+
response
784+
}
785+
};
786+
753787
match op.handle_response(response, connection.stream_description()?) {
754788
Ok(response) => Ok(response),
755789
Err(mut err) => {
@@ -765,6 +799,34 @@ impl Client {
765799
}
766800
}
767801

802+
/// Builds an encryption context for `command` against `target_db` and runs it to completion
/// via `run_mongocrypt_ctx`, yielding the resulting document.
///
/// Returned as a boxed future rather than being an `async fn`.
// NOTE(review): presumably boxed to break the recursive future type (the state machine can
// itself execute operations) -- confirm.
#[cfg(feature = "csfle")]
fn auto_encrypt<'a>(
    &'a self,
    csfle: &'a super::csfle::ClientState,
    command: &'a RawDocument,
    target_db: &'a str,
) -> BoxFuture<'a, Result<RawDocumentBuf>> {
    let encrypt = async move {
        let encrypt_ctx = csfle
            .crypt
            .ctx_builder()
            .build_encrypt(target_db, command)?;
        self.run_mongocrypt_ctx(encrypt_ctx, Some(target_db)).await
    };
    Box::pin(encrypt)
}
817+
818+
/// Builds a decryption context for `response` and runs it to completion via
/// `run_mongocrypt_ctx` (with no database name), yielding the resulting document.
///
/// Returned as a boxed future rather than being an `async fn`.
#[cfg(feature = "csfle")]
fn auto_decrypt<'a>(
    &'a self,
    csfle: &'a super::csfle::ClientState,
    response: &'a RawDocument,
) -> BoxFuture<'a, Result<RawDocumentBuf>> {
    let decrypt = async move {
        let decrypt_ctx = csfle.crypt.ctx_builder().build_decrypt(response)?;
        self.run_mongocrypt_ctx(decrypt_ctx, None).await
    };
    Box::pin(decrypt)
}
829+
768830
/// Start an implicit session if the operation and write concern are compatible with sessions.
769831
async fn start_implicit_session<T: Operation>(&self, op: &T) -> Result<Option<ClientSession>> {
770832
match self.get_session_support_status().await? {

src/cmap/conn/command.rs

+5-4
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,11 @@ impl RawCommandResponse {
203203

204204
pub(crate) fn new(source: ServerAddress, message: Message) -> Result<Self> {
205205
let raw = message.single_document_response()?;
206-
Ok(Self {
207-
source,
208-
raw: RawDocumentBuf::from_bytes(raw)?,
209-
})
206+
Ok(Self::new_raw(source, RawDocumentBuf::from_bytes(raw)?))
207+
}
208+
209+
pub(crate) fn new_raw(source: ServerAddress, raw: RawDocumentBuf) -> Self {
210+
Self { source, raw }
210211
}
211212

212213
pub(crate) fn body<'a, T: Deserialize<'a>>(&'a self) -> Result<T> {

src/coll/mod.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -1189,6 +1189,10 @@ where
11891189
}
11901190

11911191
let ordered = options.as_ref().and_then(|o| o.ordered).unwrap_or(true);
1192+
#[cfg(feature = "csfle")]
1193+
let encrypted = self.client().auto_encryption_opts().await.is_some();
1194+
#[cfg(not(feature = "csfle"))]
1195+
let encrypted = false;
11921196

11931197
let mut cumulative_failure: Option<BulkWriteFailure> = None;
11941198
let mut error_labels: HashSet<String> = Default::default();
@@ -1198,7 +1202,7 @@ where
11981202

11991203
while n_attempted < ds.len() {
12001204
let docs: Vec<&T> = ds.iter().skip(n_attempted).map(Borrow::borrow).collect();
1201-
let insert = Insert::new(self.namespace(), docs, options.clone());
1205+
let insert = Insert::new_encrypted(self.namespace(), docs, options.clone(), encrypted);
12021206

12031207
match self
12041208
.client()

0 commit comments

Comments
 (0)