15
15
// specific language governing permissions and limitations
16
16
// under the License.
17
17
18
+ #include < algorithm>
18
19
#include < numeric>
19
20
20
21
#include < gtest/gtest.h>
30
31
#include " arrow/testing/matchers.h"
31
32
#include " arrow/testing/random.h"
32
33
#include " arrow/type_fwd.h"
34
+ #include " arrow/type_traits.h"
35
+ #include " arrow/util/bitmap_ops.h"
33
36
#include " arrow/util/checked_cast.h"
37
+ #include " arrow/util/key_value_metadata.h"
34
38
#include " arrow/util/string.h"
35
39
36
40
namespace arrow ::compute {
37
41
42
+ using ::arrow::internal::checked_cast;
38
43
using ::arrow::internal::checked_pointer_cast;
39
44
using ::arrow::internal::ToChars;
40
45
using ::testing::Eq;
@@ -605,17 +610,54 @@ struct TestGrouper {
605
610
}
606
611
}
607
612
613
+ void ExpectLookup (const std::string& key_json, const std::string& expected) {
614
+ auto expected_arr = ArrayFromJSON (uint32 (), expected);
615
+ if (shapes_.size () > 0 ) {
616
+ ExpectLookup (ExecBatchFromJSON (types_, shapes_, key_json), expected_arr);
617
+ } else {
618
+ ExpectLookup (ExecBatchFromJSON (types_, key_json), expected_arr);
619
+ }
620
+ }
621
+
622
+ void ExpectPopulate (const std::string& key_json) {
623
+ if (shapes_.size () > 0 ) {
624
+ ExpectPopulate (ExecBatchFromJSON (types_, shapes_, key_json));
625
+ } else {
626
+ ExpectPopulate (ExecBatchFromJSON (types_, key_json));
627
+ }
628
+ }
629
+
608
630
void ExpectConsume (const std::vector<Datum>& key_values, Datum expected) {
609
631
ASSERT_OK_AND_ASSIGN (auto key_batch, ExecBatch::Make (key_values));
610
632
ExpectConsume (key_batch, expected);
611
633
}
612
634
635
+ void ExpectLookup (const std::vector<Datum>& key_values, Datum expected) {
636
+ ASSERT_OK_AND_ASSIGN (auto key_batch, ExecBatch::Make (key_values));
637
+ ExpectLookup (key_batch, expected);
638
+ }
639
+
640
+ void ExpectPopulate (const std::vector<Datum>& key_values) {
641
+ ASSERT_OK_AND_ASSIGN (auto key_batch, ExecBatch::Make (key_values));
642
+ ExpectPopulate (key_batch);
643
+ }
644
+
613
645
void ExpectConsume (const ExecBatch& key_batch, Datum expected) {
614
646
Datum ids;
615
647
ConsumeAndValidate (key_batch, &ids);
616
648
AssertEquivalentIds (expected, ids);
617
649
}
618
650
651
+ void ExpectLookup (const ExecBatch& key_batch, Datum expected) {
652
+ Datum ids;
653
+ LookupAndValidate (key_batch, &ids);
654
+ AssertEquivalentIds (expected, ids);
655
+ }
656
+
657
+ void ExpectPopulate (const ExecBatch& key_batch) {
658
+ ASSERT_OK (grouper_->Populate (ExecSpan (key_batch)));
659
+ }
660
+
619
661
void ExpectUniques (const ExecBatch& uniques) {
620
662
EXPECT_THAT (grouper_->GetUniques (), ResultWith (Eq (uniques)));
621
663
}
@@ -633,27 +675,28 @@ struct TestGrouper {
633
675
auto right = actual.make_array ();
634
676
ASSERT_EQ (left->length (), right->length ()) << " #ids unequal" ;
635
677
int64_t num_ids = left->length ();
636
- auto left_data = left->data ();
637
- auto right_data = right->data ();
638
- auto left_ids = reinterpret_cast <const uint32_t *>(left_data->buffers [1 ]->data ());
639
- auto right_ids = reinterpret_cast <const uint32_t *>(right_data->buffers [1 ]->data ());
678
+ const auto & left_ids = checked_cast<const UInt32Array&>(*left);
679
+ const auto & right_ids = checked_cast<const UInt32Array&>(*right);
640
680
uint32_t max_left_id = 0 ;
641
681
uint32_t max_right_id = 0 ;
642
682
for (int64_t i = 0 ; i < num_ids; ++i) {
643
- if (left_ids[i] > max_left_id) {
644
- max_left_id = left_ids[i];
645
- }
646
- if (right_ids[i] > max_right_id) {
647
- max_right_id = right_ids[i];
683
+ ASSERT_EQ (left_ids.IsNull (i), right_ids.IsNull (i)) << " at index " << i;
684
+ if (left_ids.IsNull (i)) {
685
+ continue ;
648
686
}
687
+ max_left_id = std::max (max_left_id, left_ids.Value (i));
688
+ max_right_id = std::max (max_right_id, right_ids.Value (i));
649
689
}
650
690
std::vector<bool > right_to_left_present (max_right_id + 1 , false );
651
691
std::vector<bool > left_to_right_present (max_left_id + 1 , false );
652
692
std::vector<uint32_t > right_to_left (max_right_id + 1 );
653
693
std::vector<uint32_t > left_to_right (max_left_id + 1 );
654
694
for (int64_t i = 0 ; i < num_ids; ++i) {
655
- uint32_t left_id = left_ids[i];
656
- uint32_t right_id = right_ids[i];
695
+ if (left_ids.IsNull (i)) {
696
+ continue ;
697
+ }
698
+ uint32_t left_id = left_ids.Value (i);
699
+ uint32_t right_id = right_ids.Value (i);
657
700
if (!left_to_right_present[left_id]) {
658
701
left_to_right[left_id] = right_id;
659
702
left_to_right_present[left_id] = true ;
@@ -662,22 +705,33 @@ struct TestGrouper {
662
705
right_to_left[right_id] = left_id;
663
706
right_to_left_present[right_id] = true ;
664
707
}
665
- ASSERT_EQ (left_id, right_to_left[right_id]);
666
- ASSERT_EQ (right_id, left_to_right[left_id]);
708
+ ASSERT_EQ (left_id, right_to_left[right_id]) << " at index " << i ;
709
+ ASSERT_EQ (right_id, left_to_right[left_id]) << " at index " << i ;
667
710
}
668
711
}
669
712
670
713
void ConsumeAndValidate (const ExecBatch& key_batch, Datum* ids = nullptr ) {
671
714
ASSERT_OK_AND_ASSIGN (Datum id_batch, grouper_->Consume (ExecSpan (key_batch)));
672
715
673
- ValidateConsume (key_batch, id_batch);
716
+ ValidateConsume (key_batch, id_batch, /* can_be_null= */ false );
674
717
675
718
if (ids) {
676
719
*ids = std::move (id_batch);
677
720
}
678
721
}
679
722
680
- void ValidateConsume (const ExecBatch& key_batch, const Datum& id_batch) {
723
+ void LookupAndValidate (const ExecBatch& key_batch, Datum* ids = nullptr ) {
724
+ ASSERT_OK_AND_ASSIGN (Datum id_batch, grouper_->Lookup (ExecSpan (key_batch)));
725
+
726
+ ValidateConsume (key_batch, id_batch, /* can_be_null=*/ true );
727
+
728
+ if (ids) {
729
+ *ids = std::move (id_batch);
730
+ }
731
+ }
732
+
733
+ void ValidateConsume (const ExecBatch& key_batch, const Datum& id_batch,
734
+ bool can_be_null) {
681
735
if (uniques_.length == -1 ) {
682
736
ASSERT_OK_AND_ASSIGN (uniques_, grouper_->GetUniques ());
683
737
} else if (static_cast <int64_t >(grouper_->num_groups ()) > uniques_.length ) {
@@ -695,18 +749,49 @@ struct TestGrouper {
695
749
uniques_ = std::move (new_uniques);
696
750
}
697
751
698
- // check that the ids encode an equivalent key sequence
699
- auto ids = id_batch.make_array ();
700
- ValidateOutput (*ids);
752
+ // Check that the group ids encode an equivalent key sequence:
753
+ // calling Take(uniques, group_ids) should yield the original data.
754
+ auto group_ids = id_batch.make_array ();
755
+ ValidateOutput (*group_ids);
701
756
702
757
for (int i = 0 ; i < key_batch.num_values (); ++i) {
703
758
SCOPED_TRACE (ToChars (i) + " th key array" );
704
759
auto original =
705
760
key_batch[i].is_array ()
706
761
? key_batch[i].make_array ()
707
762
: *MakeArrayFromScalar (*key_batch[i].scalar (), key_batch.length );
708
- ASSERT_OK_AND_ASSIGN (auto encoded, Take (*uniques_[i].make_array (), *ids));
709
- AssertArraysEqual (*original, *encoded, /* verbose=*/ true ,
763
+ ASSERT_OK_AND_ASSIGN (auto encoded, Take (*uniques_[i].make_array (), *group_ids));
764
+ std::shared_ptr<Array> expected = original;
765
+ if (can_be_null && original->type_id () != Type::NA) {
766
+ // To compute the expected output, mask out the original entries that
767
+ // have a null group id.
768
+ auto expected_data = original->data ()->Copy ();
769
+ auto original_null_bitmap = original->null_bitmap ();
770
+ auto group_ids_null_bitmap = group_ids->null_bitmap ();
771
+
772
+ // This could be simplified with `OptionalBitmapAnd` (GH-45819).
773
+ std::shared_ptr<Buffer> null_bitmap;
774
+ if (original_null_bitmap && group_ids_null_bitmap) {
775
+ ASSERT_OK_AND_ASSIGN (null_bitmap,
776
+ ::arrow::internal::BitmapAnd (
777
+ default_memory_pool (), group_ids_null_bitmap->data(),
778
+ group_ids->offset(), original_null_bitmap->data(),
779
+ original->offset(), original->length(),
780
+ /* out_offset=*/ original->offset()));
781
+ } else if (group_ids_null_bitmap) {
782
+ ASSERT_OK_AND_ASSIGN (
783
+ null_bitmap, AllocateEmptyBitmap (original->offset () + original->length ()));
784
+ ::arrow::internal::CopyBitmap (group_ids_null_bitmap->data (),
785
+ group_ids->offset(), group_ids->length(),
786
+ null_bitmap->mutable_data(), original->offset());
787
+ } else {
788
+ null_bitmap = original_null_bitmap;
789
+ }
790
+ expected_data->buffers[0 ] = null_bitmap;
791
+ expected_data->null_count = kUnknownNullCount ;
792
+ expected = MakeArray(expected_data);
793
+ }
794
+ AssertArraysEqual (*expected, *encoded, /* verbose=*/ true ,
710
795
EqualOptions ().nans_equal (true ));
711
796
}
712
797
}
@@ -719,16 +804,27 @@ struct TestGrouper {
719
804
};
720
805
721
806
TEST (Grouper, BooleanKey) {
722
- TestGrouper g ({boolean ()});
723
-
724
- g.ExpectConsume (" [[true], [true]]" , " [0, 0]" );
725
-
726
- g.ExpectConsume (" [[true], [true]]" , " [0, 0]" );
727
-
728
- g.ExpectConsume (" [[false], [null]]" , " [1, 2]" );
729
-
730
- g.ExpectConsume (" [[true], [false], [true], [false], [null], [false], [null]]" ,
731
- " [0, 1, 0, 1, 2, 1, 2]" );
807
+ {
808
+ TestGrouper g ({boolean ()});
809
+ g.ExpectConsume (" [[true], [true]]" , " [0, 0]" );
810
+ g.ExpectConsume (" [[true], [true]]" , " [0, 0]" );
811
+ g.ExpectConsume (" [[false], [null]]" , " [1, 2]" );
812
+ g.ExpectConsume (" [[true], [false], [true], [false], [null], [false], [null]]" ,
813
+ " [0, 1, 0, 1, 2, 1, 2]" );
814
+ }
815
+ {
816
+ TestGrouper g ({boolean ()});
817
+ g.ExpectPopulate (" [[true], [true]]" );
818
+ g.ExpectPopulate (" [[true], [true]]" );
819
+ g.ExpectConsume (" [[false], [null]]" , " [1, 2]" );
820
+ g.ExpectConsume (" [[true], [false], [true], [false], [null], [false], [null]]" ,
821
+ " [0, 1, 0, 1, 2, 1, 2]" );
822
+ }
823
+ {
824
+ TestGrouper g ({boolean ()});
825
+ g.ExpectPopulate (" [[true], [null]]" );
826
+ g.ExpectLookup (" [[null], [false], [true], [null]]" , " [1, null, 0, 1]" );
827
+ }
732
828
}
733
829
734
830
TEST (Grouper, NumericKey) {
@@ -747,20 +843,41 @@ TEST(Grouper, NumericKey) {
747
843
}) {
748
844
SCOPED_TRACE (" key type: " + ty->ToString ());
749
845
750
- TestGrouper g ({ty});
846
+ {
847
+ TestGrouper g ({ty});
848
+ g.ExpectConsume (" [[3], [3]]" , " [0, 0]" );
849
+ g.ExpectUniques (" [[3]]" );
850
+
851
+ g.ExpectConsume (" [[3], [3]]" , " [0, 0]" );
852
+ g.ExpectUniques (" [[3]]" );
751
853
752
- g.ExpectConsume (" [[3 ], [3]] " , " [0, 0 ]" );
753
- g.ExpectUniques (" [[3]]" );
854
+ g.ExpectConsume (" [[27 ], [81], [81]] " , " [1, 2, 2 ]" );
855
+ g.ExpectUniques (" [[3], [27], [81 ]]" );
754
856
755
- g.ExpectConsume (" [[3], [3]]" , " [0, 0]" );
756
- g.ExpectUniques (" [[3]]" );
857
+ g.ExpectConsume (" [[3], [27], [3], [27], [null], [81], [27], [81]]" ,
858
+ " [0, 1, 0, 1, 3, 2, 1, 2]" );
859
+ g.ExpectUniques (" [[3], [27], [81], [null]]" );
860
+ }
861
+ {
862
+ TestGrouper g ({ty});
863
+ g.ExpectPopulate (" [[3], [3]]" );
864
+ g.ExpectPopulate (" [[3], [3]]" );
865
+ g.ExpectUniques (" [[3]]" );
757
866
758
- g. ExpectConsume (" [[27], [81], [81]] " , " [1, 2, 2 ]" );
759
- g.ExpectUniques (" [[3], [27], [81]]" );
867
+ g. ExpectPopulate (" [[27], [81], [81]]" );
868
+ g.ExpectUniques (" [[3], [27], [81]]" );
760
869
761
- g.ExpectConsume (" [[3], [27], [3], [27], [null], [81], [27], [81]]" ,
762
- " [0, 1, 0, 1, 3, 2, 1, 2]" );
763
- g.ExpectUniques (" [[3], [27], [81], [null]]" );
870
+ g.ExpectConsume (" [[3], [27], [3], [27], [null], [81], [27], [81]]" ,
871
+ " [0, 1, 0, 1, 3, 2, 1, 2]" );
872
+ g.ExpectUniques (" [[3], [27], [81], [null]]" );
873
+ }
874
+ {
875
+ TestGrouper g ({ty});
876
+ g.ExpectPopulate (" [[3], [3]]" );
877
+ g.ExpectPopulate (" [[27], [81], [81]]" );
878
+ g.ExpectLookup (" [[3], [27], [6], [27], [null], [81], [27], [6]]" ,
879
+ " [0, 1, null, 1, null, 2, 1, null]" );
880
+ }
764
881
}
765
882
}
766
883
@@ -780,21 +897,23 @@ TEST(Grouper, FloatingPointKey) {
780
897
781
898
TEST (Grouper, StringKey) {
782
899
for (auto ty : {utf8 (), large_utf8 (), fixed_size_binary (2 )}) {
783
- SCOPED_TRACE (" key type: " + ty->ToString ());
784
-
785
- TestGrouper g ({ty});
786
-
787
- g.ExpectConsume (R"( [["eh"], ["eh"]])" , " [0, 0]" );
788
-
789
- g.ExpectConsume (R"( [["eh"], ["eh"]])" , " [0, 0]" );
790
-
791
- g.ExpectConsume (R"( [["be"], [null]])" , " [1, 2]" );
900
+ ARROW_SCOPED_TRACE (" key type = " , *ty);
901
+ {
902
+ TestGrouper g ({ty});
903
+ g.ExpectConsume (R"( [["eh"], ["eh"]])" , " [0, 0]" );
904
+ g.ExpectConsume (R"( [["eh"], ["eh"]])" , " [0, 0]" );
905
+ g.ExpectConsume (R"( [["be"], [null]])" , " [1, 2]" );
906
+ }
907
+ {
908
+ TestGrouper g ({ty});
909
+ g.ExpectConsume (R"( [["eh"], ["eh"]])" , " [0, 0]" );
910
+ g.ExpectConsume (R"( [["be"], [null]])" , " [1, 2]" );
911
+ g.ExpectLookup (R"( [["be"], [null], ["da"]])" , " [1, 2, null]" );
912
+ }
792
913
}
793
914
}
794
915
795
916
TEST (Grouper, DictKey) {
796
- TestGrouper g ({dictionary (int32 (), utf8 ())});
797
-
798
917
// For dictionary keys, all batches must share a single dictionary.
799
918
// Eventually, differing dictionaries will be unified and indices transposed
800
919
// during encoding to relieve this restriction.
@@ -804,25 +923,47 @@ TEST(Grouper, DictKey) {
804
923
return Datum (*DictionaryArray::FromArrays (ArrayFromJSON (int32 (), indices), dict));
805
924
};
806
925
807
- // NB: null index is not considered equivalent to index=3 (which encodes null in dict)
808
- g.ExpectConsume ({WithIndices (" [3, 1, null, 0, 2]" )},
809
- ArrayFromJSON (uint32 (), " [0, 1, 2, 3, 4]" ));
810
-
811
- g = TestGrouper ({dictionary (int32 (), utf8 ())});
812
-
813
- g.ExpectConsume ({WithIndices (" [0, 1, 2, 3, null]" )},
814
- ArrayFromJSON (uint32 (), " [0, 1, 2, 3, 4]" ));
815
-
816
- g.ExpectConsume ({WithIndices (" [3, 1, null, 0, 2]" )},
817
- ArrayFromJSON (uint32 (), " [3, 1, 4, 0, 2]" ));
818
-
819
- auto dict_arr = *DictionaryArray::FromArrays (
820
- ArrayFromJSON (int32 (), " [0, 1]" ),
821
- ArrayFromJSON (utf8 (), R"( ["different", "dictionary"])" ));
822
- ExecSpan dict_span ({*dict_arr->data ()}, 2 );
823
- EXPECT_RAISES_WITH_MESSAGE_THAT (NotImplemented,
824
- HasSubstr (" Unifying differing dictionaries" ),
825
- g.grouper_ ->Consume (dict_span));
926
+ {
927
+ TestGrouper g ({dictionary (int32 (), utf8 ())});
928
+ // NB: null index is not considered equivalent to index=3 (which encodes null in dict)
929
+ g.ExpectConsume ({WithIndices (" [3, 1, null, 0, 2]" )},
930
+ ArrayFromJSON (uint32 (), " [0, 1, 2, 3, 4]" ));
931
+ }
932
+ {
933
+ TestGrouper g ({dictionary (int32 (), utf8 ())});
934
+ g.ExpectPopulate ({WithIndices (" [3, 1, null, 2]" )});
935
+ g.ExpectConsume ({WithIndices (" [1, null, 3, 0, 2]" )},
936
+ ArrayFromJSON (uint32 (), " [1, 2, 0, 4, 3]" ));
937
+ }
938
+ {
939
+ TestGrouper g ({dictionary (int32 (), utf8 ())});
940
+ g.ExpectPopulate ({WithIndices (" [3, 1, null, 2]" )});
941
+ g.ExpectLookup ({WithIndices (" [1, null, 3, 0, 2]" )},
942
+ ArrayFromJSON (uint32 (), " [1, 2, 0, null, 3]" ));
943
+ }
944
+ {
945
+ TestGrouper g ({dictionary (int32 (), utf8 ())});
946
+
947
+ g.ExpectConsume ({WithIndices (" [0, 1, 2, 3, null]" )},
948
+ ArrayFromJSON (uint32 (), " [0, 1, 2, 3, 4]" ));
949
+
950
+ g.ExpectConsume ({WithIndices (" [3, 1, null, 0, 2]" )},
951
+ ArrayFromJSON (uint32 (), " [3, 1, 4, 0, 2]" ));
952
+
953
+ auto dict_arr = *DictionaryArray::FromArrays (
954
+ ArrayFromJSON (int32 (), " [0, 1]" ),
955
+ ArrayFromJSON (utf8 (), R"( ["different", "dictionary"])" ));
956
+ ExecSpan dict_span ({*dict_arr->data ()}, 2 );
957
+ EXPECT_RAISES_WITH_MESSAGE_THAT (NotImplemented,
958
+ HasSubstr (" Unifying differing dictionaries" ),
959
+ g.grouper_ ->Consume (dict_span));
960
+ EXPECT_RAISES_WITH_MESSAGE_THAT (NotImplemented,
961
+ HasSubstr (" Unifying differing dictionaries" ),
962
+ g.grouper_ ->Populate (dict_span));
963
+ EXPECT_RAISES_WITH_MESSAGE_THAT (NotImplemented,
964
+ HasSubstr (" Unifying differing dictionaries" ),
965
+ g.grouper_ ->Lookup (dict_span));
966
+ }
826
967
}
827
968
828
969
// GH-45393: Test combinations of numeric type keys of different lengths.
@@ -834,55 +975,80 @@ TEST(Grouper, MultipleIntKeys) {
834
975
ARROW_SCOPED_TRACE (" t1=" , t1->ToString ());
835
976
for (auto & t2 : types) {
836
977
ARROW_SCOPED_TRACE (" t2=" , t2->ToString ());
837
- TestGrouper g ({t0, t1, t2});
838
-
839
- g.ExpectConsume (R"( [[0, 1, 2], [0, 1, 2]])" , " [0, 0]" );
840
- g.ExpectConsume (R"( [[0, 1, 2], [null, 1, 2]])" , " [0, 1]" );
841
- g.ExpectConsume (R"( [[0, 1, 2], [0, null, 2]])" , " [0, 2]" );
842
- g.ExpectConsume (R"( [[0, 1, 2], [0, 1, null]])" , " [0, 3]" );
843
-
844
- g.ExpectUniques (" [[0, 1, 2], [null, 1, 2], [0, null, 2], [0, 1, null]]" );
978
+ {
979
+ TestGrouper g ({t0, t1, t2});
980
+
981
+ g.ExpectConsume (R"( [[0, 1, 2], [0, 1, 2]])" , " [0, 0]" );
982
+ g.ExpectConsume (R"( [[0, 1, 2], [null, 1, 2]])" , " [0, 1]" );
983
+ g.ExpectConsume (R"( [[0, 1, 2], [0, null, 2]])" , " [0, 2]" );
984
+ g.ExpectConsume (R"( [[0, 1, 2], [0, 1, null]])" , " [0, 3]" );
985
+
986
+ g.ExpectUniques (" [[0, 1, 2], [null, 1, 2], [0, null, 2], [0, 1, null]]" );
987
+ }
988
+ {
989
+ TestGrouper g ({t0, t1, t2});
990
+
991
+ g.ExpectPopulate (R"( [[0, 1, 2], [0, 1, 2]])" );
992
+ g.ExpectPopulate (R"( [[0, 1, 2], [0, null, 2]])" );
993
+ g.ExpectLookup (R"( [[0, null, 2], [0, 1, 2], [null, 1, 0], [0, null, 2]])" ,
994
+ " [1, 0, null, 1]" );
995
+ g.ExpectLookup (R"( [[0, null, 2], [0, 1, 2], [null, 1, 0], [0, null, 2]])" ,
996
+ " [1, 0, null, 1]" );
997
+
998
+ g.ExpectUniques (" [[0, 1, 2], [0, null, 2]]" );
999
+ }
845
1000
}
846
1001
}
847
1002
}
848
1003
}
849
1004
850
1005
TEST (Grouper, StringInt64Key) {
851
- TestGrouper g ({utf8 (), int64 ()});
852
-
853
- g.ExpectConsume (R"( [["eh", 0], ["eh", 0]])" , " [0, 0]" );
854
-
855
- g.ExpectConsume (R"( [["eh", 0], ["eh", null]])" , " [0, 1]" );
856
-
857
- g.ExpectConsume (R"( [["eh", 1], ["bee", 1]])" , " [2, 3]" );
858
-
859
- g.ExpectConsume (R"( [["eh", null], ["bee", 1]])" , " [1, 3]" );
860
-
861
- g = TestGrouper ({utf8 (), int64 ()});
862
-
863
- g.ExpectConsume (R"( [
864
- ["ex", 0],
865
- ["ex", 0],
866
- ["why", 0],
867
- ["ex", 1],
868
- ["why", 0],
869
- ["ex", 1],
870
- ["ex", 0],
871
- ["why", 1]
872
- ])" ,
873
- " [0, 0, 1, 2, 1, 2, 0, 3]" );
1006
+ for (auto string_type : {utf8 (), large_utf8 ()}) {
1007
+ ARROW_SCOPED_TRACE (" string_type = " , *string_type);
1008
+ {
1009
+ TestGrouper g ({string_type, int64 ()});
874
1010
875
- g.ExpectConsume (R"( [
876
- ["ex", 0],
877
- [null, 0],
878
- [null, 0],
879
- ["ex", 1],
880
- [null, null],
881
- ["ex", 1],
882
- ["ex", 0],
883
- ["why", null]
884
- ])" ,
885
- " [0, 4, 4, 2, 5, 2, 0, 6]" );
1011
+ g.ExpectConsume (R"( [["eh", 0], ["eh", 0]])" , " [0, 0]" );
1012
+ g.ExpectConsume (R"( [["eh", 0], ["eh", null]])" , " [0, 1]" );
1013
+ g.ExpectConsume (R"( [["eh", 1], ["bee", 1]])" , " [2, 3]" );
1014
+ g.ExpectConsume (R"( [["eh", null], ["bee", 1]])" , " [1, 3]" );
1015
+ }
1016
+ {
1017
+ TestGrouper g ({string_type, int64 ()});
1018
+
1019
+ g.ExpectPopulate (R"( [["eh", 0], ["eh", 0]])" );
1020
+ g.ExpectPopulate (R"( [["eh", 0], ["eh", null]])" );
1021
+ g.ExpectConsume (R"( [["eh", 1], ["bee", 1]])" , " [2, 3]" );
1022
+ g.ExpectConsume (R"( [["eh", null], ["bee", 1]])" , " [1, 3]" );
1023
+ g.ExpectLookup (R"( [["da", null], ["bee", 1]])" , " [null, 3]" );
1024
+ g.ExpectLookup (R"( [["da", null], ["bee", 1]])" , " [null, 3]" );
1025
+ }
1026
+ {
1027
+ TestGrouper g ({string_type, int64 ()});
1028
+ g.ExpectConsume (R"( [
1029
+ ["ex", 0],
1030
+ ["ex", 0],
1031
+ ["why", 0],
1032
+ ["ex", 1],
1033
+ ["why", 0],
1034
+ ["ex", 1],
1035
+ ["ex", 0],
1036
+ ["why", 1]
1037
+ ])" ,
1038
+ " [0, 0, 1, 2, 1, 2, 0, 3]" );
1039
+ g.ExpectConsume (R"( [
1040
+ ["ex", 0],
1041
+ [null, 0],
1042
+ [null, 0],
1043
+ ["ex", 1],
1044
+ [null, null],
1045
+ ["ex", 1],
1046
+ ["ex", 0],
1047
+ ["why", null]
1048
+ ])" ,
1049
+ " [0, 4, 4, 2, 5, 2, 0, 6]" );
1050
+ }
1051
+ }
886
1052
}
887
1053
888
1054
TEST (Grouper, DoubleStringInt64Key) {
@@ -898,42 +1064,88 @@ TEST(Grouper, DoubleStringInt64Key) {
898
1064
g.ExpectConsume (R"( [[-0.0, "be", 7], [0.0, "be", 7]])" , " [3, 4]" );
899
1065
}
900
1066
901
- TEST (Grouper, RandomInt64Keys) {
902
- TestGrouper g ({int64 ()});
1067
+ FieldVector AnnotateForRandomGeneration (FieldVector fields) {
1068
+ for (auto & field : fields) {
1069
+ // For each field, constrain random generation to ensure that group ids
1070
+ // can appear more than once.
1071
+ if (is_integer (*field->type ())) {
1072
+ field =
1073
+ field->WithMergedMetadata (key_value_metadata ({" min" , " max" }, {" 100" , " 10000" }));
1074
+ } else if (is_binary_like (*field->type ())) {
1075
+ // (note this is unsupported for large binary types)
1076
+ field = field->WithMergedMetadata (key_value_metadata ({" unique" }, {" 100" }));
1077
+ }
1078
+ field = field->WithMergedMetadata (key_value_metadata ({" null_probability" }, {" 0.1" }));
1079
+ }
1080
+ return fields;
1081
+ }
1082
+
1083
+ void TestRandomConsume (TestGrouper g) {
1084
+ // Exercise Consume
1085
+ auto fields = AnnotateForRandomGeneration (g.key_schema_ ->fields ());
903
1086
for (int i = 0 ; i < 4 ; ++i) {
904
1087
SCOPED_TRACE (ToChars (i) + " th key batch" );
905
1088
906
- ExecBatch key_batch{
907
- *random ::GenerateBatch (g.key_schema_ ->fields (), 1 << 12 , 0xDEADBEEF )};
1089
+ ExecBatch key_batch{*random ::GenerateBatch (fields, 1 << 12 , /* seed=*/ i + 1 )};
908
1090
g.ConsumeAndValidate (key_batch);
909
1091
}
910
1092
}
911
1093
912
- TEST (Grouper, RandomStringInt64Keys) {
913
- TestGrouper g ({utf8 (), int64 ()});
1094
+ void TestRandomLookup (TestGrouper g) {
1095
+ // Exercise Populate then Lookup
1096
+ auto fields = AnnotateForRandomGeneration (g.key_schema_ ->fields ());
1097
+ ExecBatch key_batch{*random ::GenerateBatch (fields, 1 << 12 , /* seed=*/ 1 )};
1098
+ ASSERT_OK (g.grouper_ ->Populate (ExecSpan{key_batch}));
914
1099
for (int i = 0 ; i < 4 ; ++i) {
915
1100
SCOPED_TRACE (ToChars (i) + " th key batch" );
916
1101
917
- ExecBatch key_batch{
918
- *random ::GenerateBatch (g.key_schema_ ->fields (), 1 << 12 , 0xDEADBEEF )};
919
- g.ConsumeAndValidate (key_batch);
1102
+ ExecBatch key_batch{*random ::GenerateBatch (fields, 1 << 12 , /* seed=*/ i + 1 )};
1103
+ g.LookupAndValidate (key_batch);
920
1104
}
921
1105
}
922
1106
923
- TEST (Grouper, RandomStringInt64DoubleInt32Keys ) {
924
- TestGrouper g ({ utf8 (), int64 (), float64 (), int32 ()} );
925
- for ( int i = 0 ; i < 4 ; ++i) {
926
- SCOPED_TRACE ( ToChars (i) + " th key batch " );
1107
+ TEST (Grouper, RandomInt64Keys ) {
1108
+ TestRandomConsume ( TestGrouper ({ int64 ()}) );
1109
+ TestRandomLookup ( TestGrouper ({ int64 ()}));
1110
+ }
927
1111
928
- ExecBatch key_batch{
929
- *random ::GenerateBatch (g.key_schema_ ->fields (), 1 << 12 , 0xDEADBEEF )};
930
- g.ConsumeAndValidate (key_batch);
1112
+ TEST (Grouper, RandomStringKeys) {
1113
+ for (auto string_type : {utf8 (), large_utf8 ()}) {
1114
+ ARROW_SCOPED_TRACE (" string_type = " , *string_type);
1115
+ TestRandomConsume (TestGrouper ({string_type}));
1116
+ TestRandomLookup (TestGrouper ({string_type}));
1117
+ }
1118
+ }
1119
+
1120
+ TEST (Grouper, RandomStringInt64Keys) {
1121
+ for (auto string_type : {utf8 (), large_utf8 ()}) {
1122
+ ARROW_SCOPED_TRACE (" string_type = " , *string_type);
1123
+ TestRandomConsume (TestGrouper ({string_type, int64 ()}));
1124
+ TestRandomLookup (TestGrouper ({string_type, int64 ()}));
931
1125
}
932
1126
}
933
1127
1128
+ TEST (Grouper, RandomStringInt64DoubleInt32Keys) {
1129
+ TestRandomConsume (TestGrouper ({utf8 (), int64 (), float64 (), int32 ()}));
1130
+ TestRandomLookup (TestGrouper ({utf8 (), int64 (), float64 (), int32 ()}));
1131
+ }
1132
+
934
1133
TEST (Grouper, NullKeys) {
935
- TestGrouper g ({null ()});
936
- g.ExpectConsume (" [[null], [null]]" , " [0, 0]" );
1134
+ {
1135
+ TestGrouper g ({null ()});
1136
+ g.ExpectConsume (" [[null], [null]]" , " [0, 0]" );
1137
+ }
1138
+ {
1139
+ TestGrouper g ({null ()});
1140
+ g.ExpectPopulate (" [[null], [null]]" );
1141
+ g.ExpectConsume (" [[null], [null]]" , " [0, 0]" );
1142
+ }
1143
+ {
1144
+ TestGrouper g ({null ()});
1145
+ g.ExpectLookup (" [[null], [null]]" , " [null, null]" );
1146
+ g.ExpectPopulate (" [[null], [null]]" );
1147
+ g.ExpectLookup (" [[null], [null], [null]]" , " [0, 0, 0]" );
1148
+ }
937
1149
}
938
1150
939
1151
TEST (Grouper, MultipleNullKeys) {
@@ -971,8 +1183,16 @@ TEST(Grouper, DoubleNullStringKey) {
971
1183
}
972
1184
973
1185
TEST (Grouper, EmptyNullKeys) {
974
- TestGrouper g ({null ()});
975
- g.ExpectConsume (" []" , " []" );
1186
+ {
1187
+ TestGrouper g ({null ()});
1188
+ g.ExpectConsume (" []" , " []" );
1189
+ }
1190
+ {
1191
+ TestGrouper g ({null ()});
1192
+ g.ExpectPopulate (" []" );
1193
+ g.ExpectConsume (" []" , " []" );
1194
+ g.ExpectLookup (" []" , " []" );
1195
+ }
976
1196
}
977
1197
978
1198
TEST (Grouper, MakeGroupings) {
@@ -1021,22 +1241,49 @@ TEST(Grouper, ScalarValues) {
1021
1241
ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::ARRAY});
1022
1242
g.ExpectConsume (
1023
1243
R"( [
1024
- [true, 1, "1.00", "2.00", "ab", "foo", 2],
1025
- [true, 1, "1.00", "2.00", "ab", "foo", 2],
1026
- [true, 1, "1.00", "2.00", "ab", "foo", 3]
1027
- ])" ,
1244
+ [true, 1, "1.00", "2.00", "ab", "foo", 2],
1245
+ [true, 1, "1.00", "2.00", "ab", "foo", 2],
1246
+ [true, 1, "1.00", "2.00", "ab", "foo", 3]
1247
+ ])" ,
1028
1248
" [0, 0, 1]" );
1029
1249
}
1250
+ {
1251
+ TestGrouper g (
1252
+ {boolean (), int32 (), decimal128 (3 , 2 ), decimal256 (3 , 2 ), fixed_size_binary (2 ),
1253
+ str_type, int32 ()},
1254
+ {ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR,
1255
+ ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::ARRAY});
1256
+ g.ExpectPopulate (
1257
+ R"( [
1258
+ [true, 1, "1.00", "2.00", "ab", "foo", 2],
1259
+ [true, 1, "1.00", "2.00", "ab", "foo", 2],
1260
+ [true, 1, "1.00", "2.00", "ab", "foo", 3]
1261
+ ])" );
1262
+ g.ExpectLookup (
1263
+ R"( [
1264
+ [true, 1, "1.00", "2.00", "ab", "foo", 3],
1265
+ [true, 1, "1.00", "2.00", "ab", "foo", 4],
1266
+ [true, 1, "1.00", "2.00", "ab", "foo", 2],
1267
+ [true, 1, "1.00", "2.00", "ab", "foo", 3]
1268
+ ])" ,
1269
+ " [1, null, 0, 1]" );
1270
+ }
1030
1271
{
1031
1272
auto dict_type = dictionary (int32 (), utf8 ());
1032
1273
TestGrouper g ({dict_type, str_type}, {ArgShape::SCALAR, ArgShape::SCALAR});
1033
- const auto dict = R"( ["foo", null])" ;
1274
+ const auto dict = R"( ["foo", null, "bar" ])" ;
1034
1275
g.ExpectConsume (
1035
1276
{DictScalarFromJSON (dict_type, " 0" , dict), ScalarFromJSON (str_type, R"( "")" )},
1036
1277
ArrayFromJSON (uint32 (), " [0]" ));
1037
1278
g.ExpectConsume (
1038
1279
{DictScalarFromJSON (dict_type, " 1" , dict), ScalarFromJSON (str_type, R"( "")" )},
1039
1280
ArrayFromJSON (uint32 (), " [1]" ));
1281
+ g.ExpectLookup (
1282
+ {DictScalarFromJSON (dict_type, " 1" , dict), ScalarFromJSON (str_type, R"( "")" )},
1283
+ ArrayFromJSON (uint32 (), " [1]" ));
1284
+ g.ExpectLookup (
1285
+ {DictScalarFromJSON (dict_type, " 2" , dict), ScalarFromJSON (str_type, R"( "")" )},
1286
+ ArrayFromJSON (uint32 (), " [null]" ));
1040
1287
}
1041
1288
}
1042
1289
}
0 commit comments