@@ -35,6 +35,8 @@ pub(crate) struct CryptExecutor {
35
35
mongocryptd : Option < Mongocryptd > ,
36
36
mongocryptd_client : Option < Client > ,
37
37
metadata_client : Option < WeakClient > ,
38
+ #[ cfg( feature = "azure-kms" ) ]
39
+ azure : azure:: ExecutorState ,
38
40
}
39
41
40
42
impl CryptExecutor {
@@ -56,6 +58,8 @@ impl CryptExecutor {
56
58
mongocryptd : None ,
57
59
mongocryptd_client : None ,
58
60
metadata_client : None ,
61
+ #[ cfg( feature = "azure-kms" ) ]
62
+ azure : azure:: ExecutorState :: new ( ) ?,
59
63
} )
60
64
}
61
65
@@ -211,11 +215,10 @@ impl CryptExecutor {
211
215
let ctx = result_mut ( & mut ctx) ?;
212
216
#[ allow( unused_mut) ]
213
217
let mut out = rawdoc ! { } ;
214
- if self
215
- . kms_providers
216
- . credentials ( )
218
+ let credentials = self . kms_providers . credentials ( ) ;
219
+ if credentials
217
220
. get ( & KmsProvider :: Aws )
218
- . map_or ( false , |d| d . is_empty ( ) )
221
+ . map_or ( false , Document :: is_empty)
219
222
{
220
223
#[ cfg( feature = "aws-auth" ) ]
221
224
{
@@ -240,6 +243,21 @@ impl CryptExecutor {
240
243
) ) ;
241
244
}
242
245
}
246
+ if credentials
247
+ . get ( & KmsProvider :: Azure )
248
+ . map_or ( false , Document :: is_empty)
249
+ {
250
+ #[ cfg( feature = "azure-kms" ) ]
251
+ {
252
+ out. append ( "azure" , self . azure . get_token ( ) . await ?) ;
253
+ }
254
+ #[ cfg( not( feature = "azure-kms" ) ) ]
255
+ {
256
+ return Err ( Error :: invalid_argument (
257
+ "On-demand Azure KMS credentials require the `azure-kms` feature." ,
258
+ ) ) ;
259
+ }
260
+ }
243
261
ctx. provide_kms_providers ( & out) ?;
244
262
}
245
263
State :: Ready => {
@@ -346,3 +364,134 @@ fn raw_to_doc(raw: &RawDocument) -> Result<Document> {
346
364
raw. try_into ( )
347
365
. map_err ( |e| Error :: internal ( format ! ( "could not parse raw document: {}" , e) ) )
348
366
}
367
+
368
+ #[ cfg( feature = "azure-kms" ) ]
369
+ pub ( crate ) mod azure {
370
+ use bson:: { rawdoc, RawDocumentBuf } ;
371
+ use serde:: Deserialize ;
372
+ use std:: time:: { Duration , Instant } ;
373
+ use tokio:: sync:: Mutex ;
374
+
375
+ use crate :: {
376
+ error:: { Error , Result } ,
377
+ runtime:: HttpClient ,
378
+ } ;
379
+
380
+ #[ derive( Debug ) ]
381
+ pub ( crate ) struct ExecutorState {
382
+ cached_access_token : Mutex < Option < CachedAccessToken > > ,
383
+ http : HttpClient ,
384
+ #[ cfg( test) ]
385
+ pub ( crate ) test_host : Option < ( & ' static str , u16 ) > ,
386
+ #[ cfg( test) ]
387
+ pub ( crate ) test_param : Option < & ' static str > ,
388
+ }
389
+
390
+ impl ExecutorState {
391
+ pub ( crate ) fn new ( ) -> Result < Self > {
392
+ const AZURE_IMDS_TIMEOUT : Duration = Duration :: from_secs ( 10 ) ;
393
+ Ok ( Self {
394
+ cached_access_token : Mutex :: new ( None ) ,
395
+ http : HttpClient :: with_timeout ( AZURE_IMDS_TIMEOUT ) ?,
396
+ #[ cfg( test) ]
397
+ test_host : None ,
398
+ #[ cfg( test) ]
399
+ test_param : None ,
400
+ } )
401
+ }
402
+
403
+ pub ( crate ) async fn get_token ( & self ) -> Result < RawDocumentBuf > {
404
+ let mut cached_token = self . cached_access_token . lock ( ) . await ;
405
+ if let Some ( cached) = & * cached_token {
406
+ if cached. expire_time . saturating_duration_since ( Instant :: now ( ) )
407
+ > Duration :: from_secs ( 60 )
408
+ {
409
+ return Ok ( cached. token_doc . clone ( ) ) ;
410
+ }
411
+ }
412
+ let token = self . fetch_new_token ( ) . await ?;
413
+ let out = token. token_doc . clone ( ) ;
414
+ * cached_token = Some ( token) ;
415
+ Ok ( out)
416
+ }
417
+
418
+ async fn fetch_new_token ( & self ) -> Result < CachedAccessToken > {
419
+ let now = Instant :: now ( ) ;
420
+ let server_response: ServerResponse = self
421
+ . http
422
+ . get_and_deserialize_json ( self . make_url ( ) ?, & self . make_headers ( ) )
423
+ . await
424
+ . map_err ( |e| Error :: authentication_error ( "azure imds" , & format ! ( "{}" , e) ) ) ?;
425
+ let expires_in_secs: u64 = server_response. expires_in . parse ( ) . map_err ( |e| {
426
+ Error :: authentication_error (
427
+ "azure imds" ,
428
+ & format ! ( "invalid `expires_in` response field: {}" , e) ,
429
+ )
430
+ } ) ?;
431
+ #[ allow( clippy:: redundant_clone) ]
432
+ Ok ( CachedAccessToken {
433
+ token_doc : rawdoc ! { "accessToken" : server_response. access_token. clone( ) } ,
434
+ expire_time : now + Duration :: from_secs ( expires_in_secs) ,
435
+ #[ cfg( test) ]
436
+ server_response,
437
+ } )
438
+ }
439
+
440
+ fn make_url ( & self ) -> Result < reqwest:: Url > {
441
+ let url = reqwest:: Url :: parse_with_params (
442
+ "http://169.254.169.254/metadata/identity/oauth2/token" ,
443
+ & [
444
+ ( "api-version" , "2018-02-01" ) ,
445
+ ( "resource" , "https://vault.azure.net" ) ,
446
+ ] ,
447
+ )
448
+ . map_err ( |e| Error :: internal ( format ! ( "invalid Azure IMDS URL: {}" , e) ) ) ?;
449
+ #[ cfg( test) ]
450
+ let url = {
451
+ let mut url = url;
452
+ if let Some ( ( host, port) ) = self . test_host {
453
+ url. set_host ( Some ( host) )
454
+ . map_err ( |e| Error :: internal ( format ! ( "invalid test host: {}" , e) ) ) ?;
455
+ url. set_port ( Some ( port) )
456
+ . map_err ( |( ) | Error :: internal ( format ! ( "invalid test port {}" , port) ) ) ?;
457
+ }
458
+ url
459
+ } ;
460
+ Ok ( url)
461
+ }
462
+
463
+ fn make_headers ( & self ) -> Vec < ( & ' static str , & ' static str ) > {
464
+ let headers = vec ! [ ( "Metadata" , "true" ) , ( "Accept" , "application/json" ) ] ;
465
+ #[ cfg( test) ]
466
+ let headers = {
467
+ let mut headers = headers;
468
+ if let Some ( p) = self . test_param {
469
+ headers. push ( ( "X-MongoDB-HTTP-TestParams" , p) ) ;
470
+ }
471
+ headers
472
+ } ;
473
+ headers
474
+ }
475
+
476
+ #[ cfg( test) ]
477
+ pub ( crate ) async fn take_cached ( & self ) -> Option < CachedAccessToken > {
478
+ self . cached_access_token . lock ( ) . await . take ( )
479
+ }
480
+ }
481
+
482
+ #[ derive( Debug , Deserialize ) ]
483
+ pub ( crate ) struct ServerResponse {
484
+ pub ( crate ) access_token : String ,
485
+ pub ( crate ) expires_in : String ,
486
+ #[ allow( unused) ]
487
+ pub ( crate ) resource : String ,
488
+ }
489
+
490
+ #[ derive( Debug ) ]
491
+ pub ( crate ) struct CachedAccessToken {
492
+ pub ( crate ) token_doc : RawDocumentBuf ,
493
+ pub ( crate ) expire_time : Instant ,
494
+ #[ cfg( test) ]
495
+ pub ( crate ) server_response : ServerResponse ,
496
+ }
497
+ }
0 commit comments