Skip to content

Commit cf92a05

Browse files
authored
Sample code improvements around token caching (#2821)
1 parent 5dc7b06 commit cf92a05

4 files changed

+76
-33
lines changed

doc/samples/AzureKeyVaultProviderExample.cs

+4-2
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,13 @@ public static void Main(string[] args)
7979
}
8080
}
8181

82+
// Maintain an instance of the ClientCredential object to take advantage of underlying token caching
83+
private static ClientCredential clientCredential = new ClientCredential(s_clientId, s_clientSecret);
84+
8285
public static async Task<string> AzureActiveDirectoryAuthenticationCallback(string authority, string resource, string scope)
8386
{
8487
var authContext = new AuthenticationContext(authority);
85-
ClientCredential clientCred = new ClientCredential(s_clientId, s_clientSecret);
86-
AuthenticationResult result = await authContext.AcquireTokenAsync(resource, clientCred);
88+
AuthenticationResult result = await authContext.AcquireTokenAsync(resource, clientCredential);
8789
if (result == null)
8890
{
8991
throw new InvalidOperationException($"Failed to retrieve an access token for {resource}");

doc/samples/AzureKeyVaultProviderWithEnclaveProviderExample.cs

+4-2
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,13 @@ static void Main(string[] args)
8181
}
8282
}
8383

