Skip to content

Commit bb6510a

Browse files
committed
Implement engine consumer and conversation api client for tab organization feature
- GetSuggestedTopics, DedupeTopics, GetFocusTabs are implemented to get topics and classified tabs result from the server.
1 parent fb04cff commit bb6510a

11 files changed

+780
-24
lines changed

components/ai_chat/core/browser/engine/conversation_api_client.cc

+9
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,10 @@ base::Value::List ConversationEventsToList(
200200
{ConversationEventType::RequestSuggestedActions,
201201
"requestSuggestedActions"},
202202
{ConversationEventType::SuggestedActions, "suggestedActions"},
203+
{ConversationEventType::GetSuggestedTopicsForFocusTabs,
204+
"suggestFocusTopics"},
205+
{ConversationEventType::DedupeTopics, "dedupeFocusTopics"},
206+
{ConversationEventType::GetFocusTabsForTopic, "classifyTabs"},
203207
{ConversationEventType::UploadImage, "uploadImage"}});
204208

205209
base::Value::List events;
@@ -217,6 +221,11 @@ base::Value::List ConversationEventsToList(
217221
event_dict.Set("type", type_it->second);
218222

219223
event_dict.Set("content", event.content);
224+
225+
if (event.type == ConversationEventType::GetFocusTabsForTopic) {
226+
event_dict.Set("topic", event.topic);
227+
}
228+
220229
events.Append(std::move(event_dict));
221230
}
222231
return events;

components/ai_chat/core/browser/engine/conversation_api_client.h

