Skip to content

Commit f70b9c6

Browse files
authored
[5.1.1] Fix | Throttling of token requests by calling AcquireTokenSilent (#1966)
1 parent daa1a74 commit f70b9c6

File tree

1 file changed

+126
-67
lines changed

1 file changed

+126
-67
lines changed

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs

+126-67
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44

55
using System;
66
using System.Collections.Concurrent;
7-
using System.Security;
7+
using System.Linq;
8+
using System.Runtime.Caching;
9+
using System.Security.Cryptography;
10+
using System.Text;
811
using System.Threading;
912
using System.Threading.Tasks;
1013
using Azure.Core;
@@ -24,6 +27,8 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro
2427
/// </summary>
2528
private static ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication> s_pcaMap
2629
= new ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication>();
30+
private static readonly MemoryCache s_accountPwCache = new(nameof(ActiveDirectoryAuthenticationProvider));
31+
private static readonly int s_accountPwCacheTtlInHours = 2;
2732
private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient";
2833
private static readonly string s_defaultScopeSuffix = "/.default";
2934
private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name;
@@ -172,7 +177,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
172177
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
173178
}
174179

175-
AuthenticationResult result;
180+
AuthenticationResult result = null;
176181
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal)
177182
{
178183
AccessToken accessToken = await new ClientSecretCredential(audience, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
@@ -208,82 +213,82 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
208213

209214
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
210215
{
211-
if (!string.IsNullOrEmpty(parameters.UserId))
212-
{
213-
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
214-
.WithCorrelationId(parameters.ConnectionId)
215-
.WithUsername(parameters.UserId)
216-
.ExecuteAsync(cancellationToken: cts.Token)
217-
.ConfigureAwait(false);
218-
}
219-
else
220-
{
221-
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
222-
.WithCorrelationId(parameters.ConnectionId)
223-
.ExecuteAsync(cancellationToken: cts.Token)
224-
.ConfigureAwait(false);
225-
}
226-
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn);
227-
}
228-
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword)
229-
{
230-
result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, parameters.Password)
231-
.WithCorrelationId(parameters.ConnectionId)
232-
.ExecuteAsync(cancellationToken: cts.Token)
233-
.ConfigureAwait(false);
234-
235-
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn);
236-
}
237-
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive ||
238-
parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
239-
{
240-
// Fetch available accounts from 'app' instance
241-
System.Collections.Generic.IEnumerator<IAccount> accounts = (await app.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator();
216+
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);
242217

243-
IAccount account = default;
244-
if (accounts.MoveNext())
218+
if (null == result)
245219
{
246220
if (!string.IsNullOrEmpty(parameters.UserId))
247221
{
248-
do
249-
{
250-
IAccount currentVal = accounts.Current;
251-
if (string.Compare(parameters.UserId, currentVal.Username, StringComparison.InvariantCultureIgnoreCase) == 0)
252-
{
253-
account = currentVal;
254-
break;
255-
}
256-
}
257-
while (accounts.MoveNext());
222+
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
223+
.WithCorrelationId(parameters.ConnectionId)
224+
.WithUsername(parameters.UserId)
225+
.ExecuteAsync(cancellationToken: cts.Token)
226+
.ConfigureAwait(false);
258227
}
259228
else
260229
{
261-
account = accounts.Current;
230+
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
231+
.WithCorrelationId(parameters.ConnectionId)
232+
.ExecuteAsync(cancellationToken: cts.Token)
233+
.ConfigureAwait(false);
262234
}
235+
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn);
236+
}
237+
}
238+
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword)
239+
{
240+
string pwCacheKey = GetAccountPwCacheKey(parameters);
241+
object previousPw = s_accountPwCache.Get(pwCacheKey);
242+
byte[] currPwHash = GetHash(parameters.Password);
243+
244+
if (null != previousPw &&
245+
previousPw is byte[] previousPwBytes &&
246+
// Only get the cached token if the current password hash matches the previously used password hash
247+
currPwHash.SequenceEqual(previousPwBytes))
248+
{
249+
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);
263250
}
264251

