Skip to content

Commit ff19a0b

Browse files
[Release 2.1] Fix | Fixes Kerberos auth when SPN does not contain port (#935)
1 parent 4c45dce commit ff19a0b

File tree

7 files changed

+67
-48
lines changed

7 files changed

+67
-48
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Common/src/System/Net/Security/NegotiateStreamPal.Unix.cs

+30-15
Original file line numberDiff line numberDiff line change
@@ -135,15 +135,18 @@ private static SecurityStatusPal EstablishSecurityContext(
135135
}
136136
catch (Exception ex)
137137
{
138-
if (NetEventSource.IsEnabled) NetEventSource.Error(null, ex);
138+
if (NetEventSource.IsEnabled)
139+
{
140+
NetEventSource.Error(null, ex);
141+
}
139142
return new SecurityStatusPal(SecurityStatusPalErrorCode.InternalError, ex);
140143
}
141144
}
142145

143146
internal static SecurityStatusPal InitializeSecurityContext(
144147
SafeFreeCredentials credentialsHandle,
145148
ref SafeDeleteContext securityContext,
146-
string spn,
149+
string[] spns,
147150
ContextFlagsPal requestedContextFlags,
148151
SecurityBuffer[] inSecurityBufferArray,
149152
SecurityBuffer outSecurityBuffer,
@@ -156,20 +159,33 @@ internal static SecurityStatusPal InitializeSecurityContext(
156159
}
157160

158161
SafeFreeNegoCredentials negoCredentialsHandle = (SafeFreeNegoCredentials)credentialsHandle;
162+
SecurityStatusPal status = default;
159163

160-
if (negoCredentialsHandle.IsDefault && string.IsNullOrEmpty(spn))
164+
foreach (string spn in spns)
161165
{
162-
throw new PlatformNotSupportedException(Strings.net_nego_not_supported_empty_target_with_defaultcreds);
163-
}
166+
if (negoCredentialsHandle.IsDefault && string.IsNullOrEmpty(spn))
167+
{
168+
throw new PlatformNotSupportedException(Strings.net_nego_not_supported_empty_target_with_defaultcreds);
169+
}
164170

165-
SecurityStatusPal status = EstablishSecurityContext(
166-
negoCredentialsHandle,
167-
ref securityContext,
168-
spn,
169-
requestedContextFlags,
170-
((inSecurityBufferArray != null && inSecurityBufferArray.Length != 0) ? inSecurityBufferArray[0] : null),
171-
outSecurityBuffer,
172-
ref contextFlags);
171+
status = EstablishSecurityContext(
172+
negoCredentialsHandle,
173+
ref securityContext,
174+
spn,
175+
requestedContextFlags,
176+
((inSecurityBufferArray != null && inSecurityBufferArray.Length != 0) ? inSecurityBufferArray[0] : null),
177+
outSecurityBuffer,
178+
ref contextFlags);
179+
180+
if (status.ErrorCode != SecurityStatusPalErrorCode.InternalError)
181+
{
182+
break; // Successful case, exit the loop with current SPN.
183+
}
184+
else
185+
{
186+
securityContext = null; // Reset security context to be generated again for next SPN.
187+
}
188+
}
173189

174190
// Confidentiality flag should not be set if not requested
175191
if (status.ErrorCode == SecurityStatusPalErrorCode.CompleteNeeded)
@@ -180,7 +196,6 @@ internal static SecurityStatusPal InitializeSecurityContext(
180196
throw new PlatformNotSupportedException(Strings.net_nego_protection_level_not_supported);
181197
}
182198
}
183-
184199
return status;
185200
}
186201

@@ -224,7 +239,7 @@ internal static SafeFreeCredentials AcquireCredentialsHandle(string package, boo
224239
new SafeFreeNegoCredentials(false, string.Empty, string.Empty, string.Empty) :
225240
new SafeFreeNegoCredentials(ntlmOnly, credential.UserName, credential.Password, credential.Domain);
226241
}
227-
catch(Exception ex)
242+
catch (Exception ex)
228243
{
229244
throw new Win32Exception(NTE_FAIL, ex.Message);
230245
}

