@@ -72,7 +72,7 @@ internal uint DisableSsl(SNIHandle handle)
72
72
/// <param name="sendBuff">Send buffer</param>
73
73
/// <param name="serverName">Service Principal Name buffer</param>
74
74
/// <returns>SNI error code</returns>
75
- internal void GenSspiClientContext ( SspiClientContextStatus sspiClientContextStatus , byte [ ] receivedBuff , ref byte [ ] sendBuff , byte [ ] serverName )
75
+ internal void GenSspiClientContext ( SspiClientContextStatus sspiClientContextStatus , byte [ ] receivedBuff , ref byte [ ] sendBuff , byte [ ] [ ] serverName )
76
76
{
77
77
SafeDeleteContext securityContext = sspiClientContextStatus . SecurityContext ;
78
78
ContextFlagsPal contextFlags = sspiClientContextStatus . ContextFlags ;
@@ -104,12 +104,15 @@ internal void GenSspiClientContext(SspiClientContextStatus sspiClientContextStat
104
104
| ContextFlagsPal . Delegate
105
105
| ContextFlagsPal . MutualAuth ;
106
106
107
- string serverSPN = System . Text . Encoding . UTF8 . GetString ( serverName ) ;
108
-
107
+ string [ ] serverSPNs = new string [ serverName . Length ] ;
108
+ for ( int i = 0 ; i < serverName . Length ; i ++ )
109
+ {
110
+ serverSPNs [ i ] = System . Text . Encoding . UTF8 . GetString ( serverName [ i ] ) ;
111
+ }
109
112
SecurityStatusPal statusCode = NegotiateStreamPal . InitializeSecurityContext (
110
113
credentialsHandle ,
111
114
ref securityContext ,
112
- serverSPN ,
115
+ serverSPNs ,
113
116
requestedContextFlags ,
114
117
inSecurityBufferArray ,
115
118
outSecurityBuffer ,
@@ -253,7 +256,7 @@ internal uint WritePacket(SNIHandle handle, SNIPacket packet, bool sync)
253
256
/// <param name="cachedFQDN">Used for DNS Cache</param>
254
257
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
255
258
/// <returns>SNI handle</returns>
256
- internal SNIHandle CreateConnectionHandle ( string fullServerName , bool ignoreSniOpenTimeout , long timerExpire , out byte [ ] instanceName , ref byte [ ] spnBuffer , bool flushCache , bool async , bool parallel , bool isIntegratedSecurity , string cachedFQDN , ref SQLDNSInfo pendingDNSInfo )
259
+ internal SNIHandle CreateConnectionHandle ( string fullServerName , bool ignoreSniOpenTimeout , long timerExpire , out byte [ ] instanceName , ref byte [ ] [ ] spnBuffer , bool flushCache , bool async , bool parallel , bool isIntegratedSecurity , string cachedFQDN , ref SQLDNSInfo pendingDNSInfo )
257
260
{
258
261
instanceName = new byte [ 1 ] ;
259
262
@@ -294,7 +297,7 @@ internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniO
294
297
{
295
298
try
296
299
{
297
- spnBuffer = GetSqlServerSPN ( details ) ;
300
+ spnBuffer = GetSqlServerSPNs ( details ) ;
298
301
}
299
302
catch ( Exception e )
300
303
{
@@ -305,7 +308,7 @@ internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniO
305
308
return sniHandle ;
306
309
}
307
310
308
- private static byte [ ] GetSqlServerSPN ( DataSource dataSource )
311
+ private static byte [ ] [ ] GetSqlServerSPNs ( DataSource dataSource )
309
312
{
310
313
Debug . Assert ( ! string . IsNullOrWhiteSpace ( dataSource . ServerName ) ) ;
311
314
@@ -319,16 +322,11 @@ private static byte[] GetSqlServerSPN(DataSource dataSource)
319
322
{
320
323
postfix = dataSource . InstanceName ;
321
324
}
322
- // For handling tcp:<hostname> format
323
- else if ( dataSource . _connectionProtocol == DataSource . Protocol . TCP )
324
- {
325
- postfix = DefaultSqlServerPort . ToString ( ) ;
326
- }
327
325
328
- return GetSqlServerSPN ( hostName , postfix ) ;
326
+ return GetSqlServerSPNs ( hostName , postfix , dataSource . _connectionProtocol ) ;
329
327
}
330
328
331
- private static byte [ ] GetSqlServerSPN ( string hostNameOrAddress , string portOrInstanceName )
329
+ private static byte [ ] [ ] GetSqlServerSPNs ( string hostNameOrAddress , string portOrInstanceName , DataSource . Protocol protocol )
332
330
{
333
331
Debug . Assert ( ! string . IsNullOrWhiteSpace ( hostNameOrAddress ) ) ;
334
332
IPHostEntry hostEntry = null ;
@@ -347,16 +345,22 @@ private static byte[] GetSqlServerSPN(string hostNameOrAddress, string portOrIns
347
345
// If the DNS lookup failed, then resort to using the user provided hostname to construct the SPN.
348
346
fullyQualifiedDomainName = hostEntry ? . HostName ?? hostNameOrAddress ;
349
347
}
348
+
350
349
string serverSpn = SqlServerSpnHeader + "/" + fullyQualifiedDomainName ;
350
+
351
351
if ( ! string . IsNullOrWhiteSpace ( portOrInstanceName ) )
352
352
{
353
353
serverSpn += ":" + portOrInstanceName ;
354
354
}
355
- else
355
+ else if ( protocol == DataSource . Protocol . None || protocol == DataSource . Protocol . TCP ) // Default is TCP
356
356
{
357
- serverSpn += $ ":{ DefaultSqlServerPort } ";
357
+ string serverSpnWithDefaultPort = serverSpn + $ ":{ DefaultSqlServerPort } ";
358
+ // Set both SPNs with and without Port as Port is optional for default instance
359
+ return new byte [ ] [ ] { Encoding . UTF8 . GetBytes ( serverSpn ) , Encoding . UTF8 . GetBytes ( serverSpnWithDefaultPort ) } ;
358
360
}
359
- return Encoding . UTF8 . GetBytes ( serverSpn ) ;
361
+ // else Named Pipes do not need to valid port
362
+
363
+ return new byte [ ] [ ] { Encoding . UTF8 . GetBytes ( serverSpn ) } ;
360
364
}
361
365
362
366
/// <summary>
0 commit comments