Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Disables custom serialization #3826

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -173,7 +173,7 @@ jobs:
- id: build-previous
uses: ./.github/actions/run-bwc-suite
with:
plugin-previous-branch: "2.10"
plugin-previous-branch: "2.11"
plugin-next-branch: "current_branch"
report-artifact-name: bwc-${{ matrix.platform }}-jdk${{ matrix.jdk }}
username: admin
Original file line number Diff line number Diff line change
@@ -102,7 +102,7 @@ public class SecurityFilter implements ActionFilter {
protected final Logger log = LogManager.getLogger(this.getClass());
private final PrivilegesEvaluator evalp;
private final AdminDNs adminDns;
private DlsFlsRequestValve dlsFlsValve;
private final DlsFlsRequestValve dlsFlsValve;
private final AuditLog auditLog;
private final ThreadContext threadContext;
private final ClusterService cs;
@@ -184,7 +184,7 @@ private <Request extends ActionRequest, Response extends ActionResponse> void ap
}

if (threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION) == null) {
threadContext.putTransient(ConfigConstants.USE_JDK_SERIALIZATION, false);
threadContext.putTransient(ConfigConstants.USE_JDK_SERIALIZATION, true);
}

final ComplianceConfig complianceConfig = auditLog.getComplianceConfig();
@@ -255,7 +255,7 @@ private <Request extends ActionRequest, Response extends ActionResponse> void ap
);

