4
4
5
5
using System ;
6
6
using System . Collections . Concurrent ;
7
- using System . Security ;
7
+ using System . Linq ;
8
+ using System . Runtime . Caching ;
9
+ using System . Security . Cryptography ;
10
+ using System . Text ;
8
11
using System . Threading ;
9
12
using System . Threading . Tasks ;
10
13
using Azure . Core ;
@@ -24,6 +27,8 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro
24
27
/// </summary>
25
28
private static ConcurrentDictionary < PublicClientAppKey , IPublicClientApplication > s_pcaMap
26
29
= new ConcurrentDictionary < PublicClientAppKey , IPublicClientApplication > ( ) ;
30
+ private static readonly MemoryCache s_accountPwCache = new ( nameof ( ActiveDirectoryAuthenticationProvider ) ) ;
31
+ private static readonly int s_accountPwCacheTtlInHours = 2 ;
27
32
private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient" ;
28
33
private static readonly string s_defaultScopeSuffix = "/.default" ;
29
34
private readonly string _type = typeof ( ActiveDirectoryAuthenticationProvider ) . Name ;
@@ -172,7 +177,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
172
177
return new SqlAuthenticationToken ( accessToken . Token , accessToken . ExpiresOn ) ;
173
178
}
174
179
175
- AuthenticationResult result ;
180
+ AuthenticationResult result = null ;
176
181
if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryServicePrincipal )
177
182
{
178
183
AccessToken accessToken = await new ClientSecretCredential ( audience , parameters . UserId , parameters . Password , tokenCredentialOptions ) . GetTokenAsync ( tokenRequestContext , cts . Token ) . ConfigureAwait ( false ) ;
@@ -208,82 +213,82 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
208
213
209
214
if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryIntegrated )
210
215
{
211
- if ( ! string . IsNullOrEmpty ( parameters . UserId ) )
212
- {
213
- result = await app . AcquireTokenByIntegratedWindowsAuth ( scopes )
214
- . WithCorrelationId ( parameters . ConnectionId )
215
- . WithUsername ( parameters . UserId )
216
- . ExecuteAsync ( cancellationToken : cts . Token )
217
- . ConfigureAwait ( false ) ;
218
- }
219
- else
220
- {
221
- result = await app . AcquireTokenByIntegratedWindowsAuth ( scopes )
222
- . WithCorrelationId ( parameters . ConnectionId )
223
- . ExecuteAsync ( cancellationToken : cts . Token )
224
- . ConfigureAwait ( false ) ;
225
- }
226
- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}" , result ? . ExpiresOn ) ;
227
- }
228
- else if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryPassword )
229
- {
230
- result = await app . AcquireTokenByUsernamePassword ( scopes , parameters . UserId , parameters . Password )
231
- . WithCorrelationId ( parameters . ConnectionId )
232
- . ExecuteAsync ( cancellationToken : cts . Token )
233
- . ConfigureAwait ( false ) ;
234
-
235
- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}" , result ? . ExpiresOn ) ;
236
- }
237
- else if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryInteractive ||
238
- parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryDeviceCodeFlow )
239
- {
240
- // Fetch available accounts from 'app' instance
241
- System . Collections . Generic . IEnumerator < IAccount > accounts = ( await app . GetAccountsAsync ( ) . ConfigureAwait ( false ) ) . GetEnumerator ( ) ;
216
+ result = await TryAcquireTokenSilent ( app , parameters , scopes , cts ) . ConfigureAwait ( false ) ;
242
217
243
- IAccount account = default ;
244
- if ( accounts . MoveNext ( ) )
218
+ if ( null == result )
245
219
{
246
220
if ( ! string . IsNullOrEmpty ( parameters . UserId ) )
247
221
{
248
- do
249
- {
250
- IAccount currentVal = accounts . Current ;
251
- if ( string . Compare ( parameters . UserId , currentVal . Username , StringComparison . InvariantCultureIgnoreCase ) == 0 )
252
- {
253
- account = currentVal ;
254
- break ;
255
- }
256
- }
257
- while ( accounts . MoveNext ( ) ) ;
222
+ result = await app . AcquireTokenByIntegratedWindowsAuth ( scopes )
223
+ . WithCorrelationId ( parameters . ConnectionId )
224
+ . WithUsername ( parameters . UserId )
225
+ . ExecuteAsync ( cancellationToken : cts . Token )
226
+ . ConfigureAwait ( false ) ;
258
227
}
259
228
else
260
229
{
261
- account = accounts . Current ;
230
+ result = await app . AcquireTokenByIntegratedWindowsAuth ( scopes )
231
+ . WithCorrelationId ( parameters . ConnectionId )
232
+ . ExecuteAsync ( cancellationToken : cts . Token )
233
+ . ConfigureAwait ( false ) ;
262
234
}
235
+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}" , result ? . ExpiresOn ) ;
236
+ }
237
+ }
238
+ else if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryPassword )
239
+ {
240
+ string pwCacheKey = GetAccountPwCacheKey ( parameters ) ;
241
+ object previousPw = s_accountPwCache . Get ( pwCacheKey ) ;
242
+ byte [ ] currPwHash = GetHash ( parameters . Password ) ;
243
+
244
+ if ( null != previousPw &&
245
+ previousPw is byte [ ] previousPwBytes &&
246
+ // Only get the cached token if the current password hash matches the previously used password hash
247
+ currPwHash . SequenceEqual ( previousPwBytes ) )
248
+ {
249
+ result = await TryAcquireTokenSilent ( app , parameters , scopes , cts ) . ConfigureAwait ( false ) ;
263
250
}
264
251
265
- if ( null != account )
252
+ if ( null == result )
266
253
{
267
- try
268
- {
269
- // If 'account' is available in 'app', we use the same to acquire token silently.
270
- // Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
271
- result = await app . AcquireTokenSilent ( scopes , account ) . ExecuteAsync ( cancellationToken : cts . Token ) . ConfigureAwait ( false ) ;
272
- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
273
- }
274
- catch ( MsalUiRequiredException )
254
+ result = await app . AcquireTokenByUsernamePassword ( scopes , parameters . UserId , parameters . Password )
255
+ . WithCorrelationId ( parameters . ConnectionId )
256
+ . ExecuteAsync ( cancellationToken : cts . Token )
257
+ . ConfigureAwait ( false ) ;
258
+
259
+ // We cache the password hash to ensure future connection requests include a validated password
260
+ // when we check for a cached MSAL account. Otherwise, a connection request with the same username
261
+ // against the same tenant could succeed with an invalid password when we re-use the cached token.
262
+ if ( ! s_accountPwCache . Add ( pwCacheKey , GetHash ( parameters . Password ) , DateTime . UtcNow . AddHours ( s_accountPwCacheTtlInHours ) ) )
275
263
{
276
- // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
277
- // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
278
- // or the user needs to perform two factor authentication.
279
- result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod , cts ) . ConfigureAwait ( false ) ;
280
- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
264
+ s_accountPwCache . Remove ( pwCacheKey ) ;
265
+ s_accountPwCache . Add ( pwCacheKey , GetHash ( parameters . Password ) , DateTime . UtcNow . AddHours ( s_accountPwCacheTtlInHours ) ) ;
281
266
}
267
+
268
+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}" , result ? . ExpiresOn ) ;
282
269
}
283
- else
270
+ }
271
+ else if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryInteractive ||
272
+ parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryDeviceCodeFlow )
273
+ {
274
+ try
275
+ {
276
+ result = await TryAcquireTokenSilent ( app , parameters , scopes , cts ) . ConfigureAwait ( false ) ;
277
+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
278
+ }
279
+ catch ( MsalUiRequiredException )
280
+ {
281
+ // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
282
+ // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
283
+ // or the user needs to perform two factor authentication.
284
+ result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod , cts , _customWebUI , _deviceCodeFlowCallback ) . ConfigureAwait ( false ) ;
285
+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
286
+ }
287
+
288
+ if ( null == result )
284
289
{
285
290
// If no existing 'account' is found, we request user to sign in interactively.
286
- result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod , cts ) . ConfigureAwait ( false ) ;
291
+ result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod , cts , _customWebUI , _deviceCodeFlowCallback ) . ConfigureAwait ( false ) ;
287
292
SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
288
293
}
289
294
}
@@ -296,8 +301,49 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
296
301
return new SqlAuthenticationToken ( result . AccessToken , result . ExpiresOn ) ;
297
302
}
298
303
299
- private async Task < AuthenticationResult > AcquireTokenInteractiveDeviceFlowAsync ( IPublicClientApplication app , string [ ] scopes , Guid connectionId , string userId ,
300
- SqlAuthenticationMethod authenticationMethod , CancellationTokenSource cts )
304
+ private static async Task < AuthenticationResult > TryAcquireTokenSilent ( IPublicClientApplication app , SqlAuthenticationParameters parameters ,
305
+ string [ ] scopes , CancellationTokenSource cts )
306
+ {
307
+ AuthenticationResult result = null ;
308
+
309
+ // Fetch available accounts from 'app' instance
310
+ System . Collections . Generic . IEnumerator < IAccount > accounts = ( await app . GetAccountsAsync ( ) . ConfigureAwait ( false ) ) . GetEnumerator ( ) ;
311
+
312
+ IAccount account = default ;
313
+ if ( accounts . MoveNext ( ) )
314
+ {
315
+ if ( ! string . IsNullOrEmpty ( parameters . UserId ) )
316
+ {
317
+ do
318
+ {
319
+ IAccount currentVal = accounts . Current ;
320
+ if ( string . Compare ( parameters . UserId , currentVal . Username , StringComparison . InvariantCultureIgnoreCase ) == 0 )
321
+ {
322
+ account = currentVal ;
323
+ break ;
324
+ }
325
+ }
326
+ while ( accounts . MoveNext ( ) ) ;
327
+ }
328
+ else
329
+ {
330
+ account = accounts . Current ;
331
+ }
332
+ }
333
+
334
+ if ( null != account )
335
+ {
336
+ // If 'account' is available in 'app', we use the same to acquire token silently.
337
+ // Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
338
+ result = await app . AcquireTokenSilent ( scopes , account ) . ExecuteAsync ( cancellationToken : cts . Token ) . ConfigureAwait ( false ) ;
339
+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
340
+ }
341
+
342
+ return result ;
343
+ }
344
+
345
+ private static async Task < AuthenticationResult > AcquireTokenInteractiveDeviceFlowAsync ( IPublicClientApplication app , string [ ] scopes , Guid connectionId , string userId ,
346
+ SqlAuthenticationMethod authenticationMethod , CancellationTokenSource cts , ICustomWebUi customWebUI , Func < DeviceCodeResult , Task > deviceCodeFlowCallback )
301
347
{
302
348
try
303
349
{
@@ -316,11 +362,11 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
316
362
*/
317
363
ctsInteractive . CancelAfter ( 180000 ) ;
318
364
#endif
319
- if ( _customWebUI != null )
365
+ if ( customWebUI != null )
320
366
{
321
367
return await app . AcquireTokenInteractive ( scopes )
322
368
. WithCorrelationId ( connectionId )
323
- . WithCustomWebUi ( _customWebUI )
369
+ . WithCustomWebUi ( customWebUI )
324
370
. WithLoginHint ( userId )
325
371
. ExecuteAsync ( ctsInteractive . Token )
326
372
. ConfigureAwait ( false ) ;
@@ -354,7 +400,7 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
354
400
else
355
401
{
356
402
AuthenticationResult result = await app . AcquireTokenWithDeviceCode ( scopes ,
357
- deviceCodeResult => _deviceCodeFlowCallback ( deviceCodeResult ) )
403
+ deviceCodeResult => deviceCodeFlowCallback ( deviceCodeResult ) )
358
404
. WithCorrelationId ( connectionId )
359
405
. ExecuteAsync ( cancellationToken : cts . Token )
360
406
. ConfigureAwait ( false ) ;
@@ -407,6 +453,19 @@ private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey p
407
453
return clientApplicationInstance ;
408
454
}
409
455
456
+ private static string GetAccountPwCacheKey ( SqlAuthenticationParameters parameters )
457
+ {
458
+ return parameters . Authority + "+" + parameters . UserId ;
459
+ }
460
+
461
+ private static byte [ ] GetHash ( string input )
462
+ {
463
+ byte [ ] unhashedBytes = Encoding . Unicode . GetBytes ( input ) ;
464
+ SHA256 sha256 = SHA256 . Create ( ) ;
465
+ byte [ ] hashedBytes = sha256 . ComputeHash ( unhashedBytes ) ;
466
+ return hashedBytes ;
467
+ }
468
+
410
469
private IPublicClientApplication CreateClientAppInstance ( PublicClientAppKey publicClientAppKey )
411
470
{
412
471
IPublicClientApplication publicClientApplication ;
0 commit comments