From d3e39e2709383da616b4404325c24f725bddd557 Mon Sep 17 00:00:00 2001
From: Craig Perkins <cwperx@amazon.com>
Date: Thu, 23 Nov 2023 20:52:37 -0500
Subject: [PATCH 1/4] Move channel.getVersion after getInnerChannel

Signed-off-by: Craig Perkins <cwperx@amazon.com>
---
 .../ssl/transport/SecuritySSLRequestHandler.java       | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java
index 78c98dd99f..cc1aa58b31 100644
--- a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java
+++ b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java
@@ -86,6 +86,11 @@ public final void messageReceived(T request, TransportChannel channel, Task task
 
         ThreadContext threadContext = getThreadContext();
 
+        String channelType = channel.getChannelType();
+        if (!channelType.equals("direct") && !channelType.equals("transport")) {
+            channel = getInnerChannel(channel);
+        }
+
         threadContext.putTransient(
             ConfigConstants.USE_JDK_SERIALIZATION,
             channel.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION)
@@ -97,11 +102,6 @@ public final void messageReceived(T request, TransportChannel channel, Task task
             throw exception;
         }
 
-        String channelType = channel.getChannelType();
-        if (!channelType.equals("direct") && !channelType.equals("transport")) {
-            channel = getInnerChannel(channel);
-        }
-
         if (!"transport".equals(channel.getChannelType())) { // netty4
             messageReceivedDecorate(request, actualHandler, channel, task);
             return;

From f15a4a2afbe3f735606cdc78ca813ca25e5204cb Mon Sep 17 00:00:00 2001
From: Craig Perkins <cwperx@amazon.com>
Date: Fri, 24 Nov 2023 06:52:40 -0500
Subject: [PATCH 2/4] Add test with wrapped transport channel

Signed-off-by: Craig Perkins <cwperx@amazon.com>
---
 .../SecuritySSLRequestHandlerTests.java       | 60 +++++++++++++++++++
 1 file changed, 60 insertions(+)

diff --git a/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java b/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java
index b6967b0e68..f4e0773fa0 100644
--- a/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java
+++ b/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java
@@ -9,12 +9,15 @@
  */
 package org.opensearch.security.transport;
 
+import java.io.IOException;
+
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
 import org.opensearch.Version;
 import org.opensearch.common.settings.Settings;
+import org.opensearch.core.transport.TransportResponse;
 import org.opensearch.security.ssl.SslExceptionHandler;
 import org.opensearch.security.ssl.transport.PrincipalExtractor;
 import org.opensearch.security.ssl.transport.SSLConfig;
@@ -93,4 +96,61 @@ public void testUseJDKSerializationHeaderIsSetOnMessageReceived() throws Excepti
         Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task));
         Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
     }
+
+    @Test
+    public void testUseJDKSerializationHeaderIsSetAfterGetInnerChannel() throws Exception {
+        TransportRequest transportRequest = mock(TransportRequest.class);
+        TransportChannel transportChannel = mock(TransportChannel.class);
+        TransportChannel wrappedChannel = new WrappedTransportChannel(transportChannel);
+        Task task = mock(Task.class);
+        doNothing().when(transportChannel).sendResponse(ArgumentMatchers.any(Exception.class));
+        when(transportChannel.getVersion()).thenReturn(Version.V_2_10_0);
+        when(transportChannel.getChannelType()).thenReturn("other");
+
+        Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task));
+        Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
+
+        threadPool.getThreadContext().stashContext();
+        when(transportChannel.getVersion()).thenReturn(Version.V_2_11_0);
+        Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task));
+        Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
+
+        threadPool.getThreadContext().stashContext();
+        when(transportChannel.getVersion()).thenReturn(Version.V_3_0_0);
+        Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task));
+        Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
+    }
+
+    public class WrappedTransportChannel implements TransportChannel {
+
+        private TransportChannel inner;
+
+        public WrappedTransportChannel(TransportChannel inner) {
+            this.inner = inner;
+        }
+
+        @Override
+        public String getProfileName() {
+            return "WrappedTransportChannelProfileName";
+        }
+
+        public TransportChannel getInnerChannel() {
+            return this.inner;
+        }
+
+        @Override
+        public void sendResponse(TransportResponse response) throws IOException {
+            inner.sendResponse(response);
+        }
+
+        @Override
+        public void sendResponse(Exception e) throws IOException {
+
+        }
+
+        @Override
+        public String getChannelType() {
+            return "WrappedTransportChannelType";
+        }
+    }
 }

