diff --git a/src/binder/expression/mod.rs b/src/binder/expression/mod.rs index 205535b24..df2ed18fd 100644 --- a/src/binder/expression/mod.rs +++ b/src/binder/expression/mod.rs @@ -133,6 +133,25 @@ impl BoundExpr { visitor.visit_expr(self); visitor.0 } + + pub fn format_name(&self, child_schema: &Vec) -> String { + match self { + Self::Constant(DataValue::Int64(num)) => format!("{}", num), + Self::Constant(DataValue::Int32(num)) => format!("{}", num), + Self::Constant(DataValue::Float64(num)) => format!("{:.3}", num), + Self::BinaryOp(expr) => { + let left_expr_name = expr.left_expr.format_name(child_schema); + let right_expr_name = expr.right_expr.format_name(child_schema); + format!("{}{}{}", left_expr_name, expr.op, right_expr_name) + } + Self::UnaryOp(expr) => { + let expr_name = expr.expr.format_name(child_schema); + format!("{}{}", expr.op, expr_name) + } + Self::InputRef(expr) => child_schema[expr.index].name().to_string(), + _ => "".to_string(), + } + } } impl std::fmt::Debug for BoundExpr { @@ -282,3 +301,86 @@ impl From<&Value> for DataValue { } } } + +#[cfg(test)] +mod tests { + use sqlparser::ast::{BinaryOperator, UnaryOperator}; + + use crate::binder::{BoundBinaryOp, BoundExpr, BoundInputRef, BoundUnaryOp}; + use crate::catalog::ColumnDesc; + use crate::types::{DataType, DataTypeKind, DataValue}; + + // test when BoundExpr is Constant + #[test] + fn test_format_name_constant() { + let expr = BoundExpr::Constant(DataValue::Int32(1_i32)); + assert_eq!("1", expr.format_name(&vec![])); + let expr = BoundExpr::Constant(DataValue::Int64(1_i64)); + assert_eq!("1", expr.format_name(&vec![])); + let expr = BoundExpr::Constant(DataValue::Float64(32.0_f64)); + assert_eq!("32.000", expr.format_name(&vec![])); + } + + // test when BoundExpr is UnaryOp(form like -a) + #[test] + fn test_format_name_unary_op() { + let data_type = DataType::new(DataTypeKind::Int(None), true); + let expr = BoundExpr::InputRef(BoundInputRef { + index: 0, + return_type: data_type.clone(), + }); + let child_schema = vec![ColumnDesc::new(data_type.clone(), "a".to_string(), false)]; + let expr = BoundExpr::UnaryOp(BoundUnaryOp { + op: UnaryOperator::Minus, + expr: Box::new(expr), + return_type: Some(data_type), + }); + assert_eq!("-a", expr.format_name(&child_schema)); + } + + // test when BoundExpr is BinaryOp + #[test] + fn test_format_name_binary_op() { + // forms like a + 1 + { + let left_data_type = DataType::new(DataTypeKind::Int(None), true); + let left_expr = BoundExpr::InputRef(BoundInputRef { + index: 0, + return_type: left_data_type.clone(), + }); + let right_expr = BoundExpr::Constant(DataValue::Int64(1_i64)); + let child_schema = vec![ColumnDesc::new(left_data_type, "a".to_string(), false)]; + let expr = BoundExpr::BinaryOp(BoundBinaryOp { + op: BinaryOperator::Plus, + left_expr: Box::new(left_expr), + right_expr: Box::new(right_expr), + return_type: Some(DataType::new(DataTypeKind::Int(None), true)), + }); + assert_eq!("a+1", expr.format_name(&child_schema)); + } + // forms like a + b + { + let data_type = DataType::new(DataTypeKind::Int(None), true); + let left_expr = BoundExpr::InputRef(BoundInputRef { + index: 0, + return_type: data_type.clone(), + }); + let right_expr = BoundExpr::InputRef(BoundInputRef { + index: 1, + return_type: data_type.clone(), + }); + let child_schema = vec![ + ColumnDesc::new(data_type.clone(), "a".to_string(), false), + ColumnDesc::new(data_type.clone(), "b".to_string(), false), + ]; + + let expr = BoundExpr::BinaryOp(BoundBinaryOp { + op: BinaryOperator::Plus, + left_expr: Box::new(left_expr), + right_expr: Box::new(right_expr), + return_type: Some(data_type), + }); + assert_eq!("a+b", expr.format_name(&child_schema)); + } + } +} diff --git a/src/optimizer/plan_nodes/logical_aggregate.rs b/src/optimizer/plan_nodes/logical_aggregate.rs index fe277e5c0..06c9e4db6 100644 --- a/src/optimizer/plan_nodes/logical_aggregate.rs +++ b/src/optimizer/plan_nodes/logical_aggregate.rs @@ -69,13 +69,15 @@ impl_plan_tree_node_for_unary!(LogicalAggregate); impl PlanNode for LogicalAggregate { fn schema(&self) -> Vec { let child_schema = self.child.schema(); - let mut input_refs = vec![]; self.group_keys .iter() - .for_each(|expr| expr.resolve_input_ref(&mut input_refs)); - input_refs - .iter() - .map(|expr| child_schema[expr.index].clone()) + .map(|expr| { + ColumnDesc::new( + expr.return_type().unwrap(), + expr.format_name(&child_schema), + false, + ) + }) .chain(self.agg_calls.iter().map(|agg_call| { agg_call .return_type