threadContext.putHeader(
"_opendistro_security_trace" + System.currentTimeMillis() + "#" + UUID.randomUUID().toString(),
"_opendistro_security_trace" + System.currentTimeMillis() + "#" + UUID.randomUUID(),
Thread.currentThread().getName()
+ " FILTER -> "
+ "Node "
@@ -481,11 +481,7 @@ public void onFailure(Exception e) {
}

private static boolean isUserAdmin(User user, final AdminDNs adminDns) {
if (user != null && adminDns.isAdmin(user)) {
return true;
}

return false;
return user != null && adminDns.isAdmin(user);
}

private void attachSourceFieldContext(ActionRequest request) {
Original file line number Diff line number Diff line change
@@ -46,6 +46,8 @@

import io.netty.handler.ssl.SslHandler;

import static org.opensearch.security.support.Base64Helper.shouldUseJDKSerialization;

public class SecuritySSLRequestHandler<T extends TransportRequest> implements TransportRequestHandler<T> {

private final String action;
@@ -94,10 +96,7 @@ public final void messageReceived(T request, TransportChannel channel, Task task
channel = getInnerChannel(channel);
}

threadContext.putTransient(
ConfigConstants.USE_JDK_SERIALIZATION,
channel.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION)
);
threadContext.putTransient(ConfigConstants.USE_JDK_SERIALIZATION, shouldUseJDKSerialization(channel.getVersion()));

if (SSLRequestHelper.containsBadHeader(threadContext, "_opendistro_security_ssl_")) {
final Exception exception = ExceptionUtils.createBadHeaderException();
34 changes: 32 additions & 2 deletions src/main/java/org/opensearch/security/support/Base64Helper.java
Original file line number Diff line number Diff line change
@@ -28,18 +28,20 @@

import java.io.Serializable;

import org.opensearch.Version;

public class Base64Helper {

public static String serializeObject(final Serializable object, final boolean useJDKSerialization) {
return useJDKSerialization ? Base64JDKHelper.serializeObject(object) : Base64CustomHelper.serializeObject(object);
}

public static String serializeObject(final Serializable object) {
return serializeObject(object, false);
return serializeObject(object, true);
}

public static Serializable deserializeObject(final String string) {
return deserializeObject(string, false);
return deserializeObject(string, true);
}

public static Serializable deserializeObject(final String string, final boolean useJDKDeserialization) {
@@ -69,4 +71,32 @@ public static String ensureJDKSerialized(final String string) {
// If we see an exception now, we want the caller to see it -
return Base64Helper.serializeObject(serializable, true);
}

/**
* Ensures that the returned string is custom serialized.
*
* If the supplied string is a JDK serialized representation, will deserialize it and further serialize using
* custom, otherwise returns the string as is.
*
* @param string original string, can be JDK or custom serialized
* @return custom serialized string
*/
public static String ensureCustomSerialized(final String string) {
Serializable serializable;
try {
serializable = Base64Helper.deserializeObject(string, true);
} catch (Exception e) {
// We received an exception when de-serializing the given string. It is probably custom serialized.
// Try to deserialize using custom
Base64Helper.deserializeObject(string, false);
// Since we could deserialize the object using custom, the string is already custom serialized, return as is
return string;
}
// If we see an exception now, we want the caller to see it -
return Base64Helper.serializeObject(serializable, false);
}

public static boolean shouldUseJDKSerialization(Version remoteVersion) {
return !remoteVersion.equals(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@parasjain1 Does this need to include 2.11.1 as well?

return !Set.of(Version.2_11_0, Version.2_11_1).contains(remoteVersion);

}
}
Original file line number Diff line number Diff line change
@@ -72,12 +72,13 @@
import org.opensearch.transport.TransportResponseHandler;

import static org.opensearch.security.OpenSearchSecurityPlugin.isActionTraceEnabled;
import static org.opensearch.security.support.Base64Helper.shouldUseJDKSerialization;

public class SecurityInterceptor {

protected final Logger log = LogManager.getLogger(getClass());
private BackendRegistry backendRegistry;
private AuditLog auditLog;
private final BackendRegistry backendRegistry;
private final AuditLog auditLog;
private final ThreadPool threadPool;
private final PrincipalExtractor principalExtractor;
private final InterClusterRequestEvaluator requestEvalProvider;
@@ -148,7 +149,7 @@ public <T extends TransportResponse> void sendRequestDecorate(
final String origCCSTransientMf = getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_CCS);

final boolean isDebugEnabled = log.isDebugEnabled();
final boolean useJDKSerialization = connection.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION);
final boolean useJDKSerialization = shouldUseJDKSerialization(connection.getVersion());
final boolean isSameNodeRequest = localNode != null && localNode.equals(connection.getNode());

try (ThreadContext.StoredContext stashedContext = getThreadContext().stashContext()) {
@@ -226,13 +227,13 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL
);
}

if (useJDKSerialization) {
Map<String, String> jdkSerializedHeaders = new HashMap<>();
if (!useJDKSerialization) {
Map<String, String> customSerializedHeaders = new HashMap<>();
HeaderHelper.getAllSerializedHeaderNames()
.stream()
.filter(k -> headerMap.get(k) != null)
.forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k))));
headerMap.putAll(jdkSerializedHeaders);
.forEach(k -> customSerializedHeaders.put(k, Base64Helper.ensureCustomSerialized(headerMap.get(k))));
headerMap.putAll(customSerializedHeaders);
}

getThreadContext().putHeader(headerMap);
@@ -249,7 +250,7 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL

if (isActionTraceEnabled()) {
getThreadContext().putHeader(
"_opendistro_security_trace" + System.currentTimeMillis() + "#" + UUID.randomUUID().toString(),
"_opendistro_security_trace" + System.currentTimeMillis() + "#" + UUID.randomUUID(),
Thread.currentThread().getName()
+ " IC -> "
+ action
Original file line number Diff line number Diff line change
@@ -38,6 +38,11 @@ public void testSerde() {
String test = "string";
Assert.assertEquals(test, ds(test));
Assert.assertEquals(test, dsJDK(test));

// verify that default methods use JDK serialization
Assert.assertEquals(serializeObject(test), serializeObject(test, true));
String serialized = serializeObject(test);
Assert.assertEquals(deserializeObject(serialized), deserializeObject(serialized, true));
}

@Test
@@ -48,4 +53,13 @@ public void testEnsureJDKSerialized() {
Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(jdkSerialized));
Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(customSerialized));
}

@Test
public void testEnsureCustomSerialized() {
String test = "string";
String jdkSerialized = Base64Helper.serializeObject(test, true);
String customSerialized = Base64Helper.serializeObject(test, false);
Assert.assertEquals(customSerialized, Base64Helper.ensureCustomSerialized(jdkSerialized));
Assert.assertEquals(customSerialized, Base64Helper.ensureCustomSerialized(customSerialized));
}
}
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.common.transport.TransportAddress;
import org.opensearch.core.transport.TransportResponse;
import org.opensearch.extensions.ExtensionsManager;
@@ -51,6 +52,7 @@

import static java.util.Collections.emptySet;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

@@ -108,8 +110,7 @@ public void setup() {
);
}

