Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix | Linux SPN port number using named instance and Kerberos authentication does not return port# #2240

Merged
merged 17 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN
}
else if (!string.IsNullOrWhiteSpace(dataSource.InstanceName))
{
postfix = dataSource.InstanceName;
postfix = (dataSource._connectionProtocol == DataSource.Protocol.TCP ? dataSource.ResolvedPort.ToString() : dataSource.InstanceName);
}

SqlClientEventSource.Log.TryTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerName {0}, InstanceName {1}, Port {2}, postfix {3}", dataSource?.ServerName, dataSource?.InstanceName, dataSource?.Port, postfix);
Expand Down Expand Up @@ -317,7 +317,7 @@ private static SNITCPHandle CreateTcpHandle(
{
try
{
port = isAdminConnection ?
details.ResolvedPort = port = isAdminConnection ?
SSRP.GetDacPortByInstanceName(hostName, details.InstanceName, timeout, parallel, ipPreference) :
SSRP.GetPortByInstanceName(hostName, details.InstanceName, timeout, parallel, ipPreference);
}
Expand Down Expand Up @@ -436,6 +436,11 @@ internal enum Protocol { TCP, NP, None, Admin };
/// </summary>
internal int Port { get; private set; } = -1;

/// <summary>
/// The port resolved by SSRP when InstanceName is specified
/// </summary>
internal int ResolvedPort { get; set; } = -1;

/// <summary>
/// Provides the inferred Instance Name from Server Data Source
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1073,6 +1073,33 @@ public static string GetMachineFQDN(string hostname)
return fqdn.ToString();
}

public static bool IsManagedSNI()
{
return UseManagedSNIOnWindows;
}

public static bool IsIntegratedSecurity()
{
SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString);
return builder.IntegratedSecurity;
}

public static bool IsNotLocalhost()
{
// get the tcp connection string
SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString);

// remove tcp from connection string
string dataSourceStr = builder.DataSource.Replace("tcp:", "");
// create a string array
string[] serverNamePartsByBackSlash = dataSourceStr.Split('\\');
// first element of array is the hostname
string hostname = serverNamePartsByBackSlash[0];

// check if hostname = localhost
return !hostname.Equals("localhost", StringComparison.OrdinalIgnoreCase);
}

private static bool RunningAsUWPApp()
{
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System;
using System.Net;
using System.Net.Sockets;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using Xunit;
Expand Down Expand Up @@ -83,6 +84,126 @@ public static void ConnectManagedWithInstanceNameTest(bool useMultiSubnetFailove
}
}

// Note: This Unit test was tested in a VM within the sqldrv.ad domain. i.e. from server sqldrv-win22 and
// is connecting to a Sql Server using Kerberos at sqldrv-sql22 server in the same domain.
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsManagedSNI), nameof(DataTestUtility.IsNotLocalhost), nameof(DataTestUtility.IsKerberosTest), nameof(DataTestUtility.IsNotAzureServer), nameof(DataTestUtility.IsNotAzureSynapse))]
public static void PortNumberInSPNTest()
{
SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString);

Assert.True(DataTestUtility.ParseDataSource(builder.DataSource, out string hostname, out _, out string instanceName));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With specifying instance name and port number, this method call always returns false!


if (IsBrowserAlive(hostname) && IsValidInstance(hostname, instanceName))
{
using SqlConnection connection = new(builder.ConnectionString);
connection.Open();
using SqlCommand command = new("SELECT auth_scheme, local_tcp_port from sys.dm_exec_connections where session_id = @@spid", connection);
using SqlDataReader reader = command.ExecuteReader();
Assert.True(reader.Read(), "Expected to receive one row data");
Assert.Equal("KERBEROS", reader.GetString(0));
int Port = reader.GetInt32(1);

int port = -1;
string spnInfo = GetSPNInfo(builder.DataSource, out port);

// sample output to validate = MSSQLSvc/sqldrv-sql22.sqldrv.ad:1433"
Assert.Contains($"MSSQLSvc/{hostname}", spnInfo);
// the local_tcp_port Port is the same as the inferred port from instance name
Assert.Equal(Port, port);
}
}

