Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DX-61034: Arrow changes from GH-34865 PR #44

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ python/doc/
# Egg metadata
*.egg-info

# Generated C++
CMakeFiles
cpp-jni
java-jni

.vscode
.idea/
.pytest_cache/
Expand Down Expand Up @@ -91,4 +96,4 @@ java-native-cpp/
# archery files
dev/archery/build

swift/Arrow/.build
swift/Arrow/.build
3 changes: 2 additions & 1 deletion cpp/src/arrow/flight/sql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ set(ARROW_FLIGHT_SQL_SRCS
sql_info_internal.cc
column_metadata.cc
client.cc
protocol_internal.cc)
protocol_internal.cc
server_session_middleware.cc)

add_arrow_lib(arrow_flight_sql
CMAKE_PACKAGE_NAME
Expand Down
150 changes: 150 additions & 0 deletions cpp/src/arrow/flight/sql/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@
#include "arrow/result.h"
#include "arrow/util/logging.h"

// Lambda helper & CTAD
template<class... Ts>
struct overloaded : Ts... { using Ts::operator()...; };
template<class... Ts> // CTAD will not be needed for >=C++20
overloaded(Ts...) -> overloaded<Ts...>;

namespace pb = arrow::flight::protocol;
namespace flight_sql_pb = arrow::flight::protocol::sql;

namespace arrow {
Expand Down Expand Up @@ -802,6 +809,149 @@ ::arrow::Result<CancelResult> FlightSqlClient::CancelQuery(
return Status::IOError("Server returned unknown result ", result.result());
}

::arrow::Result<std::map<std::string, SetSessionOptionResult>>
FlightSqlClient::SetSessionOptions(
const FlightCallOptions& options,
const std::map<std::string, SessionOptionValue>& session_options) {
pb::ActionSetSessionOptionsRequest request;
auto* options_map = request.mutable_session_options();

for (const auto & [name, opt_value] : session_options) {
pb::SessionOptionValue pb_opt_value;

if (opt_value.index() == std::variant_npos)
return Status::Invalid("Undefined SessionOptionValue type ");

std::visit(overloaded{
// TODO move this somewhere common that can have Proto-involved code
[&](std::string v) { pb_opt_value.set_string_value(v); },
[&](bool v) { pb_opt_value.set_bool_value(v); },
[&](int32_t v) { pb_opt_value.set_int32_value(v); },
[&](int64_t v) { pb_opt_value.set_int64_value(v); },
[&](float v) { pb_opt_value.set_float_value(v); },
[&](double v) { pb_opt_value.set_double_value(v); },
[&](std::vector<std::string> v) {
auto* string_list_value = pb_opt_value.mutable_string_list_value();
for (const std::string& s : v)
string_list_value->add_values(s);
}
}, opt_value);
(*options_map)[name] = std::move(pb_opt_value);
}

std::unique_ptr<ResultStream> results;
ARROW_ASSIGN_OR_RAISE(auto action, PackAction("SetSessionOptions", request));
ARROW_RETURN_NOT_OK(DoAction(options, action, &results));

pb::ActionSetSessionOptionsResult pb_result;
ARROW_RETURN_NOT_OK(ReadResult(results.get(), &pb_result));
ARROW_RETURN_NOT_OK(DrainResultStream(results.get()));
std::map<std::string, SetSessionOptionResult> result;
for (const auto & [result_key, result_value] : pb_result.results()) {
switch (result_value) {
case pb::ActionSetSessionOptionsResult::SET_SESSION_OPTION_RESULT_UNSPECIFIED:
result[result_key] = SetSessionOptionResult::kUnspecified;
break;
case pb::ActionSetSessionOptionsResult
::SET_SESSION_OPTION_RESULT_OK:
result[result_key] = SetSessionOptionResult::kOk;
break;
case pb::ActionSetSessionOptionsResult
::SET_SESSION_OPTION_RESULT_INVALID_VALUE:
result[result_key] = SetSessionOptionResult::kInvalidResult;
break;
case pb::ActionSetSessionOptionsResult::SET_SESSION_OPTION_RESULT_ERROR:
result[result_key] = SetSessionOptionResult::kError;
break;
default:
return Status::IOError("Invalid SetSessionOptionResult value for key "
+ result_key);
}
}

return result;
}

::arrow::Result<std::map<std::string, SessionOptionValue>>
FlightSqlClient::GetSessionOptions (
const FlightCallOptions& options) {
pb::ActionGetSessionOptionsRequest request;

std::unique_ptr<ResultStream> results;
ARROW_ASSIGN_OR_RAISE(auto action, PackAction("GetSessionOptions", request));
ARROW_RETURN_NOT_OK(DoAction(options, action, &results));

pb::ActionGetSessionOptionsResult pb_result;
ARROW_RETURN_NOT_OK(ReadResult(results.get(), &pb_result));
ARROW_RETURN_NOT_OK(DrainResultStream(results.get()));

std::map<std::string, SessionOptionValue> result;
if (pb_result.session_options_size() > 0) {
for (auto& [pb_opt_name, pb_opt_val] : pb_result.session_options()) {
SessionOptionValue val;
switch (pb_opt_val.option_value_case()) {
case pb::SessionOptionValue::OPTION_VALUE_NOT_SET:
return Status::Invalid("Unset option_value for name '" + pb_opt_name + "'");
case pb::SessionOptionValue::kStringValue:
val = pb_opt_val.string_value();
break;
case pb::SessionOptionValue::kBoolValue:
val = pb_opt_val.bool_value();
break;
case pb::SessionOptionValue::kInt32Value:
val = pb_opt_val.int32_value();
break;
case pb::SessionOptionValue::kInt64Value:
val = pb_opt_val.int64_value();
break;
case pb::SessionOptionValue::kFloatValue:
val = pb_opt_val.float_value();
break;
case pb::SessionOptionValue::kDoubleValue:
val = pb_opt_val.double_value();
break;
case pb::SessionOptionValue::kStringListValue:
val.emplace<std::vector<std::string>>();
std::get<std::vector<std::string>>(val)
.reserve(pb_opt_val.string_list_value().values_size());
for (const std::string& s : pb_opt_val.string_list_value().values())
std::get<std::vector<std::string>>(val).push_back(s);
break;
}
result[pb_opt_name] = std::move(val);
}
}

return result;
}

::arrow::Result<CloseSessionResult> FlightSqlClient::CloseSession(
const FlightCallOptions& options) {
pb::ActionCloseSessionRequest request;

std::unique_ptr<ResultStream> results;
ARROW_ASSIGN_OR_RAISE(auto action, PackAction("CloseSession", request));
ARROW_RETURN_NOT_OK(DoAction(options, action, &results));

pb::ActionCloseSessionResult result;
ARROW_RETURN_NOT_OK(ReadResult(results.get(), &result));
ARROW_RETURN_NOT_OK(DrainResultStream(results.get()));
switch (result.result()) {
case pb::ActionCloseSessionResult::CLOSE_RESULT_UNSPECIFIED:
return CloseSessionResult::kUnspecified;
case pb::ActionCloseSessionResult::CLOSE_RESULT_CLOSED:
return CloseSessionResult::kClosed;
case pb::ActionCloseSessionResult::CLOSE_RESULT_CLOSING:
return CloseSessionResult::kClosing;
case pb::ActionCloseSessionResult::CLOSE_RESULT_NOT_CLOSEABLE:
return CloseSessionResult::kNotClosable;
default:
break;
}

return Status::IOError("Server returned unknown result ", result.result());
}

Status FlightSqlClient::Close() { return impl_->Close(); }

std::ostream& operator<<(std::ostream& os, CancelResult result) {
Expand Down
20 changes: 20 additions & 0 deletions cpp/src/arrow/flight/sql/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <cstdint>
#include <memory>
#include <map>
#include <string>

#include "arrow/flight/client.h"
Expand Down Expand Up @@ -329,6 +330,25 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient {
/// \param[in] info The FlightInfo of the query to cancel.
::arrow::Result<CancelResult> CancelQuery(const FlightCallOptions& options,
const FlightInfo& info);

/// \brief Sets session options.
///
/// \param[in] options RPC-layer hints for this call.
/// \param[in] session_options The session options to set.
::arrow::Result<std::map<std::string, SetSessionOptionResult>> SetSessionOptions(
const FlightCallOptions& options,
const std::map<std::string, SessionOptionValue>& session_options);

/// \brief Gets current session options.
///
/// \param[in] options RPC-layer hints for this call.
::arrow::Result<std::map<std::string, SessionOptionValue>> GetSessionOptions(
const FlightCallOptions& options);

/// \brief Explicitly closes the session if applicable.
///
/// \param[in] options RPC-layer hints for this call.
::arrow::Result<CloseSessionResult> CloseSession(const FlightCallOptions& options);

/// \brief Explicitly shut down and clean up the client.
Status Close();
Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/flight/sql/protocol_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@
#include "arrow/flight/sql/visibility.h"

#include "arrow/flight/sql/FlightSql.pb.h" // IWYU pragma: export
#include "arrow/flight/Flight.pb.h"
Loading