private void testSendRequestDecorate(Version remoteNodeVersion) {
boolean useJDKSerialization = remoteNodeVersion.before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION);
private void testSendRequestDecorate(DiscoveryNode localNode, DiscoveryNode otherNode, boolean shouldUseJDKSerialization) {
ClusterName clusterName = ClusterName.DEFAULT;
when(clusterService.getClusterName()).thenReturn(clusterName);

@@ -143,17 +144,7 @@ private void testSendRequestDecorate(Version remoteNodeVersion) {
@SuppressWarnings("unchecked")
TransportResponseHandler<TransportResponse> handler = mock(TransportResponseHandler.class);

InetAddress localAddress = null;
try {
localAddress = InetAddress.getByName("0.0.0.0");
} catch (final UnknownHostException uhe) {
throw new RuntimeException(uhe);
}

DiscoveryNode localNode = new DiscoveryNode("local-node", new TransportAddress(localAddress, 1234), Version.CURRENT);
Connection connection1 = transportService.getConnection(localNode);

DiscoveryNode otherNode = new DiscoveryNode("remote-node", new TransportAddress(localAddress, 4321), remoteNodeVersion);
Connection connection2 = transportService.getConnection(otherNode);

// from thread context inside sendRequestDecorate
@@ -176,7 +167,7 @@ public <T extends TransportResponse> void sendRequest(
// from original context
User transientUser = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER);
assertEquals(transientUser, user);
assertEquals(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER), null);
assertNull(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER));

// checking thread context inside sendRequestDecorate
sender = new AsyncSender() {
@@ -189,7 +180,7 @@ public <T extends TransportResponse> void sendRequest(
TransportResponseHandler<T> handler
) {
String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER);
assertEquals(serializedUserHeader, Base64Helper.serializeObject(user, useJDKSerialization));
assertEquals(serializedUserHeader, Base64Helper.serializeObject(user, shouldUseJDKSerialization));
}
};
// isSameNodeRequest = false
@@ -198,20 +189,52 @@ public <T extends TransportResponse> void sendRequest(
// from original context
User transientUser2 = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER);
assertEquals(transientUser2, user);
assertEquals(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER), null);
assertNull(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER));
}

/**
* Tests the scenario when remote node is on same OS version
*/
@Test
public void testSendRequestDecorate() {
testSendRequestDecorate(Version.CURRENT);
DiscoveryNode localNode = new DiscoveryNode("local-node", new TransportAddress(getLocalAddress(), 1234), Version.CURRENT);
DiscoveryNode otherNode = new DiscoveryNode("other-node", new TransportAddress(getLocalAddress(), 3456), Version.CURRENT);
testSendRequestDecorate(localNode, otherNode, true);
}

/**
* Tests the scenario when remote node does not implement custom serialization protocol and uses JDK serialization
* Tests the scenarios for mixed node versions
*/
@Test
public void testSendRequestDecorateWhenRemoteNodeUsesJDKSerde() {
testSendRequestDecorate(Version.V_2_0_0);
public void testSendRequestDecorateWithMixedNodeVersions() {

// local on latest version, remote on 2.11.0 - should use custom

try (ThreadContext.StoredContext ignore = threadPool.getThreadContext().stashContext()) {
DiscoveryNode localNode = new DiscoveryNode("local-node", new TransportAddress(getLocalAddress(), 1234), Version.CURRENT);
DiscoveryNode otherNode = new DiscoveryNode(
"other-node",
new TransportAddress(getLocalAddress(), 3456),
ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION
);
testSendRequestDecorate(localNode, otherNode, false);
}

// remote node is on a version > 2.11.1 while local node is on version 2.11.1 - should use JDK
try (ThreadContext.StoredContext ignore = threadPool.getThreadContext().stashContext()) {
DiscoveryNode localNode = new DiscoveryNode("local-node", new TransportAddress(getLocalAddress(), 1234), Version.CURRENT);
DiscoveryNode otherNode = new DiscoveryNode("other-node", new TransportAddress(getLocalAddress(), 3456), Version.V_2_11_1);
testSendRequestDecorate(localNode, otherNode, true);
}

}

private static InetAddress getLocalAddress() {
try {
return InetAddress.getByName("0.0.0.0");
} catch (final UnknownHostException uhe) {
throw new RuntimeException(uhe);
}
}

}
Original file line number Diff line number Diff line change
@@ -89,9 +89,15 @@ public void testUseJDKSerializationHeaderIsSetOnMessageReceived() throws Excepti
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.V_2_11_0);
when(transportChannel.getVersion()).thenReturn(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task));
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.CURRENT);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task));
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

}

@Test
@@ -108,9 +114,14 @@ public void testUseJDKSerializationHeaderIsSetWithWrapperChannel() throws Except
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.V_2_11_0);
when(transportChannel.getVersion()).thenReturn(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task));
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.CURRENT);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task));
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
}

@Test
@@ -135,7 +146,7 @@ public void testUseJDKSerializationHeaderIsSetAfterGetInnerChannel() throws Exce

public class WrappedTransportChannel implements TransportChannel {

private TransportChannel inner;
private final TransportChannel inner;

public WrappedTransportChannel(TransportChannel inner) {
this.inner = inner;