src/Microsoft.Data.SqlClient/netcore/src/Common/src/System/Net/Security/NegotiateStreamPal.Windows.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ internal static string QueryContextAuthenticationPackage(SafeDeleteContext secur
7070
internal static SecurityStatusPal InitializeSecurityContext(
7171
SafeFreeCredentials credentialsHandle,
7272
ref SafeDeleteContext securityContext,
73-
string spn,
73+
string[] spn,
7474
ContextFlagsPal requestedContextFlags,
7575
SecurityBuffer[] inSecurityBufferArray,
7676
SecurityBuffer outSecurityBuffer,
@@ -81,7 +81,7 @@ internal static SecurityStatusPal InitializeSecurityContext(
8181
GlobalSSPI.SSPIAuth,
8282
credentialsHandle,
8383
ref securityContext,
84-
spn,
84+
spn[0],
8585
ContextFlagsAdapterPal.GetInteropFromContextFlagsPal(requestedContextFlags),
8686
Interop.SspiCli.Endianness.SECURITY_NETWORK_DREP,
8787
inSecurityBufferArray,

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs

+21-17
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ internal uint DisableSsl(SNIHandle handle)
7272
/// <param name="sendBuff">Send buffer</param>
7373
/// <param name="serverName">Service Principal Name buffer</param>
7474
/// <returns>SNI error code</returns>
75-
internal void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[] serverName)
75+
internal void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[][] serverName)
7676
{
7777
SafeDeleteContext securityContext = sspiClientContextStatus.SecurityContext;
7878
ContextFlagsPal contextFlags = sspiClientContextStatus.ContextFlags;
@@ -104,12 +104,15 @@ internal void GenSspiClientContext(SspiClientContextStatus sspiClientContextStat
104104
| ContextFlagsPal.Delegate
105105
| ContextFlagsPal.MutualAuth;
106106

107-
string serverSPN = System.Text.Encoding.UTF8.GetString(serverName);
108-
107+
string[] serverSPNs = new string[serverName.Length];
108+
for (int i = 0; i < serverName.Length; i++)
109+
{
110+
serverSPNs[i] = System.Text.Encoding.UTF8.GetString(serverName[i]);
111+
}
109112
SecurityStatusPal statusCode = NegotiateStreamPal.InitializeSecurityContext(
110113
credentialsHandle,
111114
ref securityContext,
112-
serverSPN,
115+
serverSPNs,
113116
requestedContextFlags,
114117
inSecurityBufferArray,
115118
outSecurityBuffer,
@@ -253,7 +256,7 @@ internal uint WritePacket(SNIHandle handle, SNIPacket packet, bool sync)
253256
/// <param name="cachedFQDN">Used for DNS Cache</param>
254257
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
255258
/// <returns>SNI handle</returns>
256-
internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
259+
internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
257260
{
258261
instanceName = new byte[1];
259262

@@ -294,7 +297,7 @@ internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniO
294297
{
295298
try
296299
{
297-
spnBuffer = GetSqlServerSPN(details);
300+
spnBuffer = GetSqlServerSPNs(details);
298301
}
299302
catch (Exception e)
300303
{
@@ -305,7 +308,7 @@ internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniO
305308
return sniHandle;
306309
}
307310

308-
private static byte[] GetSqlServerSPN(DataSource dataSource)
311+
private static byte[][] GetSqlServerSPNs(DataSource dataSource)
309312
{
310313
Debug.Assert(!string.IsNullOrWhiteSpace(dataSource.ServerName));
311314

@@ -319,16 +322,11 @@ private static byte[] GetSqlServerSPN(DataSource dataSource)
319322
{
320323
postfix = dataSource.InstanceName;
321324
}
322-
// For handling tcp:<hostname> format
323-
else if (dataSource._connectionProtocol == DataSource.Protocol.TCP)
324-
{
325-
postfix = DefaultSqlServerPort.ToString();
326-
}
327325

328-
return GetSqlServerSPN(hostName, postfix);
326+
return GetSqlServerSPNs(hostName, postfix, dataSource._connectionProtocol);
329327
}
330328

331-
private static byte[] GetSqlServerSPN(string hostNameOrAddress, string portOrInstanceName)
329+
private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
332330
{
333331
Debug.Assert(!string.IsNullOrWhiteSpace(hostNameOrAddress));
334332
IPHostEntry hostEntry = null;
@@ -347,16 +345,22 @@ private static byte[] GetSqlServerSPN(string hostNameOrAddress, string portOrIns
347345
// If the DNS lookup failed, then resort to using the user provided hostname to construct the SPN.
348346
fullyQualifiedDomainName = hostEntry?.HostName ?? hostNameOrAddress;
349347
}
348+
350349
string serverSpn = SqlServerSpnHeader + "/" + fullyQualifiedDomainName;
350+
351351
if (!string.IsNullOrWhiteSpace(portOrInstanceName))
352352
{
353353
serverSpn += ":" + portOrInstanceName;
354354
}
355-
else
355+
else if (protocol == DataSource.Protocol.None || protocol == DataSource.Protocol.TCP) // Default is TCP
356356
{
357-
serverSpn += $":{DefaultSqlServerPort}";
357+
string serverSpnWithDefaultPort = serverSpn + $":{DefaultSqlServerPort}";
358+
// Set both SPNs with and without Port as Port is optional for default instance
359+
return new byte[][] { Encoding.UTF8.GetBytes(serverSpn), Encoding.UTF8.GetBytes(serverSpnWithDefaultPort) };
358360
}
359-
return Encoding.UTF8.GetBytes(serverSpn);
361+
// else Named Pipes do not need to valid port
362+
363+
return new byte[][] { Encoding.UTF8.GetBytes(serverSpn) };
360364
}
361365

362366
/// <summary>

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ internal sealed partial class TdsParser
114114

115115
private bool _isDenali = false;
116116

117-
private byte[] _sniSpnBuffer = null;
117+
private byte[][] _sniSpnBuffer = null;
118118

119119
// SqlStatistics
120120
private SqlStatistics _statistics = null;

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -789,7 +789,7 @@ private void ResetCancelAndProcessAttention()
789789
}
790790
}
791791

792-
internal abstract void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool fParallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity = false);
792+
internal abstract void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool fParallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity = false);
793793

794794
internal abstract void AssignPendingDNSInfo(string userProtocol, string DNSCacheKey, ref SQLDNSInfo pendingDNSInfo);
795795

@@ -831,7 +831,7 @@ private void ResetCancelAndProcessAttention()
831831

832832
protected abstract void RemovePacketFromPendingList(PacketHandle pointer);
833833

834-
internal abstract uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[] _sniSpnBuffer);
834+
internal abstract uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer);
835835