private static string GetSPNInfo(string datasource, out int out_port)
{
Assembly systemData = Assembly.GetAssembly(typeof(SqlConnection));

// Get all required types using reflection
Type SniProxy = systemData.GetType("Microsoft.Data.SqlClient.SNI.SNIProxy");
Type SSRP = systemData.GetType("Microsoft.Data.SqlClient.SNI.SSRP");
Type DataSource = systemData.GetType("Microsoft.Data.SqlClient.SNI.DataSource");
Type TimeoutTimer = systemData.GetType("Microsoft.Data.ProviderBase.TimeoutTimer");

// Used in Datasource constructor param type array
Type[] types = new Type[1];
types[0] = typeof(string);

// Used in GetSqlServerSPNs function param types array
Type[] types2 = new Type[2];
types2[0] = DataSource;
types2[1] = typeof(string);

// GetPortByInstanceName parameters array
Type[] types3 = new Type[5];
types3[0] = typeof(string);
types3[1] = typeof(string);
types3[2] = TimeoutTimer;
types3[3] = typeof(bool);
types3[4] = typeof(Microsoft.Data.SqlClient.SqlConnectionIPAddressPreference);

// TimeoutTimer.StartSecondsTimeout params
Type[] types4 = new Type[1];
types4[0] = typeof(int);

// Get all types constructors
ConstructorInfo sniProxyCtor = SniProxy.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
ConstructorInfo SSRPCtor = SSRP.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
ConstructorInfo datasSourceCtor = DataSource.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, types , null);
ConstructorInfo timeoutTimerCtor = TimeoutTimer.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);

// Instantiate SNIProxy
var sniProxy = sniProxyCtor.Invoke(new object[] { });

// Instatntiate datasource
var details = datasSourceCtor.Invoke(new object[] { datasource });

// Instantiate SSRP
var ssrp = SSRPCtor.Invoke(new object[] { });

// Instantiate TimeoutTimer
var timeoutTimer = timeoutTimerCtor.Invoke(new object[] { });

// Get TimeoutTimer.StartSecondsTimeout Method
MethodInfo startSecondsTimeout = timeoutTimer.GetType().GetMethod("StartSecondsTimeout", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, types4, null);
// Create a timeoutTimer that expires in 30 seconds
timeoutTimer = startSecondsTimeout.Invoke(details, new object[] { 30 });

// Parse the datasource to separate the server name and instance name
MethodInfo ParseServerName = details.GetType().GetMethod("ParseServerName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, types, null);
var dataSrcInfo = ParseServerName.Invoke(details, new object[] { datasource });

// Get the GetPortByInstanceName method of SSRP
MethodInfo getPortByInstanceName = ssrp.GetType().GetMethod("GetPortByInstanceName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, types3, null);

// Get the server name
PropertyInfo serverInfo = dataSrcInfo.GetType().GetProperty("ServerName", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
var serverName = serverInfo.GetValue(dataSrcInfo, null).ToString();

// Get the instance name
PropertyInfo instanceNameInfo = dataSrcInfo.GetType().GetProperty("InstanceName", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
var instanceName = instanceNameInfo.GetValue(dataSrcInfo, null).ToString();

// Get the port number using the GetPortByInstanceName method of SSRP
var port = getPortByInstanceName.Invoke(ssrp, parameters: new object[] { serverName, instanceName, timeoutTimer, false, 0 } );

// Set the resolved port property of datasource
PropertyInfo resolvedPortInfo = dataSrcInfo.GetType().GetProperty("ResolvedPort", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
resolvedPortInfo.SetValue(dataSrcInfo, (int)port, null);

// Prepare the GetSqlServerSPNs method
string serverSPN = "";
MethodInfo getSqlServerSPNs = sniProxy.GetType().GetMethod("GetSqlServerSPNs", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, types2, null);

// Finally call GetSqlServerSPNs
dynamic result = getSqlServerSPNs.Invoke(sniProxy, new object[] { dataSrcInfo, serverSPN });

// MSSQLSvc/sqldrv-sql22.sqldrv.ad:1433"
var spnInfo = System.Text.Encoding.Unicode.GetString(result[0]);

out_port = (int)port;

return spnInfo;
}

private static bool IsBrowserAlive(string browserHostname)
{
const byte ClntUcastEx = 0x03;
Expand Down