+4
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ class ConversationAPIClient {
6464
RequestRewrite,
6565
SuggestedActions,
6666
UploadImage,
67+
GetSuggestedTopicsForFocusTabs,
68+
DedupeTopics,
69+
GetFocusTabsForTopic,
6770
// TODO(petemill):
6871
// - Search in-progress?
6972
// - Sources?
@@ -76,6 +79,7 @@ class ConversationAPIClient {
7679
mojom::CharacterType role;
7780
ConversationEventType type;
7881
std::string content;
82+
std::string topic; // Used in GetFocusTabsForTopic event.
7983
};
8084

8185
ConversationAPIClient(

components/ai_chat/core/browser/engine/conversation_api_client_unittest.cc

+46-24
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,44 @@ using Ticket = api_request_helper::APIRequestHelper::Ticket;
6464

6565
namespace ai_chat {
6666

67+
namespace {
68+
69+
const std::pair<std::vector<ConversationAPIClient::ConversationEvent>,
70+
std::string>&
71+
GetMockEventsAndExpectedEventsBody() {
72+
static base::NoDestructor<std::pair<
73+
std::vector<ConversationAPIClient::ConversationEvent>, std::string>>
74+
mock_events_and_expected_events_body{
75+
std::vector<ConversationAPIClient::ConversationEvent>{
76+
{mojom::CharacterType::HUMAN, ConversationAPIClient::PageText,
77+
"This is a page about The Mandalorian."},
78+
{mojom::CharacterType::HUMAN, ConversationAPIClient::PageExcerpt,
79+
"The Mandalorian"},
80+
{mojom::CharacterType::HUMAN, ConversationAPIClient::ChatMessage,
81+
"Est-ce lié à une série plus large?"},
82+
{mojom::CharacterType::HUMAN,
83+
ConversationAPIClient::GetSuggestedTopicsForFocusTabs,
84+
"GetSuggestedTopicsForFocusTabs"},
85+
{mojom::CharacterType::HUMAN, ConversationAPIClient::DedupeTopics,
86+
"DedupeTopics"},
87+
{mojom::CharacterType::HUMAN,
88+
ConversationAPIClient::GetFocusTabsForTopic,
89+
"GetFocusTabsForTopics", "C++"},
90+
},
91+
R"([
92+
{"role": "user", "type": "pageText", "content": "This is a page about The Mandalorian."},
93+
{"role": "user", "type": "pageExcerpt", "content": "The Mandalorian"},
94+
{"role": "user", "type": "chatMessage", "content": "Est-ce lié à une série plus large?"},
95+
{"role": "user", "type": "suggestFocusTopics", "content": "GetSuggestedTopicsForFocusTabs"},
96+
{"role": "user", "type": "dedupeFocusTopics", "content": "DedupeTopics"},
97+
{"role": "user", "type": "classifyTabs", "content": "GetFocusTabsForTopics", "topic": "C++"}
98+
])"};
99+
100+
return *mock_events_and_expected_events_body;
101+
}
102+
103+
} // namespace
104+
67105
using ConversationEvent = ConversationAPIClient::ConversationEvent;
68106

69107
class MockCallbacks {
@@ -198,18 +236,10 @@ TEST_F(ConversationAPIUnitTest, PerformRequest_PremiumHeaders) {
198236
// - ConversationEvent is correctly formatted into JSON
199237
// - completion response is parsed and passed through to the callbacks
200238
std::string expected_crediential = "unit_test_credential";
201-
std::vector<ConversationAPIClient::ConversationEvent> events = {
202-
{mojom::CharacterType::HUMAN, ConversationAPIClient::PageText,
203-
"This is a page about The Mandalorian."},
204-
{mojom::CharacterType::HUMAN, ConversationAPIClient::PageExcerpt,
205-
"The Mandalorian"},
206-
{mojom::CharacterType::HUMAN, ConversationAPIClient::ChatMessage,
207-
"Est-ce lié à une série plus large?"}};
208-
std::string expected_events_body = R"([
209-
{"role": "user", "type": "pageText", "content": "This is a page about The Mandalorian."},
210-
{"role": "user", "type": "pageExcerpt", "content": "The Mandalorian"},
211-
{"role": "user", "type": "chatMessage", "content": "Est-ce lié à une série plus large?"}
212-
])";
239+
const std::vector<ConversationAPIClient::ConversationEvent>& events =
240+
GetMockEventsAndExpectedEventsBody().first;
241+
const std::string& expected_events_body =
242+
GetMockEventsAndExpectedEventsBody().second;
213243
std::string expected_system_language = "en_KY";
214244
const brave_l10n::test::ScopedDefaultLocale scoped_default_locale(
215245
expected_system_language);
@@ -406,18 +436,10 @@ TEST_F(ConversationAPIUnitTest, PerformRequest_NonPremium) {
406436
// - ConversationEvent is correctly formatted into JSON
407437
// - completion response is parsed and passed through to the callbacks
408438
std::string expected_crediential = "unit_test_credential";
409-
std::vector<ConversationAPIClient::ConversationEvent> events = {
410-
{mojom::CharacterType::HUMAN, ConversationAPIClient::PageText,
411-
"This is a page about The Mandalorian."},
412-
{mojom::CharacterType::HUMAN, ConversationAPIClient::PageExcerpt,
413-
"The Mandalorian"},
414-
{mojom::CharacterType::HUMAN, ConversationAPIClient::ChatMessage,
415-
"Est-ce lié à une série plus large?"}};
416-
std::string expected_events_body = R"([
417-
{"role": "user", "type": "pageText", "content": "This is a page about The Mandalorian."},
418-
{"role": "user", "type": "pageExcerpt", "content": "The Mandalorian"},
419-
{"role": "user", "type": "chatMessage", "content": "Est-ce lié à une série plus large?"}
420-
])";
439+
const std::vector<ConversationAPIClient::ConversationEvent>& events =
440+
GetMockEventsAndExpectedEventsBody().first;
441+
const std::string& expected_events_body =
442+
GetMockEventsAndExpectedEventsBody().second;
421443
std::string expected_system_language = "en_KY";
422444
const brave_l10n::test::ScopedDefaultLocale scoped_default_locale(
423445
expected_system_language);

components/ai_chat/core/browser/engine/engine_consumer.cc

+13
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,17 @@ bool EngineConsumer::CanPerformCompletionRequest(
4949
return true;
5050
}
5151

52+
void EngineConsumer::GetSuggestedTopics(const std::vector<Tab>& tabs,
53+
GetSuggestedTopicsCallback callback) {
54+
NOTIMPLEMENTED();
55+
std::move(callback).Run(base::unexpected(mojom::APIError::InternalError));
56+
}
57+
58+
void EngineConsumer::GetFocusTabs(const std::vector<Tab>& tabs,
59+
const std::string& topic,
60+
GetFocusTabsCallback callback) {
61+
NOTIMPLEMENTED();
62+
std::move(callback).Run(base::unexpected(mojom::APIError::InternalError));
63+
}
64+
5265
} // namespace ai_chat

