Skip to content

Commit 8ae88a7

Browse files
authored
Add ensureCustomSerialization to ensure that headers are serialized correctly with multiple transport hops (opensearch-project#4741)
Signed-off-by: Craig Perkins <cwperx@amazon.com>
1 parent 7ddbf6a commit 8ae88a7

File tree

6 files changed

+150
-20
lines changed

6 files changed

+150
-20
lines changed

src/main/java/org/opensearch/security/configuration/ClusterInfoHolder.java

+12
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.apache.logging.log4j.LogManager;
3030
import org.apache.logging.log4j.Logger;
3131

32+
import org.opensearch.Version;
3233
import org.opensearch.cluster.ClusterChangedEvent;
3334
import org.opensearch.cluster.ClusterStateListener;
3435
import org.opensearch.cluster.node.DiscoveryNode;
@@ -67,6 +68,17 @@ public boolean isInitialized() {
6768
return initialized;
6869
}
6970

71+
public Version getMinNodeVersion() {
72+
if (nodes == null) {
73+
if (log.isDebugEnabled()) {
74+
log.debug("Cluster Info Holder not initialized yet for 'nodes'");
75+
}
76+
return null;
77+
}
78+
79+
return nodes.getMinNodeVersion();
80+
}
81+
7082
public Boolean hasNode(DiscoveryNode node) {
7183
if (nodes == null) {
7284
if (log.isDebugEnabled()) {

src/main/java/org/opensearch/security/filter/SecurityFilter.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ private <Request extends ActionRequest, Response extends ActionResponse> void ap
185185
}
186186

187187
if (threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION) == null) {
188-
threadContext.putTransient(ConfigConstants.USE_JDK_SERIALIZATION, false);
188+
threadContext.putTransient(ConfigConstants.USE_JDK_SERIALIZATION, true);
189189
}
190190

191191
final ComplianceConfig complianceConfig = auditLog.getComplianceConfig();

src/main/java/org/opensearch/security/support/Base64Helper.java

+26-2
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@ public static String serializeObject(final Serializable object, final boolean us
3535
}
3636

3737
public static String serializeObject(final Serializable object) {
38-
return serializeObject(object, false);
38+
return serializeObject(object, true);
3939
}
4040

4141
public static Serializable deserializeObject(final String string) {
42-
return deserializeObject(string, false);
42+
return deserializeObject(string, true);
4343
}
4444

4545
public static Serializable deserializeObject(final String string, final boolean useJDKDeserialization) {
@@ -69,4 +69,28 @@ public static String ensureJDKSerialized(final String string) {
6969
// If we see an exception now, we want the caller to see it -
7070
return Base64Helper.serializeObject(serializable, true);
7171
}
72+
73+
/**
74+
* Ensures that the returned string is custom serialized.
75+
*
76+
* If the supplied string is a JDK serialized representation, will deserialize it and further serialize using
77+
* custom, otherwise returns the string as is.
78+
*
79+
* @param string original string, can be JDK or custom serialized
80+
* @return custom serialized string
81+
*/
82+
public static String ensureCustomSerialized(final String string) {
83+
Serializable serializable;
84+
try {
85+
serializable = Base64Helper.deserializeObject(string, true);
86+
} catch (Exception e) {
87+
// We received an exception when de-serializing the given string. It is probably custom serialized.
88+
// Try to deserialize using custom
89+
Base64Helper.deserializeObject(string, false);
90+
// Since we could deserialize the object using custom, the string is already custom serialized, return as is
91+
return string;
92+
}
93+
// If we see an exception now, we want the caller to see it -
94+
return Base64Helper.serializeObject(serializable, false);
95+
}
7296
}

src/main/java/org/opensearch/security/transport/SecurityInterceptor.java

+17-7
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.apache.logging.log4j.LogManager;
4040
import org.apache.logging.log4j.Logger;
4141

42+
import org.opensearch.Version;
4243
import org.opensearch.action.admin.cluster.shards.ClusterSearchShardsAction;
4344
import org.opensearch.action.admin.cluster.shards.ClusterSearchShardsResponse;
4445
import org.opensearch.action.get.GetRequest;
@@ -231,13 +232,22 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL
231232
}
232233

233234
try {
234-
if (serializationFormat == SerializationFormat.JDK) {
235-
Map<String, String> jdkSerializedHeaders = new HashMap<>();
236-
HeaderHelper.getAllSerializedHeaderNames()
237-
.stream()
238-
.filter(k -> headerMap.get(k) != null)
239-
.forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k))));
240-
headerMap.putAll(jdkSerializedHeaders);
235+
if (clusterInfoHolder.getMinNodeVersion() == null || clusterInfoHolder.getMinNodeVersion().before(Version.V_2_14_0)) {
236+
if (serializationFormat == SerializationFormat.JDK) {
237+
Map<String, String> jdkSerializedHeaders = new HashMap<>();
238+
HeaderHelper.getAllSerializedHeaderNames()
239+
.stream()
240+
.filter(k -> headerMap.get(k) != null)
241+
.forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k))));
242+
headerMap.putAll(jdkSerializedHeaders);
243+
} else if (serializationFormat == SerializationFormat.CustomSerializer_2_11) {
244+
Map<String, String> customSerializedHeaders = new HashMap<>();
245+
HeaderHelper.getAllSerializedHeaderNames()
246+
.stream()
247+
.filter(k -> headerMap.get(k) != null)
248+
.forEach(k -> customSerializedHeaders.put(k, Base64Helper.ensureCustomSerialized(headerMap.get(k))));
249+
headerMap.putAll(customSerializedHeaders);
250+
}
241251
}
242252
getThreadContext().putHeader(headerMap);
243253
} catch (IllegalArgumentException iae) {

src/test/java/org/opensearch/security/support/Base64HelperTest.java

+9
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,15 @@ public void testEnsureJDKSerialized() {
5353
assertThat(Base64Helper.ensureJDKSerialized(customSerialized), is(jdkSerialized));
5454
}
5555

56+
@Test
57+
public void testEnsureCustomSerialized() {
58+
String test = "string";
59+
String jdkSerialized = Base64Helper.serializeObject(test, true);
60+
String customSerialized = Base64Helper.serializeObject(test, false);
61+
assertThat(Base64Helper.ensureCustomSerialized(jdkSerialized), is(customSerialized));
62+
assertThat(Base64Helper.ensureCustomSerialized(customSerialized), is(customSerialized));
63+
}
64+
5665
@Test
5766
public void testDuplicatedItemSizes() {
5867
var largeObject = new HashMap<String, Object>();

src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java

+85-10
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,12 @@ public class SecurityInterceptorTests {
119119
private Connection connection3;
120120
private DiscoveryNode otherRemoteNode;
121121
private Connection connection4;
122+
private DiscoveryNode remoteNodeWithCustomSerialization;
123+
private Connection connection5;
122124

123125
private AsyncSender sender;
124-
private AsyncSender serializedSender;
126+
private AsyncSender jdkSerializedSender;
127+
private AsyncSender customSerializedSender;
125128
private AtomicReference<CountDownLatch> senderLatch = new AtomicReference<>(new CountDownLatch(1));
126129

127130
@Before
@@ -199,7 +202,30 @@ public void setup() {
199202
otherRemoteNode = new DiscoveryNode("remote-node2", new TransportAddress(remoteAddress, 9876), remoteNodeVersion);
200203
connection4 = transportService.getConnection(otherRemoteNode);
201204

202-
serializedSender = new AsyncSender() {
205+
remoteNodeWithCustomSerialization = new DiscoveryNode(
206+
"remote-node-with-custom-serialization",
207+
new TransportAddress(localAddress, 7456),
208+
Version.V_2_12_0
209+
);
210+
connection5 = transportService.getConnection(remoteNodeWithCustomSerialization);
211+
212+
jdkSerializedSender = new AsyncSender() {
213+
@Override
214+
public <T extends TransportResponse> void sendRequest(
215+
Connection connection,
216+
String action,
217+
TransportRequest request,
218+
TransportRequestOptions options,
219+
TransportResponseHandler<T> handler
220+
) {
221+
String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER);
222+
User deserializedUser = (User) Base64Helper.deserializeObject(serializedUserHeader, true);
223+
assertThat(deserializedUser, is(user));
224+
senderLatch.get().countDown();
225+
}
226+
};
227+
228+
customSerializedSender = new AsyncSender() {
203229
@Override
204230
public <T extends TransportResponse> void sendRequest(
205231
Connection connection,
@@ -209,7 +235,7 @@ public <T extends TransportResponse> void sendRequest(
209235
TransportResponseHandler<T> handler
210236
) {
211237
String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER);
212-
assertThat(serializedUserHeader, is(Base64Helper.serializeObject(user, true)));
238+
assertThat(serializedUserHeader, is(Base64Helper.serializeObject(user, false)));
213239
senderLatch.get().countDown();
214240
}
215241
};
@@ -265,6 +291,27 @@ final void completableRequestDecorate(
265291
senderLatch.set(new CountDownLatch(1));
266292
}
267293

294+
@SuppressWarnings({ "rawtypes", "unchecked" })
295+
final void completableRequestDecorateWithPreviouslyPopulatedHeaders(
296+
AsyncSender sender,
297+
Connection connection,
298+
String action,
299+
TransportRequest request,
300+
TransportRequestOptions options,
301+
TransportResponseHandler handler,
302+
DiscoveryNode localNode
303+
) {
304+
securityInterceptor.sendRequestDecorate(sender, connection, action, request, options, handler, localNode);
305+
try {
306+
senderLatch.get().await(1, TimeUnit.SECONDS);
307+
} catch (final InterruptedException e) {
308+
throw new RuntimeException(e);
309+
}
310+
311+
// Reset the latch so another request can be processed
312+
senderLatch.set(new CountDownLatch(1));
313+
}
314+
268315
@Test
269316
public void testSendRequestDecorateLocalConnection() {
270317

@@ -278,16 +325,44 @@ public void testSendRequestDecorateLocalConnection() {
278325
public void testSendRequestDecorateRemoteConnection() {
279326

280327
// this is a remote request
281-
completableRequestDecorate(serializedSender, connection3, action, request, options, handler, localNode);
328+
completableRequestDecorate(jdkSerializedSender, connection3, action, request, options, handler, localNode);
282329
// this is a remote request where the transport address is different
283-
completableRequestDecorate(serializedSender, connection4, action, request, options, handler, localNode);
330+
completableRequestDecorate(jdkSerializedSender, connection4, action, request, options, handler, localNode);
331+
}
332+
333+
@Test
334+
public void testSendRequestDecorateRemoteConnectionUsesJDKSerialization() {
335+
threadPool.getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, Base64Helper.serializeObject(user, false));
336+
completableRequestDecorateWithPreviouslyPopulatedHeaders(
337+
jdkSerializedSender,
338+
connection3,
339+
action,
340+
request,
341+
options,
342+
handler,
343+
localNode
344+
);
345+
}
346+
347+
@Test
348+
public void testSendRequestDecorateRemoteConnectionUsesCustomSerialization() {
349+
threadPool.getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, Base64Helper.serializeObject(user, true));
350+
completableRequestDecorateWithPreviouslyPopulatedHeaders(
351+
customSerializedSender,
352+
connection5,
353+
action,
354+
request,
355+
options,
356+
handler,
357+
localNode
358+
);
284359
}
285360

286361
@Test
287362
public void testSendNoOriginNodeCausesSerialization() {
288363

289364
// this is a request where the local node is null; have to use the remote connection since the serialization will fail
290-
completableRequestDecorate(serializedSender, connection3, action, request, options, handler, null);
365+
completableRequestDecorate(jdkSerializedSender, connection3, action, request, options, handler, null);
291366
}
292367

293368
@Test
@@ -296,7 +371,7 @@ public void testSendNoConnectionShouldThrowNPE() {
296371
// The completable version swallows the NPE so have to call actual method
297372
assertThrows(
298373
java.lang.NullPointerException.class,
299-
() -> securityInterceptor.sendRequestDecorate(serializedSender, null, action, request, options, handler, localNode)
374+
() -> securityInterceptor.sendRequestDecorate(jdkSerializedSender, null, action, request, options, handler, localNode)
300375
);
301376
}
302377

@@ -328,7 +403,7 @@ public void testCustomRemoteAddressCausesSerialization() {
328403
ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS,
329404
String.valueOf(new TransportAddress(new InetSocketAddress("8.8.8.8", 80)))
330405
);
331-
completableRequestDecorate(serializedSender, connection3, action, request, options, handler, localNode);
406+
completableRequestDecorate(jdkSerializedSender, connection3, action, request, options, handler, localNode);
332407
}
333408

334409
@Test
@@ -351,7 +426,7 @@ public void testFakeHeaderIsIgnored() {
351426
// this is a local request
352427
completableRequestDecorate(sender, connection1, action, request, options, handler, localNode);
353428
// this is a remote request
354-
completableRequestDecorate(serializedSender, connection3, action, request, options, handler, localNode);
429+
completableRequestDecorate(jdkSerializedSender, connection3, action, request, options, handler, localNode);
355430
}
356431

357432
@Test
@@ -363,7 +438,7 @@ public void testNullHeaderIsIgnored() {
363438
// this is a local request
364439
completableRequestDecorate(sender, connection1, action, request, options, handler, localNode);
365440
// this is a remote request
366-
completableRequestDecorate(serializedSender, connection3, action, request, options, handler, localNode);
441+
completableRequestDecorate(jdkSerializedSender, connection3, action, request, options, handler, localNode);
367442
}
368443

369444
@Test

0 commit comments

Comments
 (0)