Commit ddb51f5

Fix regression

1 parent b1e5ec6 commit ddb51f5
10 files changed: +80 -42 lines


rust/ballista/rust/core/proto/ballista.proto (+6)

@@ -59,6 +59,7 @@ message LogicalExprNode {
     InListNode in_list = 14;
     bool wildcard = 15;
     ScalarFunctionNode scalar_function = 16;
+    TryCastNode try_cast = 17;
   }
 }

@@ -172,6 +173,11 @@ message CastNode {
   ArrowType arrow_type = 2;
 }

+message TryCastNode {
+  LogicalExprNode expr = 1;
+  ArrowType arrow_type = 2;
+}
+
 message SortExprNode {
   LogicalExprNode expr = 1;
   bool asc = 2;
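
The new TryCastNode mirrors CastNode so that DataFusion's TRY_CAST expression (which, unlike CAST, is expected to yield NULL rather than an error when a value cannot be converted) can round-trip through the Ballista plan protobuf. A minimal sketch of the logical expression this message serializes, assuming DataFusion's `Expr`/`col` API at this commit (the variant fields match the deserialization code added in from_proto.rs below):

use arrow::datatypes::DataType;
use datafusion::logical_plan::{col, Expr};

// Build `TRY_CAST(c1 AS BIGINT)` as a logical expression; this is the value
// that to_proto.rs/from_proto.rs convert to and from protobuf::TryCastNode.
fn try_cast_c1_to_int64() -> Expr {
    Expr::TryCast {
        expr: Box::new(col("c1")),
        data_type: DataType::Int64,
    }
}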

rust/ballista/rust/core/src/serde/logical_plan/from_proto.rs (+19 -10)

@@ -510,10 +510,10 @@ fn typechecked_scalar_value_conversion(
            ScalarValue::Date32(Some(*v))
        }
        (Value::TimeMicrosecondValue(v), PrimitiveScalarType::TimeMicrosecond) => {
-           ScalarValue::TimeMicrosecond(Some(*v))
+           ScalarValue::TimestampMicrosecond(Some(*v))
        }
        (Value::TimeNanosecondValue(v), PrimitiveScalarType::TimeMicrosecond) => {
-           ScalarValue::TimeNanosecond(Some(*v))
+           ScalarValue::TimestampNanosecond(Some(*v))
        }
        (Value::Utf8Value(v), PrimitiveScalarType::Utf8) => {
            ScalarValue::Utf8(Some(v.to_owned()))

@@ -546,10 +546,10 @@ fn typechecked_scalar_value_conversion(
            PrimitiveScalarType::LargeUtf8 => ScalarValue::LargeUtf8(None),
            PrimitiveScalarType::Date32 => ScalarValue::Date32(None),
            PrimitiveScalarType::TimeMicrosecond => {
-               ScalarValue::TimeMicrosecond(None)
+               ScalarValue::TimestampMicrosecond(None)
            }
            PrimitiveScalarType::TimeNanosecond => {
-               ScalarValue::TimeNanosecond(None)
+               ScalarValue::TimestampNanosecond(None)
            }
            PrimitiveScalarType::Null => {
                return Err(proto_error(

@@ -609,10 +609,10 @@ impl TryInto<datafusion::scalar::ScalarValue> for &protobuf::scalar_value::Value
                ScalarValue::Date32(Some(*v))
            }
            protobuf::scalar_value::Value::TimeMicrosecondValue(v) => {
-               ScalarValue::TimeMicrosecond(Some(*v))
+               ScalarValue::TimestampMicrosecond(Some(*v))
            }
            protobuf::scalar_value::Value::TimeNanosecondValue(v) => {
-               ScalarValue::TimeNanosecond(Some(*v))
+               ScalarValue::TimestampNanosecond(Some(*v))
            }
            protobuf::scalar_value::Value::ListValue(v) => v.try_into()?,
            protobuf::scalar_value::Value::NullListValue(v) => {

@@ -775,10 +775,10 @@ impl TryInto<datafusion::scalar::ScalarValue> for protobuf::PrimitiveScalarType
            protobuf::PrimitiveScalarType::LargeUtf8 => ScalarValue::LargeUtf8(None),
            protobuf::PrimitiveScalarType::Date32 => ScalarValue::Date32(None),
            protobuf::PrimitiveScalarType::TimeMicrosecond => {
-               ScalarValue::TimeMicrosecond(None)
+               ScalarValue::TimestampMicrosecond(None)
            }
            protobuf::PrimitiveScalarType::TimeNanosecond => {
-               ScalarValue::TimeNanosecond(None)
+               ScalarValue::TimestampNanosecond(None)
            }
        })
    }

@@ -828,10 +828,10 @@ impl TryInto<datafusion::scalar::ScalarValue> for &protobuf::ScalarValue {
                ScalarValue::Date32(Some(*v))
            }
            protobuf::scalar_value::Value::TimeMicrosecondValue(v) => {
-               ScalarValue::TimeMicrosecond(Some(*v))
+               ScalarValue::TimestampMicrosecond(Some(*v))
            }
            protobuf::scalar_value::Value::TimeNanosecondValue(v) => {
-               ScalarValue::TimeNanosecond(Some(*v))
+               ScalarValue::TimestampNanosecond(Some(*v))
            }
            protobuf::scalar_value::Value::ListValue(scalar_list) => {
                let protobuf::ScalarListValue {

@@ -961,6 +961,15 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
                let data_type = arrow_type.try_into()?;
                Ok(Expr::Cast { expr, data_type })
            }
+           ExprType::TryCast(cast) => {
+               let expr = Box::new(parse_required_expr(&cast.expr)?);
+               let arrow_type: &protobuf::ArrowType = cast
+                   .arrow_type
+                   .as_ref()
+                   .ok_or_else(|| proto_error("Protobuf deserialization error: CastNode message missing required field 'arrow_type'"))?;
+               let data_type = arrow_type.try_into()?;
+               Ok(Expr::TryCast { expr, data_type })
+           }
            ExprType::Sort(sort) => Ok(Expr::Sort {
                expr: Box::new(parse_required_expr(&sort.expr)?),
                asc: sort.asc,

rust/ballista/rust/core/src/serde/logical_plan/mod.rs (+9 -9)

@@ -212,8 +212,8 @@ mod roundtrip_tests {
        ScalarValue::LargeUtf8(None),
        ScalarValue::List(None, DataType::Boolean),
        ScalarValue::Date32(None),
-       ScalarValue::TimeMicrosecond(None),
-       ScalarValue::TimeNanosecond(None),
+       ScalarValue::TimestampMicrosecond(None),
+       ScalarValue::TimestampNanosecond(None),
        ScalarValue::Boolean(Some(true)),
        ScalarValue::Boolean(Some(false)),
        ScalarValue::Float32(Some(1.0)),

@@ -252,11 +252,11 @@
        ScalarValue::LargeUtf8(Some(String::from("Test Large utf8"))),
        ScalarValue::Date32(Some(0)),
        ScalarValue::Date32(Some(i32::MAX)),
-       ScalarValue::TimeNanosecond(Some(0)),
-       ScalarValue::TimeNanosecond(Some(i64::MAX)),
-       ScalarValue::TimeMicrosecond(Some(0)),
-       ScalarValue::TimeMicrosecond(Some(i64::MAX)),
-       ScalarValue::TimeMicrosecond(None),
+       ScalarValue::TimestampNanosecond(Some(0)),
+       ScalarValue::TimestampNanosecond(Some(i64::MAX)),
+       ScalarValue::TimestampMicrosecond(Some(0)),
+       ScalarValue::TimestampMicrosecond(Some(i64::MAX)),
+       ScalarValue::TimestampMicrosecond(None),
        ScalarValue::List(
            Some(vec![
                ScalarValue::Float32(Some(-213.1)),

@@ -610,8 +610,8 @@ mod roundtrip_tests {
        ScalarValue::Utf8(None),
        ScalarValue::LargeUtf8(None),
        ScalarValue::Date32(None),
-       ScalarValue::TimeMicrosecond(None),
-       ScalarValue::TimeNanosecond(None),
+       ScalarValue::TimestampMicrosecond(None),
+       ScalarValue::TimestampNanosecond(None),
        //ScalarValue::List(None, DataType::Boolean)
    ];
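
These test vectors exercise serialization round-trips for each ScalarValue variant, which is why the renamed Timestamp variants had to be updated here as well. A sketch of the property being checked, using the TryFrom/TryInto impls shown in to_proto.rs and from_proto.rs (the helper name and the external crate paths are illustrative, not code from this commit):

use std::convert::TryInto;

use ballista_core::serde::protobuf;
use datafusion::scalar::ScalarValue;

// Round-trip a scalar through its protobuf representation and require that
// the value survives unchanged.
fn assert_scalar_roundtrip(original: ScalarValue) {
    let proto: protobuf::ScalarValue = (&original).try_into().expect("serialize");
    let back: ScalarValue = (&proto).try_into().expect("deserialize");
    assert_eq!(original, back);
}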

rust/ballista/rust/core/src/serde/logical_plan/to_proto.rs (+2 -2)

@@ -641,12 +641,12 @@ impl TryFrom<&datafusion::scalar::ScalarValue> for protobuf::ScalarValue {
            datafusion::scalar::ScalarValue::Date32(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Date32, |s| Value::Date32Value(*s))
            }
-           datafusion::scalar::ScalarValue::TimeMicrosecond(val) => {
+           datafusion::scalar::ScalarValue::TimestampMicrosecond(val) => {
                create_proto_scalar(val, PrimitiveScalarType::TimeMicrosecond, |s| {
                    Value::TimeMicrosecondValue(*s)
                })
            }
-           datafusion::scalar::ScalarValue::TimeNanosecond(val) => {
+           datafusion::scalar::ScalarValue::TimestampNanosecond(val) => {
                create_proto_scalar(val, PrimitiveScalarType::TimeNanosecond, |s| {
                    Value::TimeNanosecondValue(*s)
                })

rust/ballista/rust/core/src/serde/physical_plan/from_proto.rs (+1 -4)

@@ -106,15 +106,12 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                    .file_extension(&scan.file_extension)
                    .delimiter(scan.delimiter.as_bytes()[0])
                    .schema(&schema);
-               // TODO we don't care what the DataFusion batch size was because Ballista will
-               // have its own configs. Hard-code for now.
-               let batch_size = 32768;
                let projection = scan.projection.iter().map(|i| *i as usize).collect();
                Ok(Arc::new(CsvExec::try_new(
                    &scan.path,
                    options,
                    Some(projection),
-                   batch_size,
+                   scan.batch_size as usize,
                    None,
                )?))
            }
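
With this change the deserialized CsvExec no longer uses a hard-coded batch size of 32768; it reuses whatever batch_size the serialized CsvScan node carries. The matching producer-side change in physical_plan/to_proto.rs below fills that field from exec.batch_size().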

rust/ballista/rust/core/src/serde/physical_plan/to_proto.rs (+11 -2)

@@ -28,10 +28,10 @@ use std::{

 use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
 use datafusion::physical_plan::csv::CsvExec;
-use datafusion::physical_plan::expressions::CastExpr;
 use datafusion::physical_plan::expressions::{
     CaseExpr, InListExpr, IsNotNullExpr, IsNullExpr, NegativeExpr, NotExpr,
 };
+use datafusion::physical_plan::expressions::{CastExpr, TryCastExpr};
 use datafusion::physical_plan::filter::FilterExec;
 use datafusion::physical_plan::hash_aggregate::AggregateMode;
 use datafusion::physical_plan::hash_join::HashJoinExec;

@@ -236,7 +236,7 @@ impl TryInto<protobuf::PhysicalPlanNode> for Arc<dyn ExecutionPlan> {
                    schema: Some(exec.file_schema().as_ref().into()),
                    has_header: exec.has_header(),
                    delimiter: delimiter.to_string(),
-                   batch_size: 32768,
+                   batch_size: exec.batch_size() as u32,
                },
            )),
        })

@@ -510,6 +510,15 @@ impl TryFrom<Arc<dyn PhysicalExpr>> for protobuf::LogicalExprNode {
                },
            ))),
        })
+   } else if let Some(cast) = expr.downcast_ref::<TryCastExpr>() {
+       Ok(protobuf::LogicalExprNode {
+           expr_type: Some(protobuf::logical_expr_node::ExprType::TryCast(
+               Box::new(protobuf::TryCastNode {
+                   expr: Some(Box::new(cast.expr().clone().try_into()?)),
+                   arrow_type: Some(cast.cast_type().into()),
+               }),
+           )),
+       })
    } else if let Some(expr) = expr.downcast_ref::<ScalarFunctionExpr>() {
        let fun: BuiltinScalarFunction =
            BuiltinScalarFunction::from_str(expr.name())?;
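
Two things change here: the CsvExec serializer now records the executor's actual batch size instead of the 32768 placeholder, and physical TryCastExpr expressions are serialized to the same protobuf TryCastNode used by the logical plan, an expression type this match chain previously had no case for.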

rust/ballista/rust/scheduler/src/api/mod.rs (+8 -6)

@@ -30,11 +30,11 @@ pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
 pub type HttpBody = dyn http_body::Body<Data = dyn Buf, Error = Error> + 'static;

 impl<A, B> http_body::Body for EitherBody<A, B>
-    where
-        A: http_body::Body + Send + Unpin,
-        B: http_body::Body<Data = A::Data> + Send + Unpin,
-        A::Error: Into<Error>,
-        B::Error: Into<Error>,
+where
+    A: http_body::Body + Send + Unpin,
+    B: http_body::Body<Data = A::Data> + Send + Unpin,
+    A::Error: Into<Error>,
+    B::Error: Into<Error>,
 {
     type Data = A::Data;
     type Error = Error;

@@ -67,7 +67,9 @@ impl<A, B> http_body::Body for EitherBody<A, B>
     }
 }

-fn map_option_err<T, U: Into<Error>>(err: Option<Result<T, U>>) -> Option<Result<T, Error>> {
+fn map_option_err<T, U: Into<Error>>(
+    err: Option<Result<T, U>>,
+) -> Option<Result<T, Error>> {
     err.map(|e| e.map_err(Into::into))
 }

rust/ballista/rust/scheduler/src/main.rs (+5 -3)

@@ -29,12 +29,12 @@ use ballista_core::BALLISTA_VERSION;
 use ballista_core::{
     print_version, serde::protobuf::scheduler_grpc_server::SchedulerGrpcServer,
 };
+use ballista_scheduler::api::{get_routes, EitherBody, Error};
 #[cfg(feature = "etcd")]
 use ballista_scheduler::state::EtcdClient;
 #[cfg(feature = "sled")]
 use ballista_scheduler::state::StandaloneClient;
 use ballista_scheduler::{state::ConfigBackendClient, ConfigBackend, SchedulerServer};
-use ballista_scheduler::api::{get_routes, EitherBody, Error};

 use log::info;

@@ -63,8 +63,10 @@ async fn start_server(
     );
     Ok(Server::bind(&addr)
         .serve(make_service_fn(move |_| {
-            let scheduler_server = SchedulerServer::new(config_backend.clone(), namespace.clone());
-            let scheduler_grpc_server = SchedulerGrpcServer::new(scheduler_server.clone());
+            let scheduler_server =
+                SchedulerServer::new(config_backend.clone(), namespace.clone());
+            let scheduler_grpc_server =
+                SchedulerGrpcServer::new(scheduler_server.clone());

             let mut tonic = TonicServer::builder()
                 .add_service(scheduler_grpc_server)

rust/ballista/rust/scheduler/src/planner.rs (+5 -4)

@@ -34,13 +34,13 @@ use ballista_core::{
     execution_plans::{QueryStageExec, ShuffleReaderExec, UnresolvedShuffleExec},
     serde::scheduler::PartitionLocation,
 };
-use datafusion::execution::context::ExecutionContext;
+use datafusion::execution::context::{ExecutionConfig, ExecutionContext};
 use datafusion::physical_optimizer::coalesce_batches::CoalesceBatches;
+use datafusion::physical_optimizer::merge_exec::AddMergeExec;
 use datafusion::physical_optimizer::optimizer::PhysicalOptimizerRule;
 use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec};
 use datafusion::physical_plan::hash_join::HashJoinExec;
 use datafusion::physical_plan::merge::MergeExec;
-use datafusion::physical_optimizer::merge_exec::AddMergeExec;
 use datafusion::physical_plan::ExecutionPlan;
 use log::{debug, info};
 use tokio::task::JoinHandle;

@@ -139,12 +139,13 @@ impl DistributedPlanner {
         }

         if let Some(adapter) = execution_plan.as_any().downcast_ref::<DFTableAdapter>() {
+            // remove Repartition rule because that isn't supported yet
             let rules: Vec<Arc<dyn PhysicalOptimizerRule + Send + Sync>> = vec![
                 Arc::new(CoalesceBatches::new()),
                 Arc::new(AddMergeExec::new()),
             ];
-            let mut ctx = ExecutionContext::new();
-            let ctx = ctx.with_physical_optimizers(rules);
+            let config = ExecutionConfig::new().with_physical_optimizer_rules(rules);
+            let ctx = ExecutionContext::with_config(config);
             Ok((ctx.create_physical_plan(&adapter.logical_plan)?, stages))
         } else if let Some(merge) = execution_plan.as_any().downcast_ref::<MergeExec>() {
             let query_stage = create_query_stage(
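
The planner previously built the per-stage context with ExecutionContext::new() followed by ctx.with_physical_optimizers(rules); it now constructs an ExecutionConfig via with_physical_optimizer_rules and passes it to ExecutionContext::with_config, keeping the reduced rule set (CoalesceBatches and AddMergeExec only) that omits the unsupported Repartition rule. The same pattern is applied to the test context in test_utils.rs below.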

rust/ballista/rust/scheduler/src/test_utils.rs (+14 -2)

@@ -15,18 +15,30 @@
 // specific language governing permissions and limitations
 // under the License.

+use std::sync::Arc;
+
 use ballista_core::error::Result;

 use arrow::datatypes::{DataType, Field, Schema};
-use datafusion::execution::context::ExecutionContext;
+use datafusion::execution::context::{ExecutionConfig, ExecutionContext};
+use datafusion::physical_optimizer::coalesce_batches::CoalesceBatches;
+use datafusion::physical_optimizer::merge_exec::AddMergeExec;
+use datafusion::physical_optimizer::optimizer::PhysicalOptimizerRule;
 use datafusion::physical_plan::csv::CsvReadOptions;

 pub const TPCH_TABLES: &[&str] = &[
     "part", "supplier", "partsupp", "customer", "orders", "lineitem", "nation", "region",
 ];

 pub fn datafusion_test_context(path: &str) -> Result<ExecutionContext> {
-    let mut ctx = ExecutionContext::new();
+    // remove Repartition rule because that isn't supported yet
+    let rules: Vec<Arc<dyn PhysicalOptimizerRule + Send + Sync>> = vec![
+        Arc::new(CoalesceBatches::new()),
+        Arc::new(AddMergeExec::new()),
+    ];
+    let config = ExecutionConfig::new().with_physical_optimizer_rules(rules);
+    let mut ctx = ExecutionContext::with_config(config);
+
     for table in TPCH_TABLES {
         let schema = get_tpch_schema(table);
         let options = CsvReadOptions::new()
