From d3e39e2709383da616b4404325c24f725bddd557 Mon Sep 17 00:00:00 2001 From: Craig Perkins <cwperx@amazon.com> Date: Thu, 23 Nov 2023 20:52:37 -0500 Subject: [PATCH 1/4] Move channel.getVersion after getInnerChannel Signed-off-by: Craig Perkins <cwperx@amazon.com> --- .../ssl/transport/SecuritySSLRequestHandler.java | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java index 78c98dd99f..cc1aa58b31 100644 --- a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java +++ b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java @@ -86,6 +86,11 @@ public final void messageReceived(T request, TransportChannel channel, Task task ThreadContext threadContext = getThreadContext(); + String channelType = channel.getChannelType(); + if (!channelType.equals("direct") && !channelType.equals("transport")) { + channel = getInnerChannel(channel); + } + threadContext.putTransient( ConfigConstants.USE_JDK_SERIALIZATION, channel.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION) @@ -97,11 +102,6 @@ public final void messageReceived(T request, TransportChannel channel, Task task throw exception; } - String channelType = channel.getChannelType(); - if (!channelType.equals("direct") && !channelType.equals("transport")) { - channel = getInnerChannel(channel); - } - if (!"transport".equals(channel.getChannelType())) { // netty4 messageReceivedDecorate(request, actualHandler, channel, task); return; From f15a4a2afbe3f735606cdc78ca813ca25e5204cb Mon Sep 17 00:00:00 2001 From: Craig Perkins <cwperx@amazon.com> Date: Fri, 24 Nov 2023 06:52:40 -0500 Subject: [PATCH 2/4] Add test with wrapped transport channel Signed-off-by: Craig Perkins <cwperx@amazon.com> --- .../SecuritySSLRequestHandlerTests.java | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java b/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java index b6967b0e68..f4e0773fa0 100644 --- a/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java +++ b/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java @@ -9,12 +9,15 @@ */ package org.opensearch.security.transport; +import java.io.IOException; + import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.opensearch.Version; import org.opensearch.common.settings.Settings; +import org.opensearch.core.transport.TransportResponse; import org.opensearch.security.ssl.SslExceptionHandler; import org.opensearch.security.ssl.transport.PrincipalExtractor; import org.opensearch.security.ssl.transport.SSLConfig; @@ -93,4 +96,61 @@ public void testUseJDKSerializationHeaderIsSetOnMessageReceived() throws Excepti Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task)); Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); } + + @Test + public void testUseJDKSerializationHeaderIsSetAfterGetInnerChannel() throws Exception { + TransportRequest transportRequest = mock(TransportRequest.class); + TransportChannel transportChannel = mock(TransportChannel.class); + TransportChannel wrappedChannel = new WrappedTransportChannel(transportChannel); + Task task = mock(Task.class); + doNothing().when(transportChannel).sendResponse(ArgumentMatchers.any(Exception.class)); + when(transportChannel.getVersion()).thenReturn(Version.V_2_10_0); + when(transportChannel.getChannelType()).thenReturn("other"); + + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task)); + Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + + threadPool.getThreadContext().stashContext(); + when(transportChannel.getVersion()).thenReturn(Version.V_2_11_0); + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task)); + Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + + threadPool.getThreadContext().stashContext(); + when(transportChannel.getVersion()).thenReturn(Version.V_3_0_0); + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task)); + Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + } + + public class WrappedTransportChannel implements TransportChannel { + + private TransportChannel inner; + + public WrappedTransportChannel(TransportChannel inner) { + this.inner = inner; + } + + @Override + public String getProfileName() { + return "WrappedTransportChannelProfileName"; + } + + public TransportChannel getInnerChannel() { + return this.inner; + } + + @Override + public void sendResponse(TransportResponse response) throws IOException { + inner.sendResponse(response); + } + + @Override + public void sendResponse(Exception e) throws IOException { + + } + + @Override + public String getChannelType() { + return "WrappedTransportChannelType"; + } + } } From 87168a3b4c484385ec2ff4c30651a383202f3054 Mon Sep 17 00:00:00 2001 From: Craig Perkins <cwperx@amazon.com> Date: Fri, 24 Nov 2023 06:59:06 -0500 Subject: [PATCH 3/4] Verify using InOrder Signed-off-by: Craig Perkins <cwperx@amazon.com> --- .../SecuritySSLRequestHandlerTests.java | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java b/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java index f4e0773fa0..2d10b6f84f 100644 --- a/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java +++ b/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java @@ -30,11 +30,13 @@ import org.opensearch.transport.TransportRequestHandler; import org.mockito.ArgumentMatchers; +import org.mockito.InOrder; import org.mockito.Mock; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -98,7 +100,7 @@ public void testUseJDKSerializationHeaderIsSetOnMessageReceived() throws Excepti } @Test - public void testUseJDKSerializationHeaderIsSetAfterGetInnerChannel() throws Exception { + public void testUseJDKSerializationHeaderIsSetWithWrapperChannel() throws Exception { TransportRequest transportRequest = mock(TransportRequest.class); TransportChannel transportChannel = mock(TransportChannel.class); TransportChannel wrappedChannel = new WrappedTransportChannel(transportChannel); @@ -121,6 +123,26 @@ public void testUseJDKSerializationHeaderIsSetAfterGetInnerChannel() throws Exce Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); } + @Test + public void testUseJDKSerializationHeaderIsSetAfterGetInnerChannel() throws Exception { + TransportRequest transportRequest = mock(TransportRequest.class); + TransportChannel transportChannel = mock(TransportChannel.class); + WrappedTransportChannel wrappedChannel = mock(WrappedTransportChannel.class); + Task task = mock(Task.class); + when(wrappedChannel.getInnerChannel()).thenReturn(transportChannel); + when(wrappedChannel.getChannelType()).thenReturn("other"); + doNothing().when(transportChannel).sendResponse(ArgumentMatchers.any(Exception.class)); + when(transportChannel.getVersion()).thenReturn(Version.V_2_10_0); + + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task)); + Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + + InOrder inOrder = inOrder(wrappedChannel, transportChannel); + + inOrder.verify(wrappedChannel).getInnerChannel(); + inOrder.verify(transportChannel).getVersion(); + } + public class WrappedTransportChannel implements TransportChannel { private TransportChannel inner; From fb3b0915f4bd31726ee0fff2682d2bd53bba09dc Mon Sep 17 00:00:00 2001 From: Craig Perkins <cwperx@amazon.com> Date: Mon, 27 Nov 2023 10:17:59 -0500 Subject: [PATCH 4/4] Add new constant with default transport channel types Signed-off-by: Craig Perkins <cwperx@amazon.com> --- .../security/ssl/transport/SecuritySSLRequestHandler.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java index cc1aa58b31..39312e29ad 100644 --- a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java +++ b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java @@ -21,6 +21,7 @@ import java.security.cert.Certificate; import java.security.cert.X509Certificate; import java.util.Arrays; +import java.util.Set; import javax.net.ssl.SSLPeerUnverifiedException; import org.apache.logging.log4j.LogManager; @@ -55,6 +56,8 @@ public class SecuritySSLRequestHandler<T extends TransportRequest> implements Tr private final SslExceptionHandler errorHandler; private final SSLConfig SSLConfig; + private static final Set<String> DEFAULT_CHANNEL_TYPES = Set.of("direct", "transport"); + public SecuritySSLRequestHandler( String action, TransportRequestHandler<T> actualHandler, @@ -87,7 +90,7 @@ public final void messageReceived(T request, TransportChannel channel, Task task ThreadContext threadContext = getThreadContext(); String channelType = channel.getChannelType(); - if (!channelType.equals("direct") && !channelType.equals("transport")) { + if (!DEFAULT_CHANNEL_TYPES.contains(channelType)) { channel = getInnerChannel(channel); }