Merged
22 commits
5f10176
Implement UDAF improvements for list type handling
kosiew Jan 20, 2026
fbba2a0
Document UDAF list-valued scalar returns
kosiew Jan 20, 2026
21906bb
Fix pyarrow calls and improve type handling in RustAccumulator
kosiew Jan 20, 2026
7f363a7
Refactor RustAccumulator to support pyarrow array types and improve t…
kosiew Jan 20, 2026
5271ba2
Fixed PyO3 type mismatch by cloning Array/ChunkedArray types before u…
kosiew Jan 20, 2026
9c59258
Add timezone information to datetime objects in test_udaf_list_timest…
kosiew Jan 20, 2026
6742954
clippy fix
kosiew Jan 20, 2026
f16c718
Refactor RustAccumulator and utility functions for improved type hand…
kosiew Feb 5, 2026
dcf6145
Enhance PyArrow integration by refining type handling and conversion …
kosiew Feb 5, 2026
7f2d9ae
Merge branch 'main' into typeconversion-issue-1339
kosiew Feb 5, 2026
7ff146e
Fix array data binding in py_obj_to_scalar_value function
kosiew Feb 5, 2026
87a0e30
Implement single point for scalar conversion from python objects
timsaucer Feb 6, 2026
76dee1c
Add unit tests and simplify python wrapper for literal
timsaucer Feb 6, 2026
85ee4f7
Add nanoarrow and arro3-core to dev dependencies. Sort the dependenci…
timsaucer Feb 11, 2026
9d0ac50
Refactor common code into helper function so we do not duplicate it.
timsaucer Feb 11, 2026
5ac0164
Update import path to access Scalar type
timsaucer Feb 11, 2026
1f7da06
Add test for generic python objects that support the C interface
timsaucer Feb 11, 2026
33f0b7f
Merge branch 'main' into typeconversion-issue-1339
timsaucer Feb 18, 2026
390d753
Update unit test to pass back either pyarrow array or array wrapped a…
timsaucer Feb 18, 2026
67a6bc1
Update tests to pass back raw python values or pyarrow scalar
timsaucer Feb 18, 2026
27fa92a
Expand on user documentation for how to return list arrays
timsaucer Feb 18, 2026
a10ee5a
More user documentation
timsaucer Feb 18, 2026
33 changes: 26 additions & 7 deletions docs/source/user-guide/common-operations/udf-and-udfa.rst
@@ -123,7 +123,7 @@ also see how the inputs to ``update`` and ``merge`` differ.

.. code-block:: python

import pyarrow
import pyarrow as pa
import pyarrow.compute
import datafusion
from datafusion import col, udaf, Accumulator
@@ -136,16 +136,16 @@ also see how the inputs to ``update`` and ``merge`` differ.
def __init__(self):
self._sum = 0.0

def update(self, values_a: pyarrow.Array, values_b: pyarrow.Array) -> None:
def update(self, values_a: pa.Array, values_b: pa.Array) -> None:
self._sum = self._sum + pyarrow.compute.sum(values_a).as_py() - pyarrow.compute.sum(values_b).as_py()

def merge(self, states: List[pyarrow.Array]) -> None:
def merge(self, states: list[pa.Array]) -> None:
self._sum = self._sum + pyarrow.compute.sum(states[0]).as_py()

def state(self) -> pyarrow.Array:
return pyarrow.array([self._sum])
def state(self) -> list[pa.Scalar]:
return [pyarrow.scalar(self._sum)]

def evaluate(self) -> pyarrow.Scalar:
def evaluate(self) -> pa.Scalar:
return pyarrow.scalar(self._sum)

ctx = datafusion.SessionContext()
@@ -156,10 +156,29 @@ also see how the inputs to ``update`` and ``merge`` differ.
}
)

my_udaf = udaf(MyAccumulator, [pyarrow.float64(), pyarrow.float64()], pyarrow.float64(), [pyarrow.float64()], 'stable')
my_udaf = udaf(MyAccumulator, [pa.float64(), pa.float64()], pa.float64(), [pa.float64()], 'stable')

df.aggregate([], [my_udaf(col("a"), col("b")).alias("col_diff")])

FAQ
^^^

**How do I return a list from a UDAF?**

Both the ``evaluate`` and the ``state`` functions are expected to return scalar values.
If you wish to return a list array as a scalar value, the best practice is to
wrap the values in a ``pyarrow.Scalar`` object. For example, you can return a
timestamp list with ``pa.scalar([...], type=pa.list_(pa.timestamp("ms")))`` and
register the appropriate return or state types as
``return_type=pa.list_(pa.timestamp("ms"))`` and
``state_type=[pa.list_(pa.timestamp("ms"))]``, respectively.
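
A minimal sketch of this pattern, based on an accumulator that collects
timestamps (the class name and collection logic are illustrative):

.. code-block:: python

    class CollectTimestamps(Accumulator):
        def __init__(self) -> None:
            self._values = []

        def update(self, values: pa.Array) -> None:
            self._values.extend(values.to_pylist())

        def merge(self, states: list[pa.Array]) -> None:
            for state in states[0].to_pylist():
                if state is not None:
                    self._values.extend(state)

        def state(self) -> list[pa.Scalar]:
            # Wrap the collected values in a single list-valued scalar
            return [pa.scalar(self._values, type=pa.list_(pa.timestamp("ms")))]

        def evaluate(self) -> pa.Scalar:
            return pa.scalar(self._values, type=pa.list_(pa.timestamp("ms")))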

As of DataFusion 52.0.0, you can return any Python object, including a
PyArrow array, as the return value for these functions, and DataFusion
will attempt to create a scalar value from it. This conversion has been
tested with PyArrow, nanoarrow, and arro3 objects, as well as with
primitive data types such as integers and strings.
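
For example, with this conversion in place the accumulator above could
skip the explicit wrapping and return a PyArrow array directly (a
sketch; DataFusion converts it into a list-valued scalar):

.. code-block:: python

    def evaluate(self) -> pa.Scalar:
        # DataFusion will attempt to build a list-valued scalar
        # from the returned array.
        return pa.array(self._values, type=pa.timestamp("ms"))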

Window Functions
----------------

20 changes: 11 additions & 9 deletions pyproject.toml
@@ -173,27 +173,29 @@ ignore-words-list = ["ans", "IST"]

[dependency-groups]
dev = [
"arro3-core==0.6.5",
"codespell==2.4.1",
"maturin>=1.8.1",
"nanoarrow==0.8.0",
"numpy>1.25.0;python_version<'3.14'",
"numpy>=2.3.2;python_version>='3.14'",
"pyarrow>=19.0.0",
"pre-commit>=4.3.0",
"pyyaml>=6.0.3",
"pyarrow>=19.0.0",
"pygithub==2.5.0",
"pytest>=7.4.4",
"pytest-asyncio>=0.23.3",
"pyyaml>=6.0.3",
"ruff>=0.9.1",
"toml>=0.10.2",
"pygithub==2.5.0",
"codespell==2.4.1",
]
docs = [
"sphinx>=7.1.2",
"pydata-sphinx-theme==0.8.0",
"myst-parser>=3.0.1",
"jinja2>=3.1.5",
"ipython>=8.12.3",
"jinja2>=3.1.5",
"myst-parser>=3.0.1",
"pandas>=2.0.3",
"pickleshare>=0.7.5",
"sphinx-autoapi>=3.4.0",
"pydata-sphinx-theme==0.8.0",
"setuptools>=75.3.0",
"sphinx>=7.1.2",
"sphinx-autoapi>=3.4.0",
]
3 changes: 0 additions & 3 deletions python/datafusion/expr.py
@@ -562,8 +562,6 @@ def literal(value: Any) -> Expr:
"""
if isinstance(value, str):
value = pa.scalar(value, type=pa.string_view())
if not isinstance(value, pa.Scalar):
value = pa.scalar(value)
return Expr(expr_internal.RawExpr.literal(value))

@staticmethod
@@ -576,7 +574,6 @@ def literal_with_metadata(value: Any, metadata: dict[str, str]) -> Expr:
"""
if isinstance(value, str):
value = pa.scalar(value, type=pa.string_view())
value = value if isinstance(value, pa.Scalar) else pa.scalar(value)

return Expr(expr_internal.RawExpr.literal_with_metadata(value, metadata))

22 changes: 20 additions & 2 deletions python/datafusion/user_defined.py
@@ -298,7 +298,16 @@ class Accumulator(metaclass=ABCMeta):

@abstractmethod
def state(self) -> list[pa.Scalar]:
"""Return the current state."""
"""Return the current state.

While this function template expects a list of PyArrow Scalar values as
the return type, you can return any values that can be converted into
Scalars. This includes basic Python data types such as integers and
strings, as well as PyArrow, nanoarrow, and arro3 objects. Other
objects that support the Arrow FFI standard will be given a "best
attempt" at conversion to scalar objects.
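
For example, a summing accumulator whose running total is a plain
Python float could return it directly (a sketch; DataFusion performs
the conversion)::

    def state(self) -> list[pa.Scalar]:
        return [self._sum]  # a plain float, converted for you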
"""

@abstractmethod
def update(self, *values: pa.Array) -> None:
@@ -310,7 +319,16 @@ def merge(self, states: list[pa.Array]) -> None:

@abstractmethod
def evaluate(self) -> pa.Scalar:
"""Return the resultant value."""
"""Return the resultant value.

While this function template expects a PyArrow Scalar value as the
return type, you can return any value that can be converted into a
Scalar. This includes basic Python data types such as integers and
strings, as well as PyArrow, nanoarrow, and arro3 objects. Other
objects that support the Arrow FFI standard will be given a "best
attempt" at conversion to scalar objects.
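
For example, an accumulator tracking a running sum may return the value
either wrapped or unwrapped (a sketch; both forms are accepted)::

    def evaluate(self) -> pa.Scalar:
        return pa.scalar(self._sum)  # or simply: return self._sum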
"""


class AggregateUDFExportable(Protocol):
45 changes: 45 additions & 0 deletions python/tests/test_expr.py
@@ -20,6 +20,8 @@
from datetime import date, datetime, time, timezone
from decimal import Decimal

import arro3.core
import nanoarrow
import pyarrow as pa
import pytest
from datafusion import (
@@ -980,6 +982,49 @@ def test_literal_metadata(ctx):
assert expected_field.metadata == actual_field.metadata


def test_scalar_conversion() -> None:
class WrappedPyArrow:
"""Wrapper class for testing __arrow_c_array__."""

def __init__(self, val: pa.Array) -> None:
self.val = val

def __arrow_c_array__(self, requested_schema=None):
return self.val.__arrow_c_array__(requested_schema=requested_schema)

expected_value = lit(1)
assert str(expected_value) == "Expr(Int64(1))"

# Test pyarrow inputs
assert expected_value == lit(pa.scalar(1))
assert expected_value == lit(pa.scalar(1, type=pa.int32()))

# Test nanoarrow
na_scalar = nanoarrow.Array([1], nanoarrow.int32())[0]
assert expected_value == lit(na_scalar)

# Test arro3
arro3_scalar = arro3.core.Scalar(1, type=arro3.core.DataType.int32())
assert expected_value == lit(arro3_scalar)

generic_scalar = WrappedPyArrow(pa.array([1]))
assert expected_value == lit(generic_scalar)

expected_value = lit([1, 2, 3])
assert str(expected_value) == "Expr(List([1, 2, 3]))"

assert expected_value == lit(pa.scalar([1, 2, 3]))

na_array = nanoarrow.Array([1, 2, 3], nanoarrow.int32())
assert expected_value == lit(na_array)

arro3_array = arro3.core.Array([1, 2, 3], type=arro3.core.DataType.int32())
assert expected_value == lit(arro3_array)

generic_array = WrappedPyArrow(pa.array([1, 2, 3]))
assert expected_value == lit(generic_array)


def test_ensure_expr():
e = col("a")
assert ensure_expr(e) is e.expr
89 changes: 83 additions & 6 deletions python/tests/test_udaf.py
@@ -17,6 +17,8 @@

from __future__ import annotations

from datetime import datetime, timezone

import pyarrow as pa
import pyarrow.compute as pc
import pytest
@@ -26,23 +28,28 @@
class Summarize(Accumulator):
"""Interface of a user-defined accumulation."""

def __init__(self, initial_value: float = 0.0):
self._sum = pa.scalar(initial_value)
def __init__(self, initial_value: float = 0.0, as_scalar: bool = False):
self._sum = initial_value
self.as_scalar = as_scalar

def state(self) -> list[pa.Scalar]:
if self.as_scalar:
return [pa.scalar(self._sum)]
return [self._sum]

def update(self, values: pa.Array) -> None:
# Not nice since pyarrow scalars can't be summed yet.
# This breaks on `None`
self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())
self._sum = self._sum + pc.sum(values).as_py()

def merge(self, states: list[pa.Array]) -> None:
# Not nice since pyarrow scalars can't be summed yet.
# This breaks on `None`
self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py())
self._sum = self._sum + pc.sum(states[0]).as_py()

def evaluate(self) -> pa.Scalar:
if self.as_scalar:
return pa.scalar(self._sum)
return self._sum


@@ -58,6 +65,30 @@ def state(self) -> list[pa.Scalar]:
return [self._sum]


class CollectTimestamps(Accumulator):
def __init__(self, wrap_in_scalar: bool):
self._values: list[datetime] = []
self.wrap_in_scalar = wrap_in_scalar

def state(self) -> list[pa.Scalar]:
if self.wrap_in_scalar:
return [pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))]
return [pa.array(self._values, type=pa.timestamp("ns"))]

def update(self, values: pa.Array) -> None:
self._values.extend(values.to_pylist())

def merge(self, states: list[pa.Array]) -> None:
for state in states[0].to_pylist():
if state is not None:
self._values.extend(state)

def evaluate(self) -> pa.Scalar:
if self.wrap_in_scalar:
return pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))
return pa.array(self._values, type=pa.timestamp("ns"))


@pytest.fixture
def df(ctx):
# create a RecordBatch and a new DataFrame from it
@@ -137,11 +168,12 @@ def summarize():
assert result.column(0) == pa.array([1.0 + 2.0 + 3.0])


def test_udaf_aggregate_with_arguments(df):
@pytest.mark.parametrize("as_scalar", [True, False])
def test_udaf_aggregate_with_arguments(df, as_scalar):
bias = 10.0

summarize = udaf(
lambda: Summarize(bias),
lambda: Summarize(initial_value=bias, as_scalar=as_scalar),
pa.float64(),
pa.float64(),
[pa.float64()],
@@ -217,3 +249,48 @@ def test_register_udaf(ctx, df) -> None:
df_result = ctx.sql("select summarize(b) from test_table")

assert df_result.collect()[0][0][0].as_py() == 14.0


@pytest.mark.parametrize("wrap_in_scalar", [True, False])
def test_udaf_list_timestamp_return(ctx, wrap_in_scalar) -> None:
timestamps1 = [
datetime(2024, 1, 1, tzinfo=timezone.utc),
datetime(2024, 1, 2, tzinfo=timezone.utc),
]
timestamps2 = [
datetime(2024, 1, 3, tzinfo=timezone.utc),
datetime(2024, 1, 4, tzinfo=timezone.utc),
]
batch1 = pa.RecordBatch.from_arrays(
[pa.array(timestamps1, type=pa.timestamp("ns"))],
names=["ts"],
)
batch2 = pa.RecordBatch.from_arrays(
[pa.array(timestamps2, type=pa.timestamp("ns"))],
names=["ts"],
)
df = ctx.create_dataframe([[batch1], [batch2]], name="timestamp_table")

list_type = pa.list_(
pa.field("item", type=pa.timestamp("ns"), nullable=wrap_in_scalar)
)

collect = udaf(
lambda: CollectTimestamps(wrap_in_scalar),
pa.timestamp("ns"),
list_type,
[list_type],
volatility="immutable",
)

result = df.aggregate([], [collect(column("ts"))]).collect()[0]

# There is no guarantee about the ordering of the batches, so perform a sort
# to get consistent results. Alternatively we could sort on evaluate().
assert (
result.column(0).values.sort()
== pa.array(
[[*timestamps1, *timestamps2]],
type=list_type,
).values
)
3 changes: 3 additions & 0 deletions src/common/data_type.rs
@@ -22,6 +22,9 @@ use datafusion::logical_expr::expr::NullTreatment as DFNullTreatment;
use pyo3::exceptions::{PyNotImplementedError, PyValueError};
use pyo3::prelude::*;

/// A [`ScalarValue`] wrapped in a Python object. This struct allows for conversion
/// from a variety of Python objects into a [`ScalarValue`]. See
/// ``FromPyArrow::from_pyarrow_bound`` for conversion details.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)]
pub struct PyScalarValue(pub ScalarValue);

6 changes: 3 additions & 3 deletions src/config.rs
@@ -22,8 +22,8 @@ use parking_lot::RwLock;
use pyo3::prelude::*;
use pyo3::types::*;

use crate::common::data_type::PyScalarValue;
use crate::errors::PyDataFusionResult;
use crate::utils::py_obj_to_scalar_value;
#[pyclass(name = "Config", module = "datafusion", subclass, frozen)]
#[derive(Clone)]
pub(crate) struct PyConfig {
@@ -65,9 +65,9 @@ impl PyConfig {

/// Set a configuration option
pub fn set(&self, key: &str, value: Py<PyAny>, py: Python) -> PyDataFusionResult<()> {
let scalar_value = py_obj_to_scalar_value(py, value)?;
let scalar_value: PyScalarValue = value.extract(py)?;
let mut options = self.config.write();
options.set(key, scalar_value.to_string().as_str())?;
options.set(key, scalar_value.0.to_string().as_str())?;
Ok(())
}

Expand Down
Loading