Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve serialization speeds #2802

Merged
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,18 @@
import org.apache.hc.core5.http.nio.ssl.TlsStrategy;
import org.apache.hc.core5.reactor.ssl.TlsDetails;
import org.apache.hc.core5.ssl.SSLContextBuilder;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.Before;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.client.RestClient;
import org.opensearch.client.RestClientBuilder;
import org.opensearch.common.Randomness;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.io.IOUtils;
import org.opensearch.security.bwc.helper.RestHelper;
import org.opensearch.test.rest.OpenSearchRestTestCase;
import org.opensearch.Version;
Expand All @@ -56,13 +60,22 @@ public class SecurityBackwardsCompatibilityIT extends OpenSearchRestTestCase {
private final String TEST_USER = "user";
private final String TEST_PASSWORD = "290735c0-355d-4aaf-9b42-1aaa1f2a3cee";
private final String TEST_ROLE = "test-dls-fls-role";
private static RestClient testUserRestClient = null;

@Before
public void testSetup() {
final String bwcsuiteString = System.getProperty("tests.rest.bwcsuite");
Assume.assumeTrue("Test cannot be run outside the BWC gradle task 'bwcTestSuite' or its dependent tasks", bwcsuiteString != null);
CLUSTER_TYPE = ClusterType.parse(bwcsuiteString);
CLUSTER_NAME = System.getProperty("tests.clustername");
if (testUserRestClient == null) {
testUserRestClient = buildClient(
super.restClientSettings(),
super.getClusterHosts().toArray(new HttpHost[0]),
TEST_USER,
TEST_PASSWORD
);
}
}

@Override
Expand Down Expand Up @@ -101,18 +114,21 @@ protected final Settings restClientSettings() {
.build();
}

protected RestClient buildClient(Settings settings, HttpHost[] hosts, String username, String password) {
RestClientBuilder builder = RestClient.builder(hosts);
configureHttpsClient(builder, settings, username, password);
boolean strictDeprecationMode = settings.getAsBoolean("strictDeprecationMode", true);
builder.setStrictDeprecationMode(strictDeprecationMode);
return builder.build();
}

@Override
protected RestClient buildClient(Settings settings, HttpHost[] hosts) {
String username = Optional.ofNullable(System.getProperty("tests.opensearch.username"))
.orElseThrow(() -> new RuntimeException("user name is missing"));
String password = Optional.ofNullable(System.getProperty("tests.opensearch.password"))
.orElseThrow(() -> new RuntimeException("password is missing"));

RestClientBuilder builder = RestClient.builder(hosts);
configureHttpsClient(builder, settings, username, password);
boolean strictDeprecationMode = settings.getAsBoolean("strictDeprecationMode", true);
builder.setStrictDeprecationMode(strictDeprecationMode);
return builder.build();
return buildClient(super.restClientSettings(), super.getClusterHosts().toArray(new HttpHost[0]), username, password);
}

private static void configureHttpsClient(RestClientBuilder builder, Settings settings, String userName, String password) {
Expand Down Expand Up @@ -180,6 +196,11 @@ public void testDataIngestionAndSearchBackwardsCompatibility() throws Exception
searchMatchAll(index);
}

public void testNodeStats() throws IOException {
List<Response> responses = RestHelper.requestAgainstAllNodes(client(), "GET", "_nodes/stats", null);
responses.forEach(r -> Assert.assertEquals(200, r.getStatusLine().getStatusCode()));
}

@SuppressWarnings("unchecked")
private void assertPluginUpgrade(String uri) throws Exception {
Map<String, Map<String, Object>> responseMap = (Map<String, Map<String, Object>>) getAsMap(uri).get("nodes");
Expand All @@ -205,19 +226,26 @@ private void assertPluginUpgrade(String uri) throws Exception {
private void ingestData(String index) throws IOException {
StringBuilder bulkRequestBody = new StringBuilder();
ObjectMapper objectMapper = new ObjectMapper();
for (Song song : Song.SONGS) {
Map<String, Map<String, String>> indexRequest = new HashMap<>();
indexRequest.put("index", new HashMap<>() {
{
put("_index", index);
}
});
bulkRequestBody.append(objectMapper.writeValueAsString(indexRequest) + "\n");
bulkRequestBody.append(objectMapper.writeValueAsString(song.asJson()) + "\n");
int numberOfRequests = Randomness.get().nextInt(10);
while (numberOfRequests-- > 0) {
for (int i = 0; i < Randomness.get().nextInt(100); i++) {
Map<String, Map<String, String>> indexRequest = new HashMap<>();
indexRequest.put("index", new HashMap<>() {
{
put("_index", index);
}
});
bulkRequestBody.append(objectMapper.writeValueAsString(indexRequest) + "\n");
bulkRequestBody.append(objectMapper.writeValueAsString(Song.randomSong().asJson()) + "\n");
}
List<Response> responses = RestHelper.requestAgainstAllNodes(
testUserRestClient,
"POST",
"_bulk?refresh=wait_for",
RestHelper.toHttpEntity(bulkRequestBody.toString())
);
responses.forEach(r -> assertEquals(200, r.getStatusLine().getStatusCode()));
}

Response response = RestHelper.makeRequest(client(), "POST", "_bulk", RestHelper.toHttpEntity(bulkRequestBody.toString()));
assertThat(response.getStatusLine().getStatusCode(), equalTo(200));
}

/**
Expand All @@ -226,10 +254,16 @@ private void ingestData(String index) throws IOException {
*/
private void searchMatchAll(String index) throws IOException {
String matchAllQuery = "{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " }\n" + "}";

Response response = RestHelper.makeRequest(client(), "POST", index + "/_search", RestHelper.toHttpEntity(matchAllQuery));

assertThat(response.getStatusLine().getStatusCode(), equalTo(200));
int numberOfRequests = Randomness.get().nextInt(10);
while (numberOfRequests-- > 0) {
List<Response> responses = RestHelper.requestAgainstAllNodes(
testUserRestClient,
"POST",
index + "/_search",
RestHelper.toHttpEntity(matchAllQuery)
);
responses.forEach(r -> assertEquals(200, r.getStatusLine().getStatusCode()));
}
}

/**
Expand Down Expand Up @@ -324,4 +358,10 @@ private void createUserIfNotExists(String user, String password, String role) th
assertThat(response.getStatusLine().getStatusCode(), equalTo(201));
}
}

@AfterClass
public static void cleanUp() throws IOException {
OpenSearchRestTestCase.closeClients();
IOUtils.close(testUserRestClient);
}
}
12 changes: 12 additions & 0 deletions bwc-test/src/test/java/org/opensearch/security/bwc/Song.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.opensearch.common.Randomness;

import java.util.Map;
import java.util.Objects;
import java.util.UUID;

public class Song {

Expand Down Expand Up @@ -102,4 +104,14 @@ public Map<String, Object> asMap() {
public String asJson() throws JsonProcessingException {
return new ObjectMapper().writeValueAsString(this.asMap());
}

public static Song randomSong() {
return new Song(
UUID.randomUUID().toString(),
UUID.randomUUID().toString(),
UUID.randomUUID().toString(),
Randomness.get().nextInt(5),
UUID.randomUUID().toString()
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.opensearch.security.bwc.helper;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.apache.hc.core5.http.Header;
Expand Down Expand Up @@ -63,6 +64,26 @@ public static Response makeRequest(RestClient client, String method, String endp
return response;
}

public static List<Response> requestAgainstAllNodes(RestClient client, String method, String endpoint, HttpEntity entity)
throws IOException {
return requestAgainstAllNodes(client, method, endpoint, entity, null);
}

public static List<Response> requestAgainstAllNodes(
RestClient client,
String method,
String endpoint,
HttpEntity entity,
List<Header> headers
) throws IOException {
int nodeCount = client.getNodes().size();
List<Response> responses = new ArrayList<>();
while (nodeCount-- > 0) {
responses.add(makeRequest(client, method, endpoint, entity, headers));
}
return responses;
}

public static Header getAuthorizationHeader(String username, String password) {
return new BasicHeader("Authorization", "Basic " + username + ":" + password);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,6 @@ public <T extends TransportRequest> TransportRequestHandler<T> interceptHandler(

@Override
public void messageReceived(T request, TransportChannel channel, Task task) throws Exception {
threadContext.putTransient(ConfigConstants.USE_JDK_SERIALIZATION, channel.getVersion().before(Version.V_3_0_0));
si.getHandler(action, actualHandler).messageReceived(request, channel, task);
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,14 @@ protected ThreadContext getThreadContext() {

@Override
public final void messageReceived(T request, TransportChannel channel, Task task) throws Exception {

ThreadContext threadContext = getThreadContext();

threadContext.putTransient(
ConfigConstants.USE_JDK_SERIALIZATION,
channel.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION)
);

if (SSLRequestHelper.containsBadHeader(threadContext, "_opendistro_security_ssl_")) {
final Exception exception = ExceptionUtils.createBadHeaderException();
channel.sendResponse(exception);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ protected static String serializeObject(final Serializable object) {

protected static Serializable deserializeObject(final String string) {

Preconditions.checkArgument(!Strings.isNullOrEmpty(string), "string must not be null or empty");
Preconditions.checkArgument(!Strings.isNullOrEmpty(string), "object must not be null or empty");
final byte[] bytes = BaseEncoding.base64().decode(string);
Serializable obj = null;
try (final BytesStreamInput streamInput = new SafeBytesStreamInput(bytes)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,6 @@ public static String ensureJDKSerialized(final String string) {
return string;
}
// If we see an exception now, we want the caller to see it -
return Base64Helper.serializeObject(serializable, false);
return Base64Helper.serializeObject(serializable, true);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ public static String serializeObject(final Serializable object) {

public static Serializable deserializeObject(final String string) {

Preconditions.checkArgument(!Strings.isNullOrEmpty(string), "string must not be null or empty");
Preconditions.checkArgument(!Strings.isNullOrEmpty(string), "object must not be null or empty");

final byte[] bytes = BaseEncoding.base64().decode(string);
final ByteArrayInputStream bis = new ByteArrayInputStream(bytes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import org.opensearch.Version;
import org.opensearch.common.settings.Settings;
import org.opensearch.security.auditlog.impl.AuditCategory;

Expand Down Expand Up @@ -325,6 +326,7 @@ public enum RolesMappingResolution {
public static final String TENANCY_GLOBAL_TENANT_DEFAULT_NAME = "";

public static final String USE_JDK_SERIALIZATION = "plugins.security.use_jdk_serialization";
public static final Version FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION = Version.V_3_0_0;

// On-behalf-of endpoints settings
// CS-SUPPRESS-SINGLE: RegexpSingleline get Extensions Settings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import org.opensearch.Version;
import org.opensearch.action.admin.cluster.shards.ClusterSearchShardsAction;
import org.opensearch.action.admin.cluster.shards.ClusterSearchShardsResponse;
import org.opensearch.action.get.GetRequest;
Expand Down Expand Up @@ -149,7 +148,7 @@ public <T extends TransportResponse> void sendRequestDecorate(
final String origCCSTransientMf = getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_CCS);

final boolean isDebugEnabled = log.isDebugEnabled();
final boolean useJDKSerialization = connection.getVersion().before(Version.V_3_0_0);
final boolean useJDKSerialization = connection.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION);
final boolean isSameNodeRequest = localNode != null && localNode.equals(connection.getNode());

try (ThreadContext.StoredContext stashedContext = getThreadContext().stashContext()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,14 @@ public void testSerde() {
Assert.assertEquals(test, dsJDK(test));
}

@Test
public void testEnsureJDKSerialized() {
String test = "string";
String jdkSerialized = Base64Helper.serializeObject(test, true);
String customSerialized = Base64Helper.serializeObject(test, false);
Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(jdkSerialized));
Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(customSerialized));

}

}
Loading