Skip to content

Commit 4afb527

Browse files
committed
feat: support sort agg
Signed-off-by: kikkon <nian920@outlook.com>
1 parent 6735778 commit 4afb527

File tree

2 files changed

+247
-0
lines changed

2 files changed

+247
-0
lines changed

src/executor/mod.rs

+3
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ use self::order::*;
4141
use self::projection::*;
4242
use self::simple_agg::*;
4343
#[allow(unused_imports)]
44+
use self::sort_agg::*;
45+
#[allow(unused_imports)]
4446
use self::sort_merge_join::*;
4547
use self::table_scan::*;
4648
use self::top_n::TopNExecutor;
@@ -72,6 +74,7 @@ mod nested_loop_join;
7274
mod order;
7375
mod projection;
7476
mod simple_agg;
77+
mod sort_agg;
7578
mod sort_merge_join;
7679
mod table_scan;
7780
mod top_n;

src/executor/sort_agg.rs

+244
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
// Copyright 2022 RisingLight Project Authors. Licensed under Apache-2.0.
2+
#![allow(dead_code)]
3+
use smallvec::SmallVec;
4+
5+
use super::*;
6+
use crate::array::{ArrayBuilderImpl, ArrayImpl};
7+
use crate::binder::BoundAggCall;
8+
use crate::types::{DataTypeExt, DataTypeKind};
9+
10+
pub struct SortAggExecutor {
11+
pub agg_calls: Vec<BoundAggCall>,
12+
pub group_keys: Vec<BoundExpr>,
13+
pub child: BoxedExecutor,
14+
}
15+
16+
impl SortAggExecutor {
17+
#[try_stream(boxed, ok = DataChunk, error = ExecutorError)]
18+
pub async fn execute(self) {
19+
let mut last_key = None::<HashKey>;
20+
let mut states = create_agg_states(&self.agg_calls);
21+
22+
#[for_await]
23+
for chunk in self.child {
24+
// Eval group keys and arguments
25+
let chunk = chunk?;
26+
let exprs: SmallVec<[ArrayImpl; 16]> = self
27+
.agg_calls
28+
.iter()
29+
.map(|agg| agg.args[0].eval(&chunk))
30+
.try_collect()?;
31+
let group_cols: SmallVec<[ArrayImpl; 16]> = self
32+
.group_keys
33+
.iter()
34+
.map(|e| e.eval(&chunk))
35+
.try_collect()?;
36+
37+
let num_rows = chunk.cardinality();
38+
for row_idx in 0..num_rows {
39+
// Create group key
40+
let mut group_key = HashKey::new();
41+
for col in group_cols.iter() {
42+
group_key.push(col.get(row_idx));
43+
}
44+
// Check group key & lask key
45+
match last_key {
46+
Some(last_key) => {
47+
if last_key != group_key {
48+
yield finish_agg(&states);
49+
states = create_agg_states(&self.agg_calls);
50+
}
51+
}
52+
None => (),
53+
}
54+
for (state, expr) in states.iter_mut().zip_eq(&exprs) {
55+
state.update_single(&expr.get(row_idx))?;
56+
}
57+
last_key = Some(group_key);
58+
}
59+
}
60+
yield finish_agg(&states);
61+
62+
fn finish_agg(states: &SmallVec<[Box<dyn AggregationState>; 16]>) -> DataChunk {
63+
return states
64+
.iter()
65+
.map(|s| {
66+
let result = &s.output();
67+
match &result.data_type() {
68+
Some(r) => {
69+
let mut builder = ArrayBuilderImpl::with_capacity(1, r);
70+
builder.push(result);
71+
builder.finish()
72+
}
73+
None => ArrayBuilderImpl::new(&DataTypeKind::Int(None).nullable()).finish(),
74+
}
75+
})
76+
.collect::<DataChunk>();
77+
}
78+
}
79+
}
80+
81+
#[cfg(test)]
82+
mod tests {
83+
84+
use super::*;
85+
use crate::array::ArrayImpl;
86+
use crate::binder::{AggKind, BoundInputRef};
87+
use crate::types::{DataType, DataTypeKind};
88+
89+
#[tokio::test]
90+
async fn test_multi_group_agg() {
91+
test_group_agg(
92+
vec![0, 1],
93+
vec![1, 2],
94+
vec![
95+
vec![1.1, 0.2, 0.3, 0.4, 0.5],
96+
vec![1.1, 1.1, 1.3, 1.4, 1.5],
97+
vec![1.2, 1.2, 2.3, 2.4, 2.5],
98+
],
99+
vec![
100+
vec![1.3, 2.2],
101+
vec![0.3, 1.3],
102+
vec![0.4, 1.4],
103+
vec![0.5, 1.5],
104+
],
105+
)
106+
.await;
107+
test_group_agg(
108+
vec![0, 1],
109+
vec![1, 2],
110+
vec![
111+
vec![0.1, 0.2, 0.3, 0.4, 0.5],
112+
vec![1.1, 1.1, 1.3, 1.4, 1.5],
113+
vec![1.1, 1.2, 2.3, 2.4, 2.5],
114+
],
115+
vec![
116+
vec![0.1, 1.1],
117+
vec![0.2, 1.1],
118+
vec![0.3, 1.3],
119+
vec![0.4, 1.4],
120+
vec![0.5, 1.5],
121+
],
122+
)
123+
.await
124+
}
125+
126+
#[tokio::test]
127+
async fn test_single_group_agg() {
128+
test_group_agg(
129+
vec![0, 1],
130+
vec![0],
131+
vec![vec![1.0, 1.0], vec![1.0, 2.0]],
132+
vec![vec![2.0, 3.0]],
133+
)
134+
.await;
135+
test_group_agg(
136+
vec![0, 1],
137+
vec![1],
138+
vec![
139+
vec![1.1, 0.2, 0.3, 0.4, 0.5],
140+
vec![1.1, 1.1, 1.3, 1.4, 1.5],
141+
vec![2.1, 2.2, 2.3, 2.4, 2.5],
142+
],
143+
vec![
144+
vec![1.3, 2.2],
145+
vec![0.3, 1.3],
146+
vec![0.4, 1.4],
147+
vec![0.5, 1.5],
148+
],
149+
)
150+
.await;
151+
test_group_agg(
152+
vec![0, 1],
153+
vec![1],
154+
vec![
155+
vec![0.1, 0.2, 0.3, 0.4, 0.5],
156+
vec![1.1, 1.2, 1.3, 1.4, 1.5],
157+
vec![2.1, 2.2, 2.3, 2.4, 2.5],
158+
],
159+
vec![
160+
vec![0.1, 1.1],
161+
vec![0.2, 1.2],
162+
vec![0.3, 1.3],
163+
vec![0.4, 1.4],
164+
vec![0.5, 1.5],
165+
],
166+
)
167+
.await
168+
}
169+
170+
async fn test_group_agg(
171+
agg_call_index: Vec<usize>,
172+
group_key_index: Vec<usize>,
173+
cols: Vec<Vec<f64>>,
174+
expected_cols: Vec<Vec<f64>>,
175+
) {
176+
let mut agg_calls = vec![];
177+
for index in agg_call_index {
178+
agg_calls.push(create_sum_agg_call(index));
179+
}
180+
181+
let mut group_keys = vec![];
182+
for index in group_key_index {
183+
group_keys.push(create_input_ref(index));
184+
}
185+
186+
let child: BoxedExecutor = async_stream::try_stream! {
187+
let mut child = vec![];
188+
for col in cols {
189+
child.push(ArrayImpl::new_float64(col.into_iter().collect()));
190+
}
191+
yield child.into_iter().collect()
192+
}
193+
.boxed();
194+
195+
let executor = SortAggExecutor {
196+
agg_calls,
197+
group_keys,
198+
child,
199+
};
200+
let mut executor = executor.execute();
201+
202+
let mut expected_cols_size = 0;
203+
while let Some(chunk) = executor.next().await {
204+
let chunk = chunk.unwrap();
205+
let expected_col = expected_cols.get(expected_cols_size).unwrap();
206+
let mut expected_array = vec![];
207+
for data in expected_col {
208+
expected_array.push(ArrayImpl::new_float64(vec![*data].into_iter().collect()));
209+
}
210+
assert_eq!(chunk.arrays(), expected_array);
211+
expected_cols_size += 1;
212+
}
213+
assert_eq!(expected_cols_size, expected_cols.len());
214+
}
215+
216+
fn create_sum_agg_call(value: usize) -> BoundAggCall {
217+
BoundAggCall {
218+
kind: AggKind::Sum,
219+
args: vec![BoundExpr::InputRef(BoundInputRef {
220+
index: value,
221+
return_type: DataType::new(
222+
DataTypeKind::Decimal(Option::Some(15), Option::Some(2)),
223+
false,
224+
),
225+
})],
226+
return_type: DataType::new(DataTypeKind::Double, false),
227+
}
228+
}
229+
230+
fn create_input_ref(value: usize) -> BoundExpr {
231+
BoundExpr::InputRef(BoundInputRef {
232+
index: value,
233+
return_type: DataType::new(DataTypeKind::Int(Option::None), false),
234+
})
235+
}
236+
237+
fn create_expected_col(cols: Vec<f64>) -> Vec<ArrayImpl> {
238+
let mut expected = Vec::new();
239+
for col in cols {
240+
expected.push(ArrayImpl::new_float64([col].into_iter().collect()));
241+
}
242+
expected
243+
}
244+
}

0 commit comments

Comments
 (0)