836836
internal bool Deactivate()
837837
{

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ internal SNIMarsHandle CreateMarsSession(object callbackObject, bool async)
4949
protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)
5050
=> SNIProxy.GetInstance().PacketGetData(packet.ManagedPacket, _inBuff, ref dataSize);
5151

52-
internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool parallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity)
52+
internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool parallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity)
5353
{
5454
_sessionHandle = SNIProxy.GetInstance().CreateConnectionHandle(serverName, ignoreSniOpenTimeout, timerExpire, out instanceName, ref spnBuffer, flushCache, async, parallel, isIntegratedSecurity, cachedFQDN, ref pendingDNSInfo);
5555
if (_sessionHandle == null)
@@ -215,7 +215,7 @@ internal override uint EnableMars(ref uint info)
215215

216216
internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize) => SNIProxy.GetInstance().SetConnectionBufferSize(Handle, unsignedPacketSize);
217217

218-
internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[] _sniSpnBuffer)
218+
internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer)
219219
{
220220
if (_sspiClientContextStatus == null)
221221
{

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs

+9-9
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,12 @@ internal override void AssignPendingDNSInfo(string userProtocol, string DNSCache
8080

8181
if (string.IsNullOrEmpty(userProtocol))
8282
{
83-
83+
8484
result = SNINativeMethodWrapper.SniGetProviderNumber(Handle, ref providerNumber);
8585
Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetProviderNumber");
8686
_parser.isTcpProtocol = (providerNumber == SNINativeMethodWrapper.ProviderEnum.TCP_PROV);
8787
}
88-
else if (userProtocol == TdsEnums.TCP)
88+
else if (userProtocol == TdsEnums.TCP)
8989
{
9090
_parser.isTcpProtocol = true;
9191
}
@@ -138,14 +138,14 @@ private SNINativeMethodWrapper.ConsumerInfo CreateConsumerInfo(bool async)
138138
return myInfo;
139139
}
140140

141-
internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool fParallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity)
141+
internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool fParallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity)
142142
{
143143
// We assume that the loadSSPILibrary has been called already. now allocate proper length of buffer
144-
spnBuffer = null;
144+
spnBuffer = new byte[1][];
145145
if (isIntegratedSecurity)
146146
{
147147
// now allocate proper length of buffer
148-
spnBuffer = new byte[SNINativeMethodWrapper.SniMaxComposedSpnLength];
148+
spnBuffer[0] = new byte[SNINativeMethodWrapper.SniMaxComposedSpnLength];
149149
}
150150

151151
SNINativeMethodWrapper.ConsumerInfo myInfo = CreateConsumerInfo(async);
@@ -172,7 +172,7 @@ internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSni
172172
SQLDNSInfo cachedDNSInfo;
173173
bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out cachedDNSInfo);
174174

