Skip to content

Commit 2dbe2d3

Browse files
committed
API for SSLParameters
Enabled hostname validation by default
1 parent 455d09e commit 2dbe2d3

File tree

4 files changed

+212
-3
lines changed

4 files changed

+212
-3
lines changed

src/main/java/org/java_websocket/client/WebSocketClient.java

+16-2
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,6 @@ public void run() {
449449
} else if( socket == null ) {
450450
socket = new Socket( proxy );
451451
isNewSocket = true;
452-
453452
} else if( socket.isClosed() ) {
454453
throw new IOException();
455454
}
@@ -464,13 +463,19 @@ public void run() {
464463

465464
// if the socket is set by others we don't apply any TLS wrapper
466465
if (isNewSocket && "wss".equals( uri.getScheme())) {
467-
468466
SSLContext sslContext = SSLContext.getInstance("TLSv1.2");
469467
sslContext.init(null, null, null);
470468
SSLSocketFactory factory = sslContext.getSocketFactory();
471469
socket = factory.createSocket(socket, uri.getHost(), getPort(), true);
472470
}
473471

472+
if (socket instanceof SSLSocket) {
473+
SSLSocket sslSocket = (SSLSocket)socket;
474+
SSLParameters sslParameters = sslSocket.getSSLParameters();
475+
onSetSSLParameters(sslParameters);
476+
sslSocket.setSSLParameters(sslParameters);
477+
}
478+
474479
istream = socket.getInputStream();
475480
ostream = socket.getOutputStream();
476481

@@ -511,6 +516,15 @@ public void run() {
511516
connectReadThread = null;
512517
}
513518

519+
/**
520+
* Apply specific
521+
* @param sslParameters the SSLParameters which will be used for the SSLSocket
522+
*/
523+
protected void onSetSSLParameters(SSLParameters sslParameters) {
524+
// Make sure we perform hostname validation
525+
sslParameters.setEndpointIdentificationAlgorithm("HTTPS");
526+
}
527+
514528
/**
515529
* Extract the specified port
516530
* @return the specified port or the default port for the specific scheme
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
package org.java_websocket.issues;
2+
3+
/*
4+
* Copyright (c) 2010-2020 Nathan Rajlich
5+
*
6+
* Permission is hereby granted, free of charge, to any person
7+
* obtaining a copy of this software and associated documentation
8+
* files (the "Software"), to deal in the Software without
9+
* restriction, including without limitation the rights to use,
10+
* copy, modify, merge, publish, distribute, sublicense, and/or sell
11+
* copies of the Software, and to permit persons to whom the
12+
* Software is furnished to do so, subject to the following
13+
* conditions:
14+
*
15+
* The above copyright notice and this permission notice shall be
16+
* included in all copies or substantial portions of the Software.
17+
*
18+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
19+
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
20+
* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
21+
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
22+
* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
23+
* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
24+
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
25+
* OTHER DEALINGS IN THE SOFTWARE.
26+
*
27+
*/
28+
29+
30+
import org.java_websocket.WebSocket;
31+
import org.java_websocket.client.WebSocketClient;
32+
import org.java_websocket.handshake.ClientHandshake;
33+
import org.java_websocket.handshake.ServerHandshake;
34+
import org.java_websocket.server.DefaultSSLWebSocketServerFactory;
35+
import org.java_websocket.server.WebSocketServer;
36+
import org.java_websocket.util.SSLContextUtil;
37+
import org.java_websocket.util.SocketUtil;
38+
import org.junit.Test;
39+
40+
import javax.net.ssl.SSLContext;
41+
import javax.net.ssl.SSLHandshakeException;
42+
import javax.net.ssl.SSLParameters;
43+
import java.io.IOException;
44+
import java.net.*;
45+
import java.security.KeyManagementException;
46+
import java.security.KeyStoreException;
47+
import java.security.NoSuchAlgorithmException;
48+
import java.security.UnrecoverableKeyException;
49+
import java.security.cert.CertificateException;
50+
import java.util.concurrent.CountDownLatch;
51+
import java.util.concurrent.TimeUnit;
52+
53+
import static org.junit.Assert.*;
54+
55+
public class Issue997Test {
56+
57+
@Test(timeout=2000)
58+
public void test_localServer_ServerLocalhost_Client127_CheckActive() throws CertificateException, UnrecoverableKeyException, NoSuchAlgorithmException, KeyManagementException, KeyStoreException, IOException, URISyntaxException, InterruptedException {
59+
SSLWebSocketClient client = testIssueWithLocalServer("127.0.0.1", SocketUtil.getAvailablePort(), SSLContextUtil.getLocalhostOnlyContext(), SSLContextUtil.getLocalhostOnlyContext(), "HTTPS");
60+
assertFalse(client.onOpen);
61+
assertTrue(client.onSSLError);
62+
}
63+
@Test(timeout=2000)
64+
public void test_localServer_ServerLocalhost_Client127_CheckInactive() throws CertificateException, UnrecoverableKeyException, NoSuchAlgorithmException, KeyManagementException, KeyStoreException, IOException, URISyntaxException, InterruptedException {
65+
SSLWebSocketClient client = testIssueWithLocalServer("127.0.0.1", SocketUtil.getAvailablePort(), SSLContextUtil.getLocalhostOnlyContext(), SSLContextUtil.getLocalhostOnlyContext(), "");
66+
assertTrue(client.onOpen);
67+
assertFalse(client.onSSLError);
68+
}
69+
70+
@Test(timeout=2000)
71+
public void test_localServer_ServerLocalhost_ClientLocalhost_CheckActive() throws CertificateException, UnrecoverableKeyException, NoSuchAlgorithmException, KeyManagementException, KeyStoreException, IOException, URISyntaxException, InterruptedException {
72+
SSLWebSocketClient client = testIssueWithLocalServer("localhost", SocketUtil.getAvailablePort(), SSLContextUtil.getLocalhostOnlyContext(), SSLContextUtil.getLocalhostOnlyContext(), "HTTPS");
73+
assertTrue(client.onOpen);
74+
assertFalse(client.onSSLError);
75+
}
76+
@Test(timeout=2000)
77+
public void test_localServer_ServerLocalhost_ClientLocalhost_CheckInactive() throws CertificateException, UnrecoverableKeyException, NoSuchAlgorithmException, KeyManagementException, KeyStoreException, IOException, URISyntaxException, InterruptedException {
78+
SSLWebSocketClient client = testIssueWithLocalServer("localhost", SocketUtil.getAvailablePort(), SSLContextUtil.getLocalhostOnlyContext(), SSLContextUtil.getLocalhostOnlyContext(), "");
79+
assertTrue(client.onOpen);
80+
assertFalse(client.onSSLError);
81+
}
82+
83+
84+
public SSLWebSocketClient testIssueWithLocalServer(String address, int port, SSLContext serverContext, SSLContext clientContext, String endpointIdentificationAlgorithm) throws IOException, URISyntaxException, InterruptedException {
85+
CountDownLatch countServerDownLatch = new CountDownLatch(1);
86+
SSLWebSocketClient client = new SSLWebSocketClient(address, port, endpointIdentificationAlgorithm);
87+
WebSocketServer server = new SSLWebSocketServer(port, countServerDownLatch);
88+
89+
server.setWebSocketFactory(new DefaultSSLWebSocketServerFactory(serverContext));
90+
if (clientContext != null) {
91+
client.setSocketFactory(clientContext.getSocketFactory());
92+
}
93+
server.start();
94+
countServerDownLatch.await();
95+
client.connectBlocking(1, TimeUnit.SECONDS);
96+
return client;
97+
}
98+
99+
100+
private static class SSLWebSocketClient extends WebSocketClient {
101+
private final String endpointIdentificationAlgorithm;
102+
public boolean onSSLError = false;
103+
public boolean onOpen = false;
104+
105+
public SSLWebSocketClient(String address, int port, String endpointIdentificationAlgorithm) throws URISyntaxException {
106+
super(new URI("wss://"+ address + ':' +port));
107+
this.endpointIdentificationAlgorithm = endpointIdentificationAlgorithm;
108+
}
109+
110+
@Override
111+
public void onOpen(ServerHandshake handshakedata) {
112+
this.onOpen = true;
113+
}
114+
115+
@Override
116+
public void onMessage(String message) {
117+
}
118+
119+
@Override
120+
public void onClose(int code, String reason, boolean remote) {
121+
}
122+
123+
@Override
124+
public void onError(Exception ex) {
125+
if (ex instanceof SSLHandshakeException) {
126+
this.onSSLError = true;
127+
}
128+
}
129+
130+
@Override
131+
protected void onSetSSLParameters(SSLParameters sslParameters) {
132+
if (endpointIdentificationAlgorithm == null) {
133+
super.onSetSSLParameters(sslParameters);
134+
} else {
135+
sslParameters.setEndpointIdentificationAlgorithm(endpointIdentificationAlgorithm);
136+
}
137+
}
138+
139+
};
140+
141+
142+
private static class SSLWebSocketServer extends WebSocketServer {
143+
private final CountDownLatch countServerDownLatch;
144+
145+
146+
public SSLWebSocketServer(int port, CountDownLatch countServerDownLatch) {
147+
super(new InetSocketAddress(port));
148+
this.countServerDownLatch = countServerDownLatch;
149+
}
150+
151+
@Override
152+
public void onOpen(WebSocket conn, ClientHandshake handshake) {
153+
}
154+
155+
@Override
156+
public void onClose(WebSocket conn, int code, String reason, boolean remote) {
157+
}
158+
159+
@Override
160+
public void onMessage(WebSocket conn, String message) {
161+
162+
}
163+
164+
@Override
165+
public void onError(WebSocket conn, Exception ex) {
166+
ex.printStackTrace();
167+
}
168+
169+
@Override
170+
public void onStart() {
171+
countServerDownLatch.countDown();
172+
}
173+
}
174+
}
Binary file not shown.

src/test/java/org/java_websocket/util/SSLContextUtil.java

+22-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
import javax.net.ssl.TrustManagerFactory;
3232
import java.io.File;
3333
import java.io.FileInputStream;
34-
import java.io.FileNotFoundException;
3534
import java.io.IOException;
3635
import java.security.*;
3736
import java.security.cert.CertificateException;
@@ -59,4 +58,26 @@ public static SSLContext getContext() throws NoSuchAlgorithmException, KeyManage
5958
sslContext.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
6059
return sslContext;
6160
}
61+
62+
public static SSLContext getLocalhostOnlyContext() throws NoSuchAlgorithmException, KeyManagementException, KeyStoreException, IOException, CertificateException, UnrecoverableKeyException {
63+
// load up the key store
64+
String STORETYPE = "JKS";
65+
String KEYSTORE = String.format("src%1$stest%1$1sjava%1$1sorg%1$1sjava_websocket%1$1skeystore_localhost_only.jks", File.separator);
66+
String STOREPASSWORD = "storepassword";
67+
String KEYPASSWORD = "keypassword";
68+
69+
KeyStore ks = KeyStore.getInstance(STORETYPE);
70+
File kf = new File(KEYSTORE);
71+
ks.load(new FileInputStream(kf), STOREPASSWORD.toCharArray());
72+
73+
KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
74+
kmf.init(ks, KEYPASSWORD.toCharArray());
75+
TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
76+
tmf.init(ks);
77+
78+
SSLContext sslContext = null;
79+
sslContext = SSLContext.getInstance("TLS");
80+
sslContext.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
81+
return sslContext;
82+
}
6283
}

0 commit comments

Comments
 (0)