22
22
#include " arrow/scalar.h"
23
23
#include " arrow/util/bit_run_reader.h"
24
24
#include " arrow/util/logging.h"
25
+ #include " arrow/visit_data_inline.h"
25
26
26
27
namespace arrow ::compute::internal {
27
28
namespace {
@@ -30,7 +31,8 @@ using arrow::internal::VisitSetBitRunsVoid;
30
31
using arrow::util::span;
31
32
32
33
struct PivotImpl : public ScalarAggregator {
33
- Status Init (const PivotWiderOptions& options, const std::vector<TypeHolder>& in_types) {
34
+ Status Init (const PivotWiderOptions& options, const std::vector<TypeHolder>& in_types,
35
+ ExecContext* ctx) {
34
36
options_ = &options;
35
37
key_type_ = in_types[0 ].GetSharedPtr ();
36
38
auto value_type = in_types[1 ].GetSharedPtr ();
@@ -42,47 +44,56 @@ struct PivotImpl : public ScalarAggregator {
42
44
values_.push_back (MakeNullScalar (value_type));
43
45
}
44
46
out_type_ = struct_ (std::move (fields));
45
- ARROW_ASSIGN_OR_RAISE (key_mapper_, PivotWiderKeyMapper::Make (*key_type_, options_));
47
+ ARROW_ASSIGN_OR_RAISE (key_mapper_,
48
+ PivotWiderKeyMapper::Make (*key_type_, options_, ctx));
46
49
return Status::OK ();
47
50
}
48
51
49
52
Status Consume (KernelContext*, const ExecSpan& batch) override {
50
53
DCHECK_EQ (batch.num_values (), 2 );
51
54
if (batch[0 ].is_array ()) {
52
- ARROW_ASSIGN_OR_RAISE (span< const PivotWiderKeyIndex> keys,
53
- key_mapper_-> MapKeys (batch[ 0 ]. array ) );
55
+ ARROW_ASSIGN_OR_RAISE (auto keys_array, key_mapper_-> MapKeys (batch[ 0 ]. array ));
56
+ ArraySpan keys_span (*keys_array );
54
57
if (batch[1 ].is_array ()) {
55
58
// Array keys, array values
56
59
auto values = batch[1 ].array .ToArray ();
57
- for (int64_t i = 0 ; i < batch.length ; ++i) {
58
- PivotWiderKeyIndex key = keys[i];
59
- if (key != kNullPivotKey && !values->IsNull (i)) {
60
- if (ARROW_PREDICT_FALSE (values_[key]->is_valid )) {
61
- return DuplicateValue ();
62
- }
63
- ARROW_ASSIGN_OR_RAISE (values_[key], values->GetScalar (i));
64
- DCHECK (values_[key]->is_valid );
65
- }
66
- }
60
+ int64_t i = 0 ;
61
+ RETURN_NOT_OK (VisitArraySpanInline<UInt32Type>(
62
+ keys_span,
63
+ [&](uint32_t key) {
64
+ if (!values->IsNull (i)) {
65
+ if (ARROW_PREDICT_FALSE (values_[key]->is_valid )) {
66
+ return DuplicateValue ();
67
+ }
68
+ ARROW_ASSIGN_OR_RAISE (values_[key], values->GetScalar (i));
69
+ }
70
+ ++i;
71
+ return Status::OK ();
72
+ },
73
+ [&]() {
74
+ ++i;
75
+ return Status::OK ();
76
+ }));
67
77
} else {
68
78
// Array keys, scalar value
69
79
const Scalar* value = batch[1 ].scalar ;
70
80
if (value->is_valid ) {
71
- for (int64_t i = 0 ; i < batch.length ; ++i) {
72
- PivotWiderKeyIndex key = keys[i];
73
- if (key != kNullPivotKey ) {
74
- if (ARROW_PREDICT_FALSE (values_[key]->is_valid )) {
75
- return DuplicateValue ();
76
- }
77
- values_[key] = value->GetSharedPtr ();
78
- }
79
- }
81
+ RETURN_NOT_OK (VisitArraySpanInline<UInt32Type>(
82
+ keys_span,
83
+ [&](uint32_t key) {
84
+ if (ARROW_PREDICT_FALSE (values_[key]->is_valid )) {
85
+ return DuplicateValue ();
86
+ }
87
+ values_[key] = value->GetSharedPtr ();
88
+ return Status::OK ();
89
+ },
90
+ [] { return Status::OK (); }));
80
91
}
81
92
}
82
93
} else {
83
- ARROW_ASSIGN_OR_RAISE (PivotWiderKeyIndex key,
84
- key_mapper_-> MapKey (*batch[ 0 ]. scalar ));
85
- if ( key != kNullPivotKey ) {
94
+ ARROW_ASSIGN_OR_RAISE (auto maybe_key, key_mapper_-> MapKey (*batch[ 0 ]. scalar ));
95
+ if (maybe_key. has_value ()) {
96
+ PivotWiderKeyIndex key = maybe_key. value ();
86
97
if (batch[1 ].is_array ()) {
87
98
// Scalar key, array values
88
99
auto values = batch[1 ].array .ToArray ();
@@ -145,10 +156,8 @@ struct PivotImpl : public ScalarAggregator {
145
156
Result<std::unique_ptr<KernelState>> PivotInit (KernelContext* ctx,
146
157
const KernelInitArgs& args) {
147
158
const auto & options = checked_cast<const PivotWiderOptions&>(*args.options );
148
- DCHECK_EQ (args.inputs .size (), 2 );
149
- DCHECK (is_base_binary_like (args.inputs [0 ].id ()));
150
159
auto state = std::make_unique<PivotImpl>();
151
- RETURN_NOT_OK (state->Init (options, args.inputs ));
160
+ RETURN_NOT_OK (state->Init (options, args.inputs , ctx-> exec_context () ));
152
161
// GH-45718: This can be simplified once we drop the R openSUSE155 crossbow
153
162
// job
154
163
// R build with openSUSE155 requires an explicit shared_ptr construction
@@ -167,6 +176,8 @@ const FunctionDoc pivot_doc{
167
176
" is emitted. If a pivot key doesn't appear, null is emitted.\n "
168
177
" If more than one non-null value is encountered for a given pivot key,\n "
169
178
" Invalid is raised.\n "
179
+ " The pivot key column can be string, binary or integer. The `key_names`\n "
180
+ " will be cast to the pivot key column type for matching.\n "
170
181
" Behavior of unexpected pivot keys is controlled by `unexpected_key_behavior`\n "
171
182
" in PivotWiderOptions." ),
172
183
{" pivot_keys" , " pivot_values" },
@@ -179,11 +190,17 @@ void RegisterScalarAggregatePivot(FunctionRegistry* registry) {
179
190
180
191
auto func = std::make_shared<ScalarAggregateFunction>(
181
192
" pivot_wider" , Arity::Binary (), pivot_doc, &default_pivot_options);
182
-
183
- for (auto key_type : BaseBinaryTypes ()) {
184
- auto sig = KernelSignature::Make ({key_type->id (), InputType::Any ()},
193
+ auto add_kernel = [&](InputType key_type) {
194
+ auto sig = KernelSignature::Make ({key_type, InputType::Any ()},
185
195
OutputType (ResolveOutputType));
186
196
AddAggKernel (std::move (sig), PivotInit, func.get ());
197
+ };
198
+
199
+ for (const auto & key_type : BaseBinaryTypes ()) {
200
+ add_kernel (key_type->id ());
201
+ }
202
+ for (const auto & key_type : IntTypes ()) {
203
+ add_kernel (key_type->id ());
187
204
}
188
205
DCHECK_OK (registry->AddFunction (std::move (func)));
189
206
}
0 commit comments