diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi
index a3a60b86f1..668250539b 100644
--- a/python/deltalake/_internal.pyi
+++ b/python/deltalake/_internal.pyi
@@ -126,7 +126,7 @@ class Schema:
     def to_json(self) -> str: ...
     @staticmethod
     def from_json(json: str) -> "Schema": ...
-    def to_pyarrow(self) -> pa.Schema: ...
+    def to_pyarrow(self, as_large_types: bool = False) -> pa.Schema: ...
     @staticmethod
     def from_pyarrow(type: pa.Schema) -> "Schema": ...
 
diff --git a/python/src/lib.rs b/python/src/lib.rs
index b030ae9074..67603b3e24 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -379,9 +379,7 @@ impl RawDeltaTable {
         let column_names: HashSet<&str> = self
             ._table
             .schema()
-            .ok_or(PyDeltaTableError::new_err(
-                "table does not yet have a schema",
-            ))?
+            .ok_or_else(|| PyDeltaTableError::new_err("table does not yet have a schema"))?
             .get_fields()
             .iter()
             .map(|field| field.get_name())
diff --git a/python/src/schema.rs b/python/src/schema.rs
index db5df88746..18787752f7 100644
--- a/python/src/schema.rs
+++ b/python/src/schema.rs
@@ -950,7 +950,7 @@ pub fn schema_to_pyobject(schema: &Schema, py: Python) -> PyResult<PyObject> {
 /// >>> import pyarrow as pa
 /// >>> Schema.from_pyarrow(pa.schema({"x": pa.int32(), "y": pa.string()}))
 /// Schema([Field(x, PrimitiveType("integer"), nullable=True), Field(y, PrimitiveType("string"), nullable=True)])
-#[pyclass(extends=StructType, name="Schema", module="deltalake.schema",
+#[pyclass(extends = StructType, name = "Schema", module = "deltalake.schema",
     text_signature = "(fields)")]
 pub struct PySchema;
 
@@ -1007,15 +1007,78 @@ impl PySchema {
     /// Return equivalent PyArrow schema
     ///
+    /// :param as_large_types: get schema with all variable size types (list,
+    ///     binary, string) as large variants (with int64 indices). This is for
+    ///     compatibility with systems like Polars that only support the large
+    ///     versions of Arrow types.
+    ///
     /// :rtype: pyarrow.Schema
-    #[pyo3(text_signature = "($self)")]
-    fn to_pyarrow(self_: PyRef<'_, Self>) -> PyResult<PyArrowType<ArrowSchema>> {
+    #[pyo3(signature = (as_large_types = false))]
+    fn to_pyarrow(
+        self_: PyRef<'_, Self>,
+        as_large_types: bool,
+    ) -> PyResult<PyArrowType<ArrowSchema>> {
         let super_ = self_.as_ref();
-        Ok(PyArrowType(
-            (&super_.inner_type.clone())
-                .try_into()
-                .map_err(|err: ArrowError| PyException::new_err(err.to_string()))?,
-        ))
+        let res: ArrowSchema = (&super_.inner_type.clone())
+            .try_into()
+            .map_err(|err: ArrowError| PyException::new_err(err.to_string()))?;
+
+        fn convert_to_large_type(field: ArrowField, dt: ArrowDataType) -> ArrowField {
+            match dt {
+                ArrowDataType::Utf8 => field.with_data_type(ArrowDataType::LargeUtf8),
+
+                ArrowDataType::Binary => field.with_data_type(ArrowDataType::LargeBinary),
+
+                ArrowDataType::List(f) => {
+                    let sub_field = convert_to_large_type(*f.clone(), f.data_type().clone());
+                    field.with_data_type(ArrowDataType::LargeList(Box::from(sub_field)))
+                }
+
+                ArrowDataType::FixedSizeList(f, size) => {
+                    let sub_field = convert_to_large_type(*f.clone(), f.data_type().clone());
+                    field.with_data_type(ArrowDataType::FixedSizeList(Box::from(sub_field), size))
+                }
+
+                ArrowDataType::Map(f, sorted) => {
+                    let sub_field = convert_to_large_type(*f.clone(), f.data_type().clone());
+                    field.with_data_type(ArrowDataType::Map(Box::from(sub_field), sorted))
+                }
+
+                ArrowDataType::Struct(fields) => {
+                    let sub_fields = fields
+                        .iter()
+                        .map(|f| {
+                            let dt: ArrowDataType = f.data_type().clone();
+                            let f: ArrowField = f.clone();
+
+                            convert_to_large_type(f, dt)
+                        })
+                        .collect();
+
+                    field.with_data_type(ArrowDataType::Struct(sub_fields))
+                }
+
+                _ => field,
+            }
+        }
+
+        if as_large_types {
+            let schema = ArrowSchema::new(
+                res.fields
+                    .iter()
+                    .map(|f| {
+                        let dt: ArrowDataType = f.data_type().clone();
+                        let f: ArrowField = f.clone();
+
+                        convert_to_large_type(f, dt)
+                    })
+                    .collect(),
+            );
+
+            Ok(PyArrowType(schema))
+        } else {
+            Ok(PyArrowType(res))
+        }
     }
 
     /// Create from a PyArrow schema
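For context, a minimal usage sketch of the new flag from the Python side (not part of the patch; the table path is a placeholder):

```python
# Sketch only: how the new `as_large_types` flag is intended to be used.
# "path/to/table" is a placeholder for an existing Delta table.
from deltalake import DeltaTable

dt = DeltaTable("path/to/table")

# Default behaviour: variable-size fields come back as pa.string(), pa.binary(), pa.list_(...).
small_schema = dt.schema().to_pyarrow()

# With the flag: pa.large_string(), pa.large_binary(), pa.large_list(...),
# which matches what consumers such as Polars expect.
large_schema = dt.schema().to_pyarrow(as_large_types=True)
```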
diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py
index 213349dbef..8ba35910cc 100644
--- a/python/tests/test_writer.py
+++ b/python/tests/test_writer.py
@@ -816,3 +816,43 @@ def test_max_partitions_exceeding_fragment_should_fail(
         max_partitions=1,
         partition_by=["p1", "p2"],
     )
+
+
+def test_large_arrow_types(tmp_path: pathlib.Path):
+    pylist = [
+        {"name": "Joey", "gender": b"M", "arr_type": ["x", "y"], "dict": {"a": b"M"}},
+        {"name": "Ivan", "gender": b"F", "arr_type": ["x", "z"]},
+    ]
+    schema = pa.schema(
+        [
+            pa.field("name", pa.large_string()),
+            pa.field("gender", pa.large_binary()),
+            pa.field("arr_type", pa.large_list(pa.large_string())),
+            pa.field("map_type", pa.map_(pa.large_string(), pa.large_binary())),
+            pa.field("struct", pa.struct([pa.field("sub", pa.large_string())])),
+        ]
+    )
+    table = pa.Table.from_pylist(pylist, schema=schema)
+
+    write_deltalake(tmp_path, table)
+
+    dt = DeltaTable(tmp_path)
+    assert table.schema == dt.schema().to_pyarrow(as_large_types=True)
+
+
+def test_uint_arrow_types(tmp_path: pathlib.Path):
+    pylist = [
+        {"num1": 3, "num2": 3, "num3": 3, "num4": 5},
+        {"num1": 1, "num2": 13, "num3": 35, "num4": 13},
+    ]
+    schema = pa.schema(
+        [
+            pa.field("num1", pa.uint8()),
+            pa.field("num2", pa.uint16()),
+            pa.field("num3", pa.uint32()),
+            pa.field("num4", pa.uint64()),
+        ]
+    )
+    table = pa.Table.from_pylist(pylist, schema=schema)
+
+    write_deltalake(tmp_path, table)
diff --git a/rust/src/delta_arrow.rs b/rust/src/delta_arrow.rs
index f3911ca7c3..d7dbd16f4b 100644
--- a/rust/src/delta_arrow.rs
+++ b/rust/src/delta_arrow.rs
@@ -208,14 +208,25 @@ impl TryFrom<&ArrowDataType> for schema::SchemaDataType {
     fn try_from(arrow_datatype: &ArrowDataType) -> Result<Self, ArrowError> {
         match arrow_datatype {
             ArrowDataType::Utf8 => Ok(schema::SchemaDataType::primitive("string".to_string())),
+            ArrowDataType::LargeUtf8 => Ok(schema::SchemaDataType::primitive("string".to_string())),
             ArrowDataType::Int64 => Ok(schema::SchemaDataType::primitive("long".to_string())), // undocumented type
             ArrowDataType::Int32 => Ok(schema::SchemaDataType::primitive("integer".to_string())),
             ArrowDataType::Int16 => Ok(schema::SchemaDataType::primitive("short".to_string())),
             ArrowDataType::Int8 => Ok(schema::SchemaDataType::primitive("byte".to_string())),
+            ArrowDataType::UInt64 => Ok(schema::SchemaDataType::primitive("long".to_string())), // undocumented type
+            ArrowDataType::UInt32 => Ok(schema::SchemaDataType::primitive("integer".to_string())),
+            ArrowDataType::UInt16 => Ok(schema::SchemaDataType::primitive("short".to_string())),
+            ArrowDataType::UInt8 => Ok(schema::SchemaDataType::primitive("byte".to_string())),
             ArrowDataType::Float32 => Ok(schema::SchemaDataType::primitive("float".to_string())),
             ArrowDataType::Float64 => Ok(schema::SchemaDataType::primitive("double".to_string())),
             ArrowDataType::Boolean => Ok(schema::SchemaDataType::primitive("boolean".to_string())),
             ArrowDataType::Binary => Ok(schema::SchemaDataType::primitive("binary".to_string())),
+            ArrowDataType::FixedSizeBinary(_) => {
+                Ok(schema::SchemaDataType::primitive("binary".to_string()))
+            }
+            ArrowDataType::LargeBinary => {
+                Ok(schema::SchemaDataType::primitive("binary".to_string()))
+            }
             ArrowDataType::Decimal128(p, s) => Ok(schema::SchemaDataType::primitive(format!(
                 "decimal({p},{s})"
             ))),
@@ -223,6 +234,7 @@ impl TryFrom<&ArrowDataType> for schema::SchemaDataType {
                 "decimal({p},{s})"
             ))),
             ArrowDataType::Date32 => Ok(schema::SchemaDataType::primitive("date".to_string())),
+            ArrowDataType::Date64 => Ok(schema::SchemaDataType::primitive("date".to_string())),
             ArrowDataType::Timestamp(TimeUnit::Microsecond, None) => {
                 Ok(schema::SchemaDataType::primitive("timestamp".to_string()))
             }
@@ -244,6 +256,12 @@ impl TryFrom<&ArrowDataType> for schema::SchemaDataType {
                     (*field).is_nullable(),
                 )))
             }
+            ArrowDataType::LargeList(field) => {
+                Ok(schema::SchemaDataType::array(schema::SchemaTypeArray::new(
+                    Box::new((*field).data_type().try_into()?),
+                    (*field).is_nullable(),
+                )))
+            }
             ArrowDataType::FixedSizeList(field, _) => {
                 Ok(schema::SchemaDataType::array(schema::SchemaTypeArray::new(
                     Box::new((*field).data_type().try_into()?),
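A short sketch of what the new Arrow-to-Delta mappings allow from Python (not part of the patch; the output path is a placeholder and the comments state the expected mapping under the match arms above):

```python
# Sketch only: unsigned-integer Arrow columns can now be written to a Delta table;
# they are recorded with Delta's signed primitive types per the conversions above.
import pyarrow as pa
from deltalake import DeltaTable, write_deltalake

table = pa.table(
    {
        "small": pa.array([1, 2, 3], type=pa.uint8()),    # expected Delta type: "byte"
        "medium": pa.array([1, 2, 3], type=pa.uint32()),  # expected Delta type: "integer"
        "big": pa.array([1, 2, 3], type=pa.uint64()),     # expected Delta type: "long"
    }
)

write_deltalake("/tmp/uint_example", table)  # placeholder path
print(DeltaTable("/tmp/uint_example").schema().to_json())
```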
TestCase::new("utf8", |value| lit(value.to_string())), - TestCase::new("int64", |value| lit(value)), + TestCase::new("int64", lit), TestCase::new("int32", |value| lit(value as i32)), TestCase::new("int16", |value| lit(value as i16)), TestCase::new("int8", |value| lit(value as i8)), TestCase::new("float64", |value| lit(value as f64)), TestCase::new("float32", |value| lit(value as f32)), TestCase::new("timestamp", |value| { - lit(ScalarValue::TimestampMicrosecond( - Some(value * 1_000_000), - None, - )) + lit(TimestampMicrosecond(Some(value * 1_000_000), None)) }), // TODO: I think decimal statistics are being written to the log incorrectly. The underlying i128 is written - // not the proper string representation as specified by the percision and scale + // not the proper string representation as specified by the precision and scale TestCase::new("decimal", |value| { lit(Decimal128(Some((value * 100).into()), 10, 2)) }), - // TODO: The writer does not write complete statistiics for date columns - TestCase::new("date", |value| lit(ScalarValue::Date32(Some(value as i32)))), + // TODO: The writer does not write complete statistics for date columns + TestCase::new("date", |value| lit(Date32(Some(value as i32)))), // TODO: The writer does not write complete statistics for binary columns TestCase::new("binary", |value| lit(value.to_string().as_bytes())), ]; @@ -544,7 +541,7 @@ async fn test_files_scanned() -> Result<()> { let metrics = get_scan_metrics(&table, &state, &[e]).await?; assert_eq!(metrics.num_scanned_files(), 0); - // Conjuction + // Conjunction let e = col(column) .gt(file1_value.clone()) .and(col(column).lt(file2_value.clone())); @@ -617,7 +614,7 @@ async fn test_files_scanned() -> Result<()> { let metrics = get_scan_metrics(&table, &state, &[e]).await?; assert_eq!(metrics.num_scanned_files(), 0); - // Conjuction + // Conjunction let e = col(column) .gt(file1_value.clone()) .and(col(column).lt(file2_value)); @@ -679,12 +676,12 @@ async fn test_files_scanned() -> Result<()> { // Check pruning for null partitions let e = col("k").is_null(); let metrics = get_scan_metrics(&table, &state, &[e]).await?; - assert!(metrics.num_scanned_files() == 1); + assert_eq!(metrics.num_scanned_files(), 1); // Check pruning for null partitions. Since there are no record count statistics pruning cannot be done let e = col("k").is_not_null(); let metrics = get_scan_metrics(&table, &state, &[e]).await?; - assert!(metrics.num_scanned_files() == 2); + assert_eq!(metrics.num_scanned_files(), 2); Ok(()) }