Skip to content

Commit 6a6c309

Browse files
authored
RUST-1442 On-demand Azure KMS credentials (#872)
1 parent bf3489d commit 6a6c309

17 files changed

+315
-70
lines changed

.evergreen/config.yml

+14
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ functions:
9191
export TOPOLOGY=${TOPOLOGY}
9292
export MONGODB_VERSION=${MONGODB_VERSION}
9393
94+
export AZURE_IMDS_MOCK_PORT=44175
95+
9496
if [ "Windows_NT" != "$OS" ]; then
9597
ulimit -n 64000
9698
fi
@@ -488,6 +490,16 @@ functions:
488490
export TLS_FEATURE=${TLS_FEATURE}
489491
.evergreen/run-csfle-kmip-servers.sh
490492
493+
"run mock azure imds server":
494+
- command: shell.exec
495+
params:
496+
shell: bash
497+
working_dir: "src"
498+
background: true
499+
script: |
500+
${PREPARE_SHELL}
501+
.evergreen/run-csfle-mock-azure-imds.sh
502+
491503
"build csfle expansions":
492504
- command: shell.exec
493505
params:
@@ -1214,6 +1226,7 @@ tasks:
12141226
- func: "install junit dependencies"
12151227
- func: "bootstrap mongo-orchestration"
12161228
- func: "run kmip server"
1229+
- func: "run mock azure imds server"
12171230
- func: "build csfle expansions"
12181231
- func: "run csfle tests"
12191232

@@ -1229,6 +1242,7 @@ tasks:
12291242
- func: "install junit dependencies"
12301243
- func: "install libmongocrypt"
12311244
- func: "run kmip server"
1245+
- func: "run mock azure imds server"
12321246
- func: "build csfle expansions"
12331247
- func: "run csfle serverless tests"
12341248

.evergreen/feature-combinations.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ export NO_FEATURES=''
55
# async-std-related features that conflict with the library's default features.
66
export ASYNC_STD_FEATURES='--no-default-features --features async-std-runtime,sync'
77
# All additional features that do not conflict with the default features. New features added to the library should also be added to this list.
8-
export ADDITIONAL_FEATURES='--features tokio-sync,zstd-compression,snappy-compression,zlib-compression,openssl-tls,aws-auth,tracing-unstable,in-use-encryption-unstable'
8+
export ADDITIONAL_FEATURES='--features tokio-sync,zstd-compression,snappy-compression,zlib-compression,openssl-tls,aws-auth,tracing-unstable,in-use-encryption-unstable,azure-kms'
99

1010

1111
# Array of feature combinations that, in total, provides complete coverage of the driver.
+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#!/bin/bash
2+
3+
. ${DRIVERS_TOOLS}/.evergreen/find-python3.sh
4+
PYTHON=$(find_python3)
5+
6+
function prepend() { while read line; do echo "${1}${line}"; done; }
7+
8+
cd ${DRIVERS_TOOLS}/.evergreen/csfle
9+
${PYTHON} bottle.py fake_azure:imds -b localhost:${AZURE_IMDS_MOCK_PORT} 2>&1 | prepend "[MOCK AZURE IMDS] "

.evergreen/run-csfle-tests.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ source ./.evergreen/env.sh
77

88
set -o xtrace
99

10-
FEATURE_FLAGS="in-use-encryption-unstable,aws-auth,${TLS_FEATURE}"
10+
FEATURE_FLAGS="in-use-encryption-unstable,aws-auth,azure-kms,${TLS_FEATURE}"
1111
OPTIONS="-- -Z unstable-options --format json --report-time"
1212

1313
if [ "$SINGLE_THREAD" = true ]; then

Cargo.toml

+4
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ bson-uuid-1 = ["bson/uuid-1"]
6565
# This can only be used with the tokio-runtime feature flag.
6666
aws-auth = ["reqwest"]
6767

68+
# Enable support for on-demand Azure KMS credentials.
69+
# This can only be used with the tokio-runtime feature flag.
70+
azure-kms = ["reqwest"]
71+
6872
zstd-compression = ["zstd"]
6973
zlib-compression = ["flate2"]
7074
snappy-compression = ["snap"]

src/client/auth/mod.rs

+9-4
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ use crate::{
2323
client::options::ServerApi,
2424
cmap::{Command, Connection, StreamDescription},
2525
error::{Error, ErrorKind, Result},
26-
runtime::HttpClient,
2726
};
2827

2928
const SCRAM_SHA_1_STR: &str = "SCRAM-SHA-1";
@@ -253,7 +252,7 @@ impl AuthMechanism {
253252
stream: &mut Connection,
254253
credential: &Credential,
255254
server_api: Option<&ServerApi>,
256-
#[cfg_attr(not(feature = "aws-auth"), allow(unused))] http_client: &HttpClient,
255+
#[cfg(feature = "aws-auth")] http_client: &crate::runtime::HttpClient,
257256
) -> Result<()> {
258257
self.validate_credential(credential)?;
259258

@@ -398,9 +397,9 @@ impl Credential {
398397
pub(crate) async fn authenticate_stream(
399398
&self,
400399
conn: &mut Connection,
401-
http_client: &HttpClient,
402400
server_api: Option<&ServerApi>,
403401
first_round: Option<FirstRound>,
402+
#[cfg(feature = "aws-auth")] http_client: &crate::runtime::HttpClient,
404403
) -> Result<()> {
405404
let stream_description = conn.stream_description()?;
406405

@@ -431,7 +430,13 @@ impl Credential {
431430

432431
// Authenticate according to the chosen mechanism.
433432
mechanism
434-
.authenticate_stream(conn, self, server_api, http_client)
433+
.authenticate_stream(
434+
conn,
435+
self,
436+
server_api,
437+
#[cfg(feature = "aws-auth")]
438+
http_client,
439+
)
435440
.await
436441
}
437442

src/client/csfle.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
pub(crate) mod client_builder;
22
pub mod client_encryption;
33
pub mod options;
4-
mod state_machine;
4+
pub(crate) mod state_machine;
55

66
use std::{path::Path, time::Duration};
77

src/client/csfle/state_machine.rs

+153-4
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ pub(crate) struct CryptExecutor {
3535
mongocryptd: Option<Mongocryptd>,
3636
mongocryptd_client: Option<Client>,
3737
metadata_client: Option<WeakClient>,
38+
#[cfg(feature = "azure-kms")]
39+
azure: azure::ExecutorState,
3840
}
3941

4042
impl CryptExecutor {
@@ -56,6 +58,8 @@ impl CryptExecutor {
5658
mongocryptd: None,
5759
mongocryptd_client: None,
5860
metadata_client: None,
61+
#[cfg(feature = "azure-kms")]
62+
azure: azure::ExecutorState::new()?,
5963
})
6064
}
6165

@@ -211,11 +215,10 @@ impl CryptExecutor {
211215
let ctx = result_mut(&mut ctx)?;
212216
#[allow(unused_mut)]
213217
let mut out = rawdoc! {};
214-
if self
215-
.kms_providers
216-
.credentials()
218+
let credentials = self.kms_providers.credentials();
219+
if credentials
217220
.get(&KmsProvider::Aws)
218-
.map_or(false, |d| d.is_empty())
221+
.map_or(false, Document::is_empty)
219222
{
220223
#[cfg(feature = "aws-auth")]
221224
{
@@ -240,6 +243,21 @@ impl CryptExecutor {
240243
));
241244
}
242245
}
246+
if credentials
247+
.get(&KmsProvider::Azure)
248+
.map_or(false, Document::is_empty)
249+
{
250+
#[cfg(feature = "azure-kms")]
251+
{
252+
out.append("azure", self.azure.get_token().await?);
253+
}
254+
#[cfg(not(feature = "azure-kms"))]
255+
{
256+
return Err(Error::invalid_argument(
257+
"On-demand Azure KMS credentials require the `azure-kms` feature.",
258+
));
259+
}
260+
}
243261
ctx.provide_kms_providers(&out)?;
244262
}
245263
State::Ready => {
@@ -346,3 +364,134 @@ fn raw_to_doc(raw: &RawDocument) -> Result<Document> {
346364
raw.try_into()
347365
.map_err(|e| Error::internal(format!("could not parse raw document: {}", e)))
348366
}
367+
368+
#[cfg(feature = "azure-kms")]
369+
pub(crate) mod azure {
370+
use bson::{rawdoc, RawDocumentBuf};
371+
use serde::Deserialize;
372+
use std::time::{Duration, Instant};
373+
use tokio::sync::Mutex;
374+
375+
use crate::{
376+
error::{Error, Result},
377+
runtime::HttpClient,
378+
};
379+
380+
#[derive(Debug)]
381+
pub(crate) struct ExecutorState {
382+
cached_access_token: Mutex<Option<CachedAccessToken>>,
383+
http: HttpClient,
384+
#[cfg(test)]
385+
pub(crate) test_host: Option<(&'static str, u16)>,
386+
#[cfg(test)]
387+
pub(crate) test_param: Option<&'static str>,
388+
}
389+
390+
impl ExecutorState {
391+
pub(crate) fn new() -> Result<Self> {
392+
const AZURE_IMDS_TIMEOUT: Duration = Duration::from_secs(10);
393+
Ok(Self {
394+
cached_access_token: Mutex::new(None),
395+
http: HttpClient::with_timeout(AZURE_IMDS_TIMEOUT)?,
396+
#[cfg(test)]
397+
test_host: None,
398+
#[cfg(test)]
399+
test_param: None,
400+
})
401+
}
402+
403+
pub(crate) async fn get_token(&self) -> Result<RawDocumentBuf> {
404+
let mut cached_token = self.cached_access_token.lock().await;
405+
if let Some(cached) = &*cached_token {
406+
if cached.expire_time.saturating_duration_since(Instant::now())
407+
> Duration::from_secs(60)
408+
{
409+
return Ok(cached.token_doc.clone());
410+
}
411+
}
412+
let token = self.fetch_new_token().await?;
413+
let out = token.token_doc.clone();
414+
*cached_token = Some(token);
415+
Ok(out)
416+
}
417+
418+
async fn fetch_new_token(&self) -> Result<CachedAccessToken> {
419+
let now = Instant::now();
420+
let server_response: ServerResponse = self
421+
.http
422+
.get_and_deserialize_json(self.make_url()?, &self.make_headers())
423+
.await
424+
.map_err(|e| Error::authentication_error("azure imds", &format!("{}", e)))?;
425+
let expires_in_secs: u64 = server_response.expires_in.parse().map_err(|e| {
426+
Error::authentication_error(
427+
"azure imds",
428+
&format!("invalid `expires_in` response field: {}", e),
429+
)
430+
})?;
431+
#[allow(clippy::redundant_clone)]
432+
Ok(CachedAccessToken {
433+
token_doc: rawdoc! { "accessToken": server_response.access_token.clone() },
434+
expire_time: now + Duration::from_secs(expires_in_secs),
435+
#[cfg(test)]
436+
server_response,
437+
})
438+
}
439+
440+
fn make_url(&self) -> Result<reqwest::Url> {
441+
let url = reqwest::Url::parse_with_params(
442+
"http://169.254.169.254/metadata/identity/oauth2/token",
443+
&[
444+
("api-version", "2018-02-01"),
445+
("resource", "https://vault.azure.net"),
446+
],
447+
)
448+
.map_err(|e| Error::internal(format!("invalid Azure IMDS URL: {}", e)))?;
449+
#[cfg(test)]
450+
let url = {
451+
let mut url = url;
452+
if let Some((host, port)) = self.test_host {
453+
url.set_host(Some(host))
454+
.map_err(|e| Error::internal(format!("invalid test host: {}", e)))?;
455+
url.set_port(Some(port))
456+
.map_err(|()| Error::internal(format!("invalid test port {}", port)))?;
457+
}
458+
url
459+
};
460+
Ok(url)
461+
}
462+
463+
fn make_headers(&self) -> Vec<(&'static str, &'static str)> {
464+
let headers = vec![("Metadata", "true"), ("Accept", "application/json")];
465+
#[cfg(test)]
466+
let headers = {
467+
let mut headers = headers;
468+
if let Some(p) = self.test_param {
469+
headers.push(("X-MongoDB-HTTP-TestParams", p));
470+
}
471+
headers
472+
};
473+
headers
474+
}
475+
476+
#[cfg(test)]
477+
pub(crate) async fn take_cached(&self) -> Option<CachedAccessToken> {
478+
self.cached_access_token.lock().await.take()
479+
}
480+
}
481+
482+
#[derive(Debug, Deserialize)]
483+
pub(crate) struct ServerResponse {
484+
pub(crate) access_token: String,
485+
pub(crate) expires_in: String,
486+
#[allow(unused)]
487+
pub(crate) resource: String,
488+
}
489+
490+
#[derive(Debug)]
491+
pub(crate) struct CachedAccessToken {
492+
pub(crate) token_doc: RawDocumentBuf,
493+
pub(crate) expire_time: Instant,
494+
#[cfg(test)]
495+
pub(crate) server_response: ServerResponse,
496+
}
497+
}

src/cmap/establish/handshake/mod.rs

+8-6
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ use crate::{
1313
error::Result,
1414
hello::{hello_command, run_hello, HelloReply},
1515
options::{AuthMechanism, Credential, DriverInfo, ServerApi},
16-
runtime::HttpClient,
1716
};
1817

1918
#[cfg(all(feature = "tokio-runtime", not(feature = "tokio-sync")))]
@@ -323,16 +322,17 @@ pub(crate) struct Handshaker {
323322
#[allow(dead_code)]
324323
compressors: Option<Vec<Compressor>>,
325324

326-
http_client: HttpClient,
327-
328325
server_api: Option<ServerApi>,
329326

330327
metadata: ClientMetadata,
328+
329+
#[cfg(feature = "aws-auth")]
330+
http_client: crate::runtime::HttpClient,
331331
}
332332

333333
impl Handshaker {
334334
/// Creates a new Handshaker.
335-
pub(crate) fn new(http_client: HttpClient, options: HandshakerOptions) -> Self {
335+
pub(crate) fn new(options: HandshakerOptions) -> Self {
336336
let mut metadata = BASE_CLIENT_METADATA.clone();
337337
let compressors = options.compressors;
338338

@@ -383,11 +383,12 @@ impl Handshaker {
383383
command.body.insert("client", metadata.clone());
384384

385385
Self {
386-
http_client,
387386
command,
388387
compressors,
389388
server_api: options.server_api,
390389
metadata,
390+
#[cfg(feature = "aws-auth")]
391+
http_client: crate::runtime::HttpClient::default(),
391392
}
392393
}
393394

@@ -457,9 +458,10 @@ impl Handshaker {
457458
credential
458459
.authenticate_stream(
459460
conn,
460-
&self.http_client,
461461
self.server_api.as_ref(),
462462
first_round,
463+
#[cfg(feature = "aws-auth")]
464+
&self.http_client,
463465
)
464466
.await?
465467
}

0 commit comments

Comments
 (0)