175-
_sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer, ignoreSniOpenTimeout, checked((int)timeout), out instanceName, flushCache, !async, fParallel, cachedDNSInfo);
175+
_sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer[0], ignoreSniOpenTimeout, checked((int)timeout), out instanceName, flushCache, !async, fParallel, cachedDNSInfo);
176176
}
177177

178178
protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)
@@ -385,8 +385,8 @@ internal override uint EnableSsl(ref uint info)
385385
internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize)
386386
=> SNINativeMethodWrapper.SNISetInfo(Handle, SNINativeMethodWrapper.QTypes.SNI_QUERY_CONN_BUFSIZE, ref unsignedPacketSize);
387387

388-
internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[] _sniSpnBuffer)
389-
=> SNINativeMethodWrapper.SNISecGenClientContext(Handle, receivedBuff, receivedLength, sendBuff, ref sendLength, _sniSpnBuffer);
388+
internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer)
389+
=> SNINativeMethodWrapper.SNISecGenClientContext(Handle, receivedBuff, receivedLength, sendBuff, ref sendLength, _sniSpnBuffer[0]);
390390

391391
internal override uint WaitForSSLHandShakeToComplete(out int protocolVersion)
392392
{
@@ -421,7 +421,7 @@ internal override uint WaitForSSLHandShakeToComplete(out int protocolVersion)
421421
protocolVersion = (int)SslProtocols.Ssl2;
422422
#pragma warning restore CS0618 // Type or member is obsolete : SSL is depricated
423423
}
424-
else if(nativeProtocol.HasFlag(NativeProtocols.SP_PROT_NONE))
424+
else if (nativeProtocol.HasFlag(NativeProtocols.SP_PROT_NONE))
425425
{
426426
protocolVersion = (int)SslProtocols.None;
427427
}

0 commit comments

Comments
 (0)