diff --git a/src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp b/src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp index 66a6e3fa..af05789a 100644 --- a/src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp +++ b/src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp @@ -160,6 +160,15 @@ py::object GetScalar(Value &constant, const string &timezone_config, const Arrow } } +static py::list TransformInList(const InFilter &in) { + py::list res; + ClientProperties default_properties; + for (auto &val : in.values) { + res.append(PythonObject::FromValue(val, val.type(), default_properties)); + } + return res; +} + py::object TransformFilterRecursive(TableFilter &filter, vector column_ref, const string &timezone_config, const ArrowType &type) { auto &import_cache = *DuckDBPyConnection::ImportCache(); @@ -282,17 +291,9 @@ py::object TransformFilterRecursive(TableFilter &filter, vector column_r } case TableFilterType::IN_FILTER: { auto &in_filter = filter.Cast(); - ConjunctionOrFilter or_filter; - value_set_t unique_values; - for (const auto &value : in_filter.values) { - if (unique_values.find(value) == unique_values.end()) { - unique_values.insert(value); - } - } - for (const auto &value : unique_values) { - or_filter.child_filters.push_back(make_uniq(ExpressionType::COMPARE_EQUAL, value)); - } - return TransformFilterRecursive(or_filter, column_ref, timezone_config, type); + auto constant_field = field(py::tuple(py::cast(column_ref))); + auto in_list = TransformInList(in_filter); + return constant_field.attr("isin")(std::move(in_list)); } case TableFilterType::DYNAMIC_FILTER: { //! Ignore dynamic filters for now, not necessary for correctness diff --git a/tests/fast/arrow/test_filter_pushdown.py b/tests/fast/arrow/test_filter_pushdown.py index 7938585a..d2eea92e 100644 --- a/tests/fast/arrow/test_filter_pushdown.py +++ b/tests/fast/arrow/test_filter_pushdown.py @@ -921,6 +921,33 @@ def test_in_filter_pushdown(self, duckdb_cursor): duckdb_conn.register("duck_probe_arrow", duck_probe_arrow) assert duckdb_conn.execute("SELECT * from duck_probe_arrow where a = any([1,999])").fetchall() == [(1,), (999,)] + @pytest.mark.timeout(10) + def test_in_filter_pushdown_large_list(self, duckdb_cursor): + """Large IN lists must not hang. Regression test for https://github.com/duckdb/duckdb-python/issues/52.""" + arrow_table = pa.table({"a": pa.array(range(5000))}) + in_list = ", ".join(str(i) for i in range(0, 5000, 2)) + result = duckdb.sql(f"SELECT count(*) FROM arrow_table WHERE a IN ({in_list})").fetchone() + assert result == (2500,) + + def test_in_filter_pushdown_with_nulls(self, duckdb_cursor): + arrow_table = pa.table({"a": pa.array([1, 2, None, 4, None, 6])}) + # IN list without NULL: null rows should not match + result = duckdb.sql("SELECT a FROM arrow_table WHERE a IN (1, 4) ORDER BY a").fetchall() + assert result == [(1,), (4,)] + # IN list with NULL: null rows still should not match (SQL semantics) + result = duckdb.sql("SELECT a FROM arrow_table WHERE a IN (1, 4, NULL) ORDER BY a").fetchall() + assert result == [(1,), (4,)] + + def test_in_filter_pushdown_varchar(self, duckdb_cursor): + arrow_table = pa.table({"s": pa.array(["alice", "bob", "charlie", "dave", None])}) + result = duckdb.sql("SELECT s FROM arrow_table WHERE s IN ('bob', 'dave') ORDER BY s").fetchall() + assert result == [("bob",), ("dave",)] + + def test_in_filter_pushdown_float(self, duckdb_cursor): + arrow_table = pa.table({"f": pa.array([1.0, 2.5, 3.75, 4.0, None], type=pa.float64())}) + result = duckdb.sql("SELECT f FROM arrow_table WHERE f IN (2.5, 4.0) ORDER BY f").fetchall() + assert result == [(2.5,), (4.0,)] + def test_pushdown_of_optional_filter(self, duckdb_cursor): cardinality_table = pa.Table.from_pydict( {