components/ai_chat/core/browser/engine/engine_consumer.h

+15
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "base/functional/callback_forward.h"
1717
#include "base/types/expected.h"
1818
#include "brave/components/ai_chat/core/browser/engine/remote_completion_client.h"
19+
#include "brave/components/ai_chat/core/browser/types.h"
1920
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h"
2021

2122
namespace ai_chat {
@@ -43,6 +44,11 @@ class EngineConsumer {
4344

4445
using ConversationHistory = std::vector<mojom::ConversationTurnPtr>;
4546

47+
using GetSuggestedTopicsCallback = base::OnceCallback<void(
48+
base::expected<std::vector<std::string>, mojom::APIError>)>;
49+
using GetFocusTabsCallback = base::OnceCallback<void(
50+
base::expected<std::vector<std::string>, mojom::APIError>)>;
51+
4652
static std::string GetPromptForEntry(const mojom::ConversationTurnPtr& entry);
4753

4854
EngineConsumer();
@@ -86,6 +92,15 @@ class EngineConsumer {
8692

8793
virtual void UpdateModelOptions(const mojom::ModelOptions& options) = 0;
8894

95+
// Given a list of tabs, return a list of suggested topics from the server.
96+
virtual void GetSuggestedTopics(const std::vector<Tab>& tabs,
97+
GetSuggestedTopicsCallback callback);
98+
// Given a list of tabs and a specific topic, return a list of tabs to be
99+
// focused on from the server.
100+
virtual void GetFocusTabs(const std::vector<Tab>& tabs,
101+
const std::string& topic,
102+
GetFocusTabsCallback callback);
103+
89104
void SetMaxAssociatedContentLengthForTesting(
90105
uint32_t max_associated_content_length) {
91106
max_associated_content_length_ = max_associated_content_length;

components/ai_chat/core/browser/engine/engine_consumer_conversation_api.cc

+170
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,31 @@
55

66
#include "brave/components/ai_chat/core/browser/engine/engine_consumer_conversation_api.h"
77

8+
#include <algorithm>
89
#include <optional>
910
#include <string>
1011
#include <string_view>
1112
#include <type_traits>
1213
#include <vector>
1314

15+
#include "base/barrier_callback.h"
1416
#include "base/check.h"
1517
#include "base/functional/bind.h"
1618
#include "base/functional/callback.h"
1719
#include "base/functional/callback_helpers.h"
20+
#include "base/json/json_reader.h"
21+
#include "base/json/json_writer.h"
1822
#include "base/memory/scoped_refptr.h"
1923
#include "base/memory/weak_ptr.h"
2024
#include "base/numerics/clamped_math.h"
2125
#include "base/strings/string_split.h"
26+
#include "base/strings/string_util.h"
2227
#include "base/time/time.h"
2328
#include "base/types/expected.h"
29+
#include "base/values.h"
2430
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h"
2531
#include "services/network/public/cpp/shared_url_loader_factory.h"
32+
#include "third_party/re2/src/re2/re2.h"
2633

2734
namespace ai_chat {
2835

@@ -31,6 +38,9 @@ namespace {
3138
using ConversationEvent = ConversationAPIClient::ConversationEvent;
3239
using ConversationEventType = ConversationAPIClient::ConversationEventType;
3340

41+
constexpr size_t kChunkSize = 75;
42+
constexpr char kStrArrPattern[] = R"((\[.+?\]))";
43+
3444
} // namespace
3545

3646
EngineConsumerConversationAPI::EngineConsumerConversationAPI(
@@ -49,6 +59,66 @@ void EngineConsumerConversationAPI::ClearAllQueries() {
4959
api_->ClearAllQueries();
5060
}
5161

62+
// static
63+
base::expected<std::vector<std::string>, mojom::APIError>
64+
EngineConsumerConversationAPI::GetStrArrFromTabOrganizationResponses(
65+
std::vector<EngineConsumer::GenerationResult>& results) {
66+
// Rust implementation of JSON reader is required to parse the response
67+
// safely. This function currently is only called on Desktop which uses a
68+
// rust JSON reader. Chromium is in progress of making all platforms use the
69+
// rust JSON reader and updating the rule of 2 documentation to explicitly
70+
// point out that JSON parser in base is considered safe.
71+
if (!base::JSONReader::UsingRust()) {
72+
return base::unexpected(mojom::APIError::InternalError);
73+
}
74+
75+
std::vector<std::string> str_arr;
76+
mojom::APIError error = mojom::APIError::None;
77+
for (auto& result : results) {
78+
// Fail the operation if server returns an error, such as rate limiting.
79+
// On the other hand, ignore the result which cannot be parsed as expected.
80+
if (!result.has_value()) {
81+
error = result.error();
82+
break;
83+
}
84+
85+
// Skip empty results.
86+
if (result->empty()) {
87+
continue;
88+
}
89+
90+
// Remove newline characters from the result.
91+
base::ReplaceChars(*result, "\n", "", &result.value());
92+
std::string strArr = "";
93+
if (!RE2::PartialMatch(*result, kStrArrPattern, &strArr)) {
94+
continue;
95+
}
96+
auto value = base::JSONReader::Read(strArr, base::JSON_PARSE_RFC);
97+
if (!value) {
98+
continue;
99+
}
100+
101+
auto* list = value->GetIfList();
102+
if (!list) {
103+
continue;
104+
}
105+
106+
for (const auto& item : *list) {
107+
auto* str = item.GetIfString();
108+
if (!str || str->empty()) {
109+
continue;
110+
}
111+
str_arr.push_back(*str);
112+
}
113+
}
114+
115+
if (error != mojom::APIError::None) {
116+
return base::unexpected(error);
117+
}
118+
119+
return str_arr;
120+
}
121+
52122
void EngineConsumerConversationAPI::GenerateRewriteSuggestion(
53123
std::string text,
54124
const std::string& question,
@@ -183,4 +253,104 @@ EngineConsumerConversationAPI::GetAssociatedContentConversationEvent(
183253
return event;
184254
}
185255

256+
void EngineConsumerConversationAPI::DedupeTopics(
257+
base::expected<std::vector<std::string>, mojom::APIError> topics_result,
258+
GetSuggestedTopicsCallback callback) {
259+
if (!topics_result.has_value() || topics_result->empty()) {
260+
std::move(callback).Run(topics_result);
261+
return;
262+
}
263+
264+
base::Value::List topic_list;
265+
for (const auto& topic : *topics_result) {
266+
topic_list.Append(topic);
267+
}
268+
std::vector<ConversationEvent> conversation;
269+
conversation.push_back({mojom::CharacterType::HUMAN,
270+
ConversationEventType::DedupeTopics,
271+
base::WriteJson(topic_list).value_or(std::string())});
272+
api_->PerformRequest(
273+
std::move(conversation), "" /* selected_language */,
274+
base::NullCallback() /* data_received_callback */,
275+
base::BindOnce(
276+
[](GetSuggestedTopicsCallback callback,
277+
EngineConsumer::GenerationResult result) {
278+
// Return deduped topics from the response.
279+
std::vector<EngineConsumer::GenerationResult> results = {result};
280+
std::move(callback).Run(
281+
EngineConsumerConversationAPI::
282+
GetStrArrFromTabOrganizationResponses(results));
283+
},
284+
std::move(callback)));
285+
}
286+
287+
void EngineConsumerConversationAPI::ProcessTabChunks(
288+
const std::vector<Tab>& tabs,
289+
ConversationEventType event_type,
290+
base::OnceCallback<void(std::vector<GenerationResult>)> merge_callback,
291+
const std::string& topic) {
292+
CHECK(event_type == ConversationEventType::GetSuggestedTopicsForFocusTabs ||
293+
event_type == ConversationEventType::GetFocusTabsForTopic);
294+
295+
// Split tab into chunks of 75
296+
size_t num_chunks = (tabs.size() + kChunkSize - 1) / kChunkSize;
297+
const auto barrier_callback = base::BarrierCallback<GenerationResult>(
298+
num_chunks, std::move(merge_callback));
299+
300+
for (size_t chunk = 0; chunk < num_chunks; ++chunk) {
301+
base::Value::List tab_value_list;
302+
for (size_t i = chunk * kChunkSize;
303+
i < std::min((chunk + 1) * kChunkSize, tabs.size()); ++i) {
304+
tab_value_list.Append(base::Value::Dict()
305+
.Set("id", tabs[i].id)
306+
.Set("title", tabs[i].title)
307+
.Set("url", tabs[i].origin.Serialize()));
308+
}
309+
310+
std::vector<ConversationEvent> conversation;
311+
conversation.push_back(
312+
{mojom::CharacterType::HUMAN, event_type,
313+
base::WriteJson(tab_value_list).value_or(std::string()), topic});
314+
315+
api_->PerformRequest(std::move(conversation), "" /* selected_language */,
316+
base::NullCallback() /* data_received_callback */,
317+
barrier_callback /* data_completed_callback */);
318+
}
319+
}
320+
321+
void EngineConsumerConversationAPI::MergeSuggestTopicsResults(
322+
GetSuggestedTopicsCallback callback,
323+
std::vector<GenerationResult> results) {
324+
// Merge the result and send another request to dedupe topics.
325+
DedupeTopics(GetStrArrFromTabOrganizationResponses(results),
326+
std::move(callback));
327+
}
328+
329+
void EngineConsumerConversationAPI::GetSuggestedTopics(
330+
const std::vector<Tab>& tabs,
331+
GetSuggestedTopicsCallback callback) {
332+
ProcessTabChunks(
333+
tabs, ConversationEventType::GetSuggestedTopicsForFocusTabs,
334+
base::BindOnce(&EngineConsumerConversationAPI::MergeSuggestTopicsResults,
335+
weak_ptr_factory_.GetWeakPtr(), std::move(callback)),
336+
"" /* topic */);
337+
}
338+
339+
void EngineConsumerConversationAPI::GetFocusTabs(
340+
const std::vector<Tab>& tabs,
341+
const std::string& topic,
342+
EngineConsumer::GetFocusTabsCallback callback) {
343+
ProcessTabChunks(tabs, ConversationEventType::GetFocusTabsForTopic,
344+
base::BindOnce(
345+
[&](EngineConsumer::GetFocusTabsCallback callback,
346+
std::vector<GenerationResult> results) {
347+
// Merge the results and call callback with tab IDs or
348+
// error.
349+
std::move(callback).Run(
350+
GetStrArrFromTabOrganizationResponses(results));
351+
},
352+
std::move(callback)),
353+
topic);
354+
}
355+
186356
} // namespace ai_chat

0 commit comments

Comments
 (0)