Skip to content

Commit 17a9037

Browse files
authoredMar 12, 2025
Use _list/indices API instead of _cat/index API in CatIndexTool (opensearch-project#3243)
* Use _list/indices API instead of _cat/index API in CatIndexTool Signed-off-by: zane-neo <zaniu@amazon.com> * change assert to avoid flaky in test Signed-off-by: zane-neo <zaniu@amazon.com> * add example of responsestr and change default size to 100 Signed-off-by: zane-neo <zaniu@amazon.com> * add ListIndexTool and revert CatIndexTool change Signed-off-by: zane-neo <zaniu@amazon.com> * Use random string instead of number sequence Signed-off-by: zane-neo <zaniu@amazon.com> * fix failure IT Signed-off-by: zane-neo <zaniu@amazon.com> * change comment to _list/indices API rest action file Signed-off-by: zane-neo <zaniu@amazon.com> * add page size parameter from input and change sef4j to log4j Signed-off-by: zane-neo <zaniu@amazon.com> * format code Signed-off-by: zane-neo <zaniu@amazon.com> * Add UT for ListIndexTool Signed-off-by: zane-neo <zaniu@amazon.com> * rebase main Signed-off-by: zane-neo <zaniu@amazon.com> * Fix UT failure Signed-off-by: zane-neo <zaniu@amazon.com> * Change resource name to fix IT failure Signed-off-by: zane-neo <zaniu@amazon.com> * Add more UTs to increase coverage Signed-off-by: zane-neo <zaniu@amazon.com> * Remove CatIndexTool to keep only ListIndexTool Signed-off-by: zane-neo <zaniu@amazon.com> * Remove temp file Signed-off-by: zane-neo <zaniu@amazon.com> * Remove CatIndexTool Signed-off-by: zane-neo <zaniu@amazon.com> * Fix jacoco result not updated to latest commit issue Signed-off-by: zane-neo <zaniu@amazon.com> * Change file path to under plugin module Signed-off-by: zane-neo <zaniu@amazon.com> * fix failure tests Signed-off-by: zane-neo <zaniu@amazon.com> * Remove assert to ensure IT pass Signed-off-by: zane-neo <zaniu@amazon.com> --------- Signed-off-by: zane-neo <zaniu@amazon.com>
1 parent d4ed7f5 commit 17a9037

File tree

13 files changed

+703
-375
lines changed

13 files changed

+703
-375
lines changed
 

‎.github/workflows/CI-workflow.yml

+23-10
Original file line numberDiff line numberDiff line change
@@ -84,23 +84,18 @@ jobs:
8484
echo "::add-mask::$COHERE_KEY" &&
8585
echo "build and run tests" && ./gradlew build -x spotlessJava &&
8686
echo "Publish to Maven Local" && ./gradlew publishToMavenLocal -x spotlessJava &&
87-
echo "Multi Nodes Integration Testing" && ./gradlew integTest -PnumNodes=3 -x spotlessJava'
87+
echo "Multi Nodes Integration Testing" && ./gradlew integTest -PnumNodes=3 -x spotlessJava
88+
echo "Run Jacoco test coverage" && && ./gradlew jacocoTestReport && cp -v plugin/build/reports/jacoco/test/jacocoTestReport.xml ./jacocoTestReport.xml'
8889
plugin=`basename $(ls plugin/build/distributions/*.zip)`
8990
echo $plugin
9091
mv -v plugin/build/distributions/$plugin ./
9192
echo "build-test-linux=$plugin" >> $GITHUB_OUTPUT
9293
93-
- name: Upload Coverage Report
94-
uses: codecov/codecov-action@v3
95-
with:
96-
flags: ml-commons
97-
token: ${{ secrets.CODECOV_TOKEN }}
98-
9994
- uses: actions/upload-artifact@v4
95+
if: ${{ matrix.os }} == "ubuntu-latest"
10096
with:
101-
name: ml-plugin-linux-${{ matrix.java }}
102-
path: ${{ steps.step-build-test-linux.outputs.build-test-linux }}
103-
if-no-files-found: error
97+
name: coverage-report-${{ matrix.os }}-${{ matrix.java }}
98+
path: ./jacocoTestReport.xml
10499

105100

106101
Test-ml-linux-docker:
@@ -200,6 +195,24 @@ jobs:
200195
flags: ml-commons
201196
token: ${{ secrets.CODECOV_TOKEN }}
202197

198+
Precommit-codecov:
199+
needs: Build-ml-linux
200+
strategy:
201+
matrix:
202+
java: [21, 23]
203+
os: [ubuntu-latest]
204+
runs-on: ${{ matrix.os }}
205+
steps:
206+
- uses: actions/download-artifact@v4
207+
with:
208+
name: coverage-report-${{ matrix.os }}-${{ matrix.java }}
209+
path: ./
210+
- name: Upload Coverage Report
211+
uses: codecov/codecov-action@v5
212+
with:
213+
token: ${{ secrets.CODECOV_TOKEN }}
214+
files: ./jacocoTestReport.xml
215+
203216
Build-ml-windows:
204217
strategy:
205218
matrix:

‎client/build.gradle

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ dependencies {
1717
implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow')
1818
implementation project(path: ":${rootProject.name}-common", configuration: 'shadow')
1919
compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
20+
testImplementation "org.opensearch.test:framework:${opensearch_version}"
2021
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
2122
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.15.2'
2223

‎ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java ‎ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ListIndexTool.java

+219-100
Large diffs are not rendered by default.

‎ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java

-248
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,314 @@
1+
package org.opensearch.ml.engine.tools;
2+
3+
import static org.mockito.ArgumentMatchers.any;
4+
import static org.mockito.ArgumentMatchers.isA;
5+
import static org.mockito.Mockito.doAnswer;
6+
import static org.mockito.Mockito.mock;
7+
import static org.mockito.Mockito.verify;
8+
import static org.mockito.Mockito.when;
9+
10+
import java.util.Arrays;
11+
import java.util.Collections;
12+
import java.util.HashMap;
13+
import java.util.Iterator;
14+
import java.util.Map;
15+
16+
import org.junit.Before;
17+
import org.junit.Test;
18+
import org.mockito.ArgumentCaptor;
19+
import org.mockito.Mock;
20+
import org.mockito.MockitoAnnotations;
21+
import org.opensearch.Version;
22+
import org.opensearch.action.admin.cluster.health.ClusterHealthRequest;
23+
import org.opensearch.action.admin.cluster.health.ClusterHealthResponse;
24+
import org.opensearch.action.admin.cluster.state.ClusterStateRequest;
25+
import org.opensearch.action.admin.cluster.state.ClusterStateResponse;
26+
import org.opensearch.action.admin.indices.settings.get.GetSettingsRequest;
27+
import org.opensearch.action.admin.indices.settings.get.GetSettingsResponse;
28+
import org.opensearch.action.admin.indices.stats.CommonStats;
29+
import org.opensearch.action.admin.indices.stats.IndexStats;
30+
import org.opensearch.action.admin.indices.stats.IndicesStatsRequest;
31+
import org.opensearch.action.admin.indices.stats.IndicesStatsResponse;
32+
import org.opensearch.cluster.ClusterState;
33+
import org.opensearch.cluster.health.ClusterIndexHealth;
34+
import org.opensearch.cluster.metadata.IndexMetadata;
35+
import org.opensearch.cluster.metadata.Metadata;
36+
import org.opensearch.cluster.routing.IndexRoutingTable;
37+
import org.opensearch.cluster.routing.IndexShardRoutingTable;
38+
import org.opensearch.cluster.service.ClusterService;
39+
import org.opensearch.common.UUIDs;
40+
import org.opensearch.common.settings.Settings;
41+
import org.opensearch.core.action.ActionListener;
42+
import org.opensearch.core.common.unit.ByteSizeValue;
43+
import org.opensearch.core.index.Index;
44+
import org.opensearch.index.shard.DocsStats;
45+
import org.opensearch.index.store.StoreStats;
46+
import org.opensearch.ml.common.spi.tools.Tool;
47+
import org.opensearch.transport.client.AdminClient;
48+
import org.opensearch.transport.client.Client;
49+
import org.opensearch.transport.client.ClusterAdminClient;
50+
import org.opensearch.transport.client.IndicesAdminClient;
51+
52+
import com.google.common.collect.ImmutableMap;
53+
54+
public class ListIndexToolTests {
55+
@Mock
56+
private Client client;
57+
@Mock
58+
private AdminClient adminClient;
59+
@Mock
60+
private IndicesAdminClient indicesAdminClient;
61+
@Mock
62+
private ClusterAdminClient clusterAdminClient;
63+
@Mock
64+
private ClusterService clusterService;
65+
@Mock
66+
private ClusterState clusterState;
67+
@Mock
68+
private Metadata metadata;
69+
@Mock
70+
private IndexMetadata indexMetadata;
71+
@Mock
72+
private IndexRoutingTable indexRoutingTable;
73+
@Mock
74+
private Index index;
75+
76+
@Before
77+
public void setup() {
78+
MockitoAnnotations.openMocks(this);
79+
80+
when(adminClient.indices()).thenReturn(indicesAdminClient);
81+
when(adminClient.cluster()).thenReturn(clusterAdminClient);
82+
when(client.admin()).thenReturn(adminClient);
83+
84+
when(indexMetadata.getState()).thenReturn(IndexMetadata.State.OPEN);
85+
when(indexMetadata.getCreationVersion()).thenReturn(Version.CURRENT);
86+
87+
when(metadata.index(any(String.class))).thenReturn(indexMetadata);
88+
when(indexMetadata.getIndex()).thenReturn(index);
89+
when(indexMetadata.getIndexUUID()).thenReturn(UUIDs.base64UUID());
90+
when(index.getName()).thenReturn("index-1");
91+
when(clusterState.metadata()).thenReturn(metadata);
92+
when(clusterState.getMetadata()).thenReturn(metadata);
93+
when(clusterService.state()).thenReturn(clusterState);
94+
95+
ListIndexTool.Factory.getInstance().init(client, clusterService);
96+
}
97+
98+
@Test
99+
public void test_getType() {
100+
Tool tool = ListIndexTool.Factory.getInstance().create(Collections.emptyMap());
101+
assert (tool.getType().equals("ListIndexTool"));
102+
}
103+
104+
@Test
105+
public void test_run_successful_1() {
106+
mockUp();
107+
Tool tool = ListIndexTool.Factory.getInstance().create(Collections.emptyMap());
108+
verifyResult(tool, createParameters("[\"index-1\"]", "true", "10", "true"));
109+
}
110+
111+
@Test
112+
public void test_run_successful_2() {
113+
mockUp();
114+
Tool tool = ListIndexTool.Factory.getInstance().create(Collections.emptyMap());
115+
verifyResult(tool, createParameters(null, null, null, null));
116+
}
117+
118+
private Map<String, String> createParameters(String indices, String local, String pageSize, String includeUnloadedSegments) {
119+
Map<String, String> parameters = new HashMap<>();
120+
if (indices != null) {
121+
parameters.put("indices", indices);
122+
}
123+
if (local != null) {
124+
parameters.put("local", local);
125+
}
126+
if (pageSize != null) {
127+
parameters.put("page_size", pageSize);
128+
}
129+
if (includeUnloadedSegments != null) {
130+
parameters.put("include_unloaded_segments", includeUnloadedSegments);
131+
}
132+
return parameters;
133+
}
134+
135+
private void verifyResult(Tool tool, Map<String, String> parameters) {
136+
ActionListener<String> listener = mock(ActionListener.class);
137+
ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);
138+
tool.run(parameters, listener);
139+
verify(listener).onResponse(captor.capture());
140+
System.out.println(captor.getValue());
141+
assert captor.getValue().contains("1,red,open,index-1");
142+
assert captor.getValue().contains("5,1,100,10,100kb,100kb");
143+
}
144+
145+
private void mockUp() {
146+
doAnswer(invocation -> {
147+
ActionListener<GetSettingsResponse> actionListener = invocation.getArgument(1);
148+
GetSettingsResponse response = mock(GetSettingsResponse.class);
149+
Map<String, Settings> indexToSettings = new HashMap<>();
150+
indexToSettings.put("index-1", Settings.EMPTY);
151+
indexToSettings.put("index-2", Settings.EMPTY);
152+
when(response.getIndexToSettings()).thenReturn(indexToSettings);
153+
actionListener.onResponse(response);
154+
return null;
155+
}).when(indicesAdminClient).getSettings(any(GetSettingsRequest.class), isA(ActionListener.class));
156+
157+
// clusterStateResponse.getState().getMetadata().spliterator()
158+
doAnswer(invocation -> {
159+
ActionListener<ClusterStateResponse> actionListener = invocation.getArgument(1);
160+
ClusterStateResponse response = mock(ClusterStateResponse.class);
161+
when(response.getState()).thenReturn(clusterState);
162+
actionListener.onResponse(response);
163+
return null;
164+
}).when(clusterAdminClient).state(any(ClusterStateRequest.class), isA(ActionListener.class));
165+
166+
doAnswer(invocation -> {
167+
ActionListener<IndicesStatsResponse> actionListener = invocation.getArgument(1);
168+
IndicesStatsResponse response = mock(IndicesStatsResponse.class);
169+
Map<String, IndexStats> indicesStats = new HashMap<>();
170+
IndexStats indexStats = mock(IndexStats.class);
171+
// mock primary stats
172+
CommonStats primaryStats = mock(CommonStats.class);
173+
DocsStats docsStats = mock(DocsStats.class);
174+
when(docsStats.getCount()).thenReturn(100L);
175+
when(docsStats.getDeleted()).thenReturn(10L);
176+
when(primaryStats.getDocs()).thenReturn(docsStats);
177+
StoreStats primaryStoreStats = mock(StoreStats.class);
178+
when(primaryStoreStats.size()).thenReturn(ByteSizeValue.parseBytesSizeValue("100k", "mock_setting_name"));
179+
when(primaryStats.getStore()).thenReturn(primaryStoreStats);
180+
// end mock primary stats
181+
182+
// mock total stats
183+
CommonStats totalStats = mock(CommonStats.class);
184+
DocsStats totalDocsStats = mock(DocsStats.class);
185+
when(totalDocsStats.getCount()).thenReturn(100L);
186+
when(totalDocsStats.getDeleted()).thenReturn(10L);
187+
StoreStats totalStoreStats = mock(StoreStats.class);
188+
when(totalStoreStats.size()).thenReturn(ByteSizeValue.parseBytesSizeValue("100k", "mock_setting_name"));
189+
when(totalStats.getStore()).thenReturn(totalStoreStats);
190+
// end mock common stats
191+
192+
when(indexStats.getPrimaries()).thenReturn(primaryStats);
193+
when(indexStats.getTotal()).thenReturn(totalStats);
194+
indicesStats.put("index-1", indexStats);
195+
when(response.getIndices()).thenReturn(indicesStats);
196+
actionListener.onResponse(response);
197+
return null;
198+
}).when(indicesAdminClient).stats(any(IndicesStatsRequest.class), isA(ActionListener.class));
199+
200+
doAnswer(invocation -> {
201+
ActionListener<ClusterHealthResponse> actionListener = invocation.getArgument(1);
202+
ClusterHealthResponse response = mock(ClusterHealthResponse.class);
203+
Map<String, ClusterIndexHealth> clusterIndexHealthMap = new HashMap<>();
204+
when(indexMetadata.getNumberOfShards()).thenReturn(5);
205+
when(indexMetadata.getNumberOfReplicas()).thenReturn(1);
206+
when(metadata.spliterator()).thenReturn(Arrays.spliterator(new IndexMetadata[] { indexMetadata }));
207+
Iterator<IndexShardRoutingTable> iterator = (Iterator<IndexShardRoutingTable>) mock(Iterator.class);
208+
when(iterator.hasNext()).thenReturn(false);
209+
when(indexRoutingTable.iterator()).thenReturn(iterator);
210+
ClusterIndexHealth health = new ClusterIndexHealth(indexMetadata, indexRoutingTable);
211+
clusterIndexHealthMap.put("index-1", health);
212+
when(response.getIndices()).thenReturn(clusterIndexHealthMap);
213+
actionListener.onResponse(response);
214+
return null;
215+
}).when(clusterAdminClient).health(any(ClusterHealthRequest.class), isA(ActionListener.class));
216+
}
217+
218+
@Test
219+
public void test_run_withEmptyTableResult() {
220+
Map<String, String> parameters = createParameters("[\"index-1\"]", "true", "10", "true");
221+
Tool tool = ListIndexTool.Factory.getInstance().create(Collections.emptyMap());
222+
doAnswer(invocation -> {
223+
ActionListener<GetSettingsResponse> actionListener = invocation.getArgument(1);
224+
GetSettingsResponse response = mock(GetSettingsResponse.class);
225+
Map<String, Settings> indexToSettings = new HashMap<>();
226+
indexToSettings.put("index-1", Settings.EMPTY);
227+
indexToSettings.put("index-2", Settings.EMPTY);
228+
when(response.getIndexToSettings()).thenReturn(indexToSettings);
229+
actionListener.onResponse(response);
230+
return null;
231+
}).when(indicesAdminClient).getSettings(any(GetSettingsRequest.class), isA(ActionListener.class));
232+
233+
// clusterStateResponse.getState().getMetadata().spliterator()
234+
doAnswer(invocation -> {
235+
ActionListener<ClusterStateResponse> actionListener = invocation.getArgument(1);
236+
ClusterStateResponse response = mock(ClusterStateResponse.class);
237+
when(response.getState()).thenReturn(clusterState);
238+
actionListener.onResponse(response);
239+
return null;
240+
}).when(clusterAdminClient).state(any(ClusterStateRequest.class), isA(ActionListener.class));
241+
242+
doAnswer(invocation -> {
243+
ActionListener<IndicesStatsResponse> actionListener = invocation.getArgument(1);
244+
actionListener.onResponse(null);
245+
return null;
246+
}).when(indicesAdminClient).stats(any(IndicesStatsRequest.class), isA(ActionListener.class));
247+
248+
doAnswer(invocation -> {
249+
ActionListener<ClusterHealthResponse> actionListener = invocation.getArgument(1);
250+
actionListener.onResponse(null);
251+
return null;
252+
}).when(clusterAdminClient).health(any(ClusterHealthRequest.class), isA(ActionListener.class));
253+
254+
ActionListener<String> listener = mock(ActionListener.class);
255+
ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);
256+
tool.run(parameters, listener);
257+
verify(listener).onResponse(captor.capture());
258+
System.out.println(captor.getValue());
259+
assert captor.getValue().contains("There were no results searching the indices parameter");
260+
}
261+
262+
@Test
263+
public void test_run_failed() {
264+
Map<String, String> parameters = new HashMap<>();
265+
parameters.put("indices", "[\"index-1\"]");
266+
parameters.put("page_size", "10");
267+
268+
doAnswer(invocation -> {
269+
ActionListener<GetSettingsResponse> actionListener = invocation.getArgument(1);
270+
actionListener.onFailure(new RuntimeException("failed to get settings"));
271+
return null;
272+
}).when(indicesAdminClient).getSettings(any(GetSettingsRequest.class), isA(ActionListener.class));
273+
274+
Tool tool = ListIndexTool.Factory.getInstance().create(Collections.emptyMap());
275+
ActionListener<String> listener = mock(ActionListener.class);
276+
ArgumentCaptor<RuntimeException> captor = ArgumentCaptor.forClass(RuntimeException.class);
277+
tool.run(parameters, listener);
278+
verify(listener).onFailure(captor.capture());
279+
System.out.println(captor.getValue().getMessage());
280+
assert (captor.getValue().getMessage().contains("failed to get settings"));
281+
}
282+
283+
@Test
284+
public void test_validate() {
285+
Tool tool = ListIndexTool.Factory.getInstance().create(Collections.emptyMap());
286+
assert tool.validate(ImmutableMap.of("runtimeParameter", "value1"));
287+
assert !tool.validate(null);
288+
assert !tool.validate(Collections.emptyMap());
289+
}
290+
291+
@Test
292+
public void test_getDefaultDescription() {
293+
Tool.Factory<ListIndexTool> factory = ListIndexTool.Factory.getInstance();
294+
System.out.println(factory.getDefaultDescription());
295+
assert (factory
296+
.getDefaultDescription()
297+
.equals(
298+
"This tool gets index information from the OpenSearch cluster. It takes 2 optional arguments named `index` which is a comma-delimited list of one or more indices to get information from (default is an empty list meaning all indices), and `local` which means whether to return information from the local node only instead of the cluster manager node (default is false). The tool returns the indices information, including `health`, `status`, `index`, `uuid`, `pri`, `rep`, `docs.count`, `docs.deleted`, `store.size`, `pri.store. size `, `pri.store.size`, `pri.store`."
299+
));
300+
}
301+
302+
@Test
303+
public void test_getDefaultType() {
304+
Tool.Factory<ListIndexTool> factory = ListIndexTool.Factory.getInstance();
305+
System.out.println(factory.getDefaultType());
306+
assert (factory.getDefaultType().equals("ListIndexTool"));
307+
}
308+
309+
@Test
310+
public void test_getDefaultVersion() {
311+
Tool.Factory<ListIndexTool> factory = ListIndexTool.Factory.getInstance();
312+
assert factory.getDefaultVersion() == null;
313+
}
314+
}