84+
// Maintain an instance of the ClientCredential object to take advantage of underlying token caching
85+
private static ClientCredential clientCredential = new ClientCredential(s_clientId, s_clientSecret);
86+
8487
public static async Task<string> AzureActiveDirectoryAuthenticationCallback(string authority, string resource, string scope)
8588
{
8689
var authContext = new AuthenticationContext(authority);
87-
ClientCredential clientCred = new ClientCredential(s_clientId, s_clientSecret);
88-
AuthenticationResult result = await authContext.AcquireTokenAsync(resource, clientCred);
90+
AuthenticationResult result = await authContext.AcquireTokenAsync(resource, clientCredential);
8991
if (result == null)
9092
{
9193
throw new InvalidOperationException($"Failed to retrieve an access token for {resource}");

doc/samples/CustomDeviceCodeFlowAzureAuthenticationProvider.cs

+32-12
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
//<Snippet1>
22
using System;
3+
using System.Collections.Generic;
4+
using System.Linq;
35
using System.Threading.Tasks;
4-
using Microsoft.Identity.Client;
56
using Microsoft.Data.SqlClient;
7+
using Microsoft.Identity.Client;
68

79
namespace CustomAuthenticationProviderExamples
810
{
@@ -12,28 +14,46 @@ namespace CustomAuthenticationProviderExamples
1214
/// </summary>
1315
public class CustomDeviceCodeFlowAzureAuthenticationProvider : SqlAuthenticationProvider
1416
{
17+
private const string clientId = "my-client-id";
18+
private const string clientName = "My Application Name";
19+
private const string s_defaultScopeSuffix = "/.default";
20+
21+
// Maintain a copy of the PublicClientApplication object to cache the underlying access tokens it provides
22+
private static IPublicClientApplication pcApplication;
23+
1524
public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenticationParameters parameters)
1625
{
17-
string clientId = "my-client-id";
18-
string clientName = "My Application Name";
19-
string s_defaultScopeSuffix = "/.default";
20-
2126
string[] scopes = new string[] { parameters.Resource.EndsWith(s_defaultScopeSuffix) ? parameters.Resource : parameters.Resource + s_defaultScopeSuffix };
2227

23-
IPublicClientApplication app = PublicClientApplicationBuilder.Create(clientId)
24-
.WithAuthority(parameters.Authority)
25-
.WithClientName(clientName)
26-
.WithRedirectUri("https://login.microsoftonline.com/common/oauth2/nativeclient")
28+
IPublicClientApplication app = pcApplication;
29+
if (app == null)
30+
{
31+
pcApplication = app = PublicClientApplicationBuilder.Create(clientId)
32+
.WithAuthority(parameters.Authority)
33+
.WithClientName(clientName)
34+
.WithRedirectUri("https://login.microsoftonline.com/common/oauth2/nativeclient")
2735
.Build();
36+
}
37+
38+
AuthenticationResult result;
39+
40+
try
41+
{
42+
IEnumerable<IAccount> accounts = await app.GetAccountsAsync();
43+
result = await app.AcquireTokenSilent(scopes, accounts.FirstOrDefault()).ExecuteAsync();
44+
}
45+
catch (MsalUiRequiredException)
46+
{
47+
result = await app.AcquireTokenWithDeviceCode(scopes,
48+
deviceCodeResult => CustomDeviceFlowCallback(deviceCodeResult)).ExecuteAsync();
49+
}
2850

29-
AuthenticationResult result = await app.AcquireTokenWithDeviceCode(scopes,
30-
deviceCodeResult => CustomDeviceFlowCallback(deviceCodeResult)).ExecuteAsync();
3151
return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn);
3252
}
3353

3454
public override bool IsSupported(SqlAuthenticationMethod authenticationMethod) => authenticationMethod.Equals(SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow);
3555

36-
private Task CustomDeviceFlowCallback(DeviceCodeResult result)
56+
private static Task<int> CustomDeviceFlowCallback(DeviceCodeResult result)
3757
{
3858
Console.WriteLine(result.Message);
3959
return Task.FromResult(0);

doc/samples/SqlConnection_AccessTokenCallback.cs

+36-17
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
using System;
2-
using System.Data;
32
// <Snippet1>
4-
using Microsoft.Data.SqlClient;
3+
using System.Collections.Concurrent;
4+
using System.Threading;
5+
using System.Threading.Tasks;
6+
using Azure.Core;
57
using Azure.Identity;
8+
using Microsoft.Data.SqlClient;
69

710
class Program
811
{
@@ -12,25 +15,41 @@ static void Main()
1215
Console.ReadLine();
1316
}
1417

18+
const string defaultScopeSuffix = "/.default";
19+
20+
// Reuse credential objects to take advantage of underlying token caches
21+
private static ConcurrentDictionary<string, DefaultAzureCredential> credentials = new ConcurrentDictionary<string, DefaultAzureCredential>();
22+
23+
// Use a shared callback function for connections that should be in the same connection pool
24+
private static Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> myAccessTokenCallback =
25+
async (authParams, cancellationToken) =>
26+
{
27+
string scope = authParams.Resource.EndsWith(defaultScopeSuffix)
28+
? authParams.Resource
29+
: $"{authParams.Resource}{defaultScopeSuffix}";
30+
31+
DefaultAzureCredentialOptions options = new DefaultAzureCredentialOptions();
32+
options.ManagedIdentityClientId = authParams.UserId;
33+
34+
// Reuse the same credential object if we are using the same MI Client Id
35+
AccessToken token = await credentials.GetOrAdd(authParams.UserId, new DefaultAzureCredential(options)).GetTokenAsync(
36+
new TokenRequestContext(new string[] { scope }),
37+
cancellationToken);
38+
39+
return new SqlAuthenticationToken(token.Token, token.ExpiresOn);
40+
};
41+
1542
private static void OpenSqlConnection()
1643
{
17-
const string defaultScopeSuffix = "/.default";
18-
string connectionString = GetConnectionString();
19-
DefaultAzureCredential credential = new();
44+
// (Optional) Pass a User-Assigned Managed Identity Client ID.
45+
// This will ensure different MI Client IDs are in different connection pools.
46+
string connectionString = "Server=myServer.database.windows.net;Encrypt=Mandatory;UserId=<ManagedIdentitityClientId>;";
2047

21-
using (SqlConnection connection = new(connectionString)
48+
using (SqlConnection connection = new SqlConnection(connectionString)
2249
{
23-
AccessTokenCallback = async (authParams, cancellationToken) =>
24-
{
25-
string scope = authParams.Resource.EndsWith(defaultScopeSuffix)
26-
? authParams.Resource
27-
: $"{authParams.Resource}{defaultScopeSuffix}";
28-
AccessToken token = await credential.GetTokenAsync(
29-
new TokenRequestContext([scope]),
30-
cancellationToken);
31-
32-
return new SqlAuthenticationToken(token.Token, token.ExpiresOn);
33-
}
50+
// The callback function is part of the connection pool key. Using a static callback function
51+
// ensures connections will not create a new pool per connection just for the callback.
52+
AccessTokenCallback = myAccessTokenCallback
3453
})
3554
{
3655
connection.Open();

0 commit comments

Comments
 (0)