|
28 | 28 | #include "arrow/array/array_nested.h"
|
29 | 29 | #include "arrow/array/array_primitive.h"
|
30 | 30 | #include "arrow/array/builder_primitive.h"
|
| 31 | +#include "arrow/flight/client_cookie_middleware.h" |
31 | 32 | #include "arrow/flight/client_middleware.h"
|
32 | 33 | #include "arrow/flight/server_middleware.h"
|
33 | 34 | #include "arrow/flight/sql/client.h"
|
34 | 35 | #include "arrow/flight/sql/column_metadata.h"
|
35 | 36 | #include "arrow/flight/sql/server.h"
|
| 37 | +#include "arrow/flight/sql/server_session_middleware.h" |
36 | 38 | #include "arrow/flight/sql/types.h"
|
37 | 39 | #include "arrow/flight/test_util.h"
|
38 | 40 | #include "arrow/flight/types.h"
|
@@ -744,6 +746,155 @@ class ExpirationTimeRenewFlightEndpointScenario : public Scenario {
|
744 | 746 | }
|
745 | 747 | };
|
746 | 748 |
|
| 749 | +/// \brief The server used for testing Session Options. |
| 750 | +/// |
| 751 | +/// SetSessionOptions has a blacklisted option name and string option value, |
| 752 | +/// both "lol_invalid", which will result in errors attempting to set either. |
| 753 | +class SessionOptionsServer : public sql::FlightSqlServerBase { |
| 754 | + static inline const std::string invalid_option_name = "lol_invalid"; |
| 755 | + static inline const SessionOptionValue invalid_option_value = "lol_invalid"; |
| 756 | + |
| 757 | + const std::string session_middleware_key; |
| 758 | + // These will never be threaded so using a plain map and no lock |
| 759 | + std::map<std::string, SessionOptionValue> session_store_; |
| 760 | + |
| 761 | + public: |
| 762 | + explicit SessionOptionsServer(std::string session_middleware_key) |
| 763 | + : FlightSqlServerBase(), |
| 764 | + session_middleware_key(std::move(session_middleware_key)) {} |
| 765 | + |
| 766 | + arrow::Result<SetSessionOptionsResult> SetSessionOptions( |
| 767 | + const ServerCallContext& context, |
| 768 | + const SetSessionOptionsRequest& request) override { |
| 769 | + SetSessionOptionsResult res; |
| 770 | + |
| 771 | + auto* middleware = static_cast<sql::ServerSessionMiddleware*>( |
| 772 | + context.GetMiddleware(session_middleware_key)); |
| 773 | + ARROW_ASSIGN_OR_RAISE(std::shared_ptr<sql::FlightSession> session, |
| 774 | + middleware->GetSession()); |
| 775 | + |
| 776 | + for (const auto& [name, value] : request.session_options) { |
| 777 | + // Blacklisted value name |
| 778 | + if (name == invalid_option_name) { |
| 779 | + res.errors.emplace(name, SetSessionOptionsResult::Error{ |
| 780 | + SetSessionOptionErrorValue::kInvalidName}); |
| 781 | + continue; |
| 782 | + } |
| 783 | + // Blacklisted option value |
| 784 | + if (value == invalid_option_value) { |
| 785 | + res.errors.emplace(name, SetSessionOptionsResult::Error{ |
| 786 | + SetSessionOptionErrorValue::kInvalidValue}); |
| 787 | + continue; |
| 788 | + } |
| 789 | + if (std::holds_alternative<std::monostate>(value)) { |
| 790 | + session->EraseSessionOption(name); |
| 791 | + continue; |
| 792 | + } |
| 793 | + session->SetSessionOption(name, value); |
| 794 | + } |
| 795 | + |
| 796 | + return res; |
| 797 | + } |
| 798 | + |
| 799 | + arrow::Result<GetSessionOptionsResult> GetSessionOptions( |
| 800 | + const ServerCallContext& context, |
| 801 | + const GetSessionOptionsRequest& request) override { |
| 802 | + auto* middleware = static_cast<sql::ServerSessionMiddleware*>( |
| 803 | + context.GetMiddleware(session_middleware_key)); |
| 804 | + if (!middleware->HasSession()) { |
| 805 | + return Status::Invalid("No existing session to get options from."); |
| 806 | + } |
| 807 | + ARROW_ASSIGN_OR_RAISE(std::shared_ptr<sql::FlightSession> session, |
| 808 | + middleware->GetSession()); |
| 809 | + |
| 810 | + return GetSessionOptionsResult{session->GetSessionOptions()}; |
| 811 | + } |
| 812 | + |
| 813 | + arrow::Result<CloseSessionResult> CloseSession( |
| 814 | + const ServerCallContext& context, const CloseSessionRequest& request) override { |
| 815 | + // Broken (does not expire cookie) until C++ middleware handling (GH-39791) fixed: |
| 816 | + auto* middleware = static_cast<sql::ServerSessionMiddleware*>( |
| 817 | + context.GetMiddleware(session_middleware_key)); |
| 818 | + ARROW_RETURN_NOT_OK(middleware->CloseSession()); |
| 819 | + return CloseSessionResult{CloseSessionStatus::kClosed}; |
| 820 | + } |
| 821 | +}; |
| 822 | + |
| 823 | +/// \brief The Session Options scenario. |
| 824 | +/// |
| 825 | +/// This tests Session Options functionality as well as ServerSessionMiddleware. |
| 826 | +class SessionOptionsScenario : public Scenario { |
| 827 | + static inline const std::string server_middleware_key = "sessionmiddleware"; |
| 828 | + |
| 829 | + Status MakeServer(std::unique_ptr<FlightServerBase>* server, |
| 830 | + FlightServerOptions* options) override { |
| 831 | + *server = std::make_unique<SessionOptionsServer>(server_middleware_key); |
| 832 | + |
| 833 | + auto id_gen_int = std::make_shared<std::atomic_int>(1000); |
| 834 | + options->middleware.emplace_back( |
| 835 | + server_middleware_key, |
| 836 | + sql::MakeServerSessionMiddlewareFactory( |
| 837 | + [=]() -> std::string { return std::to_string((*id_gen_int)++); })); |
| 838 | + |
| 839 | + return Status::OK(); |
| 840 | + } |
| 841 | + |
| 842 | + Status MakeClient(FlightClientOptions* options) override { |
| 843 | + options->middleware.emplace_back(GetCookieFactory()); |
| 844 | + return Status::OK(); |
| 845 | + } |
| 846 | + |
| 847 | + Status RunClient(std::unique_ptr<FlightClient> flight_client) override { |
| 848 | + sql::FlightSqlClient client{std::move(flight_client)}; |
| 849 | + |
| 850 | + // Set |
| 851 | + auto req1 = SetSessionOptionsRequest{ |
| 852 | + {{"foolong", 123L}, |
| 853 | + {"bardouble", 456.0}, |
| 854 | + {"lol_invalid", "this won't get set"}, |
| 855 | + {"key_with_invalid_value", "lol_invalid"}, |
| 856 | + {"big_ol_string_list", std::vector<std::string>{"a", "b", "sea", "dee", " ", |
| 857 | + " ", "geee", "(づ。◕‿‿◕。)づ"}}}}; |
| 858 | + ARROW_ASSIGN_OR_RAISE(auto res1, client.SetSessionOptions({}, req1)); |
| 859 | + // Some errors |
| 860 | + if (res1.errors != |
| 861 | + std::map<std::string, SetSessionOptionsResult::Error>{ |
| 862 | + {"lol_invalid", |
| 863 | + SetSessionOptionsResult::Error{SetSessionOptionErrorValue::kInvalidName}}, |
| 864 | + {"key_with_invalid_value", SetSessionOptionsResult::Error{ |
| 865 | + SetSessionOptionErrorValue::kInvalidValue}}}) { |
| 866 | + return Status::Invalid("res1 incorrect: " + res1.ToString()); |
| 867 | + } |
| 868 | + // Some set, some omitted due to above errors |
| 869 | + ARROW_ASSIGN_OR_RAISE(auto res2, client.GetSessionOptions({}, {})); |
| 870 | + if (res2.session_options != |
| 871 | + std::map<std::string, SessionOptionValue>{ |
| 872 | + {"foolong", 123L}, |
| 873 | + {"bardouble", 456.0}, |
| 874 | + {"big_ol_string_list", |
| 875 | + std::vector<std::string>{"a", "b", "sea", "dee", " ", " ", "geee", |
| 876 | + "(づ。◕‿‿◕。)づ"}}}) { |
| 877 | + return Status::Invalid("res2 incorrect: " + res2.ToString()); |
| 878 | + } |
| 879 | + // Update |
| 880 | + ARROW_ASSIGN_OR_RAISE( |
| 881 | + auto res3, |
| 882 | + client.SetSessionOptions( |
| 883 | + {}, SetSessionOptionsRequest{ |
| 884 | + {{"foolong", std::monostate{}}, |
| 885 | + {"big_ol_string_list", "a,b,sea,dee, , ,geee,(づ。◕‿‿◕。)づ"}}})); |
| 886 | + ARROW_ASSIGN_OR_RAISE(auto res4, client.GetSessionOptions({}, {})); |
| 887 | + if (res4.session_options != |
| 888 | + std::map<std::string, SessionOptionValue>{ |
| 889 | + {"bardouble", 456.0}, |
| 890 | + {"big_ol_string_list", "a,b,sea,dee, , ,geee,(づ。◕‿‿◕。)づ"}}) { |
| 891 | + return Status::Invalid("res4 incorrect: " + res4.ToString()); |
| 892 | + } |
| 893 | + |
| 894 | + return Status::OK(); |
| 895 | + } |
| 896 | +}; |
| 897 | + |
747 | 898 | /// \brief The server used for testing PollFlightInfo().
|
748 | 899 | class PollFlightInfoServer : public FlightServerBase {
|
749 | 900 | public:
|
@@ -1952,6 +2103,9 @@ Status GetScenario(const std::string& scenario_name, std::shared_ptr<Scenario>*
|
1952 | 2103 | } else if (scenario_name == "expiration_time:renew_flight_endpoint") {
|
1953 | 2104 | *out = std::make_shared<ExpirationTimeRenewFlightEndpointScenario>();
|
1954 | 2105 | return Status::OK();
|
| 2106 | + } else if (scenario_name == "session_options") { |
| 2107 | + *out = std::make_shared<SessionOptionsScenario>(); |
| 2108 | + return Status::OK(); |
1955 | 2109 | } else if (scenario_name == "poll_flight_info") {
|
1956 | 2110 | *out = std::make_shared<PollFlightInfoScenario>();
|
1957 | 2111 | return Status::OK();
|
|
0 commit comments