diff --git a/docs/source/user-guide/common-operations/udf-and-udfa.rst b/docs/source/user-guide/common-operations/udf-and-udfa.rst index d554e1e25..f669721a3 100644 --- a/docs/source/user-guide/common-operations/udf-and-udfa.rst +++ b/docs/source/user-guide/common-operations/udf-and-udfa.rst @@ -123,7 +123,7 @@ also see how the inputs to ``update`` and ``merge`` differ. .. code-block:: python - import pyarrow + import pyarrow as pa import pyarrow.compute import datafusion from datafusion import col, udaf, Accumulator @@ -136,16 +136,16 @@ also see how the inputs to ``update`` and ``merge`` differ. def __init__(self): self._sum = 0.0 - def update(self, values_a: pyarrow.Array, values_b: pyarrow.Array) -> None: + def update(self, values_a: pa.Array, values_b: pa.Array) -> None: self._sum = self._sum + pyarrow.compute.sum(values_a).as_py() - pyarrow.compute.sum(values_b).as_py() - def merge(self, states: List[pyarrow.Array]) -> None: + def merge(self, states: list[pa.Array]) -> None: self._sum = self._sum + pyarrow.compute.sum(states[0]).as_py() - def state(self) -> pyarrow.Array: - return pyarrow.array([self._sum]) + def state(self) -> list[pa.Scalar]: + return [pyarrow.scalar(self._sum)] - def evaluate(self) -> pyarrow.Scalar: + def evaluate(self) -> pa.Scalar: return pyarrow.scalar(self._sum) ctx = datafusion.SessionContext() @@ -156,10 +156,29 @@ also see how the inputs to ``update`` and ``merge`` differ. } ) - my_udaf = udaf(MyAccumulator, [pyarrow.float64(), pyarrow.float64()], pyarrow.float64(), [pyarrow.float64()], 'stable') + my_udaf = udaf(MyAccumulator, [pa.float64(), pa.float64()], pa.float64(), [pa.float64()], 'stable') df.aggregate([], [my_udaf(col("a"), col("b")).alias("col_diff")]) +FAQ +^^^ + +**How do I return a list from a UDAF?** + +Both the ``evaluate`` and the ``state`` functions expect to return scalar values. +If you wish to return a list array as a scalar value, the best practice is to +wrap the values in a ``pyarrow.Scalar`` object. For example, you can return a +timestamp list with ``pa.scalar([...], type=pa.list_(pa.timestamp("ms")))`` and +register the appropriate return or state types as +``return_type=pa.list_(pa.timestamp("ms"))`` and +``state_type=[pa.list_(pa.timestamp("ms"))]``, respectively. + +As of DataFusion 52.0.0 , you can pass return any Python object, including a +PyArrow array, as the return value(s) for these functions and DataFusion will +attempt to create a scalar type from the value. DataFusion has been tested to +convert PyArrow, nanoarrow, and arro3 objects as well as primitive data types +like integers, strings, and so on. + Window Functions ---------------- diff --git a/pyproject.toml b/pyproject.toml index d315dbe19..5a5128a2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -173,27 +173,29 @@ ignore-words-list = ["ans", "IST"] [dependency-groups] dev = [ + "arro3-core==0.6.5", + "codespell==2.4.1", "maturin>=1.8.1", + "nanoarrow==0.8.0", "numpy>1.25.0;python_version<'3.14'", "numpy>=2.3.2;python_version>='3.14'", - "pyarrow>=19.0.0", "pre-commit>=4.3.0", - "pyyaml>=6.0.3", + "pyarrow>=19.0.0", + "pygithub==2.5.0", "pytest>=7.4.4", "pytest-asyncio>=0.23.3", + "pyyaml>=6.0.3", "ruff>=0.9.1", "toml>=0.10.2", - "pygithub==2.5.0", - "codespell==2.4.1", ] docs = [ - "sphinx>=7.1.2", - "pydata-sphinx-theme==0.8.0", - "myst-parser>=3.0.1", - "jinja2>=3.1.5", "ipython>=8.12.3", + "jinja2>=3.1.5", + "myst-parser>=3.0.1", "pandas>=2.0.3", "pickleshare>=0.7.5", - "sphinx-autoapi>=3.4.0", + "pydata-sphinx-theme==0.8.0", "setuptools>=75.3.0", + "sphinx>=7.1.2", + "sphinx-autoapi>=3.4.0", ] diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 695fe7c49..9df58f52a 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -562,8 +562,6 @@ def literal(value: Any) -> Expr: """ if isinstance(value, str): value = pa.scalar(value, type=pa.string_view()) - if not isinstance(value, pa.Scalar): - value = pa.scalar(value) return Expr(expr_internal.RawExpr.literal(value)) @staticmethod @@ -576,7 +574,6 @@ def literal_with_metadata(value: Any, metadata: dict[str, str]) -> Expr: """ if isinstance(value, str): value = pa.scalar(value, type=pa.string_view()) - value = value if isinstance(value, pa.Scalar) else pa.scalar(value) return Expr(expr_internal.RawExpr.literal_with_metadata(value, metadata)) diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 5dd626568..d4e5302b5 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -298,7 +298,16 @@ class Accumulator(metaclass=ABCMeta): @abstractmethod def state(self) -> list[pa.Scalar]: - """Return the current state.""" + """Return the current state. + + While this function template expects PyArrow Scalar values return type, + you can return any value that can be converted into a Scalar. This + includes basic Python data types such as integers and strings. In + addition to primitive types, we currently support PyArrow, nanoarrow, + and arro3 objects in addition to primitive data types. Other objects + that support the Arrow FFI standard will be given a "best attempt" at + conversion to scalar objects. + """ @abstractmethod def update(self, *values: pa.Array) -> None: @@ -310,7 +319,16 @@ def merge(self, states: list[pa.Array]) -> None: @abstractmethod def evaluate(self) -> pa.Scalar: - """Return the resultant value.""" + """Return the resultant value. + + While this function template expects a PyArrow Scalar value return type, + you can return any value that can be converted into a Scalar. This + includes basic Python data types such as integers and strings. In + addition to primitive types, we currently support PyArrow, nanoarrow, + and arro3 objects in addition to primitive data types. Other objects + that support the Arrow FFI standard will be given a "best attempt" at + conversion to scalar objects. + """ class AggregateUDFExportable(Protocol): diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 39e48f7c3..92251827b 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -20,6 +20,8 @@ from datetime import date, datetime, time, timezone from decimal import Decimal +import arro3.core +import nanoarrow import pyarrow as pa import pytest from datafusion import ( @@ -980,6 +982,49 @@ def test_literal_metadata(ctx): assert expected_field.metadata == actual_field.metadata +def test_scalar_conversion() -> None: + class WrappedPyArrow: + """Wrapper class for testing __arrow_c_array__.""" + + def __init__(self, val: pa.Array) -> None: + self.val = val + + def __arrow_c_array__(self, requested_schema=None): + return self.val.__arrow_c_array__(requested_schema=requested_schema) + + expected_value = lit(1) + assert str(expected_value) == "Expr(Int64(1))" + + # Test pyarrow imports + assert expected_value == lit(pa.scalar(1)) + assert expected_value == lit(pa.scalar(1, type=pa.int32())) + + # Test nanoarrow + na_scalar = nanoarrow.Array([1], nanoarrow.int32())[0] + assert expected_value == lit(na_scalar) + + # Test pyo3 + arro3_scalar = arro3.core.Scalar(1, type=arro3.core.DataType.int32()) + assert expected_value == lit(arro3_scalar) + + generic_scalar = WrappedPyArrow(pa.array([1])) + assert expected_value == lit(generic_scalar) + + expected_value = lit([1, 2, 3]) + assert str(expected_value) == "Expr(List([1, 2, 3]))" + + assert expected_value == lit(pa.scalar([1, 2, 3])) + + na_array = nanoarrow.Array([1, 2, 3], nanoarrow.int32()) + assert expected_value == lit(na_array) + + arro3_array = arro3.core.Array([1, 2, 3], type=arro3.core.DataType.int32()) + assert expected_value == lit(arro3_array) + + generic_array = WrappedPyArrow(pa.array([1, 2, 3])) + assert expected_value == lit(generic_array) + + def test_ensure_expr(): e = col("a") assert ensure_expr(e) is e.expr diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py index 453ff6f4f..8cd480e37 100644 --- a/python/tests/test_udaf.py +++ b/python/tests/test_udaf.py @@ -17,6 +17,8 @@ from __future__ import annotations +from datetime import datetime, timezone + import pyarrow as pa import pyarrow.compute as pc import pytest @@ -26,23 +28,28 @@ class Summarize(Accumulator): """Interface of a user-defined accumulation.""" - def __init__(self, initial_value: float = 0.0): - self._sum = pa.scalar(initial_value) + def __init__(self, initial_value: float = 0.0, as_scalar: bool = False): + self._sum = initial_value + self.as_scalar = as_scalar def state(self) -> list[pa.Scalar]: + if self.as_scalar: + return [pa.scalar(self._sum)] return [self._sum] def update(self, values: pa.Array) -> None: # Not nice since pyarrow scalars can't be summed yet. # This breaks on `None` - self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py()) + self._sum = self._sum + pc.sum(values).as_py() def merge(self, states: list[pa.Array]) -> None: # Not nice since pyarrow scalars can't be summed yet. # This breaks on `None` - self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py()) + self._sum = self._sum + pc.sum(states[0]).as_py() def evaluate(self) -> pa.Scalar: + if self.as_scalar: + return pa.scalar(self._sum) return self._sum @@ -58,6 +65,30 @@ def state(self) -> list[pa.Scalar]: return [self._sum] +class CollectTimestamps(Accumulator): + def __init__(self, wrap_in_scalar: bool): + self._values: list[datetime] = [] + self.wrap_in_scalar = wrap_in_scalar + + def state(self) -> list[pa.Scalar]: + if self.wrap_in_scalar: + return [pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))] + return [pa.array(self._values, type=pa.timestamp("ns"))] + + def update(self, values: pa.Array) -> None: + self._values.extend(values.to_pylist()) + + def merge(self, states: list[pa.Array]) -> None: + for state in states[0].to_pylist(): + if state is not None: + self._values.extend(state) + + def evaluate(self) -> pa.Scalar: + if self.wrap_in_scalar: + return pa.scalar(self._values, type=pa.list_(pa.timestamp("ns"))) + return pa.array(self._values, type=pa.timestamp("ns")) + + @pytest.fixture def df(ctx): # create a RecordBatch and a new DataFrame from it @@ -137,11 +168,12 @@ def summarize(): assert result.column(0) == pa.array([1.0 + 2.0 + 3.0]) -def test_udaf_aggregate_with_arguments(df): +@pytest.mark.parametrize("as_scalar", [True, False]) +def test_udaf_aggregate_with_arguments(df, as_scalar): bias = 10.0 summarize = udaf( - lambda: Summarize(bias), + lambda: Summarize(initial_value=bias, as_scalar=as_scalar), pa.float64(), pa.float64(), [pa.float64()], @@ -217,3 +249,48 @@ def test_register_udaf(ctx, df) -> None: df_result = ctx.sql("select summarize(b) from test_table") assert df_result.collect()[0][0][0].as_py() == 14.0 + + +@pytest.mark.parametrize("wrap_in_scalar", [True, False]) +def test_udaf_list_timestamp_return(ctx, wrap_in_scalar) -> None: + timestamps1 = [ + datetime(2024, 1, 1, tzinfo=timezone.utc), + datetime(2024, 1, 2, tzinfo=timezone.utc), + ] + timestamps2 = [ + datetime(2024, 1, 3, tzinfo=timezone.utc), + datetime(2024, 1, 4, tzinfo=timezone.utc), + ] + batch1 = pa.RecordBatch.from_arrays( + [pa.array(timestamps1, type=pa.timestamp("ns"))], + names=["ts"], + ) + batch2 = pa.RecordBatch.from_arrays( + [pa.array(timestamps2, type=pa.timestamp("ns"))], + names=["ts"], + ) + df = ctx.create_dataframe([[batch1], [batch2]], name="timestamp_table") + + list_type = pa.list_( + pa.field("item", type=pa.timestamp("ns"), nullable=wrap_in_scalar) + ) + + collect = udaf( + lambda: CollectTimestamps(wrap_in_scalar), + pa.timestamp("ns"), + list_type, + [list_type], + volatility="immutable", + ) + + result = df.aggregate([], [collect(column("ts"))]).collect()[0] + + # There is no guarantee about the ordering of the batches, so perform a sort + # to get consistent results. Alternatively we could sort on evaluate(). + assert ( + result.column(0).values.sort() + == pa.array( + [[*timestamps1, *timestamps2]], + type=list_type, + ).values + ) diff --git a/src/common/data_type.rs b/src/common/data_type.rs index 55848da5c..1ff332ebb 100644 --- a/src/common/data_type.rs +++ b/src/common/data_type.rs @@ -22,6 +22,9 @@ use datafusion::logical_expr::expr::NullTreatment as DFNullTreatment; use pyo3::exceptions::{PyNotImplementedError, PyValueError}; use pyo3::prelude::*; +/// A [`ScalarValue`] wrapped in a Python object. This struct allows for conversion +/// from a variety of Python objects into a [`ScalarValue`]. See +/// ``FromPyArrow::from_pyarrow_bound`` conversion details. #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)] pub struct PyScalarValue(pub ScalarValue); diff --git a/src/config.rs b/src/config.rs index 583dea7ef..38936e6c5 100644 --- a/src/config.rs +++ b/src/config.rs @@ -22,8 +22,8 @@ use parking_lot::RwLock; use pyo3::prelude::*; use pyo3::types::*; +use crate::common::data_type::PyScalarValue; use crate::errors::PyDataFusionResult; -use crate::utils::py_obj_to_scalar_value; #[pyclass(name = "Config", module = "datafusion", subclass, frozen)] #[derive(Clone)] pub(crate) struct PyConfig { @@ -65,9 +65,9 @@ impl PyConfig { /// Set a configuration option pub fn set(&self, key: &str, value: Py, py: Python) -> PyDataFusionResult<()> { - let scalar_value = py_obj_to_scalar_value(py, value)?; + let scalar_value: PyScalarValue = value.extract(py)?; let mut options = self.config.write(); - options.set(key, scalar_value.to_string().as_str())?; + options.set(key, scalar_value.0.to_string().as_str())?; Ok(()) } diff --git a/src/dataframe.rs b/src/dataframe.rs index fe039593d..53fab58c6 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -48,6 +48,7 @@ use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods}; +use crate::common::data_type::PyScalarValue; use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err}; use crate::expr::PyExpr; use crate::expr::sort_expr::{PySortExpr, to_sort_expressions}; @@ -55,9 +56,7 @@ use crate::physical_plan::PyExecutionPlan; use crate::record_batch::{PyRecordBatchStream, poll_next_batch}; use crate::sql::logical::PyLogicalPlan; use crate::table::{PyTable, TempViewTable}; -use crate::utils::{ - is_ipython_env, py_obj_to_scalar_value, spawn_future, validate_pycapsule, wait_for_future, -}; +use crate::utils::{is_ipython_env, spawn_future, validate_pycapsule, wait_for_future}; /// File-level static CStr for the Arrow array stream capsule name. static ARROW_ARRAY_STREAM_NAME: &CStr = cstr!("arrow_array_stream"); @@ -1191,14 +1190,14 @@ impl PyDataFrame { columns: Option>, py: Python, ) -> PyDataFusionResult { - let scalar_value = py_obj_to_scalar_value(py, value)?; + let scalar_value: PyScalarValue = value.extract(py)?; let cols = match columns { Some(col_names) => col_names.iter().map(|c| c.to_string()).collect(), None => Vec::new(), // Empty vector means fill null for all columns }; - let df = self.df.as_ref().clone().fill_null(scalar_value, cols)?; + let df = self.df.as_ref().clone().fill_null(scalar_value.0, cols)?; Ok(Self::new(df)) } } diff --git a/src/pyarrow_util.rs b/src/pyarrow_util.rs index 264cfd342..2a119274f 100644 --- a/src/pyarrow_util.rs +++ b/src/pyarrow_util.rs @@ -17,8 +17,13 @@ //! Conversions between PyArrow and DataFusion types -use arrow::array::{Array, ArrayData}; +use std::sync::Arc; + +use arrow::array::{Array, ArrayData, ArrayRef, ListArray, make_array}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::Field; use arrow::pyarrow::{FromPyArrow, ToPyArrow}; +use datafusion::common::exec_err; use datafusion::scalar::ScalarValue; use pyo3::types::{PyAnyMethods, PyList}; use pyo3::{Bound, FromPyObject, PyAny, PyResult, Python}; @@ -26,21 +31,114 @@ use pyo3::{Bound, FromPyObject, PyAny, PyResult, Python}; use crate::common::data_type::PyScalarValue; use crate::errors::PyDataFusionError; +/// Helper function to turn an Array into a ScalarValue. If ``as_list_array`` is true, +/// the array will be turned into a ``ListArray``. Otherwise, we extract the first value +/// from the array. +fn array_to_scalar_value(array: ArrayRef, as_list_array: bool) -> PyResult { + if as_list_array { + let field = Arc::new(Field::new_list_field( + array.data_type().clone(), + array.nulls().is_some(), + )); + let offsets = OffsetBuffer::from_lengths(vec![array.len()]); + let list_array = ListArray::new(field, offsets, array, None); + Ok(PyScalarValue(ScalarValue::List(Arc::new(list_array)))) + } else { + let scalar = ScalarValue::try_from_array(&array, 0).map_err(PyDataFusionError::from)?; + Ok(PyScalarValue(scalar)) + } +} + +/// Helper function to take any Python object that contains an Arrow PyCapsule +/// interface and attempt to extract a scalar value from it. If `as_list_array` +/// is true, the array will be turned into a ``ListArray``. Otherwise, we extract +/// the first value from the array. +fn pyobj_extract_scalar_via_capsule( + value: &Bound<'_, PyAny>, + as_list_array: bool, +) -> PyResult { + let array_data = ArrayData::from_pyarrow_bound(value)?; + let array = make_array(array_data); + + array_to_scalar_value(array, as_list_array) +} + impl FromPyArrow for PyScalarValue { fn from_pyarrow_bound(value: &Bound<'_, PyAny>) -> PyResult { let py = value.py(); - let typ = value.getattr("type")?; + let pyarrow_mod = py.import("pyarrow"); - // construct pyarrow array from the python value and pyarrow type - let factory = py.import("pyarrow")?.getattr("array")?; - let args = PyList::new(py, [value])?; - let array = factory.call1((args, typ))?; + // Is it a PyArrow object? + if let Ok(pa) = pyarrow_mod.as_ref() { + let scalar_type = pa.getattr("Scalar")?; + if value.is_instance(&scalar_type)? { + let typ = value.getattr("type")?; - // convert the pyarrow array to rust array using C data interface - let array = arrow::array::make_array(ArrayData::from_pyarrow_bound(&array)?); - let scalar = ScalarValue::try_from_array(&array, 0).map_err(PyDataFusionError::from)?; + // construct pyarrow array from the python value and pyarrow type + let factory = py.import("pyarrow")?.getattr("array")?; + let args = PyList::new(py, [value])?; + let array = factory.call1((args, typ))?; - Ok(PyScalarValue(scalar)) + return pyobj_extract_scalar_via_capsule(&array, false); + } + + let array_type = pa.getattr("Array")?; + if value.is_instance(&array_type)? { + return pyobj_extract_scalar_via_capsule(value, true); + } + } + + // Is it a NanoArrow scalar? + if let Ok(na) = py.import("nanoarrow") { + let scalar_type = py.import("nanoarrow.array")?.getattr("Scalar")?; + if value.is_instance(&scalar_type)? { + return pyobj_extract_scalar_via_capsule(value, false); + } + let array_type = na.getattr("Array")?; + if value.is_instance(&array_type)? { + return pyobj_extract_scalar_via_capsule(value, true); + } + } + + // Is it a arro3 scalar? + if let Ok(arro3) = py.import("arro3").and_then(|arro3| arro3.getattr("core")) { + let scalar_type = arro3.getattr("Scalar")?; + if value.is_instance(&scalar_type)? { + return pyobj_extract_scalar_via_capsule(value, false); + } + let array_type = arro3.getattr("Array")?; + if value.is_instance(&array_type)? { + return pyobj_extract_scalar_via_capsule(value, true); + } + } + + // Does it have a PyCapsule interface but isn't one of our known libraries? + // If so do our "best guess". Try checking type name, and if that fails + // return a single value if the length is 1 and return a List value otherwise + if value.hasattr("__arrow_c_array__")? { + let type_name = value.get_type().repr()?; + if type_name.contains("Scalar")? { + return pyobj_extract_scalar_via_capsule(value, false); + } + if type_name.contains("Array")? { + return pyobj_extract_scalar_via_capsule(value, true); + } + + let array_data = ArrayData::from_pyarrow_bound(value)?; + let array = make_array(array_data); + + let as_array_list = array.len() != 1; + return array_to_scalar_value(array, as_array_list); + } + + // Last attempt - try to create a PyArrow scalar from a plain Python object + if let Ok(pa) = pyarrow_mod.as_ref() { + let scalar = pa.call_method1("scalar", (value,))?; + + PyScalarValue::from_pyarrow_bound(&scalar) + } else { + exec_err!("Unable to import scalar value").map_err(PyDataFusionError::from)? + } } } diff --git a/src/udaf.rs b/src/udaf.rs index 298a59b05..cc166035d 100644 --- a/src/udaf.rs +++ b/src/udaf.rs @@ -17,7 +17,7 @@ use std::sync::Arc; -use datafusion::arrow::array::{Array, ArrayRef}; +use datafusion::arrow::array::ArrayRef; use datafusion::arrow::datatypes::DataType; use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow}; use datafusion::common::ScalarValue; @@ -47,24 +47,24 @@ impl RustAccumulator { impl Accumulator for RustAccumulator { fn state(&mut self) -> Result> { - Python::attach(|py| { - self.accum - .bind(py) - .call_method0("state")? - .extract::>() + Python::attach(|py| -> PyResult> { + let values = self.accum.bind(py).call_method0("state")?; + let mut scalars = Vec::new(); + for item in values.try_iter()? { + let item: Bound<'_, PyAny> = item?; + let scalar = item.extract::()?.0; + scalars.push(scalar); + } + Ok(scalars) }) - .map(|v| v.into_iter().map(|x| x.0).collect()) .map_err(|e| DataFusionError::Execution(format!("{e}"))) } fn evaluate(&mut self) -> Result { - Python::attach(|py| { - self.accum - .bind(py) - .call_method0("evaluate")? - .extract::() + Python::attach(|py| -> PyResult { + let value = self.accum.bind(py).call_method0("evaluate")?; + value.extract::().map(|v| v.0) }) - .map(|v| v.0) .map_err(|e| DataFusionError::Execution(format!("{e}"))) } @@ -73,7 +73,7 @@ impl Accumulator for RustAccumulator { // 1. cast args to Pyarrow array let py_args = values .iter() - .map(|arg| arg.into_data().to_pyarrow(py).unwrap()) + .map(|arg| arg.to_data().to_pyarrow(py).unwrap()) .collect::>(); let py_args = PyTuple::new(py, py_args).map_err(to_datafusion_err)?; @@ -94,7 +94,7 @@ impl Accumulator for RustAccumulator { .iter() .map(|state| { state - .into_data() + .to_data() .to_pyarrow(py) .map_err(|e| DataFusionError::Execution(format!("{e}"))) }) @@ -119,7 +119,7 @@ impl Accumulator for RustAccumulator { // 1. cast args to Pyarrow array let py_args = values .iter() - .map(|arg| arg.into_data().to_pyarrow(py).unwrap()) + .map(|arg| arg.to_data().to_pyarrow(py).unwrap()) .collect::>(); let py_args = PyTuple::new(py, py_args).map_err(to_datafusion_err)?; @@ -144,7 +144,7 @@ impl Accumulator for RustAccumulator { } pub fn to_rust_accumulator(accum: Py) -> AccumulatorFactoryFunction { - Arc::new(move |_| -> Result> { + Arc::new(move |_args| -> Result> { let accum = Python::attach(|py| { accum .call0(py) diff --git a/src/udwf.rs b/src/udwf.rs index b5b795d27..4bf55a850 100644 --- a/src/udwf.rs +++ b/src/udwf.rs @@ -94,7 +94,6 @@ impl PartitionEvaluator for RustPartitionEvaluator { } fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result { - println!("evaluate all called with number of values {}", values.len()); Python::attach(|py| { let py_values = PyList::new( py, diff --git a/src/utils.rs b/src/utils.rs index 311f8fc86..28b58ba0f 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -19,7 +19,6 @@ use std::future::Future; use std::sync::{Arc, OnceLock}; use std::time::Duration; -use datafusion::common::ScalarValue; use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionContext; use datafusion::logical_expr::Volatility; @@ -34,7 +33,6 @@ use tokio::task::JoinHandle; use tokio::time::sleep; use crate::TokioRuntime; -use crate::common::data_type::PyScalarValue; use crate::context::PySessionContext; use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err, to_datafusion_err}; @@ -199,22 +197,6 @@ pub(crate) fn table_provider_from_pycapsule<'py>( } } -pub(crate) fn py_obj_to_scalar_value(py: Python, obj: Py) -> PyResult { - // convert Python object to PyScalarValue to ScalarValue - - let pa = py.import("pyarrow")?; - - // Convert Python object to PyArrow scalar - let scalar = pa.call_method1("scalar", (obj,))?; - - // Convert PyArrow scalar to PyScalarValue - let py_scalar = PyScalarValue::extract_bound(scalar.as_ref()) - .map_err(|e| PyValueError::new_err(format!("Failed to extract PyScalarValue: {e}")))?; - - // Convert PyScalarValue to ScalarValue - Ok(py_scalar.into()) -} - pub(crate) fn extract_logical_extension_codec( py: Python, obj: Option>,