Skip to content

Commit c239a6d

Browse files
committed
feat: support sort agg
Signed-off-by: kikkon <nian920@outlook.com>
1 parent 6735778 commit c239a6d

File tree

2 files changed

+248
-0
lines changed

2 files changed

+248
-0
lines changed

src/executor/mod.rs

+3
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ use self::order::*;
4141
use self::projection::*;
4242
use self::simple_agg::*;
4343
#[allow(unused_imports)]
44+
use self::sort_agg::*;
45+
#[allow(unused_imports)]
4446
use self::sort_merge_join::*;
4547
use self::table_scan::*;
4648
use self::top_n::TopNExecutor;
@@ -72,6 +74,7 @@ mod nested_loop_join;
7274
mod order;
7375
mod projection;
7476
mod simple_agg;
77+
mod sort_agg;
7578
mod sort_merge_join;
7679
mod table_scan;
7780
mod top_n;

src/executor/sort_agg.rs

+245
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
// Copyright 2022 RisingLight Project Authors. Licensed under Apache-2.0.
2+
#![allow(dead_code)]
3+
use smallvec::SmallVec;
4+
5+
use super::*;
6+
use crate::array::{ArrayBuilderImpl, ArrayImpl};
7+
use crate::binder::BoundAggCall;
8+
use crate::types::{DataTypeExt, DataTypeKind};
9+
10+
pub struct SortAggExecutor {
11+
pub agg_calls: Vec<BoundAggCall>,
12+
pub group_keys: Vec<BoundExpr>,
13+
pub child: BoxedExecutor,
14+
}
15+
16+
impl SortAggExecutor {
17+
#[try_stream(boxed, ok = DataChunk, error = ExecutorError)]
18+
pub async fn execute(self) {
19+
let mut last_key = None::<HashKey>;
20+
let mut states = create_agg_states(&self.agg_calls);
21+
22+
#[for_await]
23+
for chunk in self.child {
24+
// Eval group keys and arguments
25+
let chunk = chunk?;
26+
let exprs: SmallVec<[ArrayImpl; 16]> = self
27+
.agg_calls
28+
.iter()
29+
.map(|agg| agg.args[0].eval(&chunk))
30+
.try_collect()?;
31+
let group_cols: SmallVec<[ArrayImpl; 16]> = self
32+
.group_keys
33+
.iter()
34+
.map(|e| e.eval(&chunk))
35+
.try_collect()?;
36+
37+
let num_rows = chunk.cardinality();
38+
for row_idx in 0..num_rows {
39+
// Create group key
40+
let mut group_key = HashKey::new();
41+
for col in group_cols.iter() {
42+
group_key.push(col.get(row_idx));
43+
}
44+
// Check group key & lask key
45+
match last_key {
46+
Some(last_key) => {
47+
if last_key != group_key {
48+
yield Self::finish_agg(&states);
49+
states = create_agg_states(&self.agg_calls);
50+
}
51+
}
52+
None => (),
53+
}
54+
for (state, expr) in states.iter_mut().zip_eq(&exprs) {
55+
state.update_single(&expr.get(row_idx))?;
56+
}
57+
last_key = Some(group_key);
58+
}
59+
}
60+
yield Self::finish_agg(&states);
61+
62+
}
63+
64+
fn finish_agg(states: &SmallVec<[Box<dyn AggregationState>; 16]>) -> DataChunk {
65+
return states
66+
.iter()
67+
.map(|s| {
68+
let result = &s.output();
69+
match &result.data_type() {
70+
Some(r) => {
71+
let mut builder = ArrayBuilderImpl::with_capacity(1, r);
72+
builder.push(result);
73+
builder.finish()
74+
}
75+
None => ArrayBuilderImpl::new(&DataTypeKind::Int(None).nullable()).finish(),
76+
}
77+
})
78+
.collect::<DataChunk>();
79+
}
80+
}
81+
82+
#[cfg(test)]
mod tests {

    use super::*;
    use crate::array::ArrayImpl;
    use crate::binder::{AggKind, BoundInputRef};
    use crate::types::{DataType, DataTypeKind};

    /// Groups by two columns (indices 1 and 2) while summing columns 0 and 1.
    #[tokio::test]
    async fn test_multi_group_agg() {
        // The first two rows share the key, the remaining rows are groups of their own.
        test_group_agg(
            vec![0, 1],
            vec![1, 2],
            vec![
                vec![1.1, 0.2, 0.3, 0.4, 0.5],
                vec![1.1, 1.1, 1.3, 1.4, 1.5],
                vec![1.2, 1.2, 2.3, 2.4, 2.5],
            ],
            vec![
                vec![1.3, 2.2],
                vec![0.3, 1.3],
                vec![0.4, 1.4],
                vec![0.5, 1.5],
            ],
        )
        .await;
        // Every row has a distinct key, so every row is its own group.
        test_group_agg(
            vec![0, 1],
            vec![1, 2],
            vec![
                vec![0.1, 0.2, 0.3, 0.4, 0.5],
                vec![1.1, 1.1, 1.3, 1.4, 1.5],
                vec![1.1, 1.2, 2.3, 2.4, 2.5],
            ],
            vec![
                vec![0.1, 1.1],
                vec![0.2, 1.1],
                vec![0.3, 1.3],
                vec![0.4, 1.4],
                vec![0.5, 1.5],
            ],
        )
        .await
    }

