23 changes: 12 additions & 11 deletions src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp
@@ -160,6 +160,15 @@ py::object GetScalar(Value &constant, const string &timezone_config, const Arrow
     }
 }
 
+static py::list TransformInList(const InFilter &in) {
+    py::list res;
+    ClientProperties default_properties;
+    for (auto &val : in.values) {
+        res.append(PythonObject::FromValue(val, val.type(), default_properties));
+    }
+    return res;
+}
+
 py::object TransformFilterRecursive(TableFilter &filter, vector<string> column_ref, const string &timezone_config,
                                     const ArrowType &type) {
     auto &import_cache = *DuckDBPyConnection::ImportCache();
@@ -282,17 +291,9 @@ py::object TransformFilterRecursive(TableFilter &filter, vector<string> column_r
     }
     case TableFilterType::IN_FILTER: {
         auto &in_filter = filter.Cast<InFilter>();
-        ConjunctionOrFilter or_filter;
-        value_set_t unique_values;
-        for (const auto &value : in_filter.values) {
-            if (unique_values.find(value) == unique_values.end()) {
-                unique_values.insert(value);
-            }
-        }
-        for (const auto &value : unique_values) {
-            or_filter.child_filters.push_back(make_uniq<ConstantFilter>(ExpressionType::COMPARE_EQUAL, value));
-        }
-        return TransformFilterRecursive(or_filter, column_ref, timezone_config, type);
+        auto constant_field = field(py::tuple(py::cast(column_ref)));
+        auto in_list = TransformInList(in_filter);
+        return constant_field.attr("isin")(std::move(in_list));
     }
     case TableFilterType::DYNAMIC_FILTER: {
         //! Ignore dynamic filters for now, not necessary for correctness
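For context, a minimal Python-side sketch of what the new IN_FILTER branch produces: instead of expanding the IN list into one ConstantFilter per value wrapped in a ConjunctionOrFilter (the deleted path above), it hands Arrow a single field(...).isin([...]) expression. This is only an illustration, not the PR's code: it assumes `field` resolves to pyarrow.compute.field (pyarrow.dataset.field behaves the same), and the table, column name, and value list are made up.

import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds

# Illustrative data; in the real code path the values come from DuckDB's InFilter via TransformInList.
table = pa.table({"a": pa.array(range(10))})

# Old approach (removed): (a == 1) OR (a == 3) OR (a == 5), one equality filter per value.
# New approach: a single membership test over the whole value list.
expr = pc.field(("a",)).isin([1, 3, 5])

filtered = ds.dataset(table).to_table(filter=expr)
print(filtered["a"].to_pylist())  # [1, 3, 5]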
27 changes: 27 additions & 0 deletions tests/fast/arrow/test_filter_pushdown.py
@@ -921,6 +921,33 @@ def test_in_filter_pushdown(self, duckdb_cursor):
         duckdb_conn.register("duck_probe_arrow", duck_probe_arrow)
         assert duckdb_conn.execute("SELECT * from duck_probe_arrow where a = any([1,999])").fetchall() == [(1,), (999,)]
 
+    @pytest.mark.timeout(10)
+    def test_in_filter_pushdown_large_list(self, duckdb_cursor):
+        """Large IN lists must not hang. Regression test for https://github.com/duckdb/duckdb-python/issues/52."""
+        arrow_table = pa.table({"a": pa.array(range(5000))})
+        in_list = ", ".join(str(i) for i in range(0, 5000, 2))
+        result = duckdb.sql(f"SELECT count(*) FROM arrow_table WHERE a IN ({in_list})").fetchone()
+        assert result == (2500,)
+
+    def test_in_filter_pushdown_with_nulls(self, duckdb_cursor):
+        arrow_table = pa.table({"a": pa.array([1, 2, None, 4, None, 6])})
+        # IN list without NULL: null rows should not match
+        result = duckdb.sql("SELECT a FROM arrow_table WHERE a IN (1, 4) ORDER BY a").fetchall()
+        assert result == [(1,), (4,)]
+        # IN list with NULL: null rows still should not match (SQL semantics)
+        result = duckdb.sql("SELECT a FROM arrow_table WHERE a IN (1, 4, NULL) ORDER BY a").fetchall()
+        assert result == [(1,), (4,)]
+
+    def test_in_filter_pushdown_varchar(self, duckdb_cursor):
+        arrow_table = pa.table({"s": pa.array(["alice", "bob", "charlie", "dave", None])})
+        result = duckdb.sql("SELECT s FROM arrow_table WHERE s IN ('bob', 'dave') ORDER BY s").fetchall()
+        assert result == [("bob",), ("dave",)]
+
+    def test_in_filter_pushdown_float(self, duckdb_cursor):
+        arrow_table = pa.table({"f": pa.array([1.0, 2.5, 3.75, 4.0, None], type=pa.float64())})
+        result = duckdb.sql("SELECT f FROM arrow_table WHERE f IN (2.5, 4.0) ORDER BY f").fetchall()
+        assert result == [(2.5,), (4.0,)]
+
     def test_pushdown_of_optional_filter(self, duckdb_cursor):
         cardinality_table = pa.Table.from_pydict(
             {