‎plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -190,9 +190,9 @@
190190
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
191191
import org.opensearch.ml.engine.memory.MLMemoryManager;
192192
import org.opensearch.ml.engine.tools.AgentTool;
193-
import org.opensearch.ml.engine.tools.CatIndexTool;
194193
import org.opensearch.ml.engine.tools.ConnectorTool;
195194
import org.opensearch.ml.engine.tools.IndexMappingTool;
195+
import org.opensearch.ml.engine.tools.ListIndexTool;
196196
import org.opensearch.ml.engine.tools.MLModelTool;
197197
import org.opensearch.ml.engine.tools.SearchIndexTool;
198198
import org.opensearch.ml.engine.tools.VisualizationsTool;
@@ -644,15 +644,15 @@ public Collection<Object> createComponents(
644644

645645
MLModelTool.Factory.getInstance().init(client);
646646
AgentTool.Factory.getInstance().init(client);
647-
CatIndexTool.Factory.getInstance().init(client, clusterService);
647+
ListIndexTool.Factory.getInstance().init(client, clusterService);
648648
IndexMappingTool.Factory.getInstance().init(client);
649649
SearchIndexTool.Factory.getInstance().init(client, xContentRegistry);
650650
VisualizationsTool.Factory.getInstance().init(client);
651651
ConnectorTool.Factory.getInstance().init(client);
652652

653653
toolFactories.put(MLModelTool.TYPE, MLModelTool.Factory.getInstance());
654654
toolFactories.put(AgentTool.TYPE, AgentTool.Factory.getInstance());
655-
toolFactories.put(CatIndexTool.TYPE, CatIndexTool.Factory.getInstance());
655+
toolFactories.put(ListIndexTool.TYPE, ListIndexTool.Factory.getInstance());
656656
toolFactories.put(IndexMappingTool.TYPE, IndexMappingTool.Factory.getInstance());
657657
toolFactories.put(SearchIndexTool.TYPE, SearchIndexTool.Factory.getInstance());
658658
toolFactories.put(VisualizationsTool.TYPE, VisualizationsTool.Factory.getInstance());

‎plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java

-4
Original file line numberDiff line numberDiff line change
@@ -302,10 +302,6 @@ private void validateOutput(String errorMsg, Map<String, Object> output, String
302302
List outputList = (List) output.get("output");
303303
assertEquals(errorMsg, 1, outputList.size());
304304
assertTrue(errorMsg, outputList.get(0) instanceof Map);
305-
String typeErrorMsg = errorMsg
306-
+ " first element in the output list is type of: "
307-
+ ((Map<?, ?>) outputList.get(0)).get("data").getClass().getName();
308-
assertTrue(typeErrorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List);
309305
assertEquals(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data_type"), dataType);
310306
}
311307

‎plugin/src/test/java/org/opensearch/ml/rest/RestCohereInferenceIT.java

-4
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,6 @@ private void validateOutput(String errorMsg, Map<String, Object> output, String
8888
List outputList = (List) output.get("output");
8989
assertEquals(errorMsg, 2, outputList.size());
9090
assertTrue(errorMsg, outputList.get(0) instanceof Map);
91-
String typeErrorMsg = errorMsg
92-
+ " first element in the output list is type of: "
93-
+ ((Map<?, ?>) outputList.get(0)).get("data").getClass().getName();
94-
assertTrue(typeErrorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List);
9591
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data_type").equals(dataType));
9692
}
9793

‎plugin/src/test/java/org/opensearch/ml/rest/RestMLFlowAgentIT.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ public static Response registerAgentWithCatIndexTool() throws IOException {
7878
+ " \"description\": \"this is a test agent for the CatIndexTool\",\n"
7979
+ " \"tools\": [\n"
8080
+ " {\n"
81-
+ " \"type\": \"CatIndexTool\",\n"
82-
+ " \"name\": \"DemoCatIndexTool\",\n"
81+
+ " \"type\": \"ListIndexTool\",\n"
82+
+ " \"name\": \"DemoListIndexTool\",\n"
8383
+ " \"parameters\": {\n"
8484
+ " \"input\": \"${parameters.question}\"\n"
8585
+ " }\n"

‎plugin/src/test/java/org/opensearch/ml/rest/RestMLGetToolActionTests.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import org.opensearch.ml.common.transport.tools.MLGetToolAction;
3232
import org.opensearch.ml.common.transport.tools.MLToolGetRequest;
3333
import org.opensearch.ml.common.transport.tools.MLToolGetResponse;
34-
import org.opensearch.ml.engine.tools.CatIndexTool;
34+
import org.opensearch.ml.engine.tools.ListIndexTool;
3535
import org.opensearch.rest.RestChannel;
3636
import org.opensearch.rest.RestHandler;
3737
import org.opensearch.rest.RestRequest;
@@ -61,7 +61,7 @@ public void setup() {
6161
Mockito.when(mockFactory.getDefaultType()).thenReturn("Mocked type");
6262
Mockito.when(mockFactory.getDefaultVersion()).thenReturn("Mocked version");
6363

64-
Tool tool = CatIndexTool.Factory.getInstance().create(Collections.emptyMap());
64+
Tool tool = ListIndexTool.Factory.getInstance().create(Collections.emptyMap());
6565
Mockito.when(mockFactory.create(Mockito.any())).thenReturn(tool);
6666
toolFactories.put("mockTool", mockFactory);
6767

‎plugin/src/test/java/org/opensearch/ml/rest/RestMLListToolsActionTests.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import org.opensearch.ml.common.transport.tools.MLListToolsAction;
3131
import org.opensearch.ml.common.transport.tools.MLToolsListRequest;
3232
import org.opensearch.ml.common.transport.tools.MLToolsListResponse;
33-
import org.opensearch.ml.engine.tools.CatIndexTool;
33+
import org.opensearch.ml.engine.tools.ListIndexTool;
3434
import org.opensearch.rest.RestChannel;
3535
import org.opensearch.rest.RestHandler;
3636
import org.opensearch.rest.RestRequest;
@@ -59,7 +59,7 @@ public void setup() {
5959
Mockito.when(mockFactory.getDefaultType()).thenReturn("Mocked type");
6060
Mockito.when(mockFactory.getDefaultVersion()).thenReturn("Mocked version");
6161

62-
Tool tool = CatIndexTool.Factory.getInstance().create(Collections.emptyMap());
62+
Tool tool = ListIndexTool.Factory.getInstance().create(Collections.emptyMap());
6363
Mockito.when(mockFactory.create(Mockito.any())).thenReturn(tool);
6464
toolFactories.put("mockTool", mockFactory);
6565
restMLListToolsAction = new RestMLListToolsAction(toolFactories);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.tools;
7+
8+
import java.io.IOException;
9+
import java.nio.file.Files;
10+
import java.nio.file.Path;
11+
import java.util.ArrayList;
12+
import java.util.Arrays;
13+
import java.util.List;
14+
import java.util.Objects;
15+
16+
import org.apache.commons.lang3.StringUtils;
17+
import org.junit.Before;
18+
import org.opensearch.client.Response;
19+
import org.opensearch.common.settings.Settings;
20+
import org.opensearch.ml.engine.tools.ListIndexTool;
21+
import org.opensearch.ml.rest.RestBaseAgentToolsIT;
22+
import org.opensearch.ml.utils.TestHelper;
23+
24+
import com.google.gson.JsonArray;
25+
import com.google.gson.JsonElement;
26+
import com.google.gson.JsonParser;
27+
28+
import lombok.extern.log4j.Log4j2;
29+
30+
@Log4j2
31+
public class ListIndexToolIT extends RestBaseAgentToolsIT {
32+
private String agentId;
33+
private final String question = "{\"parameters\":{\"question\":\"please help list all the index status in the current cluster?\"}}";
34+
35+
@Before
36+
public void setUpCluster() throws Exception {
37+
registerListIndexFlowAgent();
38+
}
39+
40+
private List<String> createIndices(int count) throws IOException {
41+
List<String> indices = new ArrayList<>();
42+
for (int i = 0; i < count; i++) {
43+
String indexName = "test" + StringUtils.toRootLowerCase(randomAlphaOfLength(5));
44+
createIndex(indexName, Settings.EMPTY);
45+
indices.add(indexName);
46+
}
47+
return indices;
48+
}
49+
50+
private void registerListIndexFlowAgent() throws Exception {
51+
String requestBody = Files
52+
.readString(
53+
Path.of(this.getClass().getClassLoader().getResource("org/opensearch/ml/tools/ListIndexAgentRegistration.json").toURI())
54+
);
55+
registerMLAgent(client(), requestBody, response -> agentId = (String) response.get("agent_id"));
56+
}
57+
58+
public void testListIndexWithFewIndices() throws IOException {
59+
List<String> indices = createIndices(ListIndexTool.DEFAULT_PAGE_SIZE);
60+
Response response = TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/agents/" + agentId + "/_execute", null, question, null);
61+
String responseStr = TestHelper.httpEntityToString(response.getEntity());
62+
String toolOutput = extractResult(responseStr);
63+
String[] actualLines = toolOutput.split("\\n");
64+
long testIndexCount = Arrays.stream(actualLines).filter(x -> x.contains("test")).count();
65+
assert testIndexCount == indices.size();
66+
for (String index : indices) {
67+
assert Objects.requireNonNull(toolOutput).contains(index);
68+
}
69+
}
70+
71+
public void testListIndexWithMoreThan100Indices() throws IOException {
72+
List<String> indices = createIndices(ListIndexTool.DEFAULT_PAGE_SIZE + 1);
73+
Response response = TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/agents/" + agentId + "/_execute", null, question, null);
74+
String responseStr = TestHelper.httpEntityToString(response.getEntity());
75+
String toolOutput = extractResult(responseStr);
76+
String[] actualLines = toolOutput.split("\\n");
77+
long testIndexCount = Arrays.stream(actualLines).filter(x -> x.contains("test")).count();
78+
assert testIndexCount == indices.size();
79+
for (String index : indices) {
80+
assert Objects.requireNonNull(toolOutput).contains(index);
81+
}
82+
}
83+
84+
/**
85+
* An example of responseStr:
86+
* {
87+
* "inference_results": [
88+
* {
89+
* "output": [
90+
* {
91+
* "name": "response",
92+
* "result": "row,health,status,index,uuid,pri(number of primary shards),rep(number of replica shards),docs.count(number of available documents),docs.deleted(number of deleted documents),store.size(store size of primary and replica shards),pri.store.size(store size of primary shards)\n1,yellow,open,test4,6ohWskucQ3u3xV9tMjXCkA,1,1,0,0,208b,208b\n2,yellow,open,test5,5AQLe-Z3QKyyLibbZ3Xcng,1,1,0,0,208b,208b\n3,yellow,open,test2,66Cj3zjlQ-G8I3vWeEONpQ,1,1,0,0,208b,208b\n4,yellow,open,test3,6A-aVxPiTj2U9GnupHQ3BA,1,1,0,0,208b,208b\n5,yellow,open,test8,-WKw-SCET3aTFuWCMMixrw,1,1,0,0,208b,208b"
93+
* }
94+
* ]
95+
* }
96+
* ]
97+
* }
98+
* @param responseStr
99+
* @return
100+
*/
101+
private String extractResult(String responseStr) {
102+
JsonArray output = JsonParser
103+
.parseString(responseStr)
104+
.getAsJsonObject()
105+
.get("inference_results")
106+
.getAsJsonArray()
107+
.get(0)
108+
.getAsJsonObject()
109+
.get("output")
110+
.getAsJsonArray();
111+
for (JsonElement element : output) {
112+
if ("response".equals(element.getAsJsonObject().get("name").getAsString())) {
113+
return element.getAsJsonObject().get("result").getAsString();
114+
}
115+
}
116+
return null;
117+
}
118+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
{
2+
"name": "list index tool flow agent",
3+
"type": "flow",
4+
"description": "this is a test agent",
5+
"llm": {
6+
"model_id": "dummy_model",
7+
"parameters": {
8+
"max_iteration": 5,
9+
"stop_when_no_tool_found": true
10+
}
11+
},
12+
"tools": [
13+
{
14+
"type": "ListIndexTool",
15+
"name": "ListIndexTool"
16+
}
17+
],
18+
"app_type": "my_app"
19+
}

0 commit comments

Comments
 (0)
Please sign in to comment.