Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,8 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> {
static unique_ptr<QueryResult> CompletePendingQuery(PendingQueryResult &pending_query);

private:
unique_ptr<DuckDBPyRelation> CreateRelation(shared_ptr<Relation> rel);
unique_ptr<DuckDBPyRelation> CreateRelation(shared_ptr<DuckDBPyResult> result);
PathLike GetPathLike(const py::object &object);
ScalarFunction CreateScalarUDF(const string &name, const py::function &udf, const py::object &parameters,
const shared_ptr<DuckDBPyType> &return_type, bool vectorized,
Expand Down
7 changes: 7 additions & 0 deletions src/duckdb_py/include/duckdb_python/pyrelation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,10 @@ struct DuckDBPyRelation {

bool ContainsColumnByName(const string &name) const;

void SetConnectionOwner(py::object owner);
unique_ptr<DuckDBPyRelation> DeriveRelation(shared_ptr<Relation> new_rel);
unique_ptr<DuckDBPyRelation> DeriveRelation(shared_ptr<DuckDBPyResult> result);

private:
string ToStringInternal(const BoxRendererConfig &config, bool invalidate_cache = false);
string GenerateExpressionList(const string &function_name, const string &aggregated_columns,
Expand All @@ -284,6 +288,9 @@ struct DuckDBPyRelation {
unique_ptr<QueryResult> ExecuteInternal(bool stream_result = false);

private:
//! Prevents GC of the parent DuckDBPyConnection.
//! Declared first so it is destroyed last (reverse declaration order).
py::object connection_owner;
//! Whether the relation has been executed at least once
bool executed;
shared_ptr<Relation> rel;
Expand Down
48 changes: 31 additions & 17 deletions src/duckdb_py/pyconnection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,20 @@ DuckDBPyConnection::~DuckDBPyConnection() {
}
}

//! Wraps 'rel' in a DuckDBPyRelation and records this connection as its owner.
//! The owner is a Python handle to this connection (obtained via shared_from_this()),
//! stored on the relation through SetConnectionOwner.
unique_ptr<DuckDBPyRelation> DuckDBPyConnection::CreateRelation(shared_ptr<Relation> rel) {
	auto wrapped = make_uniq<DuckDBPyRelation>(std::move(rel));
	// The GIL is acquired here explicitly before creating the Python handle;
	// it is held until this function returns.
	py::gil_scoped_acquire gil;
	py::object self_handle = py::cast(shared_from_this());
	wrapped->SetConnectionOwner(std::move(self_handle));
	return wrapped;
}

//! Wraps an existing result in a DuckDBPyRelation and records this connection as its owner.
//! The owner is a Python handle to this connection (obtained via shared_from_this()),
//! stored on the relation through SetConnectionOwner.
unique_ptr<DuckDBPyRelation> DuckDBPyConnection::CreateRelation(shared_ptr<DuckDBPyResult> result) {
	auto wrapped = make_uniq<DuckDBPyRelation>(std::move(result));
	// The GIL is acquired here explicitly before creating the Python handle;
	// it is held until this function returns.
	py::gil_scoped_acquire gil;
	py::object self_handle = py::cast(shared_from_this());
	wrapped->SetConnectionOwner(std::move(self_handle));
	return wrapped;
}

void DuckDBPyConnection::DetectEnvironment() {
// Get the formatted Python version
py::module_ sys = py::module_::import("sys");
Expand Down Expand Up @@ -513,8 +527,9 @@ shared_ptr<DuckDBPyConnection> DuckDBPyConnection::ExecuteMany(const py::object
}
// Set the internal 'result' object
if (query_result) {
auto py_result = make_uniq<DuckDBPyResult>(std::move(query_result));
con.SetResult(make_uniq<DuckDBPyRelation>(std::move(py_result)));
// Don't use CreateRelation here — the result is stored inside the connection,
// so setting connection_owner would create a ref cycle (connection → result → connection).
con.SetResult(make_uniq<DuckDBPyRelation>(make_shared_ptr<DuckDBPyResult>(std::move(query_result))));
}

return shared_from_this();
Expand Down Expand Up @@ -713,8 +728,9 @@ shared_ptr<DuckDBPyConnection> DuckDBPyConnection::Execute(const py::object &que

// Set the internal 'result' object
if (res) {
auto py_result = make_uniq<DuckDBPyResult>(std::move(res));
con.SetResult(make_uniq<DuckDBPyRelation>(std::move(py_result)));
// Don't use CreateRelation here — the result is stored inside the connection,
// so setting connection_owner would create a ref cycle (connection → result → connection).
con.SetResult(make_uniq<DuckDBPyRelation>(make_shared_ptr<DuckDBPyResult>(std::move(res))));
}
return shared_from_this();
}
Expand Down Expand Up @@ -982,7 +998,7 @@ unique_ptr<DuckDBPyRelation> DuckDBPyConnection::ReadJSON(
if (file_like_object_wrapper) {
read_json_relation->AddExternalDependency(std::move(file_like_object_wrapper));
}
return make_uniq<DuckDBPyRelation>(std::move(read_json_relation));
return CreateRelation(std::move(read_json_relation));
}

PathLike DuckDBPyConnection::GetPathLike(const py::object &object) {
Expand Down Expand Up @@ -1553,7 +1569,7 @@ unique_ptr<DuckDBPyRelation> DuckDBPyConnection::ReadCSV(const py::object &name_
read_csv.AddExternalDependency(std::move(file_like_object_wrapper));
}

return make_uniq<DuckDBPyRelation>(read_csv_p->Alias(read_csv.alias));
return CreateRelation(read_csv_p->Alias(read_csv.alias));
}

void DuckDBPyConnection::ExecuteImmediately(vector<unique_ptr<SQLStatement>> statements) {
Expand Down Expand Up @@ -1639,7 +1655,7 @@ unique_ptr<DuckDBPyRelation> DuckDBPyConnection::RunQuery(const py::object &quer
relation = make_shared_ptr<MaterializedRelation>(connection.context, materialized_result.TakeCollection(),
res->names, alias);
}
return make_uniq<DuckDBPyRelation>(std::move(relation));
return CreateRelation(std::move(relation));
}

unique_ptr<DuckDBPyRelation> DuckDBPyConnection::Table(const string &tname) {
Expand All @@ -1649,8 +1665,7 @@ unique_ptr<DuckDBPyRelation> DuckDBPyConnection::Table(const string &tname) {
qualified_name.schema = DEFAULT_SCHEMA;
}
try {
return make_uniq<DuckDBPyRelation>(
connection.Table(qualified_name.catalog, qualified_name.schema, qualified_name.name));
return CreateRelation(connection.Table(qualified_name.catalog, qualified_name.schema, qualified_name.name));
} catch (const CatalogException &) {
// CatalogException will be of the type '... is not a table'
// Not a table in the database, make a query relation that can perform replacement scans
Expand Down Expand Up @@ -1716,7 +1731,7 @@ unique_ptr<DuckDBPyRelation> DuckDBPyConnection::Values(const py::args &args) {
py::handle first_arg = args[0];
if (arg_count == 1 && py::isinstance<py::list>(first_arg)) {
vector<vector<Value>> values {DuckDBPyConnection::TransformPythonParamList(first_arg)};
return make_uniq<DuckDBPyRelation>(connection.Values(values));
return CreateRelation(connection.Values(values));
} else {
vector<vector<unique_ptr<ParsedExpression>>> expressions;
if (py::isinstance<py::tuple>(first_arg)) {
Expand All @@ -1725,13 +1740,13 @@ unique_ptr<DuckDBPyRelation> DuckDBPyConnection::Values(const py::args &args) {
auto values = ValueListFromExpressions(args);
expressions.push_back(std::move(values));
}
return make_uniq<DuckDBPyRelation>(connection.Values(std::move(expressions)));
return CreateRelation(connection.Values(std::move(expressions)));
}
}

unique_ptr<DuckDBPyRelation> DuckDBPyConnection::View(const string &vname) {
auto &connection = con.GetConnection();
return make_uniq<DuckDBPyRelation>(connection.View(vname));
return CreateRelation(connection.View(vname));
}

unique_ptr<DuckDBPyRelation> DuckDBPyConnection::TableFunction(const string &fname, py::object params) {
Expand All @@ -1743,8 +1758,7 @@ unique_ptr<DuckDBPyRelation> DuckDBPyConnection::TableFunction(const string &fna
throw InvalidInputException("'params' has to be a list of parameters");
}

return make_uniq<DuckDBPyRelation>(
connection.TableFunction(fname, DuckDBPyConnection::TransformPythonParamList(params)));
return CreateRelation(connection.TableFunction(fname, DuckDBPyConnection::TransformPythonParamList(params)));
}

unique_ptr<DuckDBPyRelation> DuckDBPyConnection::FromDF(const PandasDataFrame &value) {
Expand All @@ -1757,7 +1771,7 @@ unique_ptr<DuckDBPyRelation> DuckDBPyConnection::FromDF(const PandasDataFrame &v
auto tableref = PythonReplacementScan::ReplacementObject(value, name, *connection.context);
D_ASSERT(tableref);
auto rel = make_shared_ptr<ViewRelation>(connection.context, std::move(tableref), name);
return make_uniq<DuckDBPyRelation>(std::move(rel));
return CreateRelation(std::move(rel));
}

unique_ptr<DuckDBPyRelation> DuckDBPyConnection::FromParquetInternal(Value &&file_param, bool binary_as_string,
Expand All @@ -1782,7 +1796,7 @@ unique_ptr<DuckDBPyRelation> DuckDBPyConnection::FromParquetInternal(Value &&fil
}
D_ASSERT(py::gil_check());
py::gil_scoped_release gil;
return make_uniq<DuckDBPyRelation>(connection.TableFunction("parquet_scan", params, named_parameters)->Alias(name));
return CreateRelation(connection.TableFunction("parquet_scan", params, named_parameters)->Alias(name));
}

unique_ptr<DuckDBPyRelation> DuckDBPyConnection::FromParquet(const string &file_glob, bool binary_as_string,
Expand Down Expand Up @@ -1818,7 +1832,7 @@ unique_ptr<DuckDBPyRelation> DuckDBPyConnection::FromArrow(py::object &arrow_obj
auto tableref = PythonReplacementScan::ReplacementObject(arrow_object, name, *connection.context, true);
D_ASSERT(tableref);
auto rel = make_shared_ptr<ViewRelation>(connection.context, std::move(tableref), name);
return make_uniq<DuckDBPyRelation>(std::move(rel));
return CreateRelation(std::move(rel));
}

unordered_set<string> DuckDBPyConnection::GetTableNames(const string &query, bool qualified) {
Expand Down
68 changes: 42 additions & 26 deletions src/duckdb_py/pyrelation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ DuckDBPyRelation::DuckDBPyRelation(shared_ptr<DuckDBPyResult> result_p) : rel(nu
}

unique_ptr<DuckDBPyRelation> DuckDBPyRelation::ProjectFromExpression(const string &expression) {
auto projected_relation = make_uniq<DuckDBPyRelation>(rel->Project(expression));
auto projected_relation = DeriveRelation(rel->Project(expression));
for (auto &dep : this->rel->external_dependencies) {
projected_relation->rel->AddExternalDependency(dep);
}
Expand Down Expand Up @@ -108,9 +108,9 @@ unique_ptr<DuckDBPyRelation> DuckDBPyRelation::Project(const py::args &args, con
vector<string> empty_aliases;
if (groups.empty()) {
// No groups provided
return make_uniq<DuckDBPyRelation>(rel->Project(std::move(expressions), empty_aliases));
return DeriveRelation(rel->Project(std::move(expressions), empty_aliases));
}
return make_uniq<DuckDBPyRelation>(rel->Aggregate(std::move(expressions), groups));
return DeriveRelation(rel->Aggregate(std::move(expressions), groups));
}
}

Expand Down Expand Up @@ -180,7 +180,7 @@ unique_ptr<DuckDBPyRelation> DuckDBPyRelation::EmptyResult(const shared_ptr<Clie
}

unique_ptr<DuckDBPyRelation> DuckDBPyRelation::SetAlias(const string &expr) {
return make_uniq<DuckDBPyRelation>(rel->Alias(expr));
return DeriveRelation(rel->Alias(expr));
}

py::str DuckDBPyRelation::GetAlias() {
Expand All @@ -197,19 +197,19 @@ unique_ptr<DuckDBPyRelation> DuckDBPyRelation::Filter(const py::object &expr) {
throw InvalidInputException("Please provide either a string or a DuckDBPyExpression object to 'filter'");
}
auto expr_p = expression->GetExpression().Copy();
return make_uniq<DuckDBPyRelation>(rel->Filter(std::move(expr_p)));
return DeriveRelation(rel->Filter(std::move(expr_p)));
}

unique_ptr<DuckDBPyRelation> DuckDBPyRelation::FilterFromExpression(const string &expr) {
return make_uniq<DuckDBPyRelation>(rel->Filter(expr));
return DeriveRelation(rel->Filter(expr));
}

unique_ptr<DuckDBPyRelation> DuckDBPyRelation::Limit(int64_t n, int64_t offset) {
return make_uniq<DuckDBPyRelation>(rel->Limit(n, offset));
return DeriveRelation(rel->Limit(n, offset));
}

unique_ptr<DuckDBPyRelation> DuckDBPyRelation::Order(const string &expr) {
return make_uniq<DuckDBPyRelation>(rel->Order(expr));
return DeriveRelation(rel->Order(expr));
}

unique_ptr<DuckDBPyRelation> DuckDBPyRelation::Sort(const py::args &args) {
Expand All @@ -228,7 +228,7 @@ unique_ptr<DuckDBPyRelation> DuckDBPyRelation::Sort(const py::args &args) {
if (order_nodes.empty()) {
throw InvalidInputException("Please provide at least one expression to sort on");
}
return make_uniq<DuckDBPyRelation>(rel->Order(std::move(order_nodes)));
return DeriveRelation(rel->Order(std::move(order_nodes)));
}

vector<unique_ptr<ParsedExpression>> GetExpressions(ClientContext &context, const py::object &expr) {
Expand Down Expand Up @@ -259,9 +259,9 @@ unique_ptr<DuckDBPyRelation> DuckDBPyRelation::Aggregate(const py::object &expr,
AssertRelation();
auto expressions = GetExpressions(*rel->context->GetContext(), expr);
if (!groups.empty()) {
return make_uniq<DuckDBPyRelation>(rel->Aggregate(std::move(expressions), groups));
return DeriveRelation(rel->Aggregate(std::move(expressions), groups));
}
return make_uniq<DuckDBPyRelation>(rel->Aggregate(std::move(expressions)));
return DeriveRelation(rel->Aggregate(std::move(expressions)));
}

void DuckDBPyRelation::AssertResult() const {
Expand Down Expand Up @@ -354,7 +354,7 @@ unique_ptr<DuckDBPyRelation> DuckDBPyRelation::Describe() {
DescribeAggregateInfo("stddev", true), DescribeAggregateInfo("min"),
DescribeAggregateInfo("max"), DescribeAggregateInfo("median", true)};
auto expressions = CreateExpressionList(columns, aggregates);
return make_uniq<DuckDBPyRelation>(rel->Aggregate(expressions));
return DeriveRelation(rel->Aggregate(expressions));
}

string DuckDBPyRelation::ToSQL() {
Expand Down Expand Up @@ -456,7 +456,7 @@ DuckDBPyRelation::GenericWindowFunction(const string &function_name, const strin
const string &projected_columns) {
auto expr = GenerateExpressionList(function_name, aggr_columns, "", function_parameters, ignore_nulls,
projected_columns, window_spec);
return make_uniq<DuckDBPyRelation>(rel->Project(expr));
return DeriveRelation(rel->Project(expr));
}

unique_ptr<DuckDBPyRelation> DuckDBPyRelation::ApplyAggOrWin(const string &function_name, const string &agg_columns,
Expand Down Expand Up @@ -722,7 +722,7 @@ py::tuple DuckDBPyRelation::Shape() {
}

unique_ptr<DuckDBPyRelation> DuckDBPyRelation::Unique(const string &std_columns) {
return make_uniq<DuckDBPyRelation>(rel->Project(std_columns)->Distinct());
return DeriveRelation(rel->Project(std_columns)->Distinct());
}

/* General-purpose window functions */
Expand Down Expand Up @@ -796,7 +796,7 @@ unique_ptr<DuckDBPyRelation> DuckDBPyRelation::NthValue(const string &column, co
}

unique_ptr<DuckDBPyRelation> DuckDBPyRelation::Distinct() {
return make_uniq<DuckDBPyRelation>(rel->Distinct());
return DeriveRelation(rel->Distinct());
}

duckdb::pyarrow::RecordBatchReader DuckDBPyRelation::FetchRecordBatchReader(idx_t rows_per_batch) {
Expand Down Expand Up @@ -1064,6 +1064,22 @@ bool DuckDBPyRelation::ContainsColumnByName(const string &name) const {
[&](const string &item) { return StringUtil::CIEquals(name, item); }) != names.end();
}

//! Stores 'owner' as this relation's connection_owner, replacing whatever was held before.
//! The handle is taken by value and moved into the member.
//! NOTE(review): assigning a py::object adjusts Python reference counts, so the caller
//! presumably needs to hold the GIL (DuckDBPyConnection::CreateRelation acquires it first) - confirm.
void DuckDBPyRelation::SetConnectionOwner(py::object owner) {
	this->connection_owner = std::move(owner);
}

//! Builds a new DuckDBPyRelation around 'new_rel' that shares this relation's connection_owner.
//! If this relation has no owner set, the derived relation has none either.
//! NOTE(review): copying the py::object handle touches the Python refcount, so callers
//! presumably hold the GIL - confirm.
unique_ptr<DuckDBPyRelation> DuckDBPyRelation::DeriveRelation(shared_ptr<Relation> new_rel) {
	auto derived = make_uniq<DuckDBPyRelation>(std::move(new_rel));
	// Pass a copy of our owner handle; SetConnectionOwner moves it into the new relation.
	derived->SetConnectionOwner(connection_owner);
	return derived;
}

//! Builds a new DuckDBPyRelation around the result 'result_p' that shares this relation's
//! connection_owner. If this relation has no owner set, the derived relation has none either.
//! NOTE(review): copying the py::object handle touches the Python refcount, so callers
//! presumably hold the GIL - confirm.
unique_ptr<DuckDBPyRelation> DuckDBPyRelation::DeriveRelation(shared_ptr<DuckDBPyResult> result_p) {
	auto derived = make_uniq<DuckDBPyRelation>(std::move(result_p));
	// Pass a copy of our owner handle; SetConnectionOwner moves it into the new relation.
	derived->SetConnectionOwner(connection_owner);
	return derived;
}

static bool ContainsStructFieldByName(LogicalType &type, const string &name) {
if (type.id() != LogicalTypeId::STRUCT) {
return false;
Expand Down Expand Up @@ -1104,19 +1120,19 @@ unique_ptr<DuckDBPyRelation> DuckDBPyRelation::GetAttribute(const string &name)
expressions.push_back(std::move(make_uniq<ColumnRefExpression>(column_names)));
vector<string> aliases;
aliases.push_back(name);
return make_uniq<DuckDBPyRelation>(rel->Project(std::move(expressions), aliases));
return DeriveRelation(rel->Project(std::move(expressions), aliases));
}

unique_ptr<DuckDBPyRelation> DuckDBPyRelation::Union(DuckDBPyRelation *other) {
return make_uniq<DuckDBPyRelation>(rel->Union(other->rel));
return DeriveRelation(rel->Union(other->rel));
}

unique_ptr<DuckDBPyRelation> DuckDBPyRelation::Except(DuckDBPyRelation *other) {
return make_uniq<DuckDBPyRelation>(rel->Except(other->rel));
return DeriveRelation(rel->Except(other->rel));
}

unique_ptr<DuckDBPyRelation> DuckDBPyRelation::Intersect(DuckDBPyRelation *other) {
return make_uniq<DuckDBPyRelation>(rel->Intersect(other->rel));
return DeriveRelation(rel->Intersect(other->rel));
}

namespace {
Expand Down Expand Up @@ -1177,7 +1193,7 @@ unique_ptr<DuckDBPyRelation> DuckDBPyRelation::Join(DuckDBPyRelation *other, con
}
if (py::isinstance<py::str>(condition)) {
auto condition_string = std::string(py::cast<py::str>(condition));
return make_uniq<DuckDBPyRelation>(rel->Join(other->rel, condition_string, join_type));
return DeriveRelation(rel->Join(other->rel, condition_string, join_type));
}
vector<string> using_list;
if (py::is_list_like(condition)) {
Expand All @@ -1193,7 +1209,7 @@ unique_ptr<DuckDBPyRelation> DuckDBPyRelation::Join(DuckDBPyRelation *other, con
throw InvalidInputException("Please provide at least one string in the condition to create a USING clause");
}
auto join_relation = make_shared_ptr<JoinRelation>(rel, other->rel, std::move(using_list), join_type);
return make_uniq<DuckDBPyRelation>(std::move(join_relation));
return DeriveRelation(std::move(join_relation));
}
shared_ptr<DuckDBPyExpression> condition_expr;
if (!py::try_cast(condition, condition_expr)) {
Expand All @@ -1202,11 +1218,11 @@ unique_ptr<DuckDBPyRelation> DuckDBPyRelation::Join(DuckDBPyRelation *other, con
}
vector<unique_ptr<ParsedExpression>> conditions;
conditions.push_back(condition_expr->GetExpression().Copy());
return make_uniq<DuckDBPyRelation>(rel->Join(other->rel, std::move(conditions), join_type));
return DeriveRelation(rel->Join(other->rel, std::move(conditions), join_type));
}

unique_ptr<DuckDBPyRelation> DuckDBPyRelation::Cross(DuckDBPyRelation *other) {
return make_uniq<DuckDBPyRelation>(rel->CrossProduct(other->rel));
return DeriveRelation(rel->CrossProduct(other->rel));
}

static Value NestedDictToStruct(const py::object &dictionary) {
Expand Down Expand Up @@ -1502,7 +1518,7 @@ void DuckDBPyRelation::ToCSV(const string &filename, const py::object &sep, cons
// should this return a rel with the new view?
unique_ptr<DuckDBPyRelation> DuckDBPyRelation::CreateView(const string &view_name, bool replace) {
rel->CreateView(view_name, replace);
return make_uniq<DuckDBPyRelation>(rel);
return DeriveRelation(rel);
}

static bool IsDescribeStatement(SQLStatement &statement) {
Expand Down Expand Up @@ -1530,7 +1546,7 @@ unique_ptr<DuckDBPyRelation> DuckDBPyRelation::Query(const string &view_name, co
auto select_statement = unique_ptr_cast<SQLStatement, SelectStatement>(std::move(parser.statements[0]));
auto query_relation = make_shared_ptr<QueryRelation>(rel->context->GetContext(), std::move(select_statement),
sql_query, "query_relation");
return make_uniq<DuckDBPyRelation>(std::move(query_relation));
return DeriveRelation(std::move(query_relation));
} else if (IsDescribeStatement(statement)) {
auto query = PragmaShow(view_name);
return Query(view_name, query);
Expand Down Expand Up @@ -1630,7 +1646,7 @@ unique_ptr<DuckDBPyRelation> DuckDBPyRelation::Map(py::function fun, Optional<py
vector<Value> params;
params.emplace_back(Value::POINTER(CastPointerToValue(fun.ptr())));
params.emplace_back(Value::POINTER(CastPointerToValue(schema.ptr())));
auto relation = make_uniq<DuckDBPyRelation>(rel->TableFunction("python_map_function", params));
auto relation = DeriveRelation(rel->TableFunction("python_map_function", params));
auto rel_dependency = make_uniq<ExternalDependency>();
rel_dependency->AddDependency("map", PythonDependencyItem::Create(std::move(fun)));
rel_dependency->AddDependency("schema", PythonDependencyItem::Create(std::move(schema)));
Expand Down
Loading
Loading