265-
if (null != account)
252+
if (null == result)
266253
{
267-
try
268-
{
269-
// If 'account' is available in 'app', we use the same to acquire token silently.
270-
// Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
271-
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token).ConfigureAwait(false);
272-
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
273-
}
274-
catch (MsalUiRequiredException)
254+
result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, parameters.Password)
255+
.WithCorrelationId(parameters.ConnectionId)
256+
.ExecuteAsync(cancellationToken: cts.Token)
257+
.ConfigureAwait(false);
258+
259+
// We cache the password hash to ensure future connection requests include a validated password
260+
// when we check for a cached MSAL account. Otherwise, a connection request with the same username
261+
// against the same tenant could succeed with an invalid password when we re-use the cached token.
262+
if (!s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours)))
275263
{
276-
// An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
277-
// for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
278-
// or the user needs to perform two factor authentication.
279-
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false);
280-
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
264+
s_accountPwCache.Remove(pwCacheKey);
265+
s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours));
281266
}
267+
268+
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn);
282269
}
283-
else
270+
}
271+
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive ||
272+
parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
273+
{
274+
try
275+
{
276+
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);
277+
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
278+
}
279+
catch (MsalUiRequiredException)
280+
{
281+
// An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
282+
// for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
283+
// or the user needs to perform two factor authentication.
284+
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback).ConfigureAwait(false);
285+
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
286+
}
287+
288+
if (null == result)
284289
{
285290
// If no existing 'account' is found, we request user to sign in interactively.
286-
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false);
291+
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback).ConfigureAwait(false);
287292
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
288293
}
289294
}
@@ -296,8 +301,49 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
296301
return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn);
297302
}
298303

299-
private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId,
300-
SqlAuthenticationMethod authenticationMethod, CancellationTokenSource cts)
304+
private static async Task<AuthenticationResult> TryAcquireTokenSilent(IPublicClientApplication app, SqlAuthenticationParameters parameters,
305+
string[] scopes, CancellationTokenSource cts)
306+
{
307+
AuthenticationResult result = null;
308+
309+
// Fetch available accounts from 'app' instance
310+
System.Collections.Generic.IEnumerator<IAccount> accounts = (await app.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator();
311+
312+
IAccount account = default;
313+
if (accounts.MoveNext())
314+
{
315+
if (!string.IsNullOrEmpty(parameters.UserId))
316+
{
317+
do
318+
{
319+
IAccount currentVal = accounts.Current;
320+
if (string.Compare(parameters.UserId, currentVal.Username, StringComparison.InvariantCultureIgnoreCase) == 0)
321+
{
322+
account = currentVal;
323+
break;
324+
}
325+
}
326+
while (accounts.MoveNext());
327+
}
328+
else
329+
{
330+
account = accounts.Current;
331+
}
332+
}
333+
334+
if (null != account)
335+
{
336+
// If 'account' is available in 'app', we use the same to acquire token silently.
337+
// Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
338+
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token).ConfigureAwait(false);
339+
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
340+
}
341+
342+
return result;
343+
}
344+
345+
private static async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId,
346+
SqlAuthenticationMethod authenticationMethod, CancellationTokenSource cts, ICustomWebUi customWebUI, Func<DeviceCodeResult, Task> deviceCodeFlowCallback)
301347
{
302348
try
303349
{
@@ -316,11 +362,11 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
316362
*/
317363
ctsInteractive.CancelAfter(180000);
318364
#endif
319-
if (_customWebUI != null)
365+
if (customWebUI != null)
320366
{
321367
return await app.AcquireTokenInteractive(scopes)
322368
.WithCorrelationId(connectionId)
323-
.WithCustomWebUi(_customWebUI)
369+
.WithCustomWebUi(customWebUI)
324370
.WithLoginHint(userId)
325371
.ExecuteAsync(ctsInteractive.Token)
326372
.ConfigureAwait(false);
@@ -354,7 +400,7 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
354400
else
355401
{
356402
AuthenticationResult result = await app.AcquireTokenWithDeviceCode(scopes,
357-
deviceCodeResult => _deviceCodeFlowCallback(deviceCodeResult))
403+
deviceCodeResult => deviceCodeFlowCallback(deviceCodeResult))
358404
.WithCorrelationId(connectionId)
359405
.ExecuteAsync(cancellationToken: cts.Token)
360406
.ConfigureAwait(false);
@@ -407,6 +453,19 @@ private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey p
407453
return clientApplicationInstance;
408454
}
409455

456+
private static string GetAccountPwCacheKey(SqlAuthenticationParameters parameters)
457+
{
458+
return parameters.Authority + "+" + parameters.UserId;
459+
}
460+
461+
private static byte[] GetHash(string input)
462+
{
463+
byte[] unhashedBytes = Encoding.Unicode.GetBytes(input);
464+
SHA256 sha256 = SHA256.Create();
465+
byte[] hashedBytes = sha256.ComputeHash(unhashedBytes);
466+
return hashedBytes;
467+
}
468+
410469
private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey)
411470
{
412471
IPublicClientApplication publicClientApplication;

0 commit comments

Comments
 (0)