Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 1b73796

Browse files
committedMar 26, 2025·
apacheGH-45732: [C++][Compute] Accept more pivot key types
1 parent 9b7875c commit 1b73796

9 files changed

+324
-194
lines changed
 

‎cpp/src/arrow/acero/hash_aggregate_test.cc

+40-1
Original file line numberDiff line numberDiff line change
@@ -4440,7 +4440,7 @@ TEST_P(GroupBy, PivotBasics) {
44404440
}
44414441
}
44424442

4443-
TEST_P(GroupBy, PivotAllKeyTypes) {
4443+
TEST_P(GroupBy, PivotBinaryKeyTypes) {
44444444
auto value_type = float32();
44454445
std::vector<std::string> table_json = {R"([
44464446
[1, "width", 10.5],
@@ -4464,6 +4464,30 @@ TEST_P(GroupBy, PivotAllKeyTypes) {
44644464
}
44654465
}
44664466

4467+
TEST_P(GroupBy, PivotIntegerKeyTypes) {
4468+
auto value_type = float32();
4469+
std::vector<std::string> table_json = {R"([
4470+
[1, 78, 10.5],
4471+
[2, 78, 11.5]
4472+
])",
4473+
R"([
4474+
[2, 56, 12.5],
4475+
[3, 78, 13.5],
4476+
[1, 56, 14.5]
4477+
])"};
4478+
std::string expected_json = R"([
4479+
[1, {"56": 14.5, "78": 10.5} ],
4480+
[2, {"56": 12.5, "78": 11.5} ],
4481+
[3, {"56": null, "78": 13.5} ]
4482+
])";
4483+
PivotWiderOptions options(/*key_names=*/{"56", "78"});
4484+
4485+
for (const auto& key_type : IntTypes()) {
4486+
ARROW_SCOPED_TRACE("key_type = ", *key_type);
4487+
TestPivot(key_type, value_type, options, table_json, expected_json);
4488+
}
4489+
}
4490+
44674491
TEST_P(GroupBy, PivotNumericValues) {
44684492
auto key_type = utf8();
44694493
std::vector<std::string> table_json = {R"([
@@ -4749,6 +4773,21 @@ TEST_P(GroupBy, PivotDuplicateKeys) {
47494773
RunPivot(key_type, value_type, options, table_json));
47504774
}
47514775

4776+
TEST_P(GroupBy, PivotInvalidKeys) {
4777+
// Integer key type, but key names cannot be converted to int
4778+
auto key_type = int32();
4779+
auto value_type = float32();
4780+
std::vector<std::string> table_json = {R"([])"};
4781+
PivotWiderOptions options(/*key_names=*/{"height", "width"});
4782+
EXPECT_RAISES_WITH_MESSAGE_THAT(
4783+
Invalid, HasSubstr("Failed to parse string: 'width' as a scalar of type int32"),
4784+
RunPivot(key_type, value_type, options, table_json));
4785+
options.key_names = {"12.3", "45"};
4786+
EXPECT_RAISES_WITH_MESSAGE_THAT(
4787+
Invalid, HasSubstr("Failed to parse string: '12.3' as a scalar of type int32"),
4788+
RunPivot(key_type, value_type, options, table_json));
4789+
}
4790+
47524791
TEST_P(GroupBy, PivotDuplicateValues) {
47534792
auto key_type = utf8();
47544793
auto value_type = float32();

‎cpp/src/arrow/compute/api_aggregate.h

+4-3
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,10 @@ class ARROW_EXPORT TDigestOptions : public FunctionOptions {
202202
/// - The corresponding `Aggregate::target` must have two FieldRef elements;
203203
/// the first one points to the pivot key column, the second points to the
204204
/// pivoted data column.
205-
/// - The pivot key column must be string-like; its values will be matched
206-
/// against `key_names` in order to dispatch the pivoted data into the
207-
/// output.
205+
/// - The pivot key column can be string-like or integer; its values will be
206+
/// matched against `key_names` in order to dispatch the pivoted data into
207+
/// the output. If the pivot key column is not string-like, the `key_names`
208+
/// will be cast to the pivot key type.
208209
///
209210
/// "pivot_wider" example
210211
/// ---------------------

‎cpp/src/arrow/compute/exec.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ struct ExecValue {
276276
ArraySpan array = {};
277277
const Scalar* scalar = NULLPTR;
278278

279-
ExecValue(Scalar* scalar) // NOLINT implicit conversion
279+
ExecValue(const Scalar* scalar) // NOLINT implicit conversion
280280
: scalar(scalar) {}
281281

282282
ExecValue(ArraySpan array) // NOLINT implicit conversion

‎cpp/src/arrow/compute/kernels/aggregate_pivot.cc

+49-32
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "arrow/scalar.h"
2323
#include "arrow/util/bit_run_reader.h"
2424
#include "arrow/util/logging.h"
25+
#include "arrow/visit_data_inline.h"
2526

2627
namespace arrow::compute::internal {
2728
namespace {
@@ -30,7 +31,8 @@ using arrow::internal::VisitSetBitRunsVoid;
3031
using arrow::util::span;
3132

3233
struct PivotImpl : public ScalarAggregator {
33-
Status Init(const PivotWiderOptions& options, const std::vector<TypeHolder>& in_types) {
34+
Status Init(const PivotWiderOptions& options, const std::vector<TypeHolder>& in_types,
35+
ExecContext* ctx) {
3436
options_ = &options;
3537
key_type_ = in_types[0].GetSharedPtr();
3638
auto value_type = in_types[1].GetSharedPtr();
@@ -42,47 +44,56 @@ struct PivotImpl : public ScalarAggregator {
4244
values_.push_back(MakeNullScalar(value_type));
4345
}
4446
out_type_ = struct_(std::move(fields));
45-
ARROW_ASSIGN_OR_RAISE(key_mapper_, PivotWiderKeyMapper::Make(*key_type_, options_));
47+
ARROW_ASSIGN_OR_RAISE(key_mapper_,
48+
PivotWiderKeyMapper::Make(*key_type_, options_, ctx));
4649
return Status::OK();
4750
}
4851

4952
Status Consume(KernelContext*, const ExecSpan& batch) override {
5053
DCHECK_EQ(batch.num_values(), 2);
5154
if (batch[0].is_array()) {
52-
ARROW_ASSIGN_OR_RAISE(span<const PivotWiderKeyIndex> keys,
53-
key_mapper_->MapKeys(batch[0].array));
55+
ARROW_ASSIGN_OR_RAISE(auto keys_array, key_mapper_->MapKeys(batch[0].array));
56+
ArraySpan keys_span(*keys_array);
5457
if (batch[1].is_array()) {
5558
// Array keys, array values
5659
auto values = batch[1].array.ToArray();
57-
for (int64_t i = 0; i < batch.length; ++i) {
58-
PivotWiderKeyIndex key = keys[i];
59-
if (key != kNullPivotKey && !values->IsNull(i)) {
60-
if (ARROW_PREDICT_FALSE(values_[key]->is_valid)) {
61-
return DuplicateValue();
62-
}
63-
ARROW_ASSIGN_OR_RAISE(values_[key], values->GetScalar(i));
64-
DCHECK(values_[key]->is_valid);
65-
}
66-
}
60+
int64_t i = 0;
61+
RETURN_NOT_OK(VisitArraySpanInline<UInt32Type>(
62+
keys_span,
63+
[&](uint32_t key) {
64+
if (!values->IsNull(i)) {
65+
if (ARROW_PREDICT_FALSE(values_[key]->is_valid)) {
66+
return DuplicateValue();
67+
}
68+
ARROW_ASSIGN_OR_RAISE(values_[key], values->GetScalar(i));
69+
}
70+
++i;
71+
return Status::OK();
72+
},
73+
[&]() {
74+
++i;
75+
return Status::OK();
76+
}));
6777
} else {
6878
// Array keys, scalar value
6979
const Scalar* value = batch[1].scalar;
7080
if (value->is_valid) {
71-
for (int64_t i = 0; i < batch.length; ++i) {
72-
PivotWiderKeyIndex key = keys[i];
73-
if (key != kNullPivotKey) {
74-
if (ARROW_PREDICT_FALSE(values_[key]->is_valid)) {
75-
return DuplicateValue();
76-
}
77-
values_[key] = value->GetSharedPtr();
78-
}
79-
}
81+
RETURN_NOT_OK(VisitArraySpanInline<UInt32Type>(
82+
keys_span,
83+
[&](uint32_t key) {
84+
if (ARROW_PREDICT_FALSE(values_[key]->is_valid)) {
85+
return DuplicateValue();
86+
}
87+
values_[key] = value->GetSharedPtr();
88+
return Status::OK();
89+
},
90+
[] { return Status::OK(); }));
8091
}
8192
}
8293
} else {
83-
ARROW_ASSIGN_OR_RAISE(PivotWiderKeyIndex key,
84-
key_mapper_->MapKey(*batch[0].scalar));
85-
if (key != kNullPivotKey) {
94+
ARROW_ASSIGN_OR_RAISE(auto maybe_key, key_mapper_->MapKey(*batch[0].scalar));
95+
if (maybe_key.has_value()) {
96+
PivotWiderKeyIndex key = maybe_key.value();
8697
if (batch[1].is_array()) {
8798
// Scalar key, array values
8899
auto values = batch[1].array.ToArray();
@@ -145,10 +156,8 @@ struct PivotImpl : public ScalarAggregator {
145156
Result<std::unique_ptr<KernelState>> PivotInit(KernelContext* ctx,
146157
const KernelInitArgs& args) {
147158
const auto& options = checked_cast<const PivotWiderOptions&>(*args.options);
148-
DCHECK_EQ(args.inputs.size(), 2);
149-
DCHECK(is_base_binary_like(args.inputs[0].id()));
150159
auto state = std::make_unique<PivotImpl>();
151-
RETURN_NOT_OK(state->Init(options, args.inputs));
160+
RETURN_NOT_OK(state->Init(options, args.inputs, ctx->exec_context()));
152161
// GH-45718: This can be simplified once we drop the R openSUSE155 crossbow
153162
// job
154163
// R build with openSUSE155 requires an explicit shared_ptr construction
@@ -167,6 +176,8 @@ const FunctionDoc pivot_doc{
167176
"is emitted. If a pivot key doesn't appear, null is emitted.\n"
168177
"If more than one non-null value is encountered for a given pivot key,\n"
169178
"Invalid is raised.\n"
179+
"The pivot key column can be string, binary or integer. The `key_names`\n"
180+
"will be cast to the pivot key column type for matching.\n"
170181
"Behavior of unexpected pivot keys is controlled by `unexpected_key_behavior`\n"
171182
"in PivotWiderOptions."),
172183
{"pivot_keys", "pivot_values"},
@@ -179,11 +190,17 @@ void RegisterScalarAggregatePivot(FunctionRegistry* registry) {
179190

180191
auto func = std::make_shared<ScalarAggregateFunction>(
181192
"pivot_wider", Arity::Binary(), pivot_doc, &default_pivot_options);
182-
183-
for (auto key_type : BaseBinaryTypes()) {
184-
auto sig = KernelSignature::Make({key_type->id(), InputType::Any()},
193+
auto add_kernel = [&](InputType key_type) {
194+
auto sig = KernelSignature::Make({key_type, InputType::Any()},
185195
OutputType(ResolveOutputType));
186196
AddAggKernel(std::move(sig), PivotInit, func.get());
197+
};
198+
199+
for (const auto& key_type : BaseBinaryTypes()) {
200+
add_kernel(key_type->id());
201+
}
202+
for (const auto& key_type : IntTypes()) {
203+
add_kernel(key_type->id());
187204
}
188205
DCHECK_OK(registry->AddFunction(std::move(func)));
189206
}

‎cpp/src/arrow/compute/kernels/aggregate_test.cc

+33-3
Original file line numberDiff line numberDiff line change
@@ -4504,10 +4504,9 @@ TEST_F(TestPivotKernel, Basics) {
45044504
PivotWiderOptions(/*key_names=*/{"height", "width"}));
45054505
}
45064506

4507-
TEST_F(TestPivotKernel, AllKeyTypes) {
4507+
TEST_F(TestPivotKernel, BinaryKeyTypes) {
4508+
auto value_type = float32();
45084509
for (auto key_type : BaseBinaryTypes()) {
4509-
auto value_type = float32();
4510-
45114510
auto keys = ArrayFromJSON(key_type, R"(["width", "height"])");
45124511
auto values = ArrayFromJSON(value_type, "[10.5, 11.5]");
45134512
auto expected =
@@ -4518,6 +4517,19 @@ TEST_F(TestPivotKernel, AllKeyTypes) {
45184517
}
45194518
}
45204519

4520+
TEST_F(TestPivotKernel, IntegerKeyTypes) {
4521+
// It is possible to use an integer key column, while passing its string equivalent
4522+
// in PivotWiderOptions::key_names.
4523+
auto value_type = float32();
4524+
for (auto key_type : IntTypes()) {
4525+
auto keys = ArrayFromJSON(key_type, "[34, 12]");
4526+
auto values = ArrayFromJSON(value_type, "[10.5, 11.5]");
4527+
auto expected = ScalarFromJSON(
4528+
struct_({field("12", value_type), field("34", value_type)}), "[11.5, 10.5]");
4529+
AssertPivot(keys, values, *expected, PivotWiderOptions(/*key_names=*/{"12", "34"}));
4530+
}
4531+
}
4532+
45214533
TEST_F(TestPivotKernel, Numbers) {
45224534
auto key_type = utf8();
45234535
for (auto value_type : NumericTypes()) {
@@ -4724,6 +4736,24 @@ TEST_F(TestPivotKernel, DuplicateKeyNames) {
47244736
CallFunction("pivot_wider", {keys, values}, &options));
47254737
}
47264738

4739+
TEST_F(TestPivotKernel, InvalidKeyName) {
4740+
auto key_type = int32();
4741+
auto value_type = float32();
4742+
4743+
auto keys = ArrayFromJSON(key_type, "[]");
4744+
auto values = ArrayFromJSON(value_type, "[]");
4745+
auto options = PivotWiderOptions(/*key_names=*/{"height", "width"});
4746+
EXPECT_RAISES_WITH_MESSAGE_THAT(
4747+
Invalid,
4748+
::testing::HasSubstr("Failed to parse string: 'width' as a scalar of type int32"),
4749+
CallFunction("pivot_wider", {keys, values}, &options));
4750+
options.key_names = {"12.3", "45"};
4751+
EXPECT_RAISES_WITH_MESSAGE_THAT(
4752+
Invalid,
4753+
::testing::HasSubstr("Failed to parse string: '12.3' as a scalar of type int32"),
4754+
CallFunction("pivot_wider", {keys, values}, &options));
4755+
}
4756+
47274757
TEST_F(TestPivotKernel, DuplicateValues) {
47284758
auto key_type = utf8();
47294759
auto value_type = float32();

0 commit comments

Comments
 (0)
Please sign in to comment.