forked from dotnet/SqlClient
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathInstanceNameTest.cs
139 lines (120 loc) · 6.2 KB
/
InstanceNameTest.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Data;
using System.Net;
using System.Net.Sockets;
using System.Text;
using System.Threading.Tasks;
using Xunit;
namespace Microsoft.Data.SqlClient.ManualTesting.Tests
{
public static class InstanceNameTest
{
    // Protocol prefix the test data source may carry; it is not part of the host name.
    private const string TcpPrefix = "tcp:";

    /// <summary>
    /// Opens a connection with <see cref="DataTestUtility.TCPConnectionString"/>. When the
    /// encryption/authentication settings allow it, the test then opens a second connection.
    /// In that connection the host name in the data source is replaced by one of its resolved
    /// IP addresses.
    /// </summary>
    /// <remarks>
    /// When the data source names an instance without an explicit port, the test needs a
    /// response from SQL Server Browser on the host. If none arrives, the test returns
    /// without connecting.
    /// </remarks>
    [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsNotAzureServer), nameof(DataTestUtility.IsNotAzureSynapse), nameof(DataTestUtility.AreConnStringsSetup))]
    public static void ConnectToSQLWithInstanceNameTest()
    {
        SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString);

        SplitDataSource(builder.DataSource, out string protocolPrefix, out string hostname, out string hostSuffix);

        bool proceed = true;
        // A "\instance" suffix without an explicit ",port" can only be resolved through
        // SQL Server Browser, so only continue if the browser answers.
        if (!hostSuffix.Contains(",") && hostSuffix.Split('\\').Length == 2)
        {
            proceed = !string.IsNullOrWhiteSpace(hostname) && IsBrowserAlive(hostname);
        }

        if (!proceed)
        {
            return;
        }

        using SqlConnection connection = new(builder.ConnectionString);
        connection.Open();
        connection.Close();

        // We can only connect via IP address if we aren't doing remote Kerberos or strict TLS
        if (builder.Encrypt != SqlConnectionEncryptOption.Strict &&
            (!builder.IntegratedSecurity || hostname.Equals("localhost", StringComparison.OrdinalIgnoreCase) ||
            hostname.Equals(Environment.MachineName, StringComparison.OrdinalIgnoreCase)))
        {
            // Exercise the IP address-specific code in SSRP
            IPAddress[] addresses = Dns.GetHostAddresses(hostname);

            // The browser probe (QueryBrowser) uses an IPv4 socket, so prefer an IPv4 literal
            // when one is available; otherwise fall back to the first resolved address.
            IPAddress address = Array.Find(addresses, a => a.AddressFamily == AddressFamily.InterNetwork) ?? addresses[0];

            // Rebuild the data source from its parts instead of using string.Replace: the
            // host name may also occur inside the instance name, which Replace would corrupt.
            builder.DataSource = protocolPrefix + address.ToString() + hostSuffix;
            builder.TrustServerCertificate = true;
            using SqlConnection connection2 = new(builder.ConnectionString);
            connection2.Open();
            connection2.Close();
        }
    }

    /// <summary>
    /// Connects to <c>hostname\instanceName</c> with the given MultiSubnetFailover and
    /// IP address preference settings. The connection is attempted only when SQL Server
    /// Browser answers and reports the instance.
    /// It then verifies that connecting to a bogus instance name fails with an
    /// "Error Locating Server/Instance Specified" <see cref="SqlException"/>.
    /// </summary>
    /// <param name="useMultiSubnetFailover">Value assigned to <see cref="SqlConnectionStringBuilder.MultiSubnetFailover"/>.</param>
    /// <param name="ipPreference">Value assigned to <see cref="SqlConnectionStringBuilder.IPAddressPreference"/>.</param>
    [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsNotAzureServer), nameof(DataTestUtility.IsNotAzureSynapse), nameof(DataTestUtility.AreConnStringsSetup))]
    [InlineData(true, SqlConnectionIPAddressPreference.IPv4First)]
    [InlineData(true, SqlConnectionIPAddressPreference.IPv6First)]
    [InlineData(true, SqlConnectionIPAddressPreference.UsePlatformDefault)]
    [InlineData(false, SqlConnectionIPAddressPreference.IPv4First)]
    [InlineData(false, SqlConnectionIPAddressPreference.IPv6First)]
    [InlineData(false, SqlConnectionIPAddressPreference.UsePlatformDefault)]
    public static void ConnectManagedWithInstanceNameTest(bool useMultiSubnetFailover, SqlConnectionIPAddressPreference ipPreference)
    {
        SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString);
        builder.MultiSubnetFailover = useMultiSubnetFailover;
        builder.IPAddressPreference = ipPreference;

        Assert.True(DataTestUtility.ParseDataSource(builder.DataSource, out string hostname, out _, out string instanceName));

        if (IsBrowserAlive(hostname) && IsValidInstance(hostname, instanceName))
        {
            // Drop any explicit port so the instance has to be resolved through the browser.
            builder.DataSource = hostname + "\\" + instanceName;
            using SqlConnection connection = new(builder.ConnectionString);
            connection.Open();
        }

        // Short timeout: the failure below is expected, so don't wait long for it.
        builder.ConnectTimeout = 2;
        instanceName = "invalidinstance3456";
        if (!IsValidInstance(hostname, instanceName))
        {
            builder.DataSource = hostname + "\\" + instanceName;
            using SqlConnection connection = new(builder.ConnectionString);
            SqlException ex = Assert.Throws<SqlException>(() => connection.Open());
            Assert.Contains("Error Locating Server/Instance Specified", ex.Message);
        }
    }

    /// <summary>
    /// Splits a data source of the form <c>[tcp:]host[\instance][,port]</c> into three parts.
    /// </summary>
    /// <param name="dataSource">The data source string to split.</param>
    /// <param name="prefix">The leading <c>tcp:</c> prefix in its original casing, or an empty string.</param>
    /// <param name="hostname">The text up to the first <c>\</c> or <c>,</c> (or the end).</param>
    /// <param name="suffix">Everything after the host name, e.g. <c>\instance</c> or <c>,1433</c>.</param>
    private static void SplitDataSource(string dataSource, out string prefix, out string hostname, out string suffix)
    {
        prefix = dataSource.StartsWith(TcpPrefix, StringComparison.OrdinalIgnoreCase)
            ? dataSource.Substring(0, TcpPrefix.Length)
            : string.Empty;

        string rest = dataSource.Substring(prefix.Length);
        int hostEnd = rest.IndexOfAny(new[] { '\\', ',' });
        if (hostEnd < 0)
        {
            hostEnd = rest.Length;
        }

        hostname = rest.Substring(0, hostEnd);
        suffix = rest.Substring(hostEnd);
    }

    /// <summary>
    /// Sends a single-byte request (0x03, named ClntUcastEx here) to the browser port on
    /// <paramref name="browserHostname"/>.
    /// </summary>
    /// <returns><c>true</c> if any non-empty reply was received within the timeouts.</returns>
    private static bool IsBrowserAlive(string browserHostname)
    {
        const byte ClntUcastEx = 0x03;
        byte[] responsePacket = QueryBrowser(browserHostname, new byte[] { ClntUcastEx });
        return responsePacket != null && responsePacket.Length > 0;
    }

    /// <summary>
    /// Asks the browser on <paramref name="browserHostName"/> about <paramref name="instanceName"/>.
    /// </summary>
    /// <returns><c>true</c> if a non-empty reply was received; no reply is treated as "not a valid instance".</returns>
    private static bool IsValidInstance(string browserHostName, string instanceName)
    {
        byte[] request = CreateInstanceInfoRequest(instanceName);
        byte[] response = QueryBrowser(browserHostName, request);
        return response != null && response.Length > 0;
    }

    /// <summary>
    /// Sends <paramref name="requestPacket"/> over IPv4 UDP to port 1434 on
    /// <paramref name="browserHostname"/>. It then waits for a single datagram in reply.
    /// </summary>
    /// <returns>
    /// The payload of the first datagram received. Returns <c>null</c> if the send or the
    /// receive does not complete within 1000 ms, or if any exception occurs.
    /// </returns>
    private static byte[] QueryBrowser(string browserHostname, byte[] requestPacket)
    {
        const int DefaultBrowserPort = 1434;
        const int sendTimeout = 1000;
        const int receiveTimeout = 1000;

        byte[] responsePacket = null;
        using (UdpClient client = new(AddressFamily.InterNetwork))
        {
            try
            {
                Task<int> sendTask = client.SendAsync(requestPacket, requestPacket.Length, browserHostname, DefaultBrowserPort);
                Task<UdpReceiveResult> receiveTask = null;
                // Only start receiving once the send finished in time; both waits are bounded.
                if (sendTask.Wait(sendTimeout) && (receiveTask = client.ReceiveAsync()).Wait(receiveTimeout))
                {
                    responsePacket = receiveTask.Result.Buffer;
                }
            }
            catch (Exception)
            {
                // Deliberately best effort: name-resolution, socket and task failures all mean
                // "no usable response", which the callers treat as a negative answer.
            }
        }

        return responsePacket;
    }

    /// <summary>
    /// Builds an instance-information request packet: the byte 0x04 (named ClntUcastInst
    /// here), followed by the ASCII-encoded instance name and a terminating NUL byte.
    /// </summary>
    private static byte[] CreateInstanceInfoRequest(string instanceName)
    {
        const byte ClntUcastInst = 0x04;
        // Append the NUL terminator so it is encoded along with the name.
        instanceName += char.MinValue;
        int byteCount = Encoding.ASCII.GetByteCount(instanceName);
        byte[] requestPacket = new byte[byteCount + 1];
        requestPacket[0] = ClntUcastInst;
        Encoding.ASCII.GetBytes(instanceName, 0, instanceName.Length, requestPacket, 1);
        return requestPacket;
    }
}
}