5
5
6
6
#include " brave/components/ai_chat/core/browser/engine/engine_consumer_conversation_api.h"
7
7
8
+ #include < algorithm>
8
9
#include < optional>
9
10
#include < string>
10
11
#include < string_view>
11
12
#include < type_traits>
12
13
#include < vector>
13
14
15
+ #include " base/barrier_callback.h"
14
16
#include " base/check.h"
15
17
#include " base/functional/bind.h"
16
18
#include " base/functional/callback.h"
17
19
#include " base/functional/callback_helpers.h"
20
+ #include " base/json/json_reader.h"
21
+ #include " base/json/json_writer.h"
18
22
#include " base/memory/scoped_refptr.h"
19
23
#include " base/memory/weak_ptr.h"
20
24
#include " base/numerics/clamped_math.h"
21
25
#include " base/strings/string_split.h"
26
+ #include " base/strings/string_util.h"
22
27
#include " base/time/time.h"
23
28
#include " base/types/expected.h"
29
+ #include " base/values.h"
24
30
#include " brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h"
25
31
#include " services/network/public/cpp/shared_url_loader_factory.h"
32
+ #include " third_party/re2/src/re2/re2.h"
26
33
27
34
namespace ai_chat {
28
35
@@ -31,6 +38,9 @@ namespace {
31
38
using ConversationEvent = ConversationAPIClient::ConversationEvent;
32
39
using ConversationEventType = ConversationAPIClient::ConversationEventType;
33
40
41
+ constexpr size_t kChunkSize = 75 ;
42
+ constexpr char kStrArrPattern [] = R"( (\[.+?\]))" ;
43
+
34
44
} // namespace
35
45
36
46
EngineConsumerConversationAPI::EngineConsumerConversationAPI (
@@ -49,6 +59,66 @@ void EngineConsumerConversationAPI::ClearAllQueries() {
49
59
api_->ClearAllQueries ();
50
60
}
51
61
62
+ // static
63
+ base::expected<std::vector<std::string>, mojom::APIError>
64
+ EngineConsumerConversationAPI::GetStrArrFromTabOrganizationResponses (
65
+ std::vector<EngineConsumer::GenerationResult>& results) {
66
+ // Rust implementation of JSON reader is required to parse the response
67
+ // safely. This function currently is only called on Desktop which uses a
68
+ // rust JSON reader. Chromium is in progress of making all platforms use the
69
+ // rust JSON reader and updating the rule of 2 documentation to explicitly
70
+ // point out that JSON parser in base is considered safe.
71
+ if (!base::JSONReader::UsingRust ()) {
72
+ return base::unexpected (mojom::APIError::InternalError);
73
+ }
74
+
75
+ std::vector<std::string> str_arr;
76
+ mojom::APIError error = mojom::APIError::None;
77
+ for (auto & result : results) {
78
+ // Fail the operation if server returns an error, such as rate limiting.
79
+ // On the other hand, ignore the result which cannot be parsed as expected.
80
+ if (!result.has_value ()) {
81
+ error = result.error ();
82
+ break ;
83
+ }
84
+
85
+ // Skip empty results.
86
+ if (result->empty ()) {
87
+ continue ;
88
+ }
89
+
90
+ // Remove newline characters from the result.
91
+ base::ReplaceChars (*result, " \n " , " " , &result.value ());
92
+ std::string strArr = " " ;
93
+ if (!RE2::PartialMatch (*result, kStrArrPattern , &strArr)) {
94
+ continue ;
95
+ }
96
+ auto value = base::JSONReader::Read (strArr, base::JSON_PARSE_RFC);
97
+ if (!value) {
98
+ continue ;
99
+ }
100
+
101
+ auto * list = value->GetIfList ();
102
+ if (!list) {
103
+ continue ;
104
+ }
105
+
106
+ for (const auto & item : *list) {
107
+ auto * str = item.GetIfString ();
108
+ if (!str || str->empty ()) {
109
+ continue ;
110
+ }
111
+ str_arr.push_back (*str);
112
+ }
113
+ }
114
+
115
+ if (error != mojom::APIError::None) {
116
+ return base::unexpected (error);
117
+ }
118
+
119
+ return str_arr;
120
+ }
121
+
52
122
void EngineConsumerConversationAPI::GenerateRewriteSuggestion (
53
123
std::string text,
54
124
const std::string& question,
@@ -183,4 +253,104 @@ EngineConsumerConversationAPI::GetAssociatedContentConversationEvent(
183
253
return event;
184
254
}
185
255
256
+ void EngineConsumerConversationAPI::DedupeTopics (
257
+ base::expected<std::vector<std::string>, mojom::APIError> topics_result,
258
+ GetSuggestedTopicsCallback callback) {
259
+ if (!topics_result.has_value () || topics_result->empty ()) {
260
+ std::move (callback).Run (topics_result);
261
+ return ;
262
+ }
263
+
264
+ base::Value::List topic_list;
265
+ for (const auto & topic : *topics_result) {
266
+ topic_list.Append (topic);
267
+ }
268
+ std::vector<ConversationEvent> conversation;
269
+ conversation.push_back ({mojom::CharacterType::HUMAN,
270
+ ConversationEventType::DedupeTopics,
271
+ base::WriteJson (topic_list).value_or (std::string ())});
272
+ api_->PerformRequest (
273
+ std::move (conversation), " " /* selected_language */ ,
274
+ base::NullCallback () /* data_received_callback */ ,
275
+ base::BindOnce (
276
+ [](GetSuggestedTopicsCallback callback,
277
+ EngineConsumer::GenerationResult result) {
278
+ // Return deduped topics from the response.
279
+ std::vector<EngineConsumer::GenerationResult> results = {result};
280
+ std::move (callback).Run (
281
+ EngineConsumerConversationAPI::
282
+ GetStrArrFromTabOrganizationResponses (results));
283
+ },
284
+ std::move (callback)));
285
+ }
286
+
287
+ void EngineConsumerConversationAPI::ProcessTabChunks (
288
+ const std::vector<Tab>& tabs,
289
+ ConversationEventType event_type,
290
+ base::OnceCallback<void (std::vector<GenerationResult>)> merge_callback,
291
+ const std::string& topic) {
292
+ CHECK (event_type == ConversationEventType::GetSuggestedTopicsForFocusTabs ||
293
+ event_type == ConversationEventType::GetFocusTabsForTopic);
294
+
295
+ // Split tab into chunks of 75
296
+ size_t num_chunks = (tabs.size () + kChunkSize - 1 ) / kChunkSize ;
297
+ const auto barrier_callback = base::BarrierCallback<GenerationResult>(
298
+ num_chunks, std::move (merge_callback));
299
+
300
+ for (size_t chunk = 0 ; chunk < num_chunks; ++chunk) {
301
+ base::Value::List tab_value_list;
302
+ for (size_t i = chunk * kChunkSize ;
303
+ i < std::min ((chunk + 1 ) * kChunkSize , tabs.size ()); ++i) {
304
+ tab_value_list.Append (base::Value::Dict ()
305
+ .Set (" id" , tabs[i].id )
306
+ .Set (" title" , tabs[i].title )
307
+ .Set (" url" , tabs[i].origin .Serialize ()));
308
+ }
309
+
310
+ std::vector<ConversationEvent> conversation;
311
+ conversation.push_back (
312
+ {mojom::CharacterType::HUMAN, event_type,
313
+ base::WriteJson (tab_value_list).value_or (std::string ()), topic});
314
+
315
+ api_->PerformRequest (std::move (conversation), " " /* selected_language */ ,
316
+ base::NullCallback () /* data_received_callback */ ,
317
+ barrier_callback /* data_completed_callback */ );
318
+ }
319
+ }
320
+
321
+ void EngineConsumerConversationAPI::MergeSuggestTopicsResults (
322
+ GetSuggestedTopicsCallback callback,
323
+ std::vector<GenerationResult> results) {
324
+ // Merge the result and send another request to dedupe topics.
325
+ DedupeTopics (GetStrArrFromTabOrganizationResponses (results),
326
+ std::move (callback));
327
+ }
328
+
329
+ void EngineConsumerConversationAPI::GetSuggestedTopics (
330
+ const std::vector<Tab>& tabs,
331
+ GetSuggestedTopicsCallback callback) {
332
+ ProcessTabChunks (
333
+ tabs, ConversationEventType::GetSuggestedTopicsForFocusTabs,
334
+ base::BindOnce (&EngineConsumerConversationAPI::MergeSuggestTopicsResults,
335
+ weak_ptr_factory_.GetWeakPtr (), std::move (callback)),
336
+ " " /* topic */ );
337
+ }
338
+
339
+ void EngineConsumerConversationAPI::GetFocusTabs (
340
+ const std::vector<Tab>& tabs,
341
+ const std::string& topic,
342
+ EngineConsumer::GetFocusTabsCallback callback) {
343
+ ProcessTabChunks (tabs, ConversationEventType::GetFocusTabsForTopic,
344
+ base::BindOnce (
345
+ [&](EngineConsumer::GetFocusTabsCallback callback,
346
+ std::vector<GenerationResult> results) {
347
+ // Merge the results and call callback with tab IDs or
348
+ // error.
349
+ std::move (callback).Run (
350
+ GetStrArrFromTabOrganizationResponses (results));
351
+ },
352
+ std::move (callback)),
353
+ topic);
354
+ }
355
+
186
356
} // namespace ai_chat
0 commit comments