From 87168a3b4c484385ec2ff4c30651a383202f3054 Mon Sep 17 00:00:00 2001
From: Craig Perkins <cwperx@amazon.com>
Date: Fri, 24 Nov 2023 06:59:06 -0500
Subject: [PATCH 3/4] Verify using InOrder

Signed-off-by: Craig Perkins <cwperx@amazon.com>
---
 .../SecuritySSLRequestHandlerTests.java       | 24 ++++++++++++++++++-
 1 file changed, 23 insertions(+), 1 deletion(-)

diff --git a/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java b/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java
index f4e0773fa0..2d10b6f84f 100644
--- a/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java
+++ b/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java
@@ -30,11 +30,13 @@
 import org.opensearch.transport.TransportRequestHandler;
 
 import org.mockito.ArgumentMatchers;
+import org.mockito.InOrder;
 import org.mockito.Mock;
 
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
@@ -98,7 +100,7 @@ public void testUseJDKSerializationHeaderIsSetOnMessageReceived() throws Excepti
     }
 
     @Test
-    public void testUseJDKSerializationHeaderIsSetAfterGetInnerChannel() throws Exception {
+    public void testUseJDKSerializationHeaderIsSetWithWrapperChannel() throws Exception {
         TransportRequest transportRequest = mock(TransportRequest.class);
         TransportChannel transportChannel = mock(TransportChannel.class);
         TransportChannel wrappedChannel = new WrappedTransportChannel(transportChannel);
@@ -121,6 +123,26 @@ public void testUseJDKSerializationHeaderIsSetAfterGetInnerChannel() throws Exce
         Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
     }
 
+    @Test
+    public void testUseJDKSerializationHeaderIsSetAfterGetInnerChannel() throws Exception {
+        TransportRequest transportRequest = mock(TransportRequest.class);
+        TransportChannel transportChannel = mock(TransportChannel.class);
+        WrappedTransportChannel wrappedChannel = mock(WrappedTransportChannel.class);
+        Task task = mock(Task.class);
+        when(wrappedChannel.getInnerChannel()).thenReturn(transportChannel);
+        when(wrappedChannel.getChannelType()).thenReturn("other");
+        doNothing().when(transportChannel).sendResponse(ArgumentMatchers.any(Exception.class));
+        when(transportChannel.getVersion()).thenReturn(Version.V_2_10_0);
+
+        Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task));
+        Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
+
+        InOrder inOrder = inOrder(wrappedChannel, transportChannel);
+
+        inOrder.verify(wrappedChannel).getInnerChannel();
+        inOrder.verify(transportChannel).getVersion();
+    }
+
     public class WrappedTransportChannel implements TransportChannel {
 
         private TransportChannel inner;

From fb3b0915f4bd31726ee0fff2682d2bd53bba09dc Mon Sep 17 00:00:00 2001
From: Craig Perkins <cwperx@amazon.com>
Date: Mon, 27 Nov 2023 10:17:59 -0500
Subject: [PATCH 4/4] Add new constant with default transport channel types

Signed-off-by: Craig Perkins <cwperx@amazon.com>
---
 .../security/ssl/transport/SecuritySSLRequestHandler.java    | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java
index cc1aa58b31..39312e29ad 100644
--- a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java
+++ b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java
@@ -21,6 +21,7 @@
 import java.security.cert.Certificate;
 import java.security.cert.X509Certificate;
 import java.util.Arrays;
+import java.util.Set;
 import javax.net.ssl.SSLPeerUnverifiedException;
 
 import org.apache.logging.log4j.LogManager;
@@ -55,6 +56,8 @@ public class SecuritySSLRequestHandler<T extends TransportRequest> implements Tr
     private final SslExceptionHandler errorHandler;
     private final SSLConfig SSLConfig;
 
+    private static final Set<String> DEFAULT_CHANNEL_TYPES = Set.of("direct", "transport");
+
     public SecuritySSLRequestHandler(
         String action,
         TransportRequestHandler<T> actualHandler,
@@ -87,7 +90,7 @@ public final void messageReceived(T request, TransportChannel channel, Task task
         ThreadContext threadContext = getThreadContext();
 
         String channelType = channel.getChannelType();
-        if (!channelType.equals("direct") && !channelType.equals("transport")) {
+        if (!DEFAULT_CHANNEL_TYPES.contains(channelType)) {
             channel = getInnerChannel(channel);
         }