    /// Groups by a single column while summing columns 0 and 1.
    #[tokio::test]
    async fn test_single_group_agg() {
        // All rows fall into one group.
        test_group_agg(
            vec![0, 1],
            vec![0],
            vec![vec![1.0, 1.0], vec![1.0, 2.0]],
            vec![vec![2.0, 3.0]],
        )
        .await;
        // The first two rows share the key, the remaining rows are groups of their own.
        test_group_agg(
            vec![0, 1],
            vec![1],
            vec![
                vec![1.1, 0.2, 0.3, 0.4, 0.5],
                vec![1.1, 1.1, 1.3, 1.4, 1.5],
                vec![2.1, 2.2, 2.3, 2.4, 2.5],
            ],
            vec![
                vec![1.3, 2.2],
                vec![0.3, 1.3],
                vec![0.4, 1.4],
                vec![0.5, 1.5],
            ],
        )
        .await;
        // Every row has a distinct key, so every row is its own group.
        test_group_agg(
            vec![0, 1],
            vec![1],
            vec![
                vec![0.1, 0.2, 0.3, 0.4, 0.5],
                vec![1.1, 1.2, 1.3, 1.4, 1.5],
                vec![2.1, 2.2, 2.3, 2.4, 2.5],
            ],
            vec![
                vec![0.1, 1.1],
                vec![0.2, 1.2],
                vec![0.3, 1.3],
                vec![0.4, 1.4],
                vec![0.5, 1.5],
            ],
        )
        .await
    }

    /// Runs a [`SortAggExecutor`] over a single input chunk built from `cols` and checks the
    /// emitted chunks against `expected_cols`.
    ///
    /// * `agg_call_index` - input column summed by each aggregate call.
    /// * `group_key_index` - input columns forming the group key.
    /// * `cols` - input data, one inner vector per column.
    /// * `expected_cols` - one inner vector per expected output chunk, holding one value per
    ///   aggregate call.
    async fn test_group_agg(
        agg_call_index: Vec<usize>,
        group_key_index: Vec<usize>,
        cols: Vec<Vec<f64>>,
        expected_cols: Vec<Vec<f64>>,
    ) {
        let agg_calls: Vec<BoundAggCall> = agg_call_index
            .into_iter()
            .map(create_sum_agg_call)
            .collect();
        let group_keys: Vec<BoundExpr> = group_key_index
            .into_iter()
            .map(create_input_ref)
            .collect();

        // A child executor that yields exactly one chunk containing all input columns.
        let child: BoxedExecutor = async_stream::try_stream! {
            let arrays: Vec<ArrayImpl> = cols
                .into_iter()
                .map(|col| ArrayImpl::new_float64(col.into_iter().collect()))
                .collect();
            yield arrays.into_iter().collect()
        }
        .boxed();

        let executor = SortAggExecutor {
            agg_calls,
            group_keys,
            child,
        };
        let mut executor = executor.execute();

        let mut seen_chunks = 0;
        while let Some(chunk) = executor.next().await {
            let chunk = chunk.unwrap();
            let expected_col = expected_cols.get(seen_chunks).unwrap_or_else(|| {
                panic!(
                    "executor produced more than the expected {} chunks",
                    expected_cols.len()
                )
            });
            assert_eq!(chunk.arrays(), create_expected_col(expected_col));
            seen_chunks += 1;
        }
        assert_eq!(seen_chunks, expected_cols.len());
    }

    /// Builds a `SUM` aggregate call over the input column at index `value`.
    fn create_sum_agg_call(value: usize) -> BoundAggCall {
        BoundAggCall {
            kind: AggKind::Sum,
            args: vec![BoundExpr::InputRef(BoundInputRef {
                index: value,
                return_type: DataType::new(DataTypeKind::Decimal(Some(15), Some(2)), false),
            })],
            return_type: DataType::new(DataTypeKind::Double, false),
        }
    }

    /// Builds a reference to the input column at index `value`, used as a group key.
    fn create_input_ref(value: usize) -> BoundExpr {
        BoundExpr::InputRef(BoundInputRef {
            index: value,
            return_type: DataType::new(DataTypeKind::Int(None), false),
        })
    }

    /// Wraps every value of `cols` in its own one-element float64 array, matching the shape
    /// of a chunk emitted for a single group.
    fn create_expected_col(cols: &[f64]) -> Vec<ArrayImpl> {
        cols.iter()
            .map(|&value| ArrayImpl::new_float64(std::iter::once(value).collect()))
            .collect()
    }
}

0 commit comments

Comments
 (0)