Skip to content

Commit 3b4e11d

Browse files
kaushalmahi12kiranprakash154
andauthoredOct 8, 2024··
Add wlm resiliency orchestrator (query group service) (#15925) (#16225)
* cancellation related * Update CHANGELOG.md * add better cancellation reason * Update DefaultTaskCancellationTests.java * refactor * refactor * Update DefaultTaskCancellation.java * Update DefaultTaskCancellation.java * Update DefaultTaskCancellation.java * Update DefaultTaskSelectionStrategy.java * refactor * refactor node level threshold * use query group task * code clean up and refactorings * add unit tests and fix existing ones * uncomment the test case * update CHANGELOG * fix imports * add queryGroupService * refactor and add UTs for new constructs * fix javadocs * remove code clutter * change annotation version and task selection strategy * rename a util class * remove wrappers from resource type * apply spotless * address comments * add rename changes * address comments * initial changes * refactor changes and logical bug fix * add chanegs * address comments * temp changes * add UTs * add changelog * add task completion listener hook * add remaining pieces to make the feature functional * extend stats and fix bugs * fix bugs and add logic to make SBP work with wlm * address comments * fix bugs and SBP ITs * add missed applyCluster state change * address comments * decouple queryGroupService and cancellationService * replace StateApplier with StateListener interface * fix precommit errors --------- Signed-off-by: Kiran Prakash <awskiran@amazon.com> Signed-off-by: Kaushal Kumar <ravi.kaushal97@gmail.com> Co-authored-by: Kiran Prakash <awskiran@amazon.com>

22 files changed

+1410
-117
lines changed
 

‎CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
55

66
## [Unreleased 2.x]
77
### Added
8+
- [Workload Management] Add orchestrator for wlm resiliency (QueryGroupService) ([#15925](https://github.com/opensearch-project/OpenSearch/pull/15925))
89
- [Offline Nodes] Adds offline-tasks library containing various interfaces to be used for Offline Background Tasks. ([#13574](https://github.com/opensearch-project/OpenSearch/pull/13574))
910
- Add support for async deletion in S3BlobContainer ([#15621](https://github.com/opensearch-project/OpenSearch/pull/15621))
1011
- [Workload Management] QueryGroup resource cancellation framework changes ([#15651](https://github.com/opensearch-project/OpenSearch/pull/15651))

‎server/src/internalClusterTest/java/org/opensearch/search/backpressure/SearchBackpressureIT.java

+2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.opensearch.test.ParameterizedStaticSettingsOpenSearchIntegTestCase;
4040
import org.opensearch.threadpool.ThreadPool;
4141
import org.opensearch.transport.TransportService;
42+
import org.opensearch.wlm.QueryGroupTask;
4243
import org.hamcrest.MatcherAssert;
4344
import org.junit.After;
4445
import org.junit.Before;
@@ -411,6 +412,7 @@ protected void doExecute(Task task, TestRequest request, ActionListener<TestResp
411412
threadPool.executor(ThreadPool.Names.SEARCH).execute(() -> {
412413
try {
413414
CancellableTask cancellableTask = (CancellableTask) task;
415+
((QueryGroupTask) task).setQueryGroupId(threadPool.getThreadContext());
414416
long startTime = System.nanoTime();
415417

416418
// Doing a busy-wait until task cancellation or timeout.

‎server/src/main/java/org/opensearch/action/ActionModule.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,7 @@
469469
import org.opensearch.tasks.Task;
470470
import org.opensearch.threadpool.ThreadPool;
471471
import org.opensearch.usage.UsageService;
472+
import org.opensearch.wlm.QueryGroupTask;
472473

473474
import java.util.ArrayList;
474475
import java.util.Collections;
@@ -552,7 +553,10 @@ public ActionModule(
552553
destructiveOperations = new DestructiveOperations(settings, clusterSettings);
553554
Set<RestHeaderDefinition> headers = Stream.concat(
554555
actionPlugins.stream().flatMap(p -> p.getRestHeaders().stream()),
555-
Stream.of(new RestHeaderDefinition(Task.X_OPAQUE_ID, false))
556+
Stream.of(
557+
new RestHeaderDefinition(Task.X_OPAQUE_ID, false),
558+
new RestHeaderDefinition(QueryGroupTask.QUERY_GROUP_ID_HEADER, false)
559+
)
556560
).collect(Collectors.toSet());
557561
UnaryOperator<RestHandler> restWrapper = null;
558562
for (ActionPlugin plugin : actionPlugins) {

‎server/src/main/java/org/opensearch/common/settings/ClusterSettings.java

+3
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,9 @@ public void apply(Settings value, Settings current, Settings previous) {
797797
WorkloadManagementSettings.NODE_LEVEL_CPU_CANCELLATION_THRESHOLD,
798798
WorkloadManagementSettings.NODE_LEVEL_MEMORY_REJECTION_THRESHOLD,
799799
WorkloadManagementSettings.NODE_LEVEL_MEMORY_CANCELLATION_THRESHOLD,
800+
WorkloadManagementSettings.WLM_MODE_SETTING,
801+
WorkloadManagementSettings.QUERYGROUP_SERVICE_RUN_INTERVAL_SETTING,
802+
WorkloadManagementSettings.QUERYGROUP_SERVICE_DURESS_STREAK_SETTING,
800803

801804
SearchService.CLUSTER_ALLOW_DERIVED_FIELD_SETTING,
802805

‎server/src/main/java/org/opensearch/node/Node.java

+35-4
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,13 @@
269269
import org.opensearch.usage.UsageService;
270270
import org.opensearch.watcher.ResourceWatcherService;
271271
import org.opensearch.wlm.QueryGroupService;
272+
import org.opensearch.wlm.QueryGroupsStateAccessor;
273+
import org.opensearch.wlm.WorkloadManagementSettings;
272274
import org.opensearch.wlm.WorkloadManagementTransportInterceptor;
275+
import org.opensearch.wlm.cancellation.MaximumResourceTaskSelectionStrategy;
276+
import org.opensearch.wlm.cancellation.QueryGroupTaskCancellationService;
273277
import org.opensearch.wlm.listeners.QueryGroupRequestOperationListener;
278+
import org.opensearch.wlm.tracker.QueryGroupResourceUsageTrackerService;
274279

275280
import javax.net.ssl.SNIHostName;
276281

@@ -1019,8 +1024,30 @@ protected Node(
10191024
List<IdentityAwarePlugin> identityAwarePlugins = pluginsService.filterPlugins(IdentityAwarePlugin.class);
10201025
identityService.initializeIdentityAwarePlugins(identityAwarePlugins);
10211026

1022-
final QueryGroupService queryGroupService = new QueryGroupService(); // We will need to replace this with actual instance of the
1023-
// queryGroupService
1027+
final QueryGroupResourceUsageTrackerService queryGroupResourceUsageTrackerService = new QueryGroupResourceUsageTrackerService(
1028+
taskResourceTrackingService
1029+
);
1030+
final WorkloadManagementSettings workloadManagementSettings = new WorkloadManagementSettings(
1031+
settings,
1032+
settingsModule.getClusterSettings()
1033+
);
1034+
1035+
final QueryGroupsStateAccessor queryGroupsStateAccessor = new QueryGroupsStateAccessor();
1036+
1037+
final QueryGroupService queryGroupService = new QueryGroupService(
1038+
new QueryGroupTaskCancellationService(
1039+
workloadManagementSettings,
1040+
new MaximumResourceTaskSelectionStrategy(),
1041+
queryGroupResourceUsageTrackerService,
1042+
queryGroupsStateAccessor
1043+
),
1044+
clusterService,
1045+
threadPool,
1046+
workloadManagementSettings,
1047+
queryGroupsStateAccessor
1048+
);
1049+
taskResourceTrackingService.addTaskCompletionListener(queryGroupService);
1050+
10241051
final QueryGroupRequestOperationListener queryGroupRequestOperationListener = new QueryGroupRequestOperationListener(
10251052
queryGroupService,
10261053
threadPool
@@ -1086,7 +1113,7 @@ protected Node(
10861113

10871114
WorkloadManagementTransportInterceptor workloadManagementTransportInterceptor = new WorkloadManagementTransportInterceptor(
10881115
threadPool,
1089-
new QueryGroupService() // We will need to replace this with actual implementation
1116+
queryGroupService
10901117
);
10911118

10921119
final Collection<SecureSettingsFactory> secureSettingsFactories = pluginsService.filterPlugins(Plugin.class)
@@ -1180,7 +1207,8 @@ protected Node(
11801207
searchBackpressureSettings,
11811208
taskResourceTrackingService,
11821209
threadPool,
1183-
transportService.getTaskManager()
1210+
transportService.getTaskManager(),
1211+
queryGroupService
11841212
);
11851213

11861214
final SegmentReplicationStatsTracker segmentReplicationStatsTracker = new SegmentReplicationStatsTracker(indicesService);
@@ -1392,6 +1420,7 @@ protected Node(
13921420
b.bind(IndexingPressureService.class).toInstance(indexingPressureService);
13931421
b.bind(TaskResourceTrackingService.class).toInstance(taskResourceTrackingService);
13941422
b.bind(SearchBackpressureService.class).toInstance(searchBackpressureService);
1423+
b.bind(QueryGroupService.class).toInstance(queryGroupService);
13951424
b.bind(AdmissionControlService.class).toInstance(admissionControlService);
13961425
b.bind(UsageService.class).toInstance(usageService);
13971426
b.bind(AggregationUsageService.class).toInstance(searchModule.getValuesSourceRegistry().getUsageService());
@@ -1583,6 +1612,7 @@ public Node start() throws NodeValidationException {
15831612
nodeService.getMonitorService().start();
15841613
nodeService.getSearchBackpressureService().start();
15851614
nodeService.getTaskCancellationMonitoringService().start();
1615+
injector.getInstance(QueryGroupService.class).start();
15861616

15871617
final ClusterService clusterService = injector.getInstance(ClusterService.class);
15881618

@@ -1756,6 +1786,7 @@ private Node stop() {
17561786
injector.getInstance(FsHealthService.class).stop();
17571787
injector.getInstance(NodeResourceUsageTracker.class).stop();
17581788
injector.getInstance(ResourceUsageCollectorService.class).stop();
1789+
injector.getInstance(QueryGroupService.class).stop();
17591790
nodeService.getMonitorService().stop();
17601791
nodeService.getSearchBackpressureService().stop();
17611792
injector.getInstance(GatewayService.class).stop();

‎server/src/main/java/org/opensearch/search/backpressure/SearchBackpressureService.java

+10-3
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import org.opensearch.tasks.TaskResourceTrackingService.TaskCompletionListener;
4343
import org.opensearch.threadpool.Scheduler;
4444
import org.opensearch.threadpool.ThreadPool;
45+
import org.opensearch.wlm.QueryGroupService;
4546
import org.opensearch.wlm.ResourceType;
4647

4748
import java.io.IOException;
@@ -86,12 +87,14 @@ public class SearchBackpressureService extends AbstractLifecycleComponent implem
8687

8788
private final Map<Class<? extends SearchBackpressureTask>, SearchBackpressureState> searchBackpressureStates;
8889
private final TaskManager taskManager;
90+
private final QueryGroupService queryGroupService;
8991

9092
public SearchBackpressureService(
9193
SearchBackpressureSettings settings,
9294
TaskResourceTrackingService taskResourceTrackingService,
9395
ThreadPool threadPool,
94-
TaskManager taskManager
96+
TaskManager taskManager,
97+
QueryGroupService queryGroupService
9598
) {
9699
this(settings, taskResourceTrackingService, threadPool, System::nanoTime, new NodeDuressTrackers(new EnumMap<>(ResourceType.class) {
97100
{
@@ -131,7 +134,8 @@ public SearchBackpressureService(
131134
settings.getClusterSettings(),
132135
SearchShardTaskSettings.SETTING_HEAP_MOVING_AVERAGE_WINDOW_SIZE
133136
),
134-
taskManager
137+
taskManager,
138+
queryGroupService
135139
);
136140
}
137141

@@ -143,14 +147,16 @@ public SearchBackpressureService(
143147
NodeDuressTrackers nodeDuressTrackers,
144148
TaskResourceUsageTrackers searchTaskTrackers,
145149
TaskResourceUsageTrackers searchShardTaskTrackers,
146-
TaskManager taskManager
150+
TaskManager taskManager,
151+
QueryGroupService queryGroupService
147152
) {
148153
this.settings = settings;
149154
this.taskResourceTrackingService = taskResourceTrackingService;
150155
this.taskResourceTrackingService.addTaskCompletionListener(this);
151156
this.threadPool = threadPool;
152157
this.nodeDuressTrackers = nodeDuressTrackers;
153158
this.taskManager = taskManager;
159+
this.queryGroupService = queryGroupService;
154160

155161
this.searchBackpressureStates = Map.of(
156162
SearchTask.class,
@@ -346,6 +352,7 @@ <T extends CancellableTask & SearchBackpressureTask> List<CancellableTask> getTa
346352
.stream()
347353
.filter(type::isInstance)
348354
.map(type::cast)
355+
.filter(queryGroupService::shouldSBPHandle)
349356
.collect(Collectors.toUnmodifiableList());
350357
}
351358

‎server/src/main/java/org/opensearch/wlm/QueryGroupService.java

+270-20
Original file line numberDiff line numberDiff line change
@@ -8,36 +8,192 @@
88

99
package org.opensearch.wlm;
1010

11+
import org.apache.logging.log4j.LogManager;
12+
import org.apache.logging.log4j.Logger;
13+
import org.opensearch.action.search.SearchShardTask;
14+
import org.opensearch.cluster.ClusterChangedEvent;
15+
import org.opensearch.cluster.ClusterStateListener;
16+
import org.opensearch.cluster.metadata.Metadata;
17+
import org.opensearch.cluster.metadata.QueryGroup;
18+
import org.opensearch.cluster.service.ClusterService;
19+
import org.opensearch.common.lifecycle.AbstractLifecycleComponent;
1120
import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;
21+
import org.opensearch.monitor.jvm.JvmStats;
22+
import org.opensearch.monitor.process.ProcessProbe;
23+
import org.opensearch.search.backpressure.trackers.NodeDuressTrackers;
24+
import org.opensearch.search.backpressure.trackers.NodeDuressTrackers.NodeDuressTracker;
25+
import org.opensearch.tasks.Task;
26+
import org.opensearch.tasks.TaskResourceTrackingService;
27+
import org.opensearch.threadpool.Scheduler;
28+
import org.opensearch.threadpool.ThreadPool;
29+
import org.opensearch.wlm.cancellation.QueryGroupTaskCancellationService;
1230
import org.opensearch.wlm.stats.QueryGroupState;
1331
import org.opensearch.wlm.stats.QueryGroupStats;
1432
import org.opensearch.wlm.stats.QueryGroupStats.QueryGroupStatsHolder;
1533

34+
import java.io.IOException;
1635
import java.util.HashMap;
36+
import java.util.HashSet;
1737
import java.util.Map;
38+
import java.util.Optional;
39+
import java.util.Set;
40+
41+
import static org.opensearch.wlm.tracker.QueryGroupResourceUsageTrackerService.TRACKED_RESOURCES;
1842

1943
/**
2044
* As of now this is a stub and main implementation PR will be raised soon.Coming PR will collate these changes with core QueryGroupService changes
2145
*/
22-
public class QueryGroupService {
23-
// This map does not need to be concurrent since we will process the cluster state change serially and update
24-
// this map with new additions and deletions of entries. QueryGroupState is thread safe
25-
private final Map<String, QueryGroupState> queryGroupStateMap;
46+
public class QueryGroupService extends AbstractLifecycleComponent
47+
implements
48+
ClusterStateListener,
49+
TaskResourceTrackingService.TaskCompletionListener {
50+
51+
private static final Logger logger = LogManager.getLogger(QueryGroupService.class);
52+
53+
private final QueryGroupTaskCancellationService taskCancellationService;
54+
private volatile Scheduler.Cancellable scheduledFuture;
55+
private final ThreadPool threadPool;
56+
private final ClusterService clusterService;
57+
private final WorkloadManagementSettings workloadManagementSettings;
58+
private Set<QueryGroup> activeQueryGroups;
59+
private final Set<QueryGroup> deletedQueryGroups;
60+
private final NodeDuressTrackers nodeDuressTrackers;
61+
private final QueryGroupsStateAccessor queryGroupsStateAccessor;
62+
63+
public QueryGroupService(
64+
QueryGroupTaskCancellationService taskCancellationService,
65+
ClusterService clusterService,
66+
ThreadPool threadPool,
67+
WorkloadManagementSettings workloadManagementSettings,
68+
QueryGroupsStateAccessor queryGroupsStateAccessor
69+
) {
70+
71+
this(
72+
taskCancellationService,
73+
clusterService,
74+
threadPool,
75+
workloadManagementSettings,
76+
new NodeDuressTrackers(
77+
Map.of(
78+
ResourceType.CPU,
79+
new NodeDuressTracker(
80+
() -> workloadManagementSettings.getNodeLevelCpuCancellationThreshold() < ProcessProbe.getInstance()
81+
.getProcessCpuPercent() / 100.0,
82+
workloadManagementSettings::getDuressStreak
83+
),
84+
ResourceType.MEMORY,
85+
new NodeDuressTracker(
86+
() -> workloadManagementSettings.getNodeLevelMemoryCancellationThreshold() <= JvmStats.jvmStats()
87+
.getMem()
88+
.getHeapUsedPercent() / 100.0,
89+
workloadManagementSettings::getDuressStreak
90+
)
91+
)
92+
),
93+
queryGroupsStateAccessor,
94+
new HashSet<>(),
95+
new HashSet<>()
96+
);
97+
}
98+
99+
public QueryGroupService(
100+
QueryGroupTaskCancellationService taskCancellationService,
101+
ClusterService clusterService,
102+
ThreadPool threadPool,
103+
WorkloadManagementSettings workloadManagementSettings,
104+
NodeDuressTrackers nodeDuressTrackers,
105+
QueryGroupsStateAccessor queryGroupsStateAccessor,
106+
Set<QueryGroup> activeQueryGroups,
107+
Set<QueryGroup> deletedQueryGroups
108+
) {
109+
this.taskCancellationService = taskCancellationService;
110+
this.clusterService = clusterService;
111+
this.threadPool = threadPool;
112+
this.workloadManagementSettings = workloadManagementSettings;
113+
this.nodeDuressTrackers = nodeDuressTrackers;
114+
this.activeQueryGroups = activeQueryGroups;
115+
this.deletedQueryGroups = deletedQueryGroups;
116+
this.queryGroupsStateAccessor = queryGroupsStateAccessor;
117+
activeQueryGroups.forEach(queryGroup -> this.queryGroupsStateAccessor.addNewQueryGroup(queryGroup.get_id()));
118+
this.queryGroupsStateAccessor.addNewQueryGroup(QueryGroupTask.DEFAULT_QUERY_GROUP_ID_SUPPLIER.get());
119+
this.clusterService.addListener(this);
120+
}
121+
122+
/**
123+
* run at regular interval
124+
*/
125+
void doRun() {
126+
if (workloadManagementSettings.getWlmMode() == WlmMode.DISABLED) {
127+
return;
128+
}
129+
taskCancellationService.cancelTasks(nodeDuressTrackers::isNodeInDuress, activeQueryGroups, deletedQueryGroups);
130+
taskCancellationService.pruneDeletedQueryGroups(deletedQueryGroups);
131+
}
132+
133+
/**
134+
* {@link AbstractLifecycleComponent} lifecycle method
135+
*/
136+
@Override
137+
protected void doStart() {
138+
scheduledFuture = threadPool.scheduleWithFixedDelay(() -> {
139+
try {
140+
doRun();
141+
} catch (Exception e) {
142+
logger.debug("Exception occurred in Query Sandbox service", e);
143+
}
144+
}, this.workloadManagementSettings.getQueryGroupServiceRunInterval(), ThreadPool.Names.GENERIC);
145+
}
26146

27-
public QueryGroupService() {
28-
this(new HashMap<>());
147+
@Override
148+
protected void doStop() {
149+
if (scheduledFuture != null) {
150+
scheduledFuture.cancel();
151+
}
29152
}
30153

31-
public QueryGroupService(Map<String, QueryGroupState> queryGroupStateMap) {
32-
this.queryGroupStateMap = queryGroupStateMap;
154+
@Override
155+
protected void doClose() throws IOException {}
156+
157+
@Override
158+
public void clusterChanged(ClusterChangedEvent event) {
159+
// Retrieve the current and previous cluster states
160+
Metadata previousMetadata = event.previousState().metadata();
161+
Metadata currentMetadata = event.state().metadata();
162+
163+
// Extract the query groups from both the current and previous cluster states
164+
Map<String, QueryGroup> previousQueryGroups = previousMetadata.queryGroups();
165+
Map<String, QueryGroup> currentQueryGroups = currentMetadata.queryGroups();
166+
167+
// Detect new query groups added in the current cluster state
168+
for (String queryGroupName : currentQueryGroups.keySet()) {
169+
if (!previousQueryGroups.containsKey(queryGroupName)) {
170+
// New query group detected
171+
QueryGroup newQueryGroup = currentQueryGroups.get(queryGroupName);
172+
// Perform any necessary actions with the new query group
173+
queryGroupsStateAccessor.addNewQueryGroup(newQueryGroup.get_id());
174+
}
175+
}
176+
177+
// Detect query groups deleted in the current cluster state
178+
for (String queryGroupName : previousQueryGroups.keySet()) {
179+
if (!currentQueryGroups.containsKey(queryGroupName)) {
180+
// Query group deleted
181+
QueryGroup deletedQueryGroup = previousQueryGroups.get(queryGroupName);
182+
// Perform any necessary actions with the deleted query group
183+
this.deletedQueryGroups.add(deletedQueryGroup);
184+
queryGroupsStateAccessor.removeQueryGroup(deletedQueryGroup.get_id());
185+
}
186+
}
187+
this.activeQueryGroups = new HashSet<>(currentMetadata.queryGroups().values());
33188
}
34189

35190
/**
36191
* updates the failure stats for the query group
192+
*
37193
* @param queryGroupId query group identifier
38194
*/
39195
public void incrementFailuresFor(final String queryGroupId) {
40-
QueryGroupState queryGroupState = queryGroupStateMap.get(queryGroupId);
196+
QueryGroupState queryGroupState = queryGroupsStateAccessor.getQueryGroupState(queryGroupId);
41197
// This can happen if the request failed for a deleted query group
42198
// or new queryGroup is being created and has not been acknowledged yet
43199
if (queryGroupState == null) {
@@ -47,12 +203,11 @@ public void incrementFailuresFor(final String queryGroupId) {
47203
}
48204

49205
/**
50-
*
51206
* @return node level query group stats
52207
*/
53208
public QueryGroupStats nodeStats() {
54209
final Map<String, QueryGroupStatsHolder> statsHolderMap = new HashMap<>();
55-
for (Map.Entry<String, QueryGroupState> queryGroupsState : queryGroupStateMap.entrySet()) {
210+
for (Map.Entry<String, QueryGroupState> queryGroupsState : queryGroupsStateAccessor.getQueryGroupStateMap().entrySet()) {
56211
final String queryGroupId = queryGroupsState.getKey();
57212
final QueryGroupState currentState = queryGroupsState.getValue();
58213

@@ -63,18 +218,113 @@ public QueryGroupStats nodeStats() {
63218
}
64219

65220
/**
66-
*
67221
* @param queryGroupId query group identifier
68222
*/
69223
public void rejectIfNeeded(String queryGroupId) {
70-
if (queryGroupId == null) return;
71-
boolean reject = false;
72-
final StringBuilder reason = new StringBuilder();
73-
// TODO: At this point this is dummy and we need to decide whether to cancel the request based on last
74-
// reported resource usage for the queryGroup. We also need to increment the rejection count here for the
75-
// query group
76-
if (reject) {
77-
throw new OpenSearchRejectedExecutionException("QueryGroup " + queryGroupId + " is already contended." + reason.toString());
224+
if (workloadManagementSettings.getWlmMode() != WlmMode.ENABLED) {
225+
return;
226+
}
227+
228+
if (queryGroupId == null || queryGroupId.equals(QueryGroupTask.DEFAULT_QUERY_GROUP_ID_SUPPLIER.get())) return;
229+
QueryGroupState queryGroupState = queryGroupsStateAccessor.getQueryGroupState(queryGroupId);
230+
231+
// This can happen if the request failed for a deleted query group
232+
// or new queryGroup is being created and has not been acknowledged yet or invalid query group id
233+
if (queryGroupState == null) {
234+
return;
235+
}
236+
237+
// rejections will not happen for SOFT mode QueryGroups
238+
Optional<QueryGroup> optionalQueryGroup = activeQueryGroups.stream().filter(x -> x.get_id().equals(queryGroupId)).findFirst();
239+
240+
if (optionalQueryGroup.isPresent() && optionalQueryGroup.get().getResiliencyMode() == MutableQueryGroupFragment.ResiliencyMode.SOFT)
241+
return;
242+
243+
optionalQueryGroup.ifPresent(queryGroup -> {
244+
boolean reject = false;
245+
final StringBuilder reason = new StringBuilder();
246+
for (ResourceType resourceType : TRACKED_RESOURCES) {
247+
if (queryGroup.getResourceLimits().containsKey(resourceType)) {
248+
final double threshold = getNormalisedRejectionThreshold(
249+
queryGroup.getResourceLimits().get(resourceType),
250+
resourceType
251+
);
252+
final double lastRecordedUsage = queryGroupState.getResourceState().get(resourceType).getLastRecordedUsage();
253+
if (threshold < lastRecordedUsage) {
254+
reject = true;
255+
reason.append(resourceType)
256+
.append(" limit is breaching for ENFORCED type QueryGroup: (")
257+
.append(threshold)
258+
.append(" < ")
259+
.append(lastRecordedUsage)
260+
.append("). ");
261+
queryGroupState.getResourceState().get(resourceType).rejections.inc();
262+
// should not double count even if both the resource limits are breaching
263+
break;
264+
}
265+
}
266+
}
267+
if (reject) {
268+
queryGroupState.totalRejections.inc();
269+
throw new OpenSearchRejectedExecutionException(
270+
"QueryGroup " + queryGroupId + " is already contended. " + reason.toString()
271+
);
272+
}
273+
});
274+
}
275+
276+
private double getNormalisedRejectionThreshold(double limit, ResourceType resourceType) {
277+
if (resourceType == ResourceType.CPU) {
278+
return limit * workloadManagementSettings.getNodeLevelCpuRejectionThreshold();
279+
} else if (resourceType == ResourceType.MEMORY) {
280+
return limit * workloadManagementSettings.getNodeLevelMemoryRejectionThreshold();
281+
}
282+
throw new IllegalArgumentException(resourceType + " is not supported in WLM yet");
283+
}
284+
285+
public Set<QueryGroup> getActiveQueryGroups() {
286+
return activeQueryGroups;
287+
}
288+
289+
public Set<QueryGroup> getDeletedQueryGroups() {
290+
return deletedQueryGroups;
291+
}
292+
293+
/**
294+
* This method determines whether the task should be accounted by SBP if both features co-exist
295+
* @param t QueryGroupTask
296+
* @return whether or not SBP handle it
297+
*/
298+
public boolean shouldSBPHandle(Task t) {
299+
QueryGroupTask task = (QueryGroupTask) t;
300+
boolean isInvalidQueryGroupTask = true;
301+
if (!task.getQueryGroupId().equals(QueryGroupTask.DEFAULT_QUERY_GROUP_ID_SUPPLIER.get())) {
302+
isInvalidQueryGroupTask = activeQueryGroups.stream()
303+
.noneMatch(queryGroup -> queryGroup.get_id().equals(task.getQueryGroupId()));
304+
}
305+
return workloadManagementSettings.getWlmMode() != WlmMode.ENABLED || isInvalidQueryGroupTask;
306+
}
307+
308+
@Override
309+
public void onTaskCompleted(Task task) {
310+
if (!(task instanceof QueryGroupTask)) {
311+
return;
312+
}
313+
final QueryGroupTask queryGroupTask = (QueryGroupTask) task;
314+
String queryGroupId = queryGroupTask.getQueryGroupId();
315+
316+
// set the default queryGroupId if not existing in the active query groups
317+
String finalQueryGroupId = queryGroupId;
318+
boolean exists = activeQueryGroups.stream().anyMatch(queryGroup -> queryGroup.get_id().equals(finalQueryGroupId));
319+
320+
if (!exists) {
321+
queryGroupId = QueryGroupTask.DEFAULT_QUERY_GROUP_ID_SUPPLIER.get();
322+
}
323+
324+
if (task instanceof SearchShardTask) {
325+
queryGroupsStateAccessor.getQueryGroupState(queryGroupId).shardCompletions.inc();
326+
} else {
327+
queryGroupsStateAccessor.getQueryGroupState(queryGroupId).completions.inc();
78328
}
79329
}
80330
}

‎server/src/main/java/org/opensearch/wlm/QueryGroupTask.java

+5-4
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import org.opensearch.tasks.CancellableTask;
1818

1919
import java.util.Map;
20-
import java.util.Optional;
2120
import java.util.function.LongSupplier;
2221
import java.util.function.Supplier;
2322

@@ -82,9 +81,11 @@ public final String getQueryGroupId() {
8281
* @param threadContext current threadContext
8382
*/
8483
public final void setQueryGroupId(final ThreadContext threadContext) {
85-
this.queryGroupId = Optional.ofNullable(threadContext)
86-
.map(threadContext1 -> threadContext1.getHeader(QUERY_GROUP_ID_HEADER))
87-
.orElse(DEFAULT_QUERY_GROUP_ID_SUPPLIER.get());
84+
if (threadContext != null && threadContext.getHeader(QUERY_GROUP_ID_HEADER) != null) {
85+
this.queryGroupId = threadContext.getHeader(QUERY_GROUP_ID_HEADER);
86+
} else {
87+
this.queryGroupId = DEFAULT_QUERY_GROUP_ID_SUPPLIER.get();
88+
}
8889
}
8990

9091
public long getElapsedTime() {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*/
8+
9+
package org.opensearch.wlm;
10+
11+
import org.opensearch.wlm.stats.QueryGroupState;
12+
13+
import java.util.HashMap;
14+
import java.util.Map;
15+
16+
/**
17+
* This class is used to decouple {@link QueryGroupService} and {@link org.opensearch.wlm.cancellation.QueryGroupTaskCancellationService} to share the
18+
* {@link QueryGroupState}s
19+
*/
20+
public class QueryGroupsStateAccessor {
21+
// This map does not need to be concurrent since we will process the cluster state change serially and update
22+
// this map with new additions and deletions of entries. QueryGroupState is thread safe
23+
private final Map<String, QueryGroupState> queryGroupStateMap;
24+
25+
public QueryGroupsStateAccessor() {
26+
this(new HashMap<>());
27+
}
28+
29+
public QueryGroupsStateAccessor(Map<String, QueryGroupState> queryGroupStateMap) {
30+
this.queryGroupStateMap = queryGroupStateMap;
31+
}
32+
33+
/**
34+
* returns the query groups state
35+
*/
36+
public Map<String, QueryGroupState> getQueryGroupStateMap() {
37+
return queryGroupStateMap;
38+
}
39+
40+
/**
41+
* return QueryGroupState for the given queryGroupId
42+
* @param queryGroupId
43+
* @return QueryGroupState for the given queryGroupId, if id is invalid return default query group state
44+
*/
45+
public QueryGroupState getQueryGroupState(String queryGroupId) {
46+
return queryGroupStateMap.getOrDefault(queryGroupId, queryGroupStateMap.get(QueryGroupTask.DEFAULT_QUERY_GROUP_ID_SUPPLIER.get()));
47+
}
48+
49+
/**
50+
* adds new QueryGroupState against given queryGroupId
51+
* @param queryGroupId
52+
*/
53+
public void addNewQueryGroup(String queryGroupId) {
54+
this.queryGroupStateMap.putIfAbsent(queryGroupId, new QueryGroupState());
55+
}
56+
57+
/**
58+
* removes QueryGroupState against given queryGroupId
59+
* @param queryGroupId
60+
*/
61+
public void removeQueryGroup(String queryGroupId) {
62+
this.queryGroupStateMap.remove(queryGroupId);
63+
}
64+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*/
8+
9+
package org.opensearch.wlm;
10+
11+
import org.opensearch.common.annotation.PublicApi;
12+
13+
/**
14+
* Enum to hold the values whether wlm is enabled or not
15+
*/
16+
@PublicApi(since = "2.18.0")
17+
public enum WlmMode {
18+
ENABLED("enabled"),
19+
MONITOR_ONLY("monitor_only"),
20+
DISABLED("disabled");
21+
22+
private final String name;
23+
24+
WlmMode(String name) {
25+
this.name = name;
26+
}
27+
28+
public String getName() {
29+
return name;
30+
}
31+
32+
public static WlmMode fromName(String name) {
33+
for (WlmMode wlmMode : values()) {
34+
if (wlmMode.getName().equals(name)) {
35+
return wlmMode;
36+
}
37+
}
38+
throw new IllegalArgumentException(name + " is an invalid WlmMode");
39+
}
40+
}

‎server/src/main/java/org/opensearch/wlm/WorkloadManagementSettings.java

+105
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.opensearch.common.settings.ClusterSettings;
1313
import org.opensearch.common.settings.Setting;
1414
import org.opensearch.common.settings.Settings;
15+
import org.opensearch.common.unit.TimeValue;
1516

1617
/**
1718
* Main class to declare Workload Management related settings
@@ -22,16 +23,66 @@ public class WorkloadManagementSettings {
2223
private static final Double DEFAULT_NODE_LEVEL_MEMORY_CANCELLATION_THRESHOLD = 0.9;
2324
private static final Double DEFAULT_NODE_LEVEL_CPU_REJECTION_THRESHOLD = 0.8;
2425
private static final Double DEFAULT_NODE_LEVEL_CPU_CANCELLATION_THRESHOLD = 0.9;
26+
private static final Long DEFAULT_QUERYGROUP_SERVICE_RUN_INTERVAL_MILLIS = 1000L;
2527
public static final double NODE_LEVEL_MEMORY_CANCELLATION_THRESHOLD_MAX_VALUE = 0.95;
2628
public static final double NODE_LEVEL_MEMORY_REJECTION_THRESHOLD_MAX_VALUE = 0.9;
2729
public static final double NODE_LEVEL_CPU_CANCELLATION_THRESHOLD_MAX_VALUE = 0.95;
2830
public static final double NODE_LEVEL_CPU_REJECTION_THRESHOLD_MAX_VALUE = 0.9;
31+
public static final String DEFAULT_WLM_MODE = "monitor_only";
2932

3033
private Double nodeLevelMemoryCancellationThreshold;
3134
private Double nodeLevelMemoryRejectionThreshold;
3235
private Double nodeLevelCpuCancellationThreshold;
3336
private Double nodeLevelCpuRejectionThreshold;
3437

38+
/**
39+
* Setting name for QueryGroupService node duress streak
40+
*/
41+
public static final String QUERYGROUP_DURESS_STREAK_SETTING_NAME = "wlm.query_group.duress_streak";
42+
private int duressStreak;
43+
public static final Setting<Integer> QUERYGROUP_SERVICE_DURESS_STREAK_SETTING = Setting.intSetting(
44+
QUERYGROUP_DURESS_STREAK_SETTING_NAME,
45+
3,
46+
3,
47+
Setting.Property.Dynamic,
48+
Setting.Property.NodeScope
49+
);
50+
51+
/**
52+
* Setting name for Query Group Service run interval
53+
*/
54+
public static final String QUERYGROUP_ENFORCEMENT_INTERVAL_SETTING_NAME = "wlm.query_group.enforcement_interval";
55+
56+
private TimeValue queryGroupServiceRunInterval;
57+
/**
58+
* Setting to control the run interval of Query Group Service
59+
*/
60+
public static final Setting<Long> QUERYGROUP_SERVICE_RUN_INTERVAL_SETTING = Setting.longSetting(
61+
QUERYGROUP_ENFORCEMENT_INTERVAL_SETTING_NAME,
62+
DEFAULT_QUERYGROUP_SERVICE_RUN_INTERVAL_MILLIS,
63+
1000,
64+
Setting.Property.Dynamic,
65+
Setting.Property.NodeScope
66+
);
67+
68+
/**
69+
* WLM mode setting name
70+
*/
71+
public static final String WLM_MODE_SETTING_NAME = "wlm.query_group.mode";
72+
73+
private volatile WlmMode wlmMode;
74+
75+
/**
76+
* WLM mode setting, which determines which mode WLM is operating in
77+
*/
78+
public static final Setting<WlmMode> WLM_MODE_SETTING = new Setting<WlmMode>(
79+
WLM_MODE_SETTING_NAME,
80+
DEFAULT_WLM_MODE,
81+
WlmMode::fromName,
82+
Setting.Property.Dynamic,
83+
Setting.Property.NodeScope
84+
);
85+
3586
/**
3687
* Setting name for node level memory based rejection threshold for QueryGroup service
3788
*/
@@ -91,10 +142,13 @@ public class WorkloadManagementSettings {
91142
* @param clusterSettings - QueryGroup cluster settings
92143
*/
93144
public WorkloadManagementSettings(Settings settings, ClusterSettings clusterSettings) {
145+
this.wlmMode = WLM_MODE_SETTING.get(settings);
94146
nodeLevelMemoryCancellationThreshold = NODE_LEVEL_MEMORY_CANCELLATION_THRESHOLD.get(settings);
95147
nodeLevelMemoryRejectionThreshold = NODE_LEVEL_MEMORY_REJECTION_THRESHOLD.get(settings);
96148
nodeLevelCpuCancellationThreshold = NODE_LEVEL_CPU_CANCELLATION_THRESHOLD.get(settings);
97149
nodeLevelCpuRejectionThreshold = NODE_LEVEL_CPU_REJECTION_THRESHOLD.get(settings);
150+
this.queryGroupServiceRunInterval = TimeValue.timeValueMillis(QUERYGROUP_SERVICE_RUN_INTERVAL_SETTING.get(settings));
151+
duressStreak = QUERYGROUP_SERVICE_DURESS_STREAK_SETTING.get(settings);
98152

99153
ensureRejectionThresholdIsLessThanCancellation(
100154
nodeLevelMemoryRejectionThreshold,
@@ -113,6 +167,57 @@ public WorkloadManagementSettings(Settings settings, ClusterSettings clusterSett
113167
clusterSettings.addSettingsUpdateConsumer(NODE_LEVEL_MEMORY_REJECTION_THRESHOLD, this::setNodeLevelMemoryRejectionThreshold);
114168
clusterSettings.addSettingsUpdateConsumer(NODE_LEVEL_CPU_CANCELLATION_THRESHOLD, this::setNodeLevelCpuCancellationThreshold);
115169
clusterSettings.addSettingsUpdateConsumer(NODE_LEVEL_CPU_REJECTION_THRESHOLD, this::setNodeLevelCpuRejectionThreshold);
170+
clusterSettings.addSettingsUpdateConsumer(WLM_MODE_SETTING, this::setWlmMode);
171+
clusterSettings.addSettingsUpdateConsumer(QUERYGROUP_SERVICE_RUN_INTERVAL_SETTING, this::setQueryGroupServiceRunInterval);
172+
clusterSettings.addSettingsUpdateConsumer(QUERYGROUP_SERVICE_DURESS_STREAK_SETTING, this::setDuressStreak);
173+
}
174+
175+
/**
176+
* node duress streak getter
177+
* @return current duressStreak value
178+
*/
179+
public int getDuressStreak() {
180+
return duressStreak;
181+
}
182+
183+
/**
184+
* node duress streak setter
185+
* @param duressStreak new value
186+
*/
187+
private void setDuressStreak(int duressStreak) {
188+
this.duressStreak = duressStreak;
189+
}
190+
191+
/**
192+
* queryGroupServiceRunInterval setter
193+
* @param newIntervalInMillis new value
194+
*/
195+
private void setQueryGroupServiceRunInterval(long newIntervalInMillis) {
196+
this.queryGroupServiceRunInterval = TimeValue.timeValueMillis(newIntervalInMillis);
197+
}
198+
199+
/**
200+
* queryGroupServiceRunInterval getter
201+
* @return current queryGroupServiceRunInterval value
202+
*/
203+
public TimeValue getQueryGroupServiceRunInterval() {
204+
return this.queryGroupServiceRunInterval;
205+
}
206+
207+
/**
208+
* WlmMode setter
209+
* @param mode new mode value
210+
*/
211+
private void setWlmMode(final WlmMode mode) {
212+
this.wlmMode = mode;
213+
}
214+
215+
/**
216+
* WlmMode getter
217+
* @return the current wlmMode
218+
*/
219+
public WlmMode getWlmMode() {
220+
return this.wlmMode;
116221
}
117222

118223
/**

‎server/src/main/java/org/opensearch/wlm/cancellation/QueryGroupTaskCancellationService.java

+96-27
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,26 @@
88

99
package org.opensearch.wlm.cancellation;
1010

11+
import org.apache.logging.log4j.LogManager;
12+
import org.apache.logging.log4j.Logger;
1113
import org.opensearch.cluster.metadata.QueryGroup;
12-
import org.opensearch.tasks.CancellableTask;
1314
import org.opensearch.tasks.TaskCancellation;
1415
import org.opensearch.wlm.MutableQueryGroupFragment.ResiliencyMode;
1516
import org.opensearch.wlm.QueryGroupLevelResourceUsageView;
1617
import org.opensearch.wlm.QueryGroupTask;
18+
import org.opensearch.wlm.QueryGroupsStateAccessor;
1719
import org.opensearch.wlm.ResourceType;
20+
import org.opensearch.wlm.WlmMode;
1821
import org.opensearch.wlm.WorkloadManagementSettings;
22+
import org.opensearch.wlm.stats.QueryGroupState;
1923
import org.opensearch.wlm.tracker.QueryGroupResourceUsageTrackerService;
2024

2125
import java.util.ArrayList;
2226
import java.util.Collection;
27+
import java.util.HashSet;
2328
import java.util.List;
2429
import java.util.Map;
30+
import java.util.Set;
2531
import java.util.function.BooleanSupplier;
2632
import java.util.function.Consumer;
2733
import java.util.stream.Collectors;
@@ -47,46 +53,78 @@
4753
*/
4854
public class QueryGroupTaskCancellationService {
4955
public static final double MIN_VALUE = 1e-9;
56+
private static final Logger log = LogManager.getLogger(QueryGroupTaskCancellationService.class);
5057

5158
private final WorkloadManagementSettings workloadManagementSettings;
5259
private final TaskSelectionStrategy taskSelectionStrategy;
5360
private final QueryGroupResourceUsageTrackerService resourceUsageTrackerService;
5461
// a map of QueryGroupId to its corresponding QueryGroupLevelResourceUsageView object
5562
Map<String, QueryGroupLevelResourceUsageView> queryGroupLevelResourceUsageViews;
56-
private final Collection<QueryGroup> activeQueryGroups;
57-
private final Collection<QueryGroup> deletedQueryGroups;
63+
private final QueryGroupsStateAccessor queryGroupStateAccessor;
5864

5965
public QueryGroupTaskCancellationService(
6066
WorkloadManagementSettings workloadManagementSettings,
6167
TaskSelectionStrategy taskSelectionStrategy,
6268
QueryGroupResourceUsageTrackerService resourceUsageTrackerService,
63-
Collection<QueryGroup> activeQueryGroups,
64-
Collection<QueryGroup> deletedQueryGroups
69+
QueryGroupsStateAccessor queryGroupStateAccessor
6570
) {
6671
this.workloadManagementSettings = workloadManagementSettings;
6772
this.taskSelectionStrategy = taskSelectionStrategy;
6873
this.resourceUsageTrackerService = resourceUsageTrackerService;
69-
this.activeQueryGroups = activeQueryGroups;
70-
this.deletedQueryGroups = deletedQueryGroups;
74+
this.queryGroupStateAccessor = queryGroupStateAccessor;
7175
}
7276

7377
/**
7478
* Cancel tasks based on the implemented strategy.
7579
*/
76-
public final void cancelTasks(BooleanSupplier isNodeInDuress) {
80+
public void cancelTasks(
81+
BooleanSupplier isNodeInDuress,
82+
Collection<QueryGroup> activeQueryGroups,
83+
Collection<QueryGroup> deletedQueryGroups
84+
) {
7785
queryGroupLevelResourceUsageViews = resourceUsageTrackerService.constructQueryGroupLevelUsageViews();
7886
// cancel tasks from QueryGroups that are in Enforced mode that are breaching their resource limits
79-
cancelTasks(ResiliencyMode.ENFORCED);
87+
cancelTasks(ResiliencyMode.ENFORCED, activeQueryGroups);
8088
// if the node is in duress, cancel tasks accordingly.
81-
handleNodeDuress(isNodeInDuress);
89+
handleNodeDuress(isNodeInDuress, activeQueryGroups, deletedQueryGroups);
90+
91+
updateResourceUsageInQueryGroupState(activeQueryGroups);
92+
}
93+
94+
private void updateResourceUsageInQueryGroupState(Collection<QueryGroup> activeQueryGroups) {
95+
Set<String> isSearchWorkloadRunning = new HashSet<>();
96+
for (Map.Entry<String, QueryGroupLevelResourceUsageView> queryGroupLevelResourceUsageViewEntry : queryGroupLevelResourceUsageViews
97+
.entrySet()) {
98+
isSearchWorkloadRunning.add(queryGroupLevelResourceUsageViewEntry.getKey());
99+
QueryGroupState queryGroupState = getQueryGroupState(queryGroupLevelResourceUsageViewEntry.getKey());
100+
TRACKED_RESOURCES.forEach(resourceType -> {
101+
final double currentUsage = queryGroupLevelResourceUsageViewEntry.getValue().getResourceUsageData().get(resourceType);
102+
queryGroupState.getResourceState().get(resourceType).setLastRecordedUsage(currentUsage);
103+
});
104+
}
105+
106+
activeQueryGroups.forEach(queryGroup -> {
107+
if (!isSearchWorkloadRunning.contains(queryGroup.get_id())) {
108+
TRACKED_RESOURCES.forEach(
109+
resourceType -> getQueryGroupState(queryGroup.get_id()).getResourceState().get(resourceType).setLastRecordedUsage(0.0)
110+
);
111+
}
112+
});
82113
}
83114

84-
private void handleNodeDuress(BooleanSupplier isNodeInDuress) {
115+
private void handleNodeDuress(
116+
BooleanSupplier isNodeInDuress,
117+
Collection<QueryGroup> activeQueryGroups,
118+
Collection<QueryGroup> deletedQueryGroups
119+
) {
85120
if (!isNodeInDuress.getAsBoolean()) {
86121
return;
87122
}
88123
// List of tasks to be executed in order if the node is in duress
89-
List<Consumer<Void>> duressActions = List.of(v -> cancelTasksFromDeletedQueryGroups(), v -> cancelTasks(ResiliencyMode.SOFT));
124+
List<Consumer<Void>> duressActions = List.of(
125+
v -> cancelTasksFromDeletedQueryGroups(deletedQueryGroups),
126+
v -> cancelTasks(ResiliencyMode.SOFT, activeQueryGroups)
127+
);
90128

91129
for (Consumer<Void> duressAction : duressActions) {
92130
if (!isNodeInDuress.getAsBoolean()) {
@@ -96,18 +134,18 @@ private void handleNodeDuress(BooleanSupplier isNodeInDuress) {
96134
}
97135
}
98136

99-
private void cancelTasksFromDeletedQueryGroups() {
100-
cancelTasks(getAllCancellableTasks(this.deletedQueryGroups));
137+
private void cancelTasksFromDeletedQueryGroups(Collection<QueryGroup> deletedQueryGroups) {
138+
cancelTasks(getAllCancellableTasks(deletedQueryGroups));
101139
}
102140

103141
/**
104142
* Get all cancellable tasks from the QueryGroups.
105143
*
106144
* @return List of tasks that can be cancelled
107145
*/
108-
List<TaskCancellation> getAllCancellableTasks(ResiliencyMode resiliencyMode) {
146+
List<TaskCancellation> getAllCancellableTasks(ResiliencyMode resiliencyMode, Collection<QueryGroup> queryGroups) {
109147
return getAllCancellableTasks(
110-
activeQueryGroups.stream().filter(queryGroup -> queryGroup.getResiliencyMode() == resiliencyMode).collect(Collectors.toList())
148+
queryGroups.stream().filter(queryGroup -> queryGroup.getResiliencyMode() == resiliencyMode).collect(Collectors.toList())
111149
);
112150
}
113151

@@ -118,6 +156,7 @@ List<TaskCancellation> getAllCancellableTasks(ResiliencyMode resiliencyMode) {
118156
*/
119157
List<TaskCancellation> getAllCancellableTasks(Collection<QueryGroup> queryGroups) {
120158
List<TaskCancellation> taskCancellations = new ArrayList<>();
159+
final List<Runnable> onCancelCallbacks = new ArrayList<>();
121160
for (QueryGroup queryGroup : queryGroups) {
122161
final List<TaskCancellation.Reason> reasons = new ArrayList<>();
123162
List<QueryGroupTask> selectedTasks = new ArrayList<>();
@@ -127,8 +166,7 @@ List<TaskCancellation> getAllCancellableTasks(Collection<QueryGroup> queryGroups
127166
.calculateResourceUsage(selectedTasks);
128167
if (excessUsage > MIN_VALUE) {
129168
reasons.add(new TaskCancellation.Reason(generateReasonString(queryGroup, resourceType), 1));
130-
// TODO: We will need to add the cancellation callback for these resources for the queryGroup to reflect stats
131-
169+
onCancelCallbacks.add(this.getResourceTypeOnCancelCallback(queryGroup.get_id(), resourceType));
132170
// Only add tasks not already added to avoid double cancellations
133171
selectedTasks.addAll(
134172
taskSelectionStrategy.selectTasksForCancellation(getTasksFor(queryGroup), excessUsage, resourceType)
@@ -140,8 +178,9 @@ List<TaskCancellation> getAllCancellableTasks(Collection<QueryGroup> queryGroups
140178
}
141179

142180
if (!reasons.isEmpty()) {
181+
onCancelCallbacks.add(getQueryGroupState(queryGroup.get_id()).totalCancellations::inc);
143182
taskCancellations.addAll(
144-
selectedTasks.stream().map(task -> createTaskCancellation(task, reasons)).collect(Collectors.toList())
183+
selectedTasks.stream().map(task -> new TaskCancellation(task, reasons, onCancelCallbacks)).collect(Collectors.toList())
145184
);
146185
}
147186
}
@@ -164,16 +203,27 @@ private List<QueryGroupTask> getTasksFor(QueryGroup queryGroup) {
164203
return queryGroupLevelResourceUsageViews.get(queryGroup.get_id()).getActiveTasks();
165204
}
166205

167-
private void cancelTasks(ResiliencyMode resiliencyMode) {
168-
cancelTasks(getAllCancellableTasks(resiliencyMode));
206+
private void cancelTasks(ResiliencyMode resiliencyMode, Collection<QueryGroup> queryGroups) {
207+
cancelTasks(getAllCancellableTasks(resiliencyMode, queryGroups));
169208
}
170209

171210
private void cancelTasks(List<TaskCancellation> cancellableTasks) {
172-
cancellableTasks.forEach(TaskCancellation::cancel);
173-
}
174211

175-
private TaskCancellation createTaskCancellation(CancellableTask task, List<TaskCancellation.Reason> reasons) {
176-
return new TaskCancellation(task, reasons, List.of(this::callbackOnCancel));
212+
Consumer<TaskCancellation> cancellationLoggingConsumer = (taskCancellation -> {
213+
log.warn(
214+
"Task {} is eligible for cancellation for reason {}",
215+
taskCancellation.getTask().getId(),
216+
taskCancellation.getReasonString()
217+
);
218+
});
219+
Consumer<TaskCancellation> cancellationConsumer = cancellationLoggingConsumer;
220+
if (workloadManagementSettings.getWlmMode() == WlmMode.ENABLED) {
221+
cancellationConsumer = (taskCancellation -> {
222+
cancellationLoggingConsumer.accept(taskCancellation);
223+
taskCancellation.cancel();
224+
});
225+
}
226+
cancellableTasks.forEach(cancellationConsumer);
177227
}
178228

179229
private double getExcessUsage(QueryGroup queryGroup, ResourceType resourceType) {
@@ -199,7 +249,26 @@ private double getNormalisedThreshold(QueryGroup queryGroup, ResourceType resour
199249
return queryGroup.getResourceLimits().get(resourceType) * nodeLevelCancellationThreshold;
200250
}
201251

202-
private void callbackOnCancel() {
203-
// TODO Implement callback logic here mostly used for Stats
252+
private Runnable getResourceTypeOnCancelCallback(String queryGroupId, ResourceType resourceType) {
253+
QueryGroupState queryGroupState = getQueryGroupState(queryGroupId);
254+
return queryGroupState.getResourceState().get(resourceType).cancellations::inc;
255+
}
256+
257+
private QueryGroupState getQueryGroupState(String queryGroupId) {
258+
assert queryGroupId != null : "queryGroupId should never be null at this point.";
259+
260+
return queryGroupStateAccessor.getQueryGroupState(queryGroupId);
261+
}
262+
263+
/**
264+
* Removes the queryGroups from deleted list if it doesn't have any tasks running
265+
*/
266+
public void pruneDeletedQueryGroups(Collection<QueryGroup> deletedQueryGroups) {
267+
List<QueryGroup> currentDeletedQueryGroups = new ArrayList<>(deletedQueryGroups);
268+
for (QueryGroup queryGroup : currentDeletedQueryGroups) {
269+
if (queryGroupLevelResourceUsageViews.get(queryGroup.get_id()).getActiveTasks().isEmpty()) {
270+
deletedQueryGroups.remove(queryGroup);
271+
}
272+
}
204273
}
205274
}

‎server/src/main/java/org/opensearch/wlm/stats/QueryGroupState.java

+21-8
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,19 @@
1919
*/
2020
public class QueryGroupState {
2121
/**
22-
* completions at the query group level, this is a cumulative counter since the Opensearch start time
22+
* co-ordinator level completions at the query group level, this is a cumulative counter since the Opensearch start time
2323
*/
24-
final CounterMetric completions = new CounterMetric();
24+
public final CounterMetric completions = new CounterMetric();
25+
26+
/**
27+
* shard level completions at the query group level, this is a cumulative counter since the Opensearch start time
28+
*/
29+
public final CounterMetric shardCompletions = new CounterMetric();
2530

2631
/**
2732
* rejections at the query group level, this is a cumulative counter since the OpenSearch start time
2833
*/
29-
final CounterMetric totalRejections = new CounterMetric();
34+
public final CounterMetric totalRejections = new CounterMetric();
3035

3136
/**
3237
* this will track the cumulative failures in a query group
@@ -36,7 +41,7 @@ public class QueryGroupState {
3641
/**
3742
* This will track total number of cancellations in the query group due to all resource type breaches
3843
*/
39-
final CounterMetric totalCancellations = new CounterMetric();
44+
public final CounterMetric totalCancellations = new CounterMetric();
4045

4146
/**
4247
* This is used to store the resource type state both for CPU and MEMORY
@@ -54,12 +59,20 @@ public QueryGroupState() {
5459

5560
/**
5661
*
57-
* @return completions in the query group
62+
* @return co-ordinator completions in the query group
5863
*/
5964
public long getCompletions() {
6065
return completions.count();
6166
}
6267

68+
/**
69+
*
70+
* @return shard completions in the query group
71+
*/
72+
public long getShardCompletions() {
73+
return shardCompletions.count();
74+
}
75+
6376
/**
6477
*
6578
* @return rejections in the query group
@@ -92,9 +105,9 @@ public Map<ResourceType, ResourceTypeState> getResourceState() {
92105
* This class holds the resource level stats for the query group
93106
*/
94107
public static class ResourceTypeState {
95-
final ResourceType resourceType;
96-
final CounterMetric cancellations = new CounterMetric();
97-
final CounterMetric rejections = new CounterMetric();
108+
public final ResourceType resourceType;
109+
public final CounterMetric cancellations = new CounterMetric();
110+
public final CounterMetric rejections = new CounterMetric();
98111
private double lastRecordedUsage = 0;
99112

100113
public ResourceTypeState(ResourceType resourceType) {

‎server/src/main/java/org/opensearch/wlm/stats/QueryGroupStats.java

+10-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,9 @@ public static class QueryGroupStatsHolder implements ToXContentObject, Writeable
9191
public static final String REJECTIONS = "rejections";
9292
public static final String TOTAL_CANCELLATIONS = "total_cancellations";
9393
public static final String FAILURES = "failures";
94+
public static final String SHARD_COMPLETIONS = "shard_completions";
9495
private long completions;
96+
private long shardCompletions;
9597
private long rejections;
9698
private long failures;
9799
private long totalCancellations;
@@ -105,11 +107,13 @@ public QueryGroupStatsHolder(
105107
long rejections,
106108
long failures,
107109
long totalCancellations,
110+
long shardCompletions,
108111
Map<ResourceType, ResourceStats> resourceStats
109112
) {
110113
this.completions = completions;
111114
this.rejections = rejections;
112115
this.failures = failures;
116+
this.shardCompletions = shardCompletions;
113117
this.totalCancellations = totalCancellations;
114118
this.resourceStats = resourceStats;
115119
}
@@ -119,6 +123,7 @@ public QueryGroupStatsHolder(StreamInput in) throws IOException {
119123
this.rejections = in.readVLong();
120124
this.failures = in.readVLong();
121125
this.totalCancellations = in.readVLong();
126+
this.shardCompletions = in.readVLong();
122127
this.resourceStats = in.readMap((i) -> ResourceType.fromName(i.readString()), ResourceStats::new);
123128
}
124129

@@ -140,6 +145,7 @@ public static QueryGroupStatsHolder from(QueryGroupState queryGroupState) {
140145
statsHolder.rejections = queryGroupState.getTotalRejections();
141146
statsHolder.failures = queryGroupState.getFailures();
142147
statsHolder.totalCancellations = queryGroupState.getTotalCancellations();
148+
statsHolder.shardCompletions = queryGroupState.getShardCompletions();
143149
statsHolder.resourceStats = resourceStatsMap;
144150
return statsHolder;
145151
}
@@ -155,6 +161,7 @@ public static void writeTo(StreamOutput out, QueryGroupStatsHolder statsHolder)
155161
out.writeVLong(statsHolder.rejections);
156162
out.writeVLong(statsHolder.failures);
157163
out.writeVLong(statsHolder.totalCancellations);
164+
out.writeVLong(statsHolder.shardCompletions);
158165
out.writeMap(statsHolder.resourceStats, (o, val) -> o.writeString(val.getName()), ResourceStats::writeTo);
159166
}
160167

@@ -166,6 +173,7 @@ public void writeTo(StreamOutput out) throws IOException {
166173
@Override
167174
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
168175
builder.field(COMPLETIONS, completions);
176+
builder.field(SHARD_COMPLETIONS, shardCompletions);
169177
builder.field(REJECTIONS, rejections);
170178
builder.field(FAILURES, failures);
171179
builder.field(TOTAL_CANCELLATIONS, totalCancellations);
@@ -187,14 +195,15 @@ public boolean equals(Object o) {
187195
QueryGroupStatsHolder that = (QueryGroupStatsHolder) o;
188196
return completions == that.completions
189197
&& rejections == that.rejections
198+
&& shardCompletions == that.shardCompletions
190199
&& Objects.equals(resourceStats, that.resourceStats)
191200
&& failures == that.failures
192201
&& totalCancellations == that.totalCancellations;
193202
}
194203

195204
@Override
196205
public int hashCode() {
197-
return Objects.hash(completions, rejections, totalCancellations, failures, resourceStats);
206+
return Objects.hash(completions, shardCompletions, rejections, totalCancellations, failures, resourceStats);
198207
}
199208
}
200209

‎server/src/main/java/org/opensearch/wlm/tracker/QueryGroupResourceUsageTrackerService.java

+2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ public Map<String, QueryGroupLevelResourceUsageView> constructQueryGroupLevelUsa
4747

4848
// Iterate over each QueryGroup entry
4949
for (Map.Entry<String, List<QueryGroupTask>> queryGroupEntry : tasksByQueryGroup.entrySet()) {
50+
// refresh the resource stats
51+
taskResourceTrackingService.refreshResourceStats(queryGroupEntry.getValue().toArray(new QueryGroupTask[0]));
5052
// Compute the QueryGroup resource usage
5153
final Map<ResourceType, Double> queryGroupUsage = new EnumMap<>(ResourceType.class);
5254
for (ResourceType resourceType : TRACKED_RESOURCES) {

‎server/src/test/java/org/opensearch/search/backpressure/SearchBackpressureServiceTests.java

+40-12
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
import org.opensearch.test.transport.MockTransportService;
4040
import org.opensearch.threadpool.TestThreadPool;
4141
import org.opensearch.threadpool.ThreadPool;
42+
import org.opensearch.wlm.QueryGroupService;
43+
import org.opensearch.wlm.QueryGroupTask;
4244
import org.opensearch.wlm.ResourceType;
4345
import org.junit.After;
4446
import org.junit.Before;
@@ -75,10 +77,12 @@ public class SearchBackpressureServiceTests extends OpenSearchTestCase {
7577
MockTransportService transportService;
7678
TaskManager taskManager;
7779
ThreadPool threadPool;
80+
QueryGroupService queryGroupService;
7881

7982
@Before
8083
public void setup() {
8184
threadPool = new TestThreadPool(getClass().getName());
85+
queryGroupService = mock(QueryGroupService.class);
8286
transportService = MockTransportService.createNewService(Settings.EMPTY, Version.CURRENT, threadPool, NoopTracer.INSTANCE);
8387
transportService.start();
8488
transportService.acceptIncomingRequests();
@@ -120,9 +124,12 @@ public void testIsNodeInDuress() {
120124
new NodeDuressTrackers(duressTrackers),
121125
new TaskResourceUsageTrackers(),
122126
new TaskResourceUsageTrackers(),
123-
taskManager
127+
taskManager,
128+
queryGroupService
124129
);
125130

131+
when(queryGroupService.shouldSBPHandle(any())).thenReturn(true);
132+
126133
// Node not in duress.
127134
cpuUsage.set(0.0);
128135
heapUsage.set(0.0);
@@ -163,9 +170,12 @@ public void testTrackerStateUpdateOnSearchTaskCompletion() {
163170
new NodeDuressTrackers(new EnumMap<>(ResourceType.class)),
164171
taskResourceUsageTrackers,
165172
new TaskResourceUsageTrackers(),
166-
taskManager
173+
taskManager,
174+
queryGroupService
167175
);
168176

177+
when(queryGroupService.shouldSBPHandle(any())).thenReturn(true);
178+
169179
for (int i = 0; i < 100; i++) {
170180
// service.onTaskCompleted(new SearchTask(1, "test", "test", () -> "Test", TaskId.EMPTY_TASK_ID, new HashMap<>()));
171181
service.onTaskCompleted(createMockTaskWithResourceStats(SearchTask.class, 100, 200, i));
@@ -194,9 +204,12 @@ public void testTrackerStateUpdateOnSearchShardTaskCompletion() {
194204
new NodeDuressTrackers(new EnumMap<>(ResourceType.class)),
195205
new TaskResourceUsageTrackers(),
196206
taskResourceUsageTrackers,
197-
taskManager
207+
taskManager,
208+
queryGroupService
198209
);
199210

211+
when(queryGroupService.shouldSBPHandle(any())).thenReturn(true);
212+
200213
// Record task completions to update the tracker state. Tasks other than SearchTask & SearchShardTask are ignored.
201214
service.onTaskCompleted(createMockTaskWithResourceStats(CancellableTask.class, 100, 200, 101));
202215
for (int i = 0; i < 100; i++) {
@@ -246,9 +259,12 @@ public void testSearchTaskInFlightCancellation() {
246259
new NodeDuressTrackers(duressTrackers),
247260
taskResourceUsageTrackers,
248261
new TaskResourceUsageTrackers(),
249-
mockTaskManager
262+
mockTaskManager,
263+
queryGroupService
250264
);
251265

266+
when(queryGroupService.shouldSBPHandle(any())).thenReturn(true);
267+
252268
// Run two iterations so that node is marked 'in duress' from the third iteration onwards.
253269
service.doRun();
254270
service.doRun();
@@ -261,14 +277,15 @@ public void testSearchTaskInFlightCancellation() {
261277
when(settings.getSearchTaskSettings()).thenReturn(searchTaskSettings);
262278

263279
// Create a mix of low and high resource usage SearchTasks (50 low + 25 high resource usage tasks).
264-
Map<Long, Task> activeSearchTasks = new HashMap<>();
280+
Map<Long, QueryGroupTask> activeSearchTasks = new HashMap<>();
265281
for (long i = 0; i < 75; i++) {
266282
if (i % 3 == 0) {
267283
activeSearchTasks.put(i, createMockTaskWithResourceStats(SearchTask.class, 500, taskHeapUsageBytes, i));
268284
} else {
269285
activeSearchTasks.put(i, createMockTaskWithResourceStats(SearchTask.class, 100, taskHeapUsageBytes, i));
270286
}
271287
}
288+
activeSearchTasks.values().forEach(task -> task.setQueryGroupId(threadPool.getThreadContext()));
272289
doReturn(activeSearchTasks).when(mockTaskResourceTrackingService).getResourceAwareTasks();
273290

274291
// There are 25 SearchTasks eligible for cancellation but only 5 will be cancelled (burst limit).
@@ -337,9 +354,12 @@ public void testSearchShardTaskInFlightCancellation() {
337354
nodeDuressTrackers,
338355
new TaskResourceUsageTrackers(),
339356
taskResourceUsageTrackers,
340-
mockTaskManager
357+
mockTaskManager,
358+
queryGroupService
341359
);
342360

361+
when(queryGroupService.shouldSBPHandle(any())).thenReturn(true);
362+
343363
// Run two iterations so that node is marked 'in duress' from the third iteration onwards.
344364
service.doRun();
345365
service.doRun();
@@ -352,14 +372,15 @@ public void testSearchShardTaskInFlightCancellation() {
352372
when(settings.getSearchShardTaskSettings()).thenReturn(searchShardTaskSettings);
353373

354374
// Create a mix of low and high resource usage tasks (60 low + 15 high resource usage tasks).
355-
Map<Long, Task> activeSearchShardTasks = new HashMap<>();
375+
Map<Long, QueryGroupTask> activeSearchShardTasks = new HashMap<>();
356376
for (long i = 0; i < 75; i++) {
357377
if (i % 5 == 0) {
358378
activeSearchShardTasks.put(i, createMockTaskWithResourceStats(SearchShardTask.class, 500, taskHeapUsageBytes, i));
359379
} else {
360380
activeSearchShardTasks.put(i, createMockTaskWithResourceStats(SearchShardTask.class, 100, taskHeapUsageBytes, i));
361381
}
362382
}
383+
activeSearchShardTasks.values().forEach(task -> task.setQueryGroupId(threadPool.getThreadContext()));
363384
doReturn(activeSearchShardTasks).when(mockTaskResourceTrackingService).getResourceAwareTasks();
364385

365386
// There are 15 SearchShardTasks eligible for cancellation but only 10 will be cancelled (burst limit).
@@ -437,9 +458,12 @@ public void testNonCancellationOfHeapBasedTasksWhenHeapNotInDuress() {
437458
nodeDuressTrackers,
438459
taskResourceUsageTrackers,
439460
new TaskResourceUsageTrackers(),
440-
mockTaskManager
461+
mockTaskManager,
462+
queryGroupService
441463
);
442464

465+
when(queryGroupService.shouldSBPHandle(any())).thenReturn(true);
466+
443467
service.doRun();
444468
service.doRun();
445469

@@ -449,14 +473,15 @@ public void testNonCancellationOfHeapBasedTasksWhenHeapNotInDuress() {
449473
when(settings.getSearchTaskSettings()).thenReturn(searchTaskSettings);
450474

451475
// Create a mix of low and high resource usage tasks (60 low + 15 high resource usage tasks).
452-
Map<Long, Task> activeSearchTasks = new HashMap<>();
476+
Map<Long, QueryGroupTask> activeSearchTasks = new HashMap<>();
453477
for (long i = 0; i < 75; i++) {
454478
if (i % 5 == 0) {
455479
activeSearchTasks.put(i, createMockTaskWithResourceStats(SearchTask.class, 500, 800, i));
456480
} else {
457481
activeSearchTasks.put(i, createMockTaskWithResourceStats(SearchTask.class, 100, 800, i));
458482
}
459483
}
484+
activeSearchTasks.values().forEach(task -> task.setQueryGroupId(threadPool.getThreadContext()));
460485
doReturn(activeSearchTasks).when(mockTaskResourceTrackingService).getResourceAwareTasks();
461486

462487
// this will trigger cancellation but these cancellation should only be cpu based
@@ -531,10 +556,12 @@ public void testNonCancellationWhenSearchTrafficIsNotQualifyingForCancellation()
531556
nodeDuressTrackers,
532557
taskResourceUsageTrackers,
533558
new TaskResourceUsageTrackers(),
534-
mockTaskManager
559+
mockTaskManager,
560+
queryGroupService
535561
)
536562
);
537563

564+
when(queryGroupService.shouldSBPHandle(any())).thenReturn(true);
538565
when(service.isHeapUsageDominatedBySearch(anyList(), anyDouble())).thenReturn(false);
539566

540567
service.doRun();
@@ -546,15 +573,16 @@ public void testNonCancellationWhenSearchTrafficIsNotQualifyingForCancellation()
546573
when(settings.getSearchTaskSettings()).thenReturn(searchTaskSettings);
547574

548575
// Create a mix of low and high resource usage tasks (60 low + 15 high resource usage tasks).
549-
Map<Long, Task> activeSearchTasks = new HashMap<>();
576+
Map<Long, QueryGroupTask> activeSearchTasks = new HashMap<>();
550577
for (long i = 0; i < 75; i++) {
551-
Class<? extends CancellableTask> taskType = randomBoolean() ? SearchTask.class : SearchShardTask.class;
578+
Class<? extends QueryGroupTask> taskType = randomBoolean() ? SearchTask.class : SearchShardTask.class;
552579
if (i % 5 == 0) {
553580
activeSearchTasks.put(i, createMockTaskWithResourceStats(taskType, 500, 800, i));
554581
} else {
555582
activeSearchTasks.put(i, createMockTaskWithResourceStats(taskType, 100, 800, i));
556583
}
557584
}
585+
activeSearchTasks.values().forEach(task -> task.setQueryGroupId(threadPool.getThreadContext()));
558586
doReturn(activeSearchTasks).when(mockTaskResourceTrackingService).getResourceAwareTasks();
559587

560588
// this will trigger cancellation but the cancellation should not happen as the node is not is duress because of search traffic

‎server/src/test/java/org/opensearch/wlm/QueryGroupServiceTests.java

+489
Large diffs are not rendered by default.

‎server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportInterceptorTests.java

+34-2
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,56 @@
88

99
package org.opensearch.wlm;
1010

11+
import org.opensearch.cluster.ClusterState;
12+
import org.opensearch.cluster.metadata.Metadata;
13+
import org.opensearch.cluster.service.ClusterService;
1114
import org.opensearch.test.OpenSearchTestCase;
1215
import org.opensearch.threadpool.TestThreadPool;
1316
import org.opensearch.threadpool.ThreadPool;
1417
import org.opensearch.transport.TransportRequest;
1518
import org.opensearch.transport.TransportRequestHandler;
1619
import org.opensearch.wlm.WorkloadManagementTransportInterceptor.RequestHandler;
20+
import org.opensearch.wlm.cancellation.QueryGroupTaskCancellationService;
21+
22+
import java.util.Collections;
1723

1824
import static org.opensearch.threadpool.ThreadPool.Names.SAME;
25+
import static org.mockito.Mockito.mock;
26+
import static org.mockito.Mockito.when;
1927

2028
public class WorkloadManagementTransportInterceptorTests extends OpenSearchTestCase {
21-
29+
private QueryGroupTaskCancellationService mockTaskCancellationService;
30+
private ClusterService mockClusterService;
31+
private ThreadPool mockThreadPool;
32+
private WorkloadManagementSettings mockWorkloadManagementSettings;
2233
private ThreadPool threadPool;
2334
private WorkloadManagementTransportInterceptor sut;
35+
private QueryGroupsStateAccessor stateAccessor;
2436

2537
public void setUp() throws Exception {
2638
super.setUp();
39+
mockTaskCancellationService = mock(QueryGroupTaskCancellationService.class);
40+
mockClusterService = mock(ClusterService.class);
41+
mockThreadPool = mock(ThreadPool.class);
42+
mockWorkloadManagementSettings = mock(WorkloadManagementSettings.class);
2743
threadPool = new TestThreadPool(getTestName());
28-
sut = new WorkloadManagementTransportInterceptor(threadPool, new QueryGroupService());
44+
stateAccessor = new QueryGroupsStateAccessor();
45+
46+
ClusterState state = mock(ClusterState.class);
47+
Metadata metadata = mock(Metadata.class);
48+
when(mockClusterService.state()).thenReturn(state);
49+
when(state.metadata()).thenReturn(metadata);
50+
when(metadata.queryGroups()).thenReturn(Collections.emptyMap());
51+
sut = new WorkloadManagementTransportInterceptor(
52+
threadPool,
53+
new QueryGroupService(
54+
mockTaskCancellationService,
55+
mockClusterService,
56+
mockThreadPool,
57+
mockWorkloadManagementSettings,
58+
stateAccessor
59+
)
60+
);
2961
}
3062

3163
public void tearDown() throws Exception {

‎server/src/test/java/org/opensearch/wlm/cancellation/QueryGroupTaskCancellationServiceTests.java

+74-26
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,16 @@
1717
import org.opensearch.wlm.MutableQueryGroupFragment.ResiliencyMode;
1818
import org.opensearch.wlm.QueryGroupLevelResourceUsageView;
1919
import org.opensearch.wlm.QueryGroupTask;
20+
import org.opensearch.wlm.QueryGroupsStateAccessor;
2021
import org.opensearch.wlm.ResourceType;
22+
import org.opensearch.wlm.WlmMode;
2123
import org.opensearch.wlm.WorkloadManagementSettings;
24+
import org.opensearch.wlm.stats.QueryGroupState;
2225
import org.opensearch.wlm.tracker.QueryGroupResourceUsageTrackerService;
2326
import org.opensearch.wlm.tracker.ResourceUsageCalculatorTrackerServiceTests.TestClock;
2427
import org.junit.Before;
2528

29+
import java.util.ArrayList;
2630
import java.util.Collection;
2731
import java.util.Collections;
2832
import java.util.HashMap;
@@ -31,7 +35,9 @@
3135
import java.util.Map;
3236
import java.util.Set;
3337
import java.util.stream.Collectors;
38+
import java.util.stream.IntStream;
3439

40+
import static org.mockito.ArgumentMatchers.any;
3541
import static org.mockito.Mockito.mock;
3642
import static org.mockito.Mockito.when;
3743

@@ -47,6 +53,7 @@ public class QueryGroupTaskCancellationServiceTests extends OpenSearchTestCase {
4753
private QueryGroupTaskCancellationService taskCancellation;
4854
private WorkloadManagementSettings workloadManagementSettings;
4955
private QueryGroupResourceUsageTrackerService resourceUsageTrackerService;
56+
private QueryGroupsStateAccessor stateAccessor;
5057

5158
@Before
5259
public void setup() {
@@ -59,12 +66,13 @@ public void setup() {
5966
when(workloadManagementSettings.getNodeLevelCpuCancellationThreshold()).thenReturn(0.9);
6067
when(workloadManagementSettings.getNodeLevelMemoryCancellationThreshold()).thenReturn(0.9);
6168
resourceUsageTrackerService = mock(QueryGroupResourceUsageTrackerService.class);
69+
stateAccessor = mock(QueryGroupsStateAccessor.class);
70+
when(stateAccessor.getQueryGroupState(any())).thenReturn(new QueryGroupState());
6271
taskCancellation = new QueryGroupTaskCancellationService(
6372
workloadManagementSettings,
6473
new MaximumResourceTaskSelectionStrategy(),
6574
resourceUsageTrackerService,
66-
activeQueryGroups,
67-
deletedQueryGroups
75+
stateAccessor
6876
);
6977
}
7078

@@ -138,7 +146,7 @@ public void testGetCancellableTasksFrom_returnsTasksWhenBreachingThresholdForMem
138146
activeQueryGroups.add(queryGroup1);
139147
taskCancellation.queryGroupLevelResourceUsageViews = queryGroupLevelViews;
140148

141-
List<TaskCancellation> cancellableTasksFrom = taskCancellation.getAllCancellableTasks(ResiliencyMode.ENFORCED);
149+
List<TaskCancellation> cancellableTasksFrom = taskCancellation.getAllCancellableTasks(ResiliencyMode.ENFORCED, activeQueryGroups);
142150
assertEquals(2, cancellableTasksFrom.size());
143151
assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId());
144152
assertEquals(4321, cancellableTasksFrom.get(1).getTask().getId());
@@ -187,11 +195,10 @@ public void testGetCancellableTasksFrom_filtersQueryGroupCorrectly() {
187195
workloadManagementSettings,
188196
new MaximumResourceTaskSelectionStrategy(),
189197
resourceUsageTrackerService,
190-
activeQueryGroups,
191-
deletedQueryGroups
198+
stateAccessor
192199
);
193200

194-
List<TaskCancellation> cancellableTasksFrom = taskCancellation.getAllCancellableTasks(ResiliencyMode.SOFT);
201+
List<TaskCancellation> cancellableTasksFrom = taskCancellation.getAllCancellableTasks(ResiliencyMode.SOFT, activeQueryGroups);
195202
assertEquals(0, cancellableTasksFrom.size());
196203
}
197204

@@ -219,19 +226,19 @@ public void testCancelTasks_cancelsGivenTasks() {
219226
workloadManagementSettings,
220227
new MaximumResourceTaskSelectionStrategy(),
221228
resourceUsageTrackerService,
222-
activeQueryGroups,
223-
deletedQueryGroups
229+
stateAccessor
224230
);
225231

226232
taskCancellation.queryGroupLevelResourceUsageViews = queryGroupLevelViews;
227233

228-
List<TaskCancellation> cancellableTasksFrom = taskCancellation.getAllCancellableTasks(ResiliencyMode.ENFORCED);
234+
List<TaskCancellation> cancellableTasksFrom = taskCancellation.getAllCancellableTasks(ResiliencyMode.ENFORCED, activeQueryGroups);
229235
assertEquals(2, cancellableTasksFrom.size());
230236
assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId());
231237
assertEquals(4321, cancellableTasksFrom.get(1).getTask().getId());
232238

233239
when(resourceUsageTrackerService.constructQueryGroupLevelUsageViews()).thenReturn(queryGroupLevelViews);
234-
taskCancellation.cancelTasks(() -> false);
240+
when(workloadManagementSettings.getWlmMode()).thenReturn(WlmMode.ENABLED);
241+
taskCancellation.cancelTasks(() -> false, activeQueryGroups, deletedQueryGroups);
235242
assertTrue(cancellableTasksFrom.get(0).getTask().isCancelled());
236243
assertTrue(cancellableTasksFrom.get(1).getTask().isCancelled());
237244
}
@@ -281,13 +288,11 @@ public void testCancelTasks_cancelsTasksFromDeletedQueryGroups() {
281288
workloadManagementSettings,
282289
new MaximumResourceTaskSelectionStrategy(),
283290
resourceUsageTrackerService,
284-
activeQueryGroups,
285-
deletedQueryGroups
291+
stateAccessor
286292
);
287-
288293
taskCancellation.queryGroupLevelResourceUsageViews = queryGroupLevelViews;
289294

290-
List<TaskCancellation> cancellableTasksFrom = taskCancellation.getAllCancellableTasks(ResiliencyMode.ENFORCED);
295+
List<TaskCancellation> cancellableTasksFrom = taskCancellation.getAllCancellableTasks(ResiliencyMode.ENFORCED, activeQueryGroups);
291296
assertEquals(2, cancellableTasksFrom.size());
292297
assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId());
293298
assertEquals(4321, cancellableTasksFrom.get(1).getTask().getId());
@@ -298,7 +303,8 @@ public void testCancelTasks_cancelsTasksFromDeletedQueryGroups() {
298303
assertEquals(1001, cancellableTasksFromDeletedQueryGroups.get(1).getTask().getId());
299304

300305
when(resourceUsageTrackerService.constructQueryGroupLevelUsageViews()).thenReturn(queryGroupLevelViews);
301-
taskCancellation.cancelTasks(() -> true);
306+
when(workloadManagementSettings.getWlmMode()).thenReturn(WlmMode.ENABLED);
307+
taskCancellation.cancelTasks(() -> true, activeQueryGroups, deletedQueryGroups);
302308

303309
assertTrue(cancellableTasksFrom.get(0).getTask().isCancelled());
304310
assertTrue(cancellableTasksFrom.get(1).getTask().isCancelled());
@@ -352,12 +358,11 @@ public void testCancelTasks_does_not_cancelTasksFromDeletedQueryGroups_whenNodeN
352358
workloadManagementSettings,
353359
new MaximumResourceTaskSelectionStrategy(),
354360
resourceUsageTrackerService,
355-
activeQueryGroups,
356-
deletedQueryGroups
361+
stateAccessor
357362
);
358363
taskCancellation.queryGroupLevelResourceUsageViews = queryGroupLevelViews;
359364

360-
List<TaskCancellation> cancellableTasksFrom = taskCancellation.getAllCancellableTasks(ResiliencyMode.ENFORCED);
365+
List<TaskCancellation> cancellableTasksFrom = taskCancellation.getAllCancellableTasks(ResiliencyMode.ENFORCED, activeQueryGroups);
361366
assertEquals(2, cancellableTasksFrom.size());
362367
assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId());
363368
assertEquals(4321, cancellableTasksFrom.get(1).getTask().getId());
@@ -368,7 +373,8 @@ public void testCancelTasks_does_not_cancelTasksFromDeletedQueryGroups_whenNodeN
368373
assertEquals(1001, cancellableTasksFromDeletedQueryGroups.get(1).getTask().getId());
369374

370375
when(resourceUsageTrackerService.constructQueryGroupLevelUsageViews()).thenReturn(queryGroupLevelViews);
371-
taskCancellation.cancelTasks(() -> false);
376+
when(workloadManagementSettings.getWlmMode()).thenReturn(WlmMode.ENABLED);
377+
taskCancellation.cancelTasks(() -> false, activeQueryGroups, deletedQueryGroups);
372378

373379
assertTrue(cancellableTasksFrom.get(0).getTask().isCancelled());
374380
assertTrue(cancellableTasksFrom.get(1).getTask().isCancelled());
@@ -411,24 +417,24 @@ public void testCancelTasks_cancelsGivenTasks_WhenNodeInDuress() {
411417
workloadManagementSettings,
412418
new MaximumResourceTaskSelectionStrategy(),
413419
resourceUsageTrackerService,
414-
activeQueryGroups,
415-
deletedQueryGroups
420+
stateAccessor
416421
);
417422

418423
taskCancellation.queryGroupLevelResourceUsageViews = queryGroupLevelViews;
419424

420-
List<TaskCancellation> cancellableTasksFrom = taskCancellation.getAllCancellableTasks(ResiliencyMode.ENFORCED);
425+
List<TaskCancellation> cancellableTasksFrom = taskCancellation.getAllCancellableTasks(ResiliencyMode.ENFORCED, activeQueryGroups);
421426
assertEquals(2, cancellableTasksFrom.size());
422427
assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId());
423428
assertEquals(4321, cancellableTasksFrom.get(1).getTask().getId());
424429

425-
List<TaskCancellation> cancellableTasksFrom1 = taskCancellation.getAllCancellableTasks(ResiliencyMode.SOFT);
430+
List<TaskCancellation> cancellableTasksFrom1 = taskCancellation.getAllCancellableTasks(ResiliencyMode.SOFT, activeQueryGroups);
426431
assertEquals(2, cancellableTasksFrom1.size());
427432
assertEquals(5678, cancellableTasksFrom1.get(0).getTask().getId());
428433
assertEquals(8765, cancellableTasksFrom1.get(1).getTask().getId());
429434

430435
when(resourceUsageTrackerService.constructQueryGroupLevelUsageViews()).thenReturn(queryGroupLevelViews);
431-
taskCancellation.cancelTasks(() -> true);
436+
when(workloadManagementSettings.getWlmMode()).thenReturn(WlmMode.ENABLED);
437+
taskCancellation.cancelTasks(() -> true, activeQueryGroups, deletedQueryGroups);
432438
assertTrue(cancellableTasksFrom.get(0).getTask().isCancelled());
433439
assertTrue(cancellableTasksFrom.get(1).getTask().isCancelled());
434440
assertTrue(cancellableTasksFrom1.get(0).getTask().isCancelled());
@@ -456,7 +462,7 @@ public void testGetAllCancellableTasks_ReturnsNoTasksWhenNotBreachingThresholds(
456462
activeQueryGroups.add(queryGroup1);
457463
taskCancellation.queryGroupLevelResourceUsageViews = queryGroupLevelViews;
458464

459-
List<TaskCancellation> allCancellableTasks = taskCancellation.getAllCancellableTasks(ResiliencyMode.ENFORCED);
465+
List<TaskCancellation> allCancellableTasks = taskCancellation.getAllCancellableTasks(ResiliencyMode.ENFORCED, activeQueryGroups);
460466
assertTrue(allCancellableTasks.isEmpty());
461467
}
462468

@@ -479,7 +485,7 @@ public void testGetAllCancellableTasks_ReturnsTasksWhenBreachingThresholds() {
479485
activeQueryGroups.add(queryGroup1);
480486
taskCancellation.queryGroupLevelResourceUsageViews = queryGroupLevelViews;
481487

482-
List<TaskCancellation> allCancellableTasks = taskCancellation.getAllCancellableTasks(ResiliencyMode.ENFORCED);
488+
List<TaskCancellation> allCancellableTasks = taskCancellation.getAllCancellableTasks(ResiliencyMode.ENFORCED, activeQueryGroups);
483489
assertEquals(2, allCancellableTasks.size());
484490
assertEquals(1234, allCancellableTasks.get(0).getTask().getId());
485491
assertEquals(4321, allCancellableTasks.get(1).getTask().getId());
@@ -513,6 +519,48 @@ public void testGetCancellableTasksFrom_doesNotReturnTasksWhenQueryGroupIdNotFou
513519
assertEquals(0, cancellableTasksFrom.size());
514520
}
515521

522+
public void testPruneDeletedQueryGroups() {
523+
QueryGroup queryGroup1 = new QueryGroup(
524+
"testQueryGroup1",
525+
queryGroupId1,
526+
new MutableQueryGroupFragment(ResiliencyMode.ENFORCED, Map.of(ResourceType.CPU, 0.2)),
527+
1L
528+
);
529+
QueryGroup queryGroup2 = new QueryGroup(
530+
"testQueryGroup2",
531+
queryGroupId2,
532+
new MutableQueryGroupFragment(ResiliencyMode.ENFORCED, Map.of(ResourceType.CPU, 0.1)),
533+
1L
534+
);
535+
List<QueryGroup> deletedQueryGroups = new ArrayList<>();
536+
deletedQueryGroups.add(queryGroup1);
537+
deletedQueryGroups.add(queryGroup2);
538+
QueryGroupLevelResourceUsageView resourceUsageView1 = createResourceUsageViewMock();
539+
540+
List<QueryGroupTask> activeTasks = IntStream.range(0, 5).mapToObj(this::getRandomSearchTask).collect(Collectors.toList());
541+
when(resourceUsageView1.getActiveTasks()).thenReturn(activeTasks);
542+
543+
QueryGroupLevelResourceUsageView resourceUsageView2 = createResourceUsageViewMock();
544+
when(resourceUsageView2.getActiveTasks()).thenReturn(new ArrayList<>());
545+
546+
queryGroupLevelViews.put(queryGroupId1, resourceUsageView1);
547+
queryGroupLevelViews.put(queryGroupId2, resourceUsageView2);
548+
549+
QueryGroupTaskCancellationService taskCancellation = new QueryGroupTaskCancellationService(
550+
workloadManagementSettings,
551+
new MaximumResourceTaskSelectionStrategy(),
552+
resourceUsageTrackerService,
553+
stateAccessor
554+
);
555+
taskCancellation.queryGroupLevelResourceUsageViews = queryGroupLevelViews;
556+
557+
taskCancellation.pruneDeletedQueryGroups(deletedQueryGroups);
558+
559+
assertEquals(1, deletedQueryGroups.size());
560+
assertEquals(queryGroupId1, deletedQueryGroups.get(0).get_id());
561+
562+
}
563+
516564
private QueryGroupLevelResourceUsageView createResourceUsageViewMock() {
517565
QueryGroupLevelResourceUsageView mockView = mock(QueryGroupLevelResourceUsageView.class);
518566
when(mockView.getActiveTasks()).thenReturn(List.of(getRandomSearchTask(1234), getRandomSearchTask(4321)));

‎server/src/test/java/org/opensearch/wlm/listeners/QueryGroupRequestOperationListenerTests.java

+92-5
Original file line numberDiff line numberDiff line change
@@ -8,38 +8,51 @@
88

99
package org.opensearch.wlm.listeners;
1010

11+
import org.opensearch.cluster.ClusterState;
12+
import org.opensearch.cluster.metadata.Metadata;
13+
import org.opensearch.cluster.service.ClusterService;
1114
import org.opensearch.common.util.concurrent.ThreadContext;
1215
import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;
1316
import org.opensearch.test.OpenSearchTestCase;
1417
import org.opensearch.threadpool.TestThreadPool;
1518
import org.opensearch.threadpool.ThreadPool;
1619
import org.opensearch.wlm.QueryGroupService;
1720
import org.opensearch.wlm.QueryGroupTask;
21+
import org.opensearch.wlm.QueryGroupsStateAccessor;
1822
import org.opensearch.wlm.ResourceType;
23+
import org.opensearch.wlm.WorkloadManagementSettings;
24+
import org.opensearch.wlm.cancellation.QueryGroupTaskCancellationService;
1925
import org.opensearch.wlm.stats.QueryGroupState;
2026
import org.opensearch.wlm.stats.QueryGroupStats;
2127

2228
import java.io.IOException;
2329
import java.util.ArrayList;
30+
import java.util.Collections;
2431
import java.util.HashMap;
2532
import java.util.List;
2633
import java.util.Map;
2734

2835
import static org.mockito.Mockito.doNothing;
2936
import static org.mockito.Mockito.doThrow;
3037
import static org.mockito.Mockito.mock;
38+
import static org.mockito.Mockito.when;
3139

3240
public class QueryGroupRequestOperationListenerTests extends OpenSearchTestCase {
3341
public static final int ITERATIONS = 20;
3442
ThreadPool testThreadPool;
3543
QueryGroupService queryGroupService;
36-
44+
private QueryGroupTaskCancellationService taskCancellationService;
45+
private ClusterService mockClusterService;
46+
private WorkloadManagementSettings mockWorkloadManagementSettings;
3747
Map<String, QueryGroupState> queryGroupStateMap;
3848
String testQueryGroupId;
3949
QueryGroupRequestOperationListener sut;
4050

4151
public void setUp() throws Exception {
4252
super.setUp();
53+
taskCancellationService = mock(QueryGroupTaskCancellationService.class);
54+
mockClusterService = mock(ClusterService.class);
55+
mockWorkloadManagementSettings = mock(WorkloadManagementSettings.class);
4356
queryGroupStateMap = new HashMap<>();
4457
testQueryGroupId = "safjgagnakg-3r3fads";
4558
testThreadPool = new TestThreadPool("RejectionTestThreadPool");
@@ -77,6 +90,21 @@ public void testValidQueryGroupRequestFailure() throws IOException {
7790
0,
7891
1,
7992
0,
93+
0,
94+
Map.of(
95+
ResourceType.CPU,
96+
new QueryGroupStats.ResourceStats(0, 0, 0),
97+
ResourceType.MEMORY,
98+
new QueryGroupStats.ResourceStats(0, 0, 0)
99+
)
100+
),
101+
QueryGroupTask.DEFAULT_QUERY_GROUP_ID_SUPPLIER.get(),
102+
new QueryGroupStats.QueryGroupStatsHolder(
103+
0,
104+
0,
105+
0,
106+
0,
107+
0,
80108
Map.of(
81109
ResourceType.CPU,
82110
new QueryGroupStats.ResourceStats(0, 0, 0),
@@ -93,8 +121,18 @@ public void testValidQueryGroupRequestFailure() throws IOException {
93121
public void testMultiThreadedValidQueryGroupRequestFailures() {
94122

95123
queryGroupStateMap.put(testQueryGroupId, new QueryGroupState());
96-
97-
queryGroupService = new QueryGroupService(queryGroupStateMap);
124+
QueryGroupsStateAccessor accessor = new QueryGroupsStateAccessor(queryGroupStateMap);
125+
setupMockedQueryGroupsFromClusterState();
126+
queryGroupService = new QueryGroupService(
127+
taskCancellationService,
128+
mockClusterService,
129+
testThreadPool,
130+
mockWorkloadManagementSettings,
131+
null,
132+
accessor,
133+
Collections.emptySet(),
134+
Collections.emptySet()
135+
);
98136

99137
sut = new QueryGroupRequestOperationListener(queryGroupService, testThreadPool);
100138

@@ -127,6 +165,21 @@ public void testMultiThreadedValidQueryGroupRequestFailures() {
127165
0,
128166
ITERATIONS,
129167
0,
168+
0,
169+
Map.of(
170+
ResourceType.CPU,
171+
new QueryGroupStats.ResourceStats(0, 0, 0),
172+
ResourceType.MEMORY,
173+
new QueryGroupStats.ResourceStats(0, 0, 0)
174+
)
175+
),
176+
QueryGroupTask.DEFAULT_QUERY_GROUP_ID_SUPPLIER.get(),
177+
new QueryGroupStats.QueryGroupStatsHolder(
178+
0,
179+
0,
180+
0,
181+
0,
182+
0,
130183
Map.of(
131184
ResourceType.CPU,
132185
new QueryGroupStats.ResourceStats(0, 0, 0),
@@ -149,6 +202,21 @@ public void testInvalidQueryGroupFailure() throws IOException {
149202
0,
150203
0,
151204
0,
205+
0,
206+
Map.of(
207+
ResourceType.CPU,
208+
new QueryGroupStats.ResourceStats(0, 0, 0),
209+
ResourceType.MEMORY,
210+
new QueryGroupStats.ResourceStats(0, 0, 0)
211+
)
212+
),
213+
QueryGroupTask.DEFAULT_QUERY_GROUP_ID_SUPPLIER.get(),
214+
new QueryGroupStats.QueryGroupStatsHolder(
215+
0,
216+
0,
217+
1,
218+
0,
219+
0,
152220
Map.of(
153221
ResourceType.CPU,
154222
new QueryGroupStats.ResourceStats(0, 0, 0),
@@ -169,12 +237,23 @@ private void assertSuccess(
169237
QueryGroupStats expectedStats,
170238
String threadContextQG_Id
171239
) {
172-
240+
QueryGroupsStateAccessor stateAccessor = new QueryGroupsStateAccessor(queryGroupStateMap);
173241
try (ThreadContext.StoredContext currentContext = testThreadPool.getThreadContext().stashContext()) {
174242
testThreadPool.getThreadContext().putHeader(QueryGroupTask.QUERY_GROUP_ID_HEADER, threadContextQG_Id);
175243
queryGroupStateMap.put(testQueryGroupId, new QueryGroupState());
176244

177-
queryGroupService = new QueryGroupService(queryGroupStateMap);
245+
setupMockedQueryGroupsFromClusterState();
246+
247+
queryGroupService = new QueryGroupService(
248+
taskCancellationService,
249+
mockClusterService,
250+
testThreadPool,
251+
mockWorkloadManagementSettings,
252+
null,
253+
stateAccessor,
254+
Collections.emptySet(),
255+
Collections.emptySet()
256+
);
178257

179258
sut = new QueryGroupRequestOperationListener(queryGroupService, testThreadPool);
180259
sut.onRequestFailure(null, null);
@@ -184,4 +263,12 @@ private void assertSuccess(
184263
}
185264

186265
}
266+
267+
private void setupMockedQueryGroupsFromClusterState() {
268+
ClusterState state = mock(ClusterState.class);
269+
Metadata metadata = mock(Metadata.class);
270+
when(mockClusterService.state()).thenReturn(state);
271+
when(state.metadata()).thenReturn(metadata);
272+
when(metadata.queryGroups()).thenReturn(Collections.emptyMap());
273+
}
187274
}

‎server/src/test/java/org/opensearch/wlm/stats/QueryGroupStateTests.java

+8-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,13 @@ public void testRandomQueryGroupsStateUpdates() {
2323

2424
for (int i = 0; i < 25; i++) {
2525
if (i % 5 == 0) {
26-
updaterThreads.add(new Thread(() -> queryGroupState.completions.inc()));
26+
updaterThreads.add(new Thread(() -> {
27+
if (randomBoolean()) {
28+
queryGroupState.completions.inc();
29+
} else {
30+
queryGroupState.shardCompletions.inc();
31+
}
32+
}));
2733
} else if (i % 5 == 1) {
2834
updaterThreads.add(new Thread(() -> {
2935
queryGroupState.totalRejections.inc();
@@ -57,7 +63,7 @@ public void testRandomQueryGroupsStateUpdates() {
5763
}
5864
});
5965

60-
assertEquals(5, queryGroupState.getCompletions());
66+
assertEquals(5, queryGroupState.getCompletions() + queryGroupState.getShardCompletions());
6167
assertEquals(5, queryGroupState.getTotalRejections());
6268

6369
final long sumOfRejectionsDueToResourceTypes = queryGroupState.getResourceState().get(ResourceType.CPU).rejections.count()

‎server/src/test/java/org/opensearch/wlm/stats/QueryGroupStatsTests.java

+4-2
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@ public void testToXContent() throws IOException {
2828
queryGroupId,
2929
new QueryGroupStats.QueryGroupStatsHolder(
3030
123456789,
31+
13,
3132
2,
3233
0,
33-
13,
34+
1213718,
3435
Map.of(ResourceType.CPU, new QueryGroupStats.ResourceStats(0.3, 13, 2))
3536
)
3637
);
@@ -40,7 +41,7 @@ public void testToXContent() throws IOException {
4041
queryGroupStats.toXContent(builder, ToXContent.EMPTY_PARAMS);
4142
builder.endObject();
4243
assertEquals(
43-
"{\"query_groups\":{\"afakjklaj304041-afaka\":{\"completions\":123456789,\"rejections\":2,\"failures\":0,\"total_cancellations\":13,\"cpu\":{\"current_usage\":0.3,\"cancellations\":13,\"rejections\":2}}}}",
44+
"{\"query_groups\":{\"afakjklaj304041-afaka\":{\"completions\":123456789,\"shard_completions\":1213718,\"rejections\":13,\"failures\":2,\"total_cancellations\":0,\"cpu\":{\"current_usage\":0.3,\"cancellations\":13,\"rejections\":2}}}}",
4445
builder.toString()
4546
);
4647
}
@@ -60,6 +61,7 @@ protected QueryGroupStats createTestInstance() {
6061
randomNonNegativeLong(),
6162
randomNonNegativeLong(),
6263
randomNonNegativeLong(),
64+
randomNonNegativeLong(),
6365
Map.of(
6466
ResourceType.CPU,
6567
new QueryGroupStats.ResourceStats(

0 commit comments

Comments
 (0)
Please sign in to comment.