From 48bb76357077bee958945164db9cfc8fdc0ca860 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Mar 2026 15:33:21 +0000 Subject: [PATCH 01/34] [TIR][Refactor] Introduce BindNode and Bind (PR 1+2/11: core IR + functor + analysis infrastructure) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces `BindNode`/`Bind`, a new TIR statement node that binds a variable to a value with flat (no-body) scope semantics, as the first step of the LetStmt-to-Bind refactor. Unlike `LetStmtNode`, `BindNode` has no body field; the bound variable is visible in subsequent siblings of the enclosing SeqStmt. PR 1 — Core IR node + functor infrastructure: - `include/tvm/tir/stmt.h`: Define BindNode (var, value, no body) and Bind ref class - `src/tir/ir/stmt.cc`: Implement Bind constructor, RegisterReflection, GlobalDef - `include/tvm/tir/stmt_functor.h`: Add VisitStmt_(BindNode*) to StmtFunctor vtable, StmtVisitor, StmtMutator - `src/tir/ir/stmt_functor.cc`: Implement StmtVisitor and StmtMutator for BindNode - `src/tir/ir/py_functor.cc`: Add BindNode dispatch entries for Python functors - `src/tir/ir/tir_visitor_with_path.{h,cc}`: Add BindNode visitor (visits value only) PR 2 — Base visitors/mutators + arithmetic + analysis: - `src/arith/ir_mutator_with_analyzer.{h,cc}`: BindNode handler binds in analyzer - `src/arith/ir_visitor_with_analyzer.{h,cc}`: BindNode handler visits value + binds - `src/tir/ir/data_type_rewriter.{h,cc}`: BindNode support in DataTypeLegalizer and IndexDataTypeRewriter - `src/tir/analysis/var_use_def_analysis.{h,cc}`: BindNode registers HandleDef - `src/tir/analysis/verify_ssa.cc`: BindNode calls MarkDef - `src/tir/analysis/verify_memory.cc`: BindNode books defs_ map - `src/tir/analysis/control_flow_graph.cc`: BindNode checks UsesLoopVar LetStmtNode/LetStmt are kept intact (deprecated aliases come in PR 11). 
Tests: TIR base (269 passed, 2 skipped), all-platform-minimal (75 passed, 77 skipped) --- include/tvm/tir/stmt.h | 37 ++++++++++++++++++++++++ include/tvm/tir/stmt_functor.h | 4 +++ src/arith/ir_mutator_with_analyzer.cc | 14 +++++++++ src/arith/ir_mutator_with_analyzer.h | 1 + src/arith/ir_visitor_with_analyzer.cc | 5 ++++ src/arith/ir_visitor_with_analyzer.h | 1 + src/tir/analysis/control_flow_graph.cc | 8 +++++ src/tir/analysis/var_use_def_analysis.cc | 5 ++++ src/tir/analysis/var_use_def_analysis.h | 2 ++ src/tir/analysis/verify_memory.cc | 5 ++++ src/tir/analysis/verify_ssa.cc | 4 +++ src/tir/ir/data_type_rewriter.cc | 30 +++++++++++++++++++ src/tir/ir/data_type_rewriter.h | 5 ++++ src/tir/ir/py_functor.cc | 8 +++++ src/tir/ir/stmt.cc | 25 ++++++++++++++++ src/tir/ir/stmt_functor.cc | 17 +++++++++++ src/tir/ir/tir_visitor_with_path.cc | 8 +++++ src/tir/ir/tir_visitor_with_path.h | 1 + 18 files changed, 180 insertions(+) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 0ded1e977fa2..7cd1eb20a9a6 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -67,8 +67,44 @@ class Stmt : public ObjectRef { TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Stmt, ObjectRef, StmtNode); }; +/*! + * \brief Bind a variable to a value in the enclosing scope. + * + * Unlike LetStmt, BindNode has no body field. The bound variable is visible + * in all subsequent statements within the same enclosing scope (SeqStmt, + * ForNode.body, etc.). This enables flat (non-nested) IR sequences. + */ +class BindNode : public StmtNode { + public: + /*! \brief The variable being bound. */ + Var var; + /*! \brief The value to bind to the variable. */ + PrimExpr value; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("var", &BindNode::var, refl::AttachFieldFlag::SEqHashDef()) + .def_ro("value", &BindNode::value); + } + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Bind", BindNode, StmtNode); +}; + +/*! 
+ * \brief Managed reference to BindNode. + * \sa BindNode + */ +class Bind : public Stmt { + public: + TVM_DLL Bind(Var var, PrimExpr value, Span span = Span()); + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Bind, Stmt, BindNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(BindNode); +}; + /*! * \brief Let binding, bind var to value, then run body. + * \deprecated Use Bind instead, which has flat scope semantics. */ class LetStmtNode : public StmtNode { public: @@ -92,6 +128,7 @@ class LetStmtNode : public StmtNode { /*! * \brief Managed reference to LetStmtNode. * \sa LetStmtNode + * \deprecated Use Bind instead. */ class LetStmt : public Stmt { public: diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index e86c6bb125dd..b7a23bd3e963 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -84,6 +84,7 @@ class StmtFunctor { return vtable(n, this, std::forward(args)...); } // Functions that can be overriden by subclass + virtual R VisitStmt_(const BindNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const LetStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AttrStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const IfThenElseNode* op, Args... args) STMT_FUNCTOR_DEFAULT; @@ -106,6 +107,7 @@ class StmtFunctor { // initialize the vtable. 
static FType InitVTable() { FType vtable; + IR_STMT_FUNCTOR_DISPATCH(BindNode); IR_STMT_FUNCTOR_DISPATCH(LetStmtNode); IR_STMT_FUNCTOR_DISPATCH(AttrStmtNode); IR_STMT_FUNCTOR_DISPATCH(IfThenElseNode); @@ -159,6 +161,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor { */ virtual void VisitBufferUse(const Buffer& buffer); // statement visitor + void VisitStmt_(const BindNode* op) override; void VisitStmt_(const AttrStmtNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; void VisitStmt_(const LetStmtNode* op) override; @@ -273,6 +276,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor { */ virtual Buffer VisitBufferUse(const Buffer& buffer); // statement visitor + Stmt VisitStmt_(const BindNode* op) override; Stmt VisitStmt_(const AttrStmtNode* op) override; Stmt VisitStmt_(const IfThenElseNode* op) override; Stmt VisitStmt_(const LetStmtNode* op) override; diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index f6c8db016132..7639288e3de0 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -80,6 +80,20 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const SBlockNode* op) { }); } +Stmt IRMutatorWithAnalyzer::VisitStmt_(const BindNode* op) { + PrimExpr value = this->VisitExpr(op->value); + if (SideEffect(value) <= CallEffectKind::kPure) { + analyzer_->Bind(op->var, value); + } + if (value.same_as(op->value)) { + return ffi::GetRef(op); + } else { + auto n = this->CopyOnWrite(op); + n->value = std::move(value); + return Stmt(n); + } +} + Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) { PrimExpr value = this->VisitExpr(op->value); if (SideEffect(value) <= CallEffectKind::kPure) { diff --git a/src/arith/ir_mutator_with_analyzer.h b/src/arith/ir_mutator_with_analyzer.h index 8810a8f78f62..e0ac48188b2c 100644 --- a/src/arith/ir_mutator_with_analyzer.h +++ b/src/arith/ir_mutator_with_analyzer.h @@ -54,6 +54,7 @@ class IRMutatorWithAnalyzer : public 
tir::StmtExprMutator { // override functions that need to populate the context information. tir::Stmt VisitStmt_(const tir::ForNode* op) override; tir::Stmt VisitStmt_(const tir::SBlockNode* op) override; + tir::Stmt VisitStmt_(const tir::BindNode* op) override; tir::Stmt VisitStmt_(const tir::LetStmtNode* op) override; tir::Stmt VisitStmt_(const tir::IfThenElseNode* op) override; tir::Stmt VisitStmt_(const tir::AttrStmtNode* op) override; diff --git a/src/arith/ir_visitor_with_analyzer.cc b/src/arith/ir_visitor_with_analyzer.cc index 736e148d7a31..a701c62cee21 100644 --- a/src/arith/ir_visitor_with_analyzer.cc +++ b/src/arith/ir_visitor_with_analyzer.cc @@ -48,6 +48,11 @@ void IRVisitorWithAnalyzer::VisitStmt_(const SBlockNode* op) { }); } +void IRVisitorWithAnalyzer::VisitStmt_(const BindNode* op) { + this->VisitExpr(op->value); + analyzer_.Bind(op->var, op->value); +} + void IRVisitorWithAnalyzer::VisitStmt_(const LetStmtNode* op) { this->VisitExpr(op->value); analyzer_.Bind(op->var, op->value); diff --git a/src/arith/ir_visitor_with_analyzer.h b/src/arith/ir_visitor_with_analyzer.h index f0553a1c428c..8de143589971 100644 --- a/src/arith/ir_visitor_with_analyzer.h +++ b/src/arith/ir_visitor_with_analyzer.h @@ -43,6 +43,7 @@ class IRVisitorWithAnalyzer : public tir::StmtExprVisitor { void VisitStmt_(const tir::ForNode* op); void VisitStmt_(const tir::SBlockNode* op); + void VisitStmt_(const tir::BindNode* op); void VisitStmt_(const tir::LetStmtNode* op); void VisitStmt_(const tir::IfThenElseNode* op); void VisitStmt_(const tir::AttrStmtNode* op); diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tir/analysis/control_flow_graph.cc index 8d7a13b0b5c7..5870346e0f5c 100644 --- a/src/tir/analysis/control_flow_graph.cc +++ b/src/tir/analysis/control_flow_graph.cc @@ -332,6 +332,14 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { Parent::VisitExpr_(op); } + void VisitStmt_(const BindNode* op) override { + std::optional binding; + if 
(UsesLoopVar(op->value)) { + binding.emplace(this, op->var, op->value); + } + Parent::VisitStmt_(op); + } + void VisitStmt_(const LetStmtNode* op) override { std::optional binding; if (UsesLoopVar(op->value)) { diff --git a/src/tir/analysis/var_use_def_analysis.cc b/src/tir/analysis/var_use_def_analysis.cc index b2236e28cee5..d7b462140533 100644 --- a/src/tir/analysis/var_use_def_analysis.cc +++ b/src/tir/analysis/var_use_def_analysis.cc @@ -54,6 +54,11 @@ void VarUseDefAnalyzer::VisitStmt_(const AttrStmtNode* op) { } } +void VarUseDefAnalyzer::VisitStmt_(const BindNode* op) { + this->HandleDef(op->var); + StmtExprVisitor::VisitStmt_(op); +} + void VarUseDefAnalyzer::VisitStmt_(const LetStmtNode* op) { this->HandleDef(op->var); StmtExprVisitor::VisitStmt_(op); diff --git a/src/tir/analysis/var_use_def_analysis.h b/src/tir/analysis/var_use_def_analysis.h index 7196fb2e8fde..52a44da5731b 100644 --- a/src/tir/analysis/var_use_def_analysis.h +++ b/src/tir/analysis/var_use_def_analysis.h @@ -57,6 +57,8 @@ class VarUseDefAnalyzer : public StmtExprVisitor { std::unordered_map let_binding_; void VisitStmt_(const AttrStmtNode* op) final; + void VisitStmt_(const BindNode* op) final; + void VisitStmt_(const LetStmtNode* op) final; void VisitStmt_(const ForNode* op) final; diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 19bc55bf64ec..ca5b63931251 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -72,6 +72,11 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { void VisitStmt(const Stmt& n) final { StmtExprVisitor::VisitStmt(n); } + void VisitStmt_(const BindNode* op) final { + // Book keep definitions + defs_[op->var.get()] = op->value; + return StmtExprVisitor::VisitStmt_(op); + } void VisitStmt_(const LetStmtNode* op) final { // Book keep definitions defs_[op->var.get()] = op->value; diff --git a/src/tir/analysis/verify_ssa.cc b/src/tir/analysis/verify_ssa.cc index 
e78e1cb58b69..4f2445c2c9bb 100644 --- a/src/tir/analysis/verify_ssa.cc +++ b/src/tir/analysis/verify_ssa.cc @@ -67,6 +67,10 @@ class SSAVerifier final : public StmtExprVisitor { StmtExprVisitor::VisitExpr_(op); } + void VisitStmt_(const BindNode* op) final { + MarkDef(op->var, op->value); + StmtExprVisitor::VisitStmt_(op); + } void VisitStmt_(const LetStmtNode* op) final { MarkDef(op->var, op->value); StmtExprVisitor::VisitStmt_(op); diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index f7c508dac9f1..0cd050c96306 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -140,6 +140,22 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const LetNode* op) { } } +Stmt DataTypeLegalizer::VisitStmt_(const BindNode* op) { + PrimExpr value = this->VisitExpr(op->value); + Var var = op->var; + + if (value.dtype() != op->var->dtype) { + var = op->var.copy_with_dtype(value.dtype()); + var_remap_[op->var.get()] = var; + } + + if (value.same_as(op->value) && var.same_as(op->var)) { + return ffi::GetRef(op); + } else { + return Bind(var, value, op->span); + } +} + Stmt DataTypeLegalizer::VisitStmt_(const LetStmtNode* op) { PrimExpr value = this->VisitExpr(op->value); Var var = op->var; @@ -528,6 +544,20 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) { } } +Stmt IndexDataTypeRewriter::VisitStmt_(const BindNode* op) { + Bind bind_stmt = Downcast(DataTypeLegalizer::VisitStmt_(op)); + if (var_remap_.find(bind_stmt->var.get()) == var_remap_.end()) { + return bind_stmt; + } + bool is_enabled = is_enabled_; + is_enabled_ = true; + PrimExpr value = VisitExpr(op->value); + Var var = var_remap_[bind_stmt->var.get()]; + is_enabled_ = is_enabled; + TVM_FFI_ICHECK(value.dtype() == var.dtype()); + return Bind(var, value, bind_stmt->span); +} + Stmt IndexDataTypeRewriter::VisitStmt_(const LetStmtNode* op) { LetStmt let_stmt = Downcast(DataTypeLegalizer::VisitStmt_(op)); if (var_remap_.find(let_stmt->var.get()) == 
var_remap_.end()) { diff --git a/src/tir/ir/data_type_rewriter.h b/src/tir/ir/data_type_rewriter.h index e886777096bd..7884d477a5d2 100644 --- a/src/tir/ir/data_type_rewriter.h +++ b/src/tir/ir/data_type_rewriter.h @@ -53,6 +53,7 @@ class DataTypeLegalizer : public StmtExprMutator { Stmt VisitStmt_(const AttrStmtNode* op) override; Stmt VisitStmt_(const SBlockRealizeNode* op) override; Stmt VisitStmt_(const SBlockNode* op) override; + Stmt VisitStmt_(const BindNode* op) override; Stmt VisitStmt_(const LetStmtNode* op) override; PrimExpr VisitExpr_(const VarNode* op) override; PrimExpr VisitExpr_(const SelectNode* op) override; @@ -110,6 +111,10 @@ class IndexDataTypeRewriter : public DataTypeLegalizer { PrimExpr VisitExpr_(const BufferLoadNode* op) override; ffi::Array VisitIndices(ffi::Array indices); Stmt VisitStmt_(const IfThenElseNode* op) override; + Stmt VisitStmt_(const DeclBufferNode* op) override; + Stmt VisitStmt_(const AllocBufferNode* op) override; + Stmt VisitStmt_(const AllocateNode* op) override; + Stmt VisitStmt_(const BindNode* op) override; Stmt VisitStmt_(const LetStmtNode* op) override; PrimExpr VisitExpr_(const EQNode* op) override; PrimExpr VisitExpr_(const NENode* op) override; diff --git a/src/tir/ir/py_functor.cc b/src/tir/ir/py_functor.cc index 867a740bcaa1..6851fd5d2698 100644 --- a/src/tir/ir/py_functor.cc +++ b/src/tir/ir/py_functor.cc @@ -170,6 +170,8 @@ class PyStmtExprVisitorNode : public Object, public StmtExprVisitor { // Statement functions /*! \brief The packed function to the `VisitStmt(const Stmt& stmt)` function. */ ffi::Function f_visit_stmt{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const BindNode* op)` function. */ + ffi::Function f_visit_bind{nullptr}; /*! \brief The packed function to the `VisitStmt_(const LetStmtNode* op)` function. */ ffi::Function f_visit_attr_stmt{nullptr}; /*! \brief The packed function to the `VisitStmt_(const IfThenElseNode* op)` function. 
*/ @@ -220,6 +222,7 @@ class PyStmtExprVisitorNode : public Object, public StmtExprVisitor { private: // Statement functions + PY_STMT_VISITOR_DISPATCH(BindNode, f_visit_bind); PY_STMT_VISITOR_DISPATCH(LetStmtNode, f_visit_let_stmt); PY_STMT_VISITOR_DISPATCH(AttrStmtNode, f_visit_attr_stmt); PY_STMT_VISITOR_DISPATCH(IfThenElseNode, f_visit_if_then_else); @@ -311,6 +314,7 @@ class PyStmtExprVisitorNode : public Object, public StmtExprVisitor { static FStmtType InitStmtVTable() { FStmtType vtable; + PY_STMT_VISITOR_DEFAULT_DISPATCH(BindNode); PY_STMT_VISITOR_DEFAULT_DISPATCH(LetStmtNode); PY_STMT_VISITOR_DEFAULT_DISPATCH(AttrStmtNode); PY_STMT_VISITOR_DEFAULT_DISPATCH(IfThenElseNode); @@ -525,6 +529,8 @@ class PyStmtExprMutatorNode : public Object, public StmtExprMutator { // Statement functions /*! \brief The packed function to the `VisitStmt(const Stmt& stmt)` function. */ ffi::Function f_visit_stmt{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const BindNode* op)` function. */ + ffi::Function f_visit_bind{nullptr}; /*! \brief The packed function to the `VisitStmt_(const LetStmtNode* op)` function. */ ffi::Function f_visit_let_stmt{nullptr}; /*! \brief The packed function to the `VisitStmt_(const AttrStmtNode* op)` function. 
*/ @@ -575,6 +581,7 @@ class PyStmtExprMutatorNode : public Object, public StmtExprMutator { private: // Statement functions + PY_STMT_MUTATOR_DISPATCH(BindNode, f_visit_bind); PY_STMT_MUTATOR_DISPATCH(LetStmtNode, f_visit_let_stmt); PY_STMT_MUTATOR_DISPATCH(AttrStmtNode, f_visit_attr_stmt); PY_STMT_MUTATOR_DISPATCH(IfThenElseNode, f_visit_if_then_else); @@ -666,6 +673,7 @@ class PyStmtExprMutatorNode : public Object, public StmtExprMutator { static FStmtType InitStmtVTable() { FStmtType vtable; + PY_STMT_MUTATOR_DEFAULT_DISPATCH(BindNode); PY_STMT_MUTATOR_DEFAULT_DISPATCH(LetStmtNode); PY_STMT_MUTATOR_DEFAULT_DISPATCH(AttrStmtNode); PY_STMT_MUTATOR_DEFAULT_DISPATCH(IfThenElseNode); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 6e0fec885fe6..6ee9fbbf6541 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -34,6 +34,7 @@ namespace tir { TVM_FFI_STATIC_INIT_BLOCK() { StmtNode::RegisterReflection(); + BindNode::RegisterReflection(); LetStmtNode::RegisterReflection(); AttrStmtNode::RegisterReflection(); AssertStmtNode::RegisterReflection(); @@ -51,6 +52,30 @@ TVM_FFI_STATIC_INIT_BLOCK() { SBlockRealizeNode::RegisterReflection(); } +// Bind +Bind::Bind(Var var, PrimExpr value, Span span) { + TVM_FFI_ICHECK(value.defined()); + auto vdtype = value.dtype(); + // It is still valid to bind a pointer type var to a value that is of type handle. 
+ if (var->type_annotation.as()) { + TVM_FFI_ICHECK(vdtype.is_handle()); + } else { + TVM_FFI_ICHECK_EQ(value.dtype(), var.dtype()); + } + + ObjectPtr node = ffi::make_object(); + node->var = std::move(var); + node->value = std::move(value); + node->span = std::move(span); + data_ = std::move(node); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.Bind", + [](Var var, PrimExpr value, Span span) { return Bind(var, value, span); }); +} + // LetStmt LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) { TVM_FFI_ICHECK(value.defined()); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index ff79c374db44..a933c3c985e9 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -33,6 +33,11 @@ namespace tvm { namespace tir { +void StmtVisitor::VisitStmt_(const BindNode* op) { + // Bind has no body -- only visit the value expression. + this->VisitExpr(op->value); +} + void StmtVisitor::VisitStmt_(const LetStmtNode* op) { this->VisitExpr(op->value); this->VisitStmt(op->body); @@ -249,6 +254,18 @@ class StmtMutator::Internal { } }; +Stmt StmtMutator::VisitStmt_(const BindNode* op) { + // Bind has no body -- only mutate the value expression. 
+ PrimExpr value = this->VisitExpr(op->value); + if (value.same_as(op->value)) { + return ffi::GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->value = std::move(value); + return Stmt(n); + } +} + Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) { PrimExpr value = this->VisitExpr(op->value); Stmt body = this->VisitStmt(op->body); diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index 5436e73d57c0..dd9dfad3db40 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -172,6 +172,14 @@ void TIRVisitorWithPath::Visit(const Range& range, AccessPath path) { Visit(range->extent, path->Attr("extent")); } +void TIRVisitorWithPath::VisitStmt_(const BindNode* op, AccessPath path) { + // Bind has no body -- var scope is defined by the enclosing scope. + Visit(op->value, path->Attr("value")); + // Note: we do NOT call WithDef here because Bind's var scope extends + // to subsequent siblings in the enclosing SeqStmt, not just a subtree. + // Scope tracking for BindNode is handled at the SeqStmt level by callers. 
+} + void TIRVisitorWithPath::VisitStmt_(const LetStmtNode* op, AccessPath path) { Visit(op->value, path->Attr("value")); auto context = WithDef(op->var, path->Attr("var")); diff --git a/src/tir/ir/tir_visitor_with_path.h b/src/tir/ir/tir_visitor_with_path.h index f5189ae61cee..7f476c32543f 100644 --- a/src/tir/ir/tir_visitor_with_path.h +++ b/src/tir/ir/tir_visitor_with_path.h @@ -106,6 +106,7 @@ class TIRVisitorWithPath } using StmtFunctor::VisitStmt; + void VisitStmt_(const BindNode* op, ffi::reflection::AccessPath path) override; void VisitStmt_(const AttrStmtNode* op, ffi::reflection::AccessPath path) override; void VisitStmt_(const IfThenElseNode* op, ffi::reflection::AccessPath path) override; void VisitStmt_(const LetStmtNode* op, ffi::reflection::AccessPath path) override; From b5a545fdf6323a21df29428ce5c8fa6f8832cb76 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Mar 2026 17:01:38 +0000 Subject: [PATCH 02/34] [TIR] Phase out LetStmtNode body field: migrate to flat BindNode This commit completes the migration from tree-nested LetStmtNode (with body field) to flat BindNode (no body) across the entire TVM codebase. BindNode binds a variable visible to subsequent siblings in the enclosing SeqStmt scope, replacing the old nested scoping. 
Key changes: - LetStmtNode is now a `using` alias for BindNode - All C++ LetStmt(var,val,body) constructions -> SeqStmt({Bind(var,val), body}) - All VisitStmt_(LetStmtNode*) handlers -> VisitStmt_(BindNode*) with op->body access removed (parent SeqStmt handles traversal) - TIRVisitorWithPath::SeqStmt handler tracks Bind-defined vars for well-formed verification - CSE pass: new SeqStmtNode handler + VisitSeqStmtSlice to process flat Bind sequences (mirrors old nested LetStmt CSE behavior) - Python: added Bind class, LetStmt = Bind alias - Updated ~65 C++ files and test files --- .../tvm/script/printer/ir_docsifier_functor.h | 6 + include/tvm/tir/stmt.h | 40 +- include/tvm/tir/stmt_functor.h | 4 - python/tvm/tir/__init__.py | 2 +- python/tvm/tir/functor.py | 1 + python/tvm/tir/stmt.py | 27 +- src/arith/ir_mutator_with_analyzer.cc | 18 - src/arith/ir_mutator_with_analyzer.h | 1 - src/arith/ir_visitor_with_analyzer.cc | 6 - src/arith/ir_visitor_with_analyzer.h | 1 - src/relax/op/tensor/inspect.cc | 58 ++- src/s_tir/analysis/estimate_flops.cc | 3 +- .../analysis/sblock_access_region_detector.cc | 4 +- .../backend/adreno/inject_texture_alloc.cc | 6 +- src/s_tir/schedule/analysis/reducer.cc | 43 ++- .../primitive/layout_transformation.cc | 2 +- src/s_tir/schedule/primitive/reduction.cc | 11 +- src/s_tir/transform/compact_buffer_region.cc | 6 +- src/s_tir/transform/hoist_expression.cc | 20 +- src/s_tir/transform/inject_virtual_thread.cc | 13 +- src/s_tir/transform/lower_thread_allreduce.cc | 8 +- src/s_tir/transform/lower_vtcm_alloc.cc | 6 +- .../transform/profile_instrumentation.cc | 14 +- src/s_tir/transform/remove_store_undef.cc | 4 +- src/s_tir/transform/renew_defs.cc | 2 +- src/s_tir/transform/storage_access.cc | 4 +- src/s_tir/transform/storage_access.h | 2 +- src/script/ir_builder/tir/frame.cc | 2 +- src/script/printer/relax/distributed.cc | 1 + src/script/printer/tir/stmt.cc | 28 +- src/target/llvm/codegen_llvm.cc | 3 +- src/target/llvm/codegen_llvm.h | 2 +- 
src/target/source/codegen_c.cc | 4 +- src/target/source/codegen_c.h | 2 +- src/target/source/codegen_webgpu.cc | 3 +- src/target/source/codegen_webgpu.h | 2 +- src/target/spirv/codegen_spirv.cc | 3 +- src/target/spirv/codegen_spirv.h | 2 +- src/te/operation/create_primfunc.cc | 9 +- src/tir/analysis/control_flow_graph.cc | 8 - src/tir/analysis/var_use_def_analysis.cc | 5 - src/tir/analysis/var_use_def_analysis.h | 1 - src/tir/analysis/verify_memory.cc | 5 - src/tir/analysis/verify_ssa.cc | 4 - src/tir/ir/data_type_rewriter.cc | 32 -- src/tir/ir/data_type_rewriter.h | 2 - src/tir/ir/py_functor.cc | 16 +- src/tir/ir/specialize.cc | 2 +- src/tir/ir/stmt.cc | 28 +- src/tir/ir/stmt_functor.cc | 18 - src/tir/ir/tir_visitor_with_path.cc | 18 +- src/tir/ir/tir_visitor_with_path.h | 1 - src/tir/transform/common_subexpr_elim.cc | 155 ++++++-- src/tir/transform/common_subexpr_elim.h | 6 +- src/tir/transform/ir_utils.cc | 13 +- src/tir/transform/ir_utils.h | 4 +- src/tir/transform/lower_tvm_builtin.cc | 90 +++-- src/tir/transform/remove_no_op.cc | 21 +- src/tir/transform/simplify.cc | 14 +- src/tir/transform/split_host_device.cc | 4 +- src/tir/transform/storage_rewrite.cc | 11 +- src/tir/transform/tvm_ffi_binder.cc | 18 +- src/tir/transform/tvm_ffi_binder.h | 10 +- .../transform/unsupported_dtype_legalize.cc | 14 +- src/tir/transform/vectorize_loop.cc | 12 +- .../test_tir_analysis_verify_ssa.py | 4 +- tests/python/tir-base/test_tir_constructor.py | 5 +- tests/python/tir-base/test_tir_nodes.py | 4 +- .../test_tir_structural_equal_hash.py | 2 +- .../test_tir_transform_common_subexpr_elim.py | 352 +++++++++--------- .../test_tir_transform_convert_ssa.py | 32 +- .../test_tir_transform_lower_tvm_builtin.py | 4 +- .../test_tir_transform_prim_func_pass.py | 2 +- .../test_tvmscript_ir_builder_tir.py | 4 +- 74 files changed, 636 insertions(+), 658 deletions(-) diff --git a/include/tvm/script/printer/ir_docsifier_functor.h b/include/tvm/script/printer/ir_docsifier_functor.h index 
2cc1782d9240..211e65510e81 100644 --- a/include/tvm/script/printer/ir_docsifier_functor.h +++ b/include/tvm/script/printer/ir_docsifier_functor.h @@ -74,6 +74,9 @@ class IRDocsifierFunctor { return (*pf)(obj, args...).template cast(); } + LOG(WARNING) << "ObjectFunctor calls un-registered function on type: " + << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")" + << ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj; LOG(WARNING) << "ObjectFunctor calls un-registered function on type: " << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")" << ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj; @@ -81,6 +84,9 @@ class IRDocsifierFunctor { << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")" << ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj; +#if defined(__GNUC__) || defined(__clang__) + __builtin_unreachable(); +#endif } /*! diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 7cd1eb20a9a6..a089d8b56287 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -102,41 +102,10 @@ class Bind : public Stmt { TVM_DEFINE_OBJECT_REF_COW_METHOD(BindNode); }; -/*! - * \brief Let binding, bind var to value, then run body. - * \deprecated Use Bind instead, which has flat scope semantics. - */ -class LetStmtNode : public StmtNode { - public: - /*! \brief The variable. */ - Var var; - /*! \brief The value to be bound. */ - PrimExpr value; - /*! \brief The body block. */ - Stmt body; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("var", &LetStmtNode::var, refl::AttachFieldFlag::SEqHashDef()) - .def_ro("value", &LetStmtNode::value) - .def_ro("body", &LetStmtNode::body); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.LetStmt", LetStmtNode, StmtNode); -}; - -/*! - * \brief Managed reference to LetStmtNode. - * \sa LetStmtNode - * \deprecated Use Bind instead. 
- */ -class LetStmt : public Stmt { - public: - TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body, Span span = Span()); - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LetStmt, Stmt, LetStmtNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(LetStmtNode); -}; +/*! \brief Deprecated: use BindNode instead. */ +using LetStmtNode = BindNode; +/*! \brief Deprecated: use Bind instead. */ +using LetStmt = Bind; /*! * \brief Define certain auxiliary attribute for the body to be a symbolic value. @@ -1015,6 +984,7 @@ inline const char* ForKind2String(ForKind t) { return "thread_binding"; } TVM_FFI_THROW(InternalError) << "Unknown ForKind" << t; + __builtin_unreachable(); } } // namespace tir diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index b7a23bd3e963..d99cdfb84e59 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -85,7 +85,6 @@ class StmtFunctor { } // Functions that can be overriden by subclass virtual R VisitStmt_(const BindNode* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const LetStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AttrStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const IfThenElseNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const ForNode* op, Args... 
args) STMT_FUNCTOR_DEFAULT; @@ -108,7 +107,6 @@ class StmtFunctor { static FType InitVTable() { FType vtable; IR_STMT_FUNCTOR_DISPATCH(BindNode); - IR_STMT_FUNCTOR_DISPATCH(LetStmtNode); IR_STMT_FUNCTOR_DISPATCH(AttrStmtNode); IR_STMT_FUNCTOR_DISPATCH(IfThenElseNode); IR_STMT_FUNCTOR_DISPATCH(ForNode); @@ -164,7 +162,6 @@ class TVM_DLL StmtVisitor : protected StmtFunctor { void VisitStmt_(const BindNode* op) override; void VisitStmt_(const AttrStmtNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; - void VisitStmt_(const LetStmtNode* op) override; void VisitStmt_(const ForNode* op) override; void VisitStmt_(const WhileNode* op) override; void VisitStmt_(const AllocBufferNode* op) override; @@ -279,7 +276,6 @@ class TVM_DLL StmtMutator : protected StmtFunctor { Stmt VisitStmt_(const BindNode* op) override; Stmt VisitStmt_(const AttrStmtNode* op) override; Stmt VisitStmt_(const IfThenElseNode* op) override; - Stmt VisitStmt_(const LetStmtNode* op) override; Stmt VisitStmt_(const ForNode* op) override; Stmt VisitStmt_(const WhileNode* op) override; Stmt VisitStmt_(const AllocBufferNode* op) override; diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index c645dccda3b8..daa23817d1b6 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -29,7 +29,7 @@ from .expr import Select, BufferLoad, ProducerLoad, Ramp, Broadcast, Shuffle from .expr import Call, CallEffectKind, Let, IterVar, CommReducer -from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While +from .stmt import Stmt, Bind, LetStmt, AssertStmt, ForKind, For, While from .stmt import ( BufferStore, AllocBuffer, diff --git a/python/tvm/tir/functor.py b/python/tvm/tir/functor.py index 5b19fd19b2c9..82b1f29aec63 100644 --- a/python/tvm/tir/functor.py +++ b/python/tvm/tir/functor.py @@ -70,6 +70,7 @@ Evaluate, For, IfThenElse, + Bind, LetStmt, SBlock, SBlockRealize, diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 
cc945ba4fcec..0985f30b259b 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -44,9 +44,14 @@ class Stmt(Object, Scriptable): """Base class of all the statements.""" -@tvm_ffi.register_object("tir.LetStmt") -class LetStmt(Stmt): - """LetStmt node. +@tvm_ffi.register_object("tir.Bind") +class Bind(Stmt): + """Bind node. + + Bind a variable to a value in the enclosing scope. + Unlike the deprecated LetStmt, Bind has no body field. + The bound variable is visible in all subsequent statements + within the same enclosing scope (SeqStmt, ForNode.body, etc.). Parameters ---------- @@ -54,10 +59,7 @@ class LetStmt(Stmt): The variable in the binding. value : PrimExpr - The value in to be bound. - - body : Stmt - The body statement. + The value to be bound. span : Optional[Span] The location of the stmt in the source code. @@ -65,19 +67,22 @@ class LetStmt(Stmt): var: Var value: PrimExpr - body: Stmt span: Span | None - def __init__(self, var: Var, value: PrimExpr, body: Stmt, span: Span | None = None) -> None: + def __init__(self, var: Var, value: PrimExpr, span: Span | None = None) -> None: self.__init_handle_by_constructor__( - _ffi_api.LetStmt, + _ffi_api.Bind, var, value, - body, span, # type: ignore ) +# Deprecated: use Bind instead. +# LetStmt is only an alias of Bind (no body); use SeqStmt([Bind(var, value), body]). +LetStmt = Bind + + @tvm_ffi.register_object("tir.AssertStmt") class AssertStmt(Stmt): """AssertStmt node. 
diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 7639288e3de0..11caef56850b 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -94,24 +94,6 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const BindNode* op) { } } -Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) { - PrimExpr value = this->VisitExpr(op->value); - if (SideEffect(value) <= CallEffectKind::kPure) { - analyzer_->Bind(op->var, value); - } - // We keep the let-binding here - // as sub-class may or maynot choose to replace it. - Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && body.same_as(op->body)) { - return ffi::GetRef(op); - } else { - auto n = this->CopyOnWrite(op); - n->value = std::move(value); - n->body = std::move(body); - return Stmt(n); - } -} - Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { return constraint_scope_.WithNewScope([&]() -> Stmt { PrimExpr condition = this->VisitExpr(op->condition); diff --git a/src/arith/ir_mutator_with_analyzer.h b/src/arith/ir_mutator_with_analyzer.h index e0ac48188b2c..0f03fef7d25e 100644 --- a/src/arith/ir_mutator_with_analyzer.h +++ b/src/arith/ir_mutator_with_analyzer.h @@ -55,7 +55,6 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator { tir::Stmt VisitStmt_(const tir::ForNode* op) override; tir::Stmt VisitStmt_(const tir::SBlockNode* op) override; tir::Stmt VisitStmt_(const tir::BindNode* op) override; - tir::Stmt VisitStmt_(const tir::LetStmtNode* op) override; tir::Stmt VisitStmt_(const tir::IfThenElseNode* op) override; tir::Stmt VisitStmt_(const tir::AttrStmtNode* op) override; tir::Stmt VisitStmt_(const tir::AssertStmtNode* op) override; diff --git a/src/arith/ir_visitor_with_analyzer.cc b/src/arith/ir_visitor_with_analyzer.cc index a701c62cee21..e5041b159f8d 100644 --- a/src/arith/ir_visitor_with_analyzer.cc +++ b/src/arith/ir_visitor_with_analyzer.cc @@ -53,12 +53,6 @@ void 
IRVisitorWithAnalyzer::VisitStmt_(const BindNode* op) { analyzer_.Bind(op->var, op->value); } -void IRVisitorWithAnalyzer::VisitStmt_(const LetStmtNode* op) { - this->VisitExpr(op->value); - analyzer_.Bind(op->var, op->value); - this->VisitStmt(op->body); -} - void IRVisitorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { constraint_scope_.WithNewScope([&]() { this->VisitExpr(op->condition); diff --git a/src/arith/ir_visitor_with_analyzer.h b/src/arith/ir_visitor_with_analyzer.h index 8de143589971..a5455659d0fe 100644 --- a/src/arith/ir_visitor_with_analyzer.h +++ b/src/arith/ir_visitor_with_analyzer.h @@ -44,7 +44,6 @@ class IRVisitorWithAnalyzer : public tir::StmtExprVisitor { void VisitStmt_(const tir::ForNode* op); void VisitStmt_(const tir::SBlockNode* op); void VisitStmt_(const tir::BindNode* op); - void VisitStmt_(const tir::LetStmtNode* op); void VisitStmt_(const tir::IfThenElseNode* op); void VisitStmt_(const tir::AttrStmtNode* op); void VisitStmt_(const tir::AssertStmtNode* op); diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc index 0130ee1063d9..bcf7e2e354f7 100644 --- a/src/relax/op/tensor/inspect.cc +++ b/src/relax/op/tensor/inspect.cc @@ -92,11 +92,12 @@ tir::PrimFunc GetDLTensorField(tir::builtin::TVMStructFieldKind field, DataType tir::Var value("value", field_dtype); - tir::LetStmt body( - value, - tir::Call(field_dtype, tir::builtin::tvm_struct_get(), - {dlpack_handle, IntImm(DataType::Int(32), 0), IntImm(DataType::Int(32), field)}), - tir::Evaluate(tvm::ret(value))); + tir::Stmt body = tir::SeqStmt({ + tir::Bind(value, + tir::Call(field_dtype, tir::builtin::tvm_struct_get(), + {dlpack_handle, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), field)})), + tir::Evaluate(tvm::ret(value))}); DictAttrs attrs({{"tir.is_scheduled", true}, {"tir.is_host", true}}); @@ -305,33 +306,26 @@ Expr LegalizeTensorShape(const BlockBuilder& bb, const Call& call) { tir::Var extent("extent", field_dtype); - tir::Stmt body = 
tir::Evaluate(tvm::ret(extent)); - - body = tir::LetStmt(extent, tir::BufferLoad(shape_buffer, {axis}), body); - body = tir::DeclBuffer(shape_buffer, body); - body = tir::LetStmt( - shape_buffer->data, - tir::Call(DataType::Handle(), tir::builtin::tvm_struct_get(), - {dlpack_handle, IntImm(DataType::Int(32), 0), - IntImm(DataType::Int(32), tir::builtin::TVMStructFieldKind::kDLTensorShape)}), - body); - - body = tir::SeqStmt( - {tir::AssertStmt( - axis < tvm::cast(axis->dtype, ndim), tir::StringImm("RuntimeError"), - {tir::StringImm("Specified axis may not be larger than the tensor's dimensionality")}), - body}); - - body = tir::LetStmt( - ndim, - tir::Call(ndim->dtype, tir::builtin::tvm_struct_get(), - {dlpack_handle, IntImm(DataType::Int(32), 0), - IntImm(DataType::Int(32), tir::builtin::TVMStructFieldKind::kDLTensorNDim)}), - body); - - body = tir::SeqStmt({tir::AssertStmt(0 <= axis, tir::StringImm("RuntimeError"), - {tir::StringImm("Specified axis may not be negative")}), - body}); + tir::Stmt body = tir::SeqStmt({ + tir::AssertStmt(0 <= axis, tir::StringImm("RuntimeError"), + {tir::StringImm("Specified axis may not be negative")}), + tir::Bind(ndim, + tir::Call(ndim->dtype, tir::builtin::tvm_struct_get(), + {dlpack_handle, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), + tir::builtin::TVMStructFieldKind::kDLTensorNDim)})), + tir::AssertStmt( + axis < tvm::cast(axis->dtype, ndim), tir::StringImm("RuntimeError"), + {tir::StringImm( + "Specified axis may not be larger than the tensor's dimensionality")}), + tir::Bind(shape_buffer->data, + tir::Call(DataType::Handle(), tir::builtin::tvm_struct_get(), + {dlpack_handle, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), + tir::builtin::TVMStructFieldKind::kDLTensorShape)})), + tir::DeclBuffer(shape_buffer, + tir::SeqStmt({tir::Bind(extent, tir::BufferLoad(shape_buffer, {axis})), + tir::Evaluate(tvm::ret(extent))}))}); DictAttrs attrs({{"tir.is_scheduled", true}, {"tir.is_host", true}}); diff --git 
a/src/s_tir/analysis/estimate_flops.cc b/src/s_tir/analysis/estimate_flops.cc index 08bebea0d3ba..6cd405ebbb27 100644 --- a/src/s_tir/analysis/estimate_flops.cc +++ b/src/s_tir/analysis/estimate_flops.cc @@ -182,9 +182,8 @@ class FlopEstimator : private ExprFunctor, return result; } - TResult VisitStmt_(const LetStmtNode* let) override { + TResult VisitStmt_(const BindNode* let) override { TResult value = VisitExpr(let->value); - value += VisitStmt(let->body); return value; } diff --git a/src/s_tir/analysis/sblock_access_region_detector.cc b/src/s_tir/analysis/sblock_access_region_detector.cc index 22c0ed5ad920..1e0025d551c4 100644 --- a/src/s_tir/analysis/sblock_access_region_detector.cc +++ b/src/s_tir/analysis/sblock_access_region_detector.cc @@ -117,7 +117,7 @@ class BlockReadWriteDetector : public StmtExprVisitor { void VisitStmt_(const IfThenElseNode* op) override; void VisitStmt_(const SBlockRealizeNode* op) override; void VisitStmt_(const BufferStoreNode* op) override; - void VisitStmt_(const LetStmtNode* op) override; + void VisitStmt_(const BindNode* op) override; void VisitExpr_(const BufferLoadNode* op) override; void VisitExpr_(const VarNode* op) override; void VisitExpr_(const CallNode* op) override; @@ -189,7 +189,7 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) { } } -void BlockReadWriteDetector::VisitStmt_(const LetStmtNode* op) { +void BlockReadWriteDetector::VisitStmt_(const BindNode* op) { let_bindings_[op->var.get()] = op->value; StmtVisitor::VisitStmt_(op); let_bindings_.erase(op->var.get()); diff --git a/src/s_tir/backend/adreno/inject_texture_alloc.cc b/src/s_tir/backend/adreno/inject_texture_alloc.cc index 400f3edc97d3..e7df647d0f85 100644 --- a/src/s_tir/backend/adreno/inject_texture_alloc.cc +++ b/src/s_tir/backend/adreno/inject_texture_alloc.cc @@ -82,9 +82,9 @@ class TextureAllocInjector : public arith::IRMutatorWithAnalyzer { args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(), {texture.width, 
texture.height, texture.depth})); args.push_back(IntImm(DataType::Int(64), channel_size)); - stmt = LetStmt(op->buffer->data, - Call(op->buffer->data.dtype(), builtin::nd_mem_alloc_with_scope(), args), - op->body); + stmt = SeqStmt({Bind(op->buffer->data, + Call(op->buffer->data.dtype(), builtin::nd_mem_alloc_with_scope(), args)), + op->body}); } return stmt; } diff --git a/src/s_tir/schedule/analysis/reducer.cc b/src/s_tir/schedule/analysis/reducer.cc index 7559a5bfb9d7..b173103454ad 100644 --- a/src/s_tir/schedule/analysis/reducer.cc +++ b/src/s_tir/schedule/analysis/reducer.cc @@ -354,7 +354,7 @@ void ErrorRFactorCrossThreadReductionNotApplicable(const ffi::Optional& self, SBlock block, - const LetStmtNode* let, int n_buffers, + const ffi::Array& stmts, int n_buffers, ffi::Array* updates, std::unordered_map* buf2index) { std::unordered_map var2index; @@ -363,37 +363,35 @@ void ExtractReductionUpdates(const ffi::Optional& self, SBlock bl updates->resize(n_buffers); // Step 1. - // - Extract the BufferStore values from the LetStmts. - // - Construct the mapping from let variables to the index. + // - Extract the Bind values from the sequence. + // - Construct the mapping from bind variables to the index. + // The first n_buffers stmts should be Bind nodes. 
for (int i = 0; i < n_buffers; ++i) { - if (let == nullptr) { + if (i >= static_cast(stmts.size())) { ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/3); } - - let_values.push_back(let->value); - auto insert_result = var2index.insert(std::make_pair(let->var.get(), i)); + const auto* bind = stmts[i].as(); + if (bind == nullptr) { + ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/3); + } + let_values.push_back(bind->value); + auto insert_result = var2index.insert(std::make_pair(bind->var.get(), i)); if (!insert_result.second) { ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/4); } - if (i != n_buffers - 1) { - let = let->body.as(); - } } - // There should be no more LetStmt. - if (let->body->IsInstance()) { + // There should be no more Bind after the first n_buffers. + if (n_buffers < static_cast(stmts.size()) && stmts[n_buffers]->IsInstance()) { ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/3); } - // Now `let` is expected to be the innermost LetStmt, whose body should either be a SeqStmt or - // a BufferStore - const auto* p_seq = let->body.as(); - const auto* p_buf_store = let->body.as(); - if (p_seq == nullptr && p_buf_store == nullptr) { - ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/5); + // The remaining stmts after the Bind nodes should be BufferStores. + // Collect them into a sequence. + ffi::Array seq; + for (int i = n_buffers; i < static_cast(stmts.size()); ++i) { + seq.push_back(stmts[i]); } - ffi::Array seq = - p_seq != nullptr ? 
p_seq->seq : ffi::Array{ffi::GetRef(p_buf_store)}; if (static_cast(seq.size()) != n_buffers) { ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/6); } @@ -460,9 +458,10 @@ std::pair, ffi::Array> GetInitValuesAndUpdates if (const auto* update = block->body.as()) { updates.push_back(ffi::GetRef(update)); buf2index[update->buffer.get()] = 0; + } else if (const auto* seq = block->body.as()) { + ExtractReductionUpdates(self, block, seq->seq, n_buffers, &updates, &buf2index); } else { - const auto* let = block->body.as(); - ExtractReductionUpdates(self, block, let, n_buffers, &updates, &buf2index); + ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/3); } TVM_FFI_ICHECK_EQ(updates.size(), n_buffers); diff --git a/src/s_tir/schedule/primitive/layout_transformation.cc b/src/s_tir/schedule/primitive/layout_transformation.cc index d20e002603ed..33ce0f1a23aa 100644 --- a/src/s_tir/schedule/primitive/layout_transformation.cc +++ b/src/s_tir/schedule/primitive/layout_transformation.cc @@ -131,7 +131,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { StmtExprVisitor::VisitStmt_(op); } - void VisitStmt_(const LetStmtNode* op) override { + void VisitStmt_(const BindNode* op) override { BindVariableDefinition context(this, op->var, op->value); StmtExprVisitor::VisitStmt_(op); } diff --git a/src/s_tir/schedule/primitive/reduction.cc b/src/s_tir/schedule/primitive/reduction.cc index 087eabbce812..3fcd121d0723 100644 --- a/src/s_tir/schedule/primitive/reduction.cc +++ b/src/s_tir/schedule/primitive/reduction.cc @@ -744,11 +744,14 @@ class BaseBlockCreator { let_vars.push_back(var); buf_stores.push_back(BufferStore(update_buffers_[i], var, update_indices_[i])); } - Stmt body = SeqStmt(buf_stores); - for (int i = n_buffers_ - 1; i >= 0; --i) { - body = LetStmt(let_vars[i], stored_values[i], std::move(body)); + ffi::Array stmts; + for (int i = 0; i < n_buffers_; ++i) { + 
stmts.push_back(tir::Bind(let_vars[i], stored_values[i])); + } + for (const auto& store : buf_stores) { + stmts.push_back(store); } - return body; + return SeqStmt(stmts); } ffi::Optional CreateBlockInit(bool has_reduce_iter) { diff --git a/src/s_tir/transform/compact_buffer_region.cc b/src/s_tir/transform/compact_buffer_region.cc index 9b2edde22189..ff811e3aa842 100644 --- a/src/s_tir/transform/compact_buffer_region.cc +++ b/src/s_tir/transform/compact_buffer_region.cc @@ -169,16 +169,12 @@ class BufferAccessRegionCollector : public StmtExprVisitor { ancestor_iters_.pop_back(); } - void VisitStmt_(const LetStmtNode* op) final { + void VisitStmt_(const BindNode* op) final { StmtExprVisitor::VisitExpr(op->value); if (arith::IsIndexType(op->value->dtype)) { dom_analyzer_.Bind(op->var, op->value); dom_map_.emplace(op->var.get(), arith::IntSet::SinglePoint(op->value)); } - StmtExprVisitor::VisitStmt(op->body); - if (arith::IsIndexType(op->value->dtype)) { - dom_map_.erase(op->var.get()); - } } void VisitExpr_(const LetNode* op) final { diff --git a/src/s_tir/transform/hoist_expression.cc b/src/s_tir/transform/hoist_expression.cc index 7858ff0e14dd..076bf2b5a442 100644 --- a/src/s_tir/transform/hoist_expression.cc +++ b/src/s_tir/transform/hoist_expression.cc @@ -322,7 +322,7 @@ class HoistInfoCollector : public StmtExprVisitor { let_var_to_let_vars[var.get()] = std::move(let_bindings_used); } - void VisitStmt_(const LetStmtNode* op) final { + void VisitStmt_(const BindNode* op) final { VisitBinding(op->var, op->value, HoistedLetBindings::kLetStmt); Parent::VisitStmt_(op); @@ -482,9 +482,16 @@ class ExpressionHoister : public arith::IRMutatorWithAnalyzer { } } } - for (auto let_it = info.let_bindings.rbegin(); let_it != info.let_bindings.rend(); let_it++) { - if (hoisted_let_bindings.count(let_it->var.get())) { - stmt = LetStmt(let_it->var, let_it->value, stmt); + { + ffi::Array binds; + for (auto let_it = info.let_bindings.begin(); let_it != info.let_bindings.end(); 
let_it++) { + if (hoisted_let_bindings.count(let_it->var.get())) { + binds.push_back(Bind(let_it->var, let_it->value)); + } + } + if (!binds.empty()) { + binds.push_back(stmt); + stmt = SeqStmt(binds); } } @@ -511,9 +518,10 @@ class ExpressionHoister : public arith::IRMutatorWithAnalyzer { } } - Stmt VisitStmt_(const LetStmtNode* op) final { + Stmt VisitStmt_(const BindNode* op) final { if (hoisted_let_bindings.count(op->var.get())) { - return this->VisitStmt(op->body); + // The binding was hoisted; remove it from this location. + return Evaluate(0); } else { return Parent::VisitStmt_(op); } diff --git a/src/s_tir/transform/inject_virtual_thread.cc b/src/s_tir/transform/inject_virtual_thread.cc index d914fa81cd20..9db3d8b91c17 100644 --- a/src/s_tir/transform/inject_virtual_thread.cc +++ b/src/s_tir/transform/inject_virtual_thread.cc @@ -99,11 +99,10 @@ class ExprTouched final : public StmtExprVisitor { // Analyze if the buffers are invariant to value of var class VarTouchedAnalysis : public StmtVisitor { public: - void VisitStmt_(const LetStmtNode* op) final { + void VisitStmt_(const BindNode* op) final { ExprTouched tc(touched_var_, false); tc(op->value); Record(op->var.get(), tc); - this->VisitStmt(op->body); } void VisitStmt_(const BufferStoreNode* op) final { @@ -299,18 +298,17 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { } } } - // LetStmt - Stmt VisitStmt_(const LetStmtNode* op) final { + // Bind + Stmt VisitStmt_(const BindNode* op) final { PrimExpr value = this->VisitExpr(op->value); if (visit_touched_var_ && !vt_loop_injected_) { return InjectVTLoop(ffi::GetRef(op), true); } visit_touched_var_ = false; - Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && body.same_as(op->body)) { + if (value.same_as(op->value)) { return ffi::GetRef(op); } else { - return LetStmt(op->var, value, body); + return Bind(op->var, value); } } // For @@ -362,6 +360,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { Stmt 
VisitStmt_(const WhileNode* op) final { // TODO(masahi): What should we do for While nodes? TVM_FFI_THROW(InternalError) << "WhileNode in InjectVirtualThread not supported yet"; + __builtin_unreachable(); } // Seq diff --git a/src/s_tir/transform/lower_thread_allreduce.cc b/src/s_tir/transform/lower_thread_allreduce.cc index f5764cbef8a3..13a8c3359f3a 100644 --- a/src/s_tir/transform/lower_thread_allreduce.cc +++ b/src/s_tir/transform/lower_thread_allreduce.cc @@ -645,10 +645,14 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { fstore({in_warp_local_vars.begin(), in_warp_local_vars.end()})); in_let_statement.emplace_back(SyncThread("warp")); - Stmt body = SeqStmt::Flatten(in_let_statement); + ffi::Array bind_stmts; for (size_t i = 0; i < size; i++) { - body = LetStmt(in_warp_local_vars[i], loads[i], body); + bind_stmts.push_back(Bind(in_warp_local_vars[i], loads[i])); } + for (const auto& s : in_let_statement) { + bind_stmts.push_back(s); + } + Stmt body = SeqStmt::Flatten(bind_stmts); in_warp_seq.push_back(body); } diff --git a/src/s_tir/transform/lower_vtcm_alloc.cc b/src/s_tir/transform/lower_vtcm_alloc.cc index d086f0d45e86..e8683f669e95 100644 --- a/src/s_tir/transform/lower_vtcm_alloc.cc +++ b/src/s_tir/transform/lower_vtcm_alloc.cc @@ -45,9 +45,9 @@ class VtcmAllocator : public StmtExprMutator { args.push_back(StringImm(storage_scope)); args.push_back(IntImm(DataType::Int(64), op->buffer->shape.size())); args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(), op->buffer->shape)); - return LetStmt(op->buffer->data, - Call(op->buffer->data.dtype(), builtin::nd_mem_alloc_with_scope(), args), - body); + return SeqStmt({Bind(op->buffer->data, + Call(op->buffer->data.dtype(), builtin::nd_mem_alloc_with_scope(), args)), + body}); } return StmtExprMutator::VisitStmt_(op); } diff --git a/src/s_tir/transform/profile_instrumentation.cc b/src/s_tir/transform/profile_instrumentation.cc index c0e1c483cbe2..07cb8c72bf0d 100644 --- 
a/src/s_tir/transform/profile_instrumentation.cc +++ b/src/s_tir/transform/profile_instrumentation.cc @@ -135,9 +135,17 @@ class LoopAnalyzer : public StmtExprVisitor { loop_info.height = height; loops[f] = loop_info; return height + 1; - } else if (stmt->IsInstance()) { - const LetStmtNode* n = stmt.as(); - return TraverseLoop(n->body, parent_depth, has_parallel); + } else if (stmt->IsInstance()) { + // Bind has no body; skip it and return 0 (not a loop). + return 0; + } else if (stmt->IsInstance()) { + // For flat sequences, traverse children looking for loops. + const SeqStmtNode* seq = stmt.as(); + unsigned max_height = 0; + for (const auto& s : seq->seq) { + max_height = std::max(max_height, TraverseLoop(s, parent_depth, has_parallel)); + } + return max_height; } else if (stmt->IsInstance()) { const AttrStmtNode* n = stmt.as(); return TraverseLoop(n->body, parent_depth, has_parallel); diff --git a/src/s_tir/transform/remove_store_undef.cc b/src/s_tir/transform/remove_store_undef.cc index 54c91fb7f016..a5d2a9f9e267 100644 --- a/src/s_tir/transform/remove_store_undef.cc +++ b/src/s_tir/transform/remove_store_undef.cc @@ -65,7 +65,7 @@ class StoreUndefLocator : public StmtExprVisitor { // ValidateAllUndefRemoved. 
} - void VisitStmt_(const LetStmtNode* op) final { + void VisitStmt_(const BindNode* op) final { bool stash_undef = false; std::swap(has_undef_, stash_undef); StmtExprVisitor::VisitExpr(op->value); @@ -76,8 +76,6 @@ class StoreUndefLocator : public StmtExprVisitor { << "must not have other side effects"; var_bindings_with_undef_.insert(op->var.get()); } - - StmtExprVisitor::VisitStmt(op->body); } void VisitExpr_(const VarNode* op) final { diff --git a/src/s_tir/transform/renew_defs.cc b/src/s_tir/transform/renew_defs.cc index 224fccbadbf8..f34fb3cd856d 100644 --- a/src/s_tir/transform/renew_defs.cc +++ b/src/s_tir/transform/renew_defs.cc @@ -99,7 +99,7 @@ class RenewDefMutator : public StmtExprMutator { } private: - STMT_REGENERATE_VAR_DEF(LetStmtNode, var); + STMT_REGENERATE_VAR_DEF(BindNode, var); STMT_REGENERATE_VAR_DEF(ForNode, loop_var); // Override VisitBufferDef to create fresh buffer copies at definition sites diff --git a/src/s_tir/transform/storage_access.cc b/src/s_tir/transform/storage_access.cc index dae797486da0..ad0c9a5bd84d 100644 --- a/src/s_tir/transform/storage_access.cc +++ b/src/s_tir/transform/storage_access.cc @@ -95,7 +95,7 @@ void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) { allow_append_ = false; } -void StorageAccessVisitor::VisitStmt_(const LetStmtNode* op) { +void StorageAccessVisitor::VisitStmt_(const BindNode* op) { allow_append_ = true; TVM_FFI_ICHECK_EQ(curr_stmt_.access.size(), 0U); curr_stmt_.stmt = op; @@ -105,8 +105,6 @@ void StorageAccessVisitor::VisitStmt_(const LetStmtNode* op) { // clear access entry. 
curr_stmt_.access.clear(); allow_append_ = false; - // traverse body block - this->VisitStmt(op->body); } void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) { diff --git a/src/s_tir/transform/storage_access.h b/src/s_tir/transform/storage_access.h index 848d8edcf546..7635996acb7a 100644 --- a/src/s_tir/transform/storage_access.h +++ b/src/s_tir/transform/storage_access.h @@ -85,7 +85,7 @@ class StorageAccessVisitor : public StmtExprVisitor { void VisitExpr_(const BufferLoadNode* op) final; void VisitStmt_(const BufferStoreNode* op) final; void VisitStmt_(const EvaluateNode* op) final; - void VisitStmt_(const LetStmtNode* op) final; + void VisitStmt_(const BindNode* op) final; void VisitStmt_(const AttrStmtNode* op) final; void VisitStmt_(const ForNode* op) final; void VisitStmt_(const IfThenElseNode* op) final; diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index 349374952109..742802695456 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -143,7 +143,7 @@ void AssertFrameNode::ExitWithScope() { void LetFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); - AddToParent(tvm::tir::LetStmt(var, value, AsStmt(stmts))); + AddToParent(tvm::tir::SeqStmt({tvm::tir::Bind(var, value), AsStmt(stmts)})); } void LaunchThreadFrameNode::ExitWithScope() { diff --git a/src/script/printer/relax/distributed.cc b/src/script/printer/relax/distributed.cc index 51fae05bf626..3f64c1002302 100644 --- a/src/script/printer/relax/distributed.cc +++ b/src/script/printer/relax/distributed.cc @@ -121,6 +121,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } TVM_FFI_THROW(InternalError) << "Cannot find device mesh in global infos"; + __builtin_unreachable(); } }); diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index 647f12c8ff7b..cf41f26317b5 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -49,6 +49,7 @@ bool AllowConciseScoping(const 
IRDocsifier& d, const ObjectRef& obj) { return f->allow_concise_scoping; } TVM_FFI_THROW(NotImplementedError) << "fragment printing"; + __builtin_unreachable(); } bool IsAncestorOfAllVarUse(const tir::Stmt& node, const ObjectRef& var, const IRDocsifier& d) { @@ -96,8 +97,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::LetStmt stmt, AccessPath p, IRDocsifier d) -> Doc { - bool concise = AllowConciseScoping(d, stmt); + .set_dispatch("", [](tir::Bind stmt, AccessPath p, IRDocsifier d) -> Doc { // Step 1. Type annotation ffi::Optional type_doc = d->AsDoc(stmt->var->type_annotation, // p->Attr("var")->Attr("type_annotation")); @@ -108,25 +108,15 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } // Step 2. RHS ExprDoc rhs = d->AsDoc(stmt->value, p->Attr("value")); - // Step 3. LHS and body - With f(d, stmt); - ffi::Array* stmts = &(*f)->stmts; + // Step 3. LHS - Bind has no body, it is a flat assignment bool var_defined = d->IsVarDefined(stmt->var); if (!var_defined) { - DefineVar(stmt->var, *f, d); - } - ExprDoc lhs = d->AsDoc(stmt->var, p->Attr("var")); - AsDocBody(stmt->body, p->Attr("body"), f->get(), d); - // Step 4. 
Dispatch - if (var_defined) { - return ScopeDoc(std::nullopt, TIR(d, "LetStmt")->Call({rhs}, {"var"}, {lhs}), *stmts); - } else if (concise) { - stmts->insert(stmts->begin(), AssignDoc(lhs, rhs, type_doc)); - return StmtBlockDoc(*stmts); - } else if (type_doc.defined() && !stmt->var->type_annotation->IsInstance()) { - return ScopeDoc(lhs, TIR(d, "LetStmt")->Call({rhs, type_doc.value()}), *stmts); + TVM_FFI_ICHECK(!d->frames.empty()); + ExprDoc lhs = DefineVar(stmt->var, d->frames.back(), d); + return AssignDoc(lhs, rhs, type_doc); } else { - return ScopeDoc(lhs, TIR(d, "LetStmt")->Call({rhs}), *stmts); + ExprDoc lhs = d->AsDoc(stmt->var, p->Attr("var")); + return AssignDoc(lhs, rhs, std::nullopt); } }); @@ -259,7 +249,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return DoConciseScoping(lhs, rhs.value(), &(*f)->stmts, concise); }); -TVM_SCRIPT_REPR(tir::LetStmtNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::BindNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::AttrStmtNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::AssertStmtNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::WhileNode, ReprPrintTIR); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 74317f8f2a33..44b5af7e82f3 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -2047,7 +2047,7 @@ void CodeGenLLVM::VisitStmt_(const AssertStmtNode* op) { // Constraint scoping is handled by ScopeStack in analysis passes. 
} -void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) { +void CodeGenLLVM::VisitStmt_(const BindNode* op) { EmitDebugLocation(op); const VarNode* v = op->var.get(); TVM_FFI_ICHECK(!var_map_.count(v)); @@ -2081,7 +2081,6 @@ void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) { alloc_storage_info_[v].alignment); } AddDebugInformation(value, op->var); - this->VisitStmt(op->body); } void CodeGenLLVM::VisitStmt_(const SeqStmtNode* op) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 08a2b07ec707..3fdbbec86fa9 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -227,7 +227,7 @@ class CodeGenLLVM : public ExprFunctor, void VisitStmt_(const AllocBufferNode* op) override; void VisitStmt_(const AttrStmtNode* op) override; void VisitStmt_(const AssertStmtNode* op) override; - void VisitStmt_(const LetStmtNode* op) override; + void VisitStmt_(const BindNode* op) override; void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; void VisitStmt_(const DeclBufferNode* op) override; diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 7fb7e0e382cb..fbf28416ef46 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -369,6 +369,7 @@ std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const Pri return os.str(); } else { TVM_FFI_THROW(RuntimeError) << "Unsupported type index: " << kind; + __builtin_unreachable(); } } @@ -1028,7 +1029,7 @@ void CodeGenC::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(* os << ")"; } -void CodeGenC::VisitStmt_(const LetStmtNode* op) { +void CodeGenC::VisitStmt_(const BindNode* op) { std::string value = PrintExpr(op->value); if (print_ssa_form_) { TVM_FFI_ICHECK(!var_idmap_.count(op->var.get())); @@ -1045,7 +1046,6 @@ void CodeGenC::VisitStmt_(const LetStmtNode* op) { this->stream << ' ' << AllocVarID(op->var.get()) << " = " << value << ";\n"; } } - 
PrintStmt(op->body); } void CodeGenC::VisitStmt_(const AllocBufferNode* op) { diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index b683bf9105e0..2142c0aa2fc5 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -187,7 +187,7 @@ class CodeGenC : public ExprFunctor, void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*) // statment - void VisitStmt_(const LetStmtNode* op) override; + void VisitStmt_(const BindNode* op) override; void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const ForNode* op) override; void VisitStmt_(const WhileNode* op) override; diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index 86006e89ae84..423c7e5850b7 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -569,7 +569,7 @@ void CodeGenWebGPU::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // } } -void CodeGenWebGPU::VisitStmt_(const LetStmtNode* op) { +void CodeGenWebGPU::VisitStmt_(const BindNode* op) { // use ssa form. 
if (print_ssa_form_) { std::string value = PrintExpr(op->value); @@ -582,7 +582,6 @@ void CodeGenWebGPU::VisitStmt_(const LetStmtNode* op) { PrintType(op->var.dtype(), this->stream); this->stream << " = " << value << ";\n"; } - PrintStmt(op->body); } void CodeGenWebGPU::VisitStmt_(const BufferStoreNode* op) { diff --git a/src/target/source/codegen_webgpu.h b/src/target/source/codegen_webgpu.h index d90d719e8d38..f53d090e586a 100644 --- a/src/target/source/codegen_webgpu.h +++ b/src/target/source/codegen_webgpu.h @@ -73,7 +73,7 @@ class CodeGenWebGPU final : public CodeGenC { void VisitExpr_(const IntImmNode* op, std::ostream& os) final; // NOLINT(*) // stmt printing - void VisitStmt_(const LetStmtNode* op) final; + void VisitStmt_(const BindNode* op) final; void VisitStmt_(const BufferStoreNode* op) final; void VisitStmt_(const ForNode* op) final; void VisitStmt_(const AllocBufferNode* op) final; diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index fe424066cbce..97fc60a00741 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -888,12 +888,11 @@ void CodeGenSPIRV::VisitStmt_(const AssertStmtNode* op) { // AssertStmt is a leaf — no body to visit. 
} -void CodeGenSPIRV::VisitStmt_(const LetStmtNode* op) { +void CodeGenSPIRV::VisitStmt_(const BindNode* op) { TVM_FFI_ICHECK(!var_map_.count(op->var.get())); TVM_FFI_ICHECK(!op->var.dtype().is_handle()); var_map_[op->var.get()] = MakeValue(op->value); analyzer_->Bind(op->var, op->value); - this->VisitStmt(op->body); } void CodeGenSPIRV::VisitStmt_(const SeqStmtNode* op) { diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index 70f88128cf8d..8daac154ecd9 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -112,7 +112,7 @@ class CodeGenSPIRV : public ExprFunctor, void VisitStmt_(const AllocBufferNode* op) override; void VisitStmt_(const AttrStmtNode* op) override; void VisitStmt_(const AssertStmtNode* op) override; - void VisitStmt_(const LetStmtNode* op) override; + void VisitStmt_(const BindNode* op) override; void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 3d2536e423eb..c4e05151e5b1 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -414,11 +414,14 @@ Stmt GenerateBodyStmt(const ffi::Array& indices, const ffi::Array 1) { - // When there are multiple buffers, we wrap the body with LetStmts. - for (int i = n_buffers - 1; i >= 0; --i) { + // When there are multiple buffers, we wrap the body with Bind stmts. + ffi::Array bind_stmts; + for (int i = 0; i < n_buffers; ++i) { PrimExpr value = f_transform_and_remap(reduce->combiner.get()->operator()(lhs, rhs)[i]); - body = LetStmt(temp_vars[i], std::move(value), std::move(body)); + bind_stmts.push_back(Bind(temp_vars[i], std::move(value))); } + bind_stmts.push_back(body); + body = SeqStmt(bind_stmts); } } else { // Case 2. 
Data parallel compute diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tir/analysis/control_flow_graph.cc index 5870346e0f5c..ca594963e528 100644 --- a/src/tir/analysis/control_flow_graph.cc +++ b/src/tir/analysis/control_flow_graph.cc @@ -340,14 +340,6 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { Parent::VisitStmt_(op); } - void VisitStmt_(const LetStmtNode* op) override { - std::optional binding; - if (UsesLoopVar(op->value)) { - binding.emplace(this, op->var, op->value); - } - Parent::VisitStmt_(op); - } - void VisitExpr_(const BufferLoadNode* op) override { Parent::VisitExpr_(op); BufferLoad load = ffi::GetRef(op); diff --git a/src/tir/analysis/var_use_def_analysis.cc b/src/tir/analysis/var_use_def_analysis.cc index d7b462140533..6951e25f8c99 100644 --- a/src/tir/analysis/var_use_def_analysis.cc +++ b/src/tir/analysis/var_use_def_analysis.cc @@ -59,11 +59,6 @@ void VarUseDefAnalyzer::VisitStmt_(const BindNode* op) { StmtExprVisitor::VisitStmt_(op); } -void VarUseDefAnalyzer::VisitStmt_(const LetStmtNode* op) { - this->HandleDef(op->var); - StmtExprVisitor::VisitStmt_(op); -} - void VarUseDefAnalyzer::VisitStmt_(const ForNode* op) { this->HandleDef(op->loop_var); StmtExprVisitor::VisitStmt_(op); diff --git a/src/tir/analysis/var_use_def_analysis.h b/src/tir/analysis/var_use_def_analysis.h index 52a44da5731b..2255ed5a63df 100644 --- a/src/tir/analysis/var_use_def_analysis.h +++ b/src/tir/analysis/var_use_def_analysis.h @@ -59,7 +59,6 @@ class VarUseDefAnalyzer : public StmtExprVisitor { void VisitStmt_(const BindNode* op) final; - void VisitStmt_(const LetStmtNode* op) final; void VisitStmt_(const ForNode* op) final; diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index ca5b63931251..35f682519c2a 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -77,11 +77,6 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { defs_[op->var.get()] = 
op->value; return StmtExprVisitor::VisitStmt_(op); } - void VisitStmt_(const LetStmtNode* op) final { - // Book keep definitions - defs_[op->var.get()] = op->value; - return StmtExprVisitor::VisitStmt_(op); - } void VisitStmt_(const AttrStmtNode* op) final { if (!InThreadEnv() && op->attr_key == attr::thread_extent) { diff --git a/src/tir/analysis/verify_ssa.cc b/src/tir/analysis/verify_ssa.cc index 4f2445c2c9bb..b8fb99d701e8 100644 --- a/src/tir/analysis/verify_ssa.cc +++ b/src/tir/analysis/verify_ssa.cc @@ -71,10 +71,6 @@ class SSAVerifier final : public StmtExprVisitor { MarkDef(op->var, op->value); StmtExprVisitor::VisitStmt_(op); } - void VisitStmt_(const LetStmtNode* op) final { - MarkDef(op->var, op->value); - StmtExprVisitor::VisitStmt_(op); - } void VisitStmt_(const ForNode* op) final { MarkDef(op->loop_var, op->loop_var); StmtExprVisitor::VisitStmt_(op); diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index 0cd050c96306..6c781e109546 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -156,24 +156,6 @@ Stmt DataTypeLegalizer::VisitStmt_(const BindNode* op) { } } -Stmt DataTypeLegalizer::VisitStmt_(const LetStmtNode* op) { - PrimExpr value = this->VisitExpr(op->value); - Var var = op->var; - - if (value.dtype() != op->var->dtype) { - var = op->var.copy_with_dtype(value.dtype()); - var_remap_[op->var.get()] = var; - } - - Stmt new_body = this->VisitStmt(op->body); - - if (value.same_as(op->value) && new_body.same_as(op->body)) { - return ffi::GetRef(op); - } else { - return LetStmt(var, value, new_body, op->span); - } -} - PrimExpr DataTypeLegalizer::VisitExpr_(const VarNode* op) { if (auto it = var_remap_.find(op); it != var_remap_.end()) { return it->second; @@ -558,20 +540,6 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BindNode* op) { return Bind(var, value, bind_stmt->span); } -Stmt IndexDataTypeRewriter::VisitStmt_(const LetStmtNode* op) { - LetStmt let_stmt = 
Downcast(DataTypeLegalizer::VisitStmt_(op)); - if (var_remap_.find(let_stmt->var.get()) == var_remap_.end()) { - return let_stmt; - } - bool is_enabled = is_enabled_; - is_enabled_ = true; - PrimExpr value = VisitExpr(op->value); - Var var = var_remap_[let_stmt->var.get()]; - is_enabled_ = is_enabled; - TVM_FFI_ICHECK(value.dtype() == var.dtype()); - // No need to re-visit body - return LetStmt(var, value, let_stmt->body, let_stmt->span); -} #define TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ PrimExpr IndexDataTypeRewriter::VisitExpr_(const OP* op) { \ diff --git a/src/tir/ir/data_type_rewriter.h b/src/tir/ir/data_type_rewriter.h index 7884d477a5d2..7363e97e1bcf 100644 --- a/src/tir/ir/data_type_rewriter.h +++ b/src/tir/ir/data_type_rewriter.h @@ -54,7 +54,6 @@ class DataTypeLegalizer : public StmtExprMutator { Stmt VisitStmt_(const SBlockRealizeNode* op) override; Stmt VisitStmt_(const SBlockNode* op) override; Stmt VisitStmt_(const BindNode* op) override; - Stmt VisitStmt_(const LetStmtNode* op) override; PrimExpr VisitExpr_(const VarNode* op) override; PrimExpr VisitExpr_(const SelectNode* op) override; PrimExpr VisitExpr_(const RampNode* op) override; @@ -115,7 +114,6 @@ class IndexDataTypeRewriter : public DataTypeLegalizer { Stmt VisitStmt_(const AllocBufferNode* op) override; Stmt VisitStmt_(const AllocateNode* op) override; Stmt VisitStmt_(const BindNode* op) override; - Stmt VisitStmt_(const LetStmtNode* op) override; PrimExpr VisitExpr_(const EQNode* op) override; PrimExpr VisitExpr_(const NENode* op) override; PrimExpr VisitExpr_(const LTNode* op) override; diff --git a/src/tir/ir/py_functor.cc b/src/tir/ir/py_functor.cc index 6851fd5d2698..c4b0a81e3533 100644 --- a/src/tir/ir/py_functor.cc +++ b/src/tir/ir/py_functor.cc @@ -172,13 +172,11 @@ class PyStmtExprVisitorNode : public Object, public StmtExprVisitor { ffi::Function f_visit_stmt{nullptr}; /*! \brief The packed function to the `VisitStmt_(const BindNode* op)` function. 
*/ ffi::Function f_visit_bind{nullptr}; - /*! \brief The packed function to the `VisitStmt_(const LetStmtNode* op)` function. */ + /*! \brief The packed function to the `VisitStmt_(const AttrStmtNode* op)` function. */ ffi::Function f_visit_attr_stmt{nullptr}; /*! \brief The packed function to the `VisitStmt_(const IfThenElseNode* op)` function. */ ffi::Function f_visit_if_then_else{nullptr}; // NOLINT(readability/braces) /*! \brief The packed function to the `VisitStmt_(const ForNode* op)` function. */ - ffi::Function f_visit_let_stmt{nullptr}; - /*! \brief The packed function to the `VisitStmt_(const AttrStmtNode* op)` function. */ ffi::Function f_visit_for{nullptr}; /*! \brief The packed function to the `VisitStmt_(const WhileNode* op)` function. */ ffi::Function f_visit_while{nullptr}; @@ -223,7 +221,6 @@ class PyStmtExprVisitorNode : public Object, public StmtExprVisitor { private: // Statement functions PY_STMT_VISITOR_DISPATCH(BindNode, f_visit_bind); - PY_STMT_VISITOR_DISPATCH(LetStmtNode, f_visit_let_stmt); PY_STMT_VISITOR_DISPATCH(AttrStmtNode, f_visit_attr_stmt); PY_STMT_VISITOR_DISPATCH(IfThenElseNode, f_visit_if_then_else); PY_STMT_VISITOR_DISPATCH(ForNode, f_visit_for); @@ -315,7 +312,6 @@ class PyStmtExprVisitorNode : public Object, public StmtExprVisitor { static FStmtType InitStmtVTable() { FStmtType vtable; PY_STMT_VISITOR_DEFAULT_DISPATCH(BindNode); - PY_STMT_VISITOR_DEFAULT_DISPATCH(LetStmtNode); PY_STMT_VISITOR_DEFAULT_DISPATCH(AttrStmtNode); PY_STMT_VISITOR_DEFAULT_DISPATCH(IfThenElseNode); PY_STMT_VISITOR_DEFAULT_DISPATCH(ForNode); @@ -394,7 +390,8 @@ class PyStmtExprVisitor : public ObjectRef { n->f_visit_stmt = std::move(f_visit_stmt); n->f_visit_expr = std::move(f_visit_expr); // Set statement functions - n->f_visit_let_stmt = std::move(f_visit_let_stmt); + // f_visit_let_stmt is the Python-facing name; internally it maps to f_visit_bind + n->f_visit_bind = std::move(f_visit_let_stmt); n->f_visit_attr_stmt = std::move(f_visit_attr_stmt); 
n->f_visit_if_then_else = std::move(f_visit_if_then_else); n->f_visit_for = std::move(f_visit_for); @@ -531,8 +528,6 @@ class PyStmtExprMutatorNode : public Object, public StmtExprMutator { ffi::Function f_visit_stmt{nullptr}; /*! \brief The packed function to the `VisitStmt_(const BindNode* op)` function. */ ffi::Function f_visit_bind{nullptr}; - /*! \brief The packed function to the `VisitStmt_(const LetStmtNode* op)` function. */ - ffi::Function f_visit_let_stmt{nullptr}; /*! \brief The packed function to the `VisitStmt_(const AttrStmtNode* op)` function. */ ffi::Function f_visit_attr_stmt{nullptr}; /*! \brief The packed function to the `VisitStmt_(const IfThenElseNode* op)` function. */ @@ -582,7 +577,6 @@ class PyStmtExprMutatorNode : public Object, public StmtExprMutator { private: // Statement functions PY_STMT_MUTATOR_DISPATCH(BindNode, f_visit_bind); - PY_STMT_MUTATOR_DISPATCH(LetStmtNode, f_visit_let_stmt); PY_STMT_MUTATOR_DISPATCH(AttrStmtNode, f_visit_attr_stmt); PY_STMT_MUTATOR_DISPATCH(IfThenElseNode, f_visit_if_then_else); PY_STMT_MUTATOR_DISPATCH(ForNode, f_visit_for); @@ -674,7 +668,6 @@ class PyStmtExprMutatorNode : public Object, public StmtExprMutator { static FStmtType InitStmtVTable() { FStmtType vtable; PY_STMT_MUTATOR_DEFAULT_DISPATCH(BindNode); - PY_STMT_MUTATOR_DEFAULT_DISPATCH(LetStmtNode); PY_STMT_MUTATOR_DEFAULT_DISPATCH(AttrStmtNode); PY_STMT_MUTATOR_DEFAULT_DISPATCH(IfThenElseNode); PY_STMT_MUTATOR_DEFAULT_DISPATCH(ForNode); @@ -754,7 +747,8 @@ class PyStmtExprMutator : public ObjectRef { n->f_visit_stmt = std::move(f_visit_stmt); n->f_visit_expr = std::move(f_visit_expr); // Statement functions - n->f_visit_let_stmt = std::move(f_visit_let_stmt); + // f_visit_let_stmt is the Python-facing name; internally it maps to f_visit_bind + n->f_visit_bind = std::move(f_visit_let_stmt); n->f_visit_attr_stmt = std::move(f_visit_attr_stmt); n->f_visit_if_then_else = std::move(f_visit_if_then_else); n->f_visit_for = std::move(f_visit_for); diff 
--git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index 1ad074971107..b4856ad9ed20 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -170,7 +170,7 @@ class PrimFuncSpecializer : public StmtExprMutator { if (new_buffer_var.same_as(old_buffer_var)) { auto remapped_data = VisitExpr(old_buffer_var); if (!remapped_data.same_as(old_buffer_var)) { - stmt = LetStmt(old_buffer_var, remapped_data, stmt); + stmt = SeqStmt({Bind(old_buffer_var, remapped_data), stmt}); } } diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 6ee9fbbf6541..75ccf6708ec6 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -35,7 +35,7 @@ namespace tir { TVM_FFI_STATIC_INIT_BLOCK() { StmtNode::RegisterReflection(); BindNode::RegisterReflection(); - LetStmtNode::RegisterReflection(); + AttrStmtNode::RegisterReflection(); AssertStmtNode::RegisterReflection(); BufferStoreNode::RegisterReflection(); @@ -76,31 +76,13 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](Var var, PrimExpr value, Span span) { return Bind(var, value, span); }); } -// LetStmt -LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) { - TVM_FFI_ICHECK(value.defined()); - TVM_FFI_ICHECK(body.defined()); - auto vdtype = value.dtype(); - // It is still valid to bind a pointer type - // var to a value that is of type handle. - if (var->type_annotation.as()) { - TVM_FFI_ICHECK(vdtype.is_handle()); - } else { - TVM_FFI_ICHECK_EQ(value.dtype(), var.dtype()); - } - - ObjectPtr node = ffi::make_object(); - node->var = std::move(var); - node->value = std::move(value); - node->body = std::move(body); - node->span = std::move(span); - data_ = std::move(node); -} - +// LetStmt is now a deprecated alias for Bind. +// Keep the Python-facing factory for backward compat: tir.LetStmt(var, value, body) +// becomes SeqStmt(Bind(var, value), body). 
TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.LetStmt", [](Var var, PrimExpr value, Stmt body, Span span) { - return LetStmt(var, value, body, span); + return SeqStmt::Flatten(Bind(var, value, span), body); }); } diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index a933c3c985e9..6c9072cae5e7 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -38,11 +38,6 @@ void StmtVisitor::VisitStmt_(const BindNode* op) { this->VisitExpr(op->value); } -void StmtVisitor::VisitStmt_(const LetStmtNode* op) { - this->VisitExpr(op->value); - this->VisitStmt(op->body); -} - void StmtVisitor::VisitStmt_(const AttrStmtNode* op) { this->VisitExpr(op->value); this->VisitStmt(op->body); @@ -279,19 +274,6 @@ Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) { } } -Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) { - PrimExpr value = this->VisitExpr(op->value); - Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && body.same_as(op->body)) { - return ffi::GetRef(op); - } else { - auto n = CopyOnWrite(op); - n->value = std::move(value); - n->body = std::move(body); - return Stmt(n); - } -} - Stmt StmtMutator::VisitStmt_(const ForNode* op) { PrimExpr min = this->VisitExpr(op->min); PrimExpr extent = this->VisitExpr(op->extent); diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index dd9dfad3db40..ea6fd20de811 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -180,11 +180,6 @@ void TIRVisitorWithPath::VisitStmt_(const BindNode* op, AccessPath path) { // Scope tracking for BindNode is handled at the SeqStmt level by callers. 
} -void TIRVisitorWithPath::VisitStmt_(const LetStmtNode* op, AccessPath path) { - Visit(op->value, path->Attr("value")); - auto context = WithDef(op->var, path->Attr("var")); - Visit(op->body, path->Attr("body")); -} void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, AccessPath path) { Visit(op->value, path->Attr("value")); @@ -253,7 +248,18 @@ void TIRVisitorWithPath::VisitStmt_(const AssertStmtNode* op, AccessPath path) { } void TIRVisitorWithPath::VisitStmt_(const SeqStmtNode* op, AccessPath path) { - Visit(op->seq, path->Attr("seq")); + // Visit children sequentially. When a child is a BindNode, define its + // variable for all subsequent siblings (BindNode scope extends to + // the rest of the enclosing SeqStmt). + auto seq_path = path->Attr("seq"); + std::vector> bind_defs; + for (size_t i = 0; i < op->seq.size(); i++) { + auto item_path = seq_path->ArrayItem(i); + Visit(op->seq[i], item_path); + if (auto bind = op->seq[i].as()) { + bind_defs.push_back(WithDef(bind->var, item_path->Attr("var"))); + } + } } void TIRVisitorWithPath::VisitStmt_(const EvaluateNode* op, AccessPath path) { diff --git a/src/tir/ir/tir_visitor_with_path.h b/src/tir/ir/tir_visitor_with_path.h index 7f476c32543f..e271fd515179 100644 --- a/src/tir/ir/tir_visitor_with_path.h +++ b/src/tir/ir/tir_visitor_with_path.h @@ -109,7 +109,6 @@ class TIRVisitorWithPath void VisitStmt_(const BindNode* op, ffi::reflection::AccessPath path) override; void VisitStmt_(const AttrStmtNode* op, ffi::reflection::AccessPath path) override; void VisitStmt_(const IfThenElseNode* op, ffi::reflection::AccessPath path) override; - void VisitStmt_(const LetStmtNode* op, ffi::reflection::AccessPath path) override; void VisitStmt_(const ForNode* op, ffi::reflection::AccessPath path) override; void VisitStmt_(const WhileNode* op, ffi::reflection::AccessPath path) override; void VisitStmt_(const AllocBufferNode* op, ffi::reflection::AccessPath path) override; diff --git 
a/src/tir/transform/common_subexpr_elim.cc b/src/tir/transform/common_subexpr_elim.cc index 9b9619fae937..00c5af0ca2b1 100644 --- a/src/tir/transform/common_subexpr_elim.cc +++ b/src/tir/transform/common_subexpr_elim.cc @@ -478,8 +478,8 @@ Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) { // right to dive. result = ReplaceSelectedExpr::ReplaceSelectedExprInStmt(result, predicate_selector, new_var, CanContainEligibleComputations); - // Build a let-in that introduces the new variable in the current `result` - result = LetStmt(new_var, computation_and_nb.first, result); + // Build a bind that introduces the new variable before the current `result` + result = SeqStmt({Bind(new_var, computation_and_nb.first), result}); // We don't add the variable to the context because the invariant is that the // context is the context in which 'result' makes sense, and we've just updated it. } else { @@ -523,45 +523,140 @@ Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) { } /*! - * \brief The method which overrides the specific treatment for a LetStmtNode + * \brief The method which overrides the specific treatment for a BindNode */ -Stmt CommonSubexpressionEliminator::VisitStmt_(const LetStmtNode* op) { - // At this point, we have already done the generic treatment of introducing (via let-in) what - // was doable at the toplevel of the given let-in. - - // Save the context at the entry of the function - Context context_at_entry = context_; - +Stmt CommonSubexpressionEliminator::VisitStmt_(const BindNode* op) { // Recurse on the `value` field for potentially rewriting it PrimExpr value_new = VisitExpr(op->value); - // Augment the context with the association (`var`, `value`) for preparing the next recursion - // on the `body` + // Augment the context with the association (`var`, `value`) + // so that subsequent sibling statements in the SeqStmt can use it. 
context_.push_back({op->var, MaybeValue(op->value)}); - // Recurse on the `body` (with this extended context) - // The recursive call will have potentially done new simplifications, because in this recursive - // call `var` will be a part of the context. - // (see in VisitStmt() that no introduction were performed when a computation was using an - // undefined variable, as that would lead to ill-formed code) - Stmt body_new = VisitStmt(op->body); + // Rebuild the Bind if value changed + if (value_new.same_as(op->value)) { + return ffi::GetRef(op); + } else { + return Bind(op->var, value_new, op->span); + } +} - // Restaure the context to its content at the entrance to not carry out of scope declarations - // as the variable introduced by the let-in is not in scope outside of its body - context_ = context_at_entry; +/*! + * \brief Process a slice of a SeqStmt starting from index `start`. + * + * This mirrors the old nested LetStmt CSE approach: each Bind is + * processed one at a time (VisitExpr on value, augment context), + * and then VisitStmt is called on the "body" (all remaining children). + * Non-Bind children at the front are processed individually, then + * we recurse on the rest. + */ +Stmt CommonSubexpressionEliminator::VisitSeqStmtSlice(const ffi::Array& seq, size_t start) { + if (start >= seq.size()) { + return Evaluate(0); // shouldn't happen + } - // Rebuild the let-in with a new `value_new` and `body_new` where new simplifications might - // have been done. 
+ // If seq[start] is a Bind, process it (like the old LetStmt handler): + // 1) VisitExpr on the value + // 2) Augment context + // 3) Call VisitStmt on the "body" (remaining children as SeqStmt) + if (auto bind = seq[start].as()) { + Context context_at_entry = context_; - // If the `value` and the `body` of the let-in have been rewritten to the same thing - if (value_new.same_as(op->value) && body_new.same_as(op->body)) { - // Return a reference to the same node - return ffi::GetRef(op); + PrimExpr value_new = VisitExpr(bind->value); + context_.push_back({bind->var, MaybeValue(bind->value)}); + + Stmt bind_new; + if (value_new.same_as(bind->value)) { + bind_new = ffi::GetRef(bind); + } else { + bind_new = Bind(bind->var, value_new, bind->span); + } + + // Construct the "body" from remaining siblings + Stmt body; + if (start + 2 == seq.size()) { + body = seq[start + 1]; + } else if (start + 1 < seq.size()) { + ffi::Array remaining; + for (size_t j = start + 1; j < seq.size(); j++) { + remaining.push_back(seq[j]); + } + body = SeqStmt(remaining); + } else { + // Bind is the last element, no body + context_ = context_at_entry; + return bind_new; + } + + // Call the full CSE VisitStmt on the body (with augmented context). + // This is the key step that allows CSE to find common subexpressions + // in subsequent siblings with the Bind variable in scope. + Stmt body_new = VisitStmt(body); + + context_ = context_at_entry; + + // Flatten into a flat result + ffi::Array result; + result.push_back(bind_new); + if (auto inner = body_new.as()) { + for (const auto& s : inner->seq) { + result.push_back(s); + } + } else { + result.push_back(body_new); + } + return SeqStmt::Flatten(result); + } + + // seq[start] is a non-Bind child. + // Process it individually with VisitStmt, then recurse on the rest. 
+ Stmt child_new = VisitStmt(seq[start]); + + if (start + 1 >= seq.size()) { + // Single remaining child -- return it directly + return child_new; + } + + ffi::Array result; + if (auto inner = child_new.as()) { + for (const auto& s : inner->seq) { + result.push_back(s); + } } else { - // Otherwise return a let-in built with the new `value_new` and the new `body_new` that - // have just been obtained - return LetStmt(op->var, value_new, body_new, op->span); + result.push_back(child_new); + } + + Stmt rest = VisitSeqStmtSlice(seq, start + 1); + if (auto inner = rest.as()) { + for (const auto& s : inner->seq) { + result.push_back(s); + } + } else { + result.push_back(rest); } + + return SeqStmt::Flatten(result); +} + +/*! + * \brief The method which overrides the specific treatment for a SeqStmtNode. + * + * With flat Bind nodes (no body), the SeqStmt must be processed + * sequentially: each Bind node augments the context, and all of its + * remaining siblings are wrapped into a "body" for CSE analysis. + */ +Stmt CommonSubexpressionEliminator::VisitStmt_(const SeqStmtNode* op) { + Context context_at_entry = context_; + + // No flag is needed to tell first entry from recursive entry. + // Every SeqStmt, including one reached recursively (from VisitStmt -> + // StmtExprMutator dispatch), is processed from index 0 by + // VisitSeqStmtSlice; the context already carries whatever an outer + // call added, and is restored to its entry state before returning. + Stmt result = VisitSeqStmtSlice(op->seq, 0); + + context_ = context_at_entry; + return result; } /*!
diff --git a/src/tir/transform/common_subexpr_elim.h b/src/tir/transform/common_subexpr_elim.h index 814161cc3535..fcf0c3fc5789 100644 --- a/src/tir/transform/common_subexpr_elim.h +++ b/src/tir/transform/common_subexpr_elim.h @@ -69,9 +69,13 @@ class CommonSubexpressionEliminator : public StmtExprMutator { PrimExpr VisitExpr_(const LetNode* op) override; - Stmt VisitStmt_(const LetStmtNode* op) override; + Stmt VisitStmt_(const BindNode* op) override; + Stmt VisitStmt_(const SeqStmtNode* op) override; Stmt VisitStmt_(const ForNode* op) override; + // Helper: process a slice of a SeqStmt starting at `start` + Stmt VisitSeqStmtSlice(const ffi::Array& seq, size_t start); + private: Stmt initial_body_; // Kept for checking if names of new variables already exist Context context_; // Context associating variables to (maybe) definitions diff --git a/src/tir/transform/ir_utils.cc b/src/tir/transform/ir_utils.cc index 9398c2561ead..d54eb1317ddd 100644 --- a/src/tir/transform/ir_utils.cc +++ b/src/tir/transform/ir_utils.cc @@ -46,11 +46,9 @@ Stmt MergeNest(const std::vector& nest, Stmt body) { TVM_FFI_ICHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); - } else if (const auto* let = s.as()) { - auto n = ffi::make_object(*let); - TVM_FFI_ICHECK(is_no_op(n->body)); - n->body = body; - body = Stmt(n); + } else if (const auto* bind = s.as()) { + // Bind has no body -- prepend it before the accumulated body in a SeqStmt. 
+ body = SeqStmt::Flatten(ffi::GetRef(bind), body); } else if (const auto* attr = s.as()) { auto n = ffi::make_object(*attr); TVM_FFI_ICHECK(is_no_op(n->body)); @@ -348,13 +346,12 @@ class IRConvertSSA final : public StmtExprMutator { return new_buf; } - Stmt VisitStmt_(const LetStmtNode* op) final { + Stmt VisitStmt_(const BindNode* op) final { const Var& v = op->var; if (defined_.count(v.get())) { PrimExpr value = this->VisitExpr(op->value); ScopedRedefine redefine(this, v); - Stmt body = this->VisitStmt(op->body); - return LetStmt(redefine.new_var, value, body); + return Bind(redefine.new_var, value); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); diff --git a/src/tir/transform/ir_utils.h b/src/tir/transform/ir_utils.h index 8077ebdea807..1282778ae09d 100644 --- a/src/tir/transform/ir_utils.h +++ b/src/tir/transform/ir_utils.h @@ -46,7 +46,7 @@ namespace tvm { namespace tir { /*! * \brief combine the nest stmt, whose body is not defined. - * \param nest A list of For and LetStmt, whose body is not defined. + * \param nest A list of For and Bind, whose body is not defined. * \param body body * \return The combined Stmt */ @@ -54,7 +54,7 @@ Stmt MergeNest(const std::vector& nest, Stmt body); /*! * \brief combine the nest stmt, whose body is not defined. - * \param nest A list of For and LetStmt, whose body is not defined. + * \param nest A list of For and Bind, whose body is not defined. * \param body body * \return The combined Stmt */ diff --git a/src/tir/transform/lower_tvm_builtin.cc b/src/tir/transform/lower_tvm_builtin.cc index 9fe3412389ce..f556e8e88567 100644 --- a/src/tir/transform/lower_tvm_builtin.cc +++ b/src/tir/transform/lower_tvm_builtin.cc @@ -154,21 +154,29 @@ class BuiltinLower : public StmtExprMutator { // used when mutating. 
scope.max_sizes = GetMaxStack(stmt); + // Build a flat list of Bind stmts followed by the body + ffi::Array alloca_stmts; + if (scope.max_sizes.arg_stack != 0) { + alloca_stmts.push_back( + Bind(scope.stack_ffi_any, StackAlloca("tvm_ffi_any", scope.max_sizes.arg_stack))); + } + + if (scope.max_sizes.array_stack != 0) { + alloca_stmts.push_back( + Bind(scope.stack_array, StackAlloca("array", scope.max_sizes.array_stack))); + } + if (scope.max_sizes.shape_stack != -1) { scope.stack_shape = decl_buffer({IntImm(DataType::Int(64), scope.max_sizes.shape_stack)}, DataType::Int(64), "stack_shape"); + alloca_stmts.push_back( + Bind(scope.stack_shape->data, StackAlloca("shape", scope.max_sizes.shape_stack))); stmt = DeclBuffer(scope.stack_shape, stmt); - stmt = LetStmt(scope.stack_shape->data, StackAlloca("shape", scope.max_sizes.shape_stack), - stmt); - } - - if (scope.max_sizes.array_stack != 0) { - stmt = LetStmt(scope.stack_array, StackAlloca("array", scope.max_sizes.array_stack), stmt); } - if (scope.max_sizes.arg_stack != 0) { - stmt = LetStmt(scope.stack_ffi_any, StackAlloca("tvm_ffi_any", scope.max_sizes.arg_stack), - stmt); + if (!alloca_stmts.empty()) { + alloca_stmts.push_back(stmt); + stmt = SeqStmt(alloca_stmts); } } @@ -212,15 +220,51 @@ class BuiltinLower : public StmtExprMutator { } } - Stmt VisitStmt_(const LetStmtNode* op) final { + Stmt VisitStmt_(const BindNode* op) final { if (const CallNode* call = op->value.as()) { if (call->op.same_as(builtin::nd_mem_alloc_with_scope())) { - return StmtExprMutator::VisitStmt(MakeNdMemAllocWithScope(op, call)); + // Save this Bind for SeqStmt-level handling. + // MakeNdMemAllocWithScope needs the body (sibling stmts), so we + // defer to VisitStmt_(const SeqStmtNode*). 
+ pending_nd_mem_alloc_ = op; + return ffi::GetRef(op); } } return StmtExprMutator::VisitStmt_(op); } + Stmt VisitStmt_(const SeqStmtNode* op) final { + ffi::Array new_seq; + bool changed = false; + for (size_t i = 0; i < op->seq.size(); ++i) { + pending_nd_mem_alloc_ = nullptr; + Stmt visited = this->VisitStmt(op->seq[i]); + if (pending_nd_mem_alloc_) { + // This Bind was an nd_mem_alloc_with_scope. + // Collect remaining stmts as the "body" that needs wrapping. + const BindNode* let = pending_nd_mem_alloc_; + const CallNode* call = let->value.as(); + pending_nd_mem_alloc_ = nullptr; + + // Collect remaining sibling stmts as the body + ffi::Array body_stmts; + for (size_t j = i + 1; j < op->seq.size(); ++j) { + body_stmts.push_back(this->VisitStmt(op->seq[j])); + } + Stmt body = body_stmts.empty() ? Evaluate(0) : SeqStmt::Flatten(body_stmts); + Stmt alloc_stmt = MakeNdMemAllocWithScope(let, call, body); + new_seq.push_back(this->VisitStmt(alloc_stmt)); + changed = true; + break; // remaining stmts already consumed + } else { + new_seq.push_back(visited); + if (!visited.same_as(op->seq[i])) changed = true; + } + } + if (!changed) return ffi::GetRef(op); + return SeqStmt::Flatten(new_seq); + } + Stmt VisitStmt_(const AllocBufferNode* op) { // Lower AllocBuffer to device allocate when needed. 
Stmt stmt = StmtExprMutator::VisitStmt_(op); @@ -264,13 +308,13 @@ class BuiltinLower : public StmtExprMutator { body = AttrStmt(op->buffer->data, attr::storage_alignment, make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body); - body = LetStmt(op->buffer->data, - Call(op->buffer->data.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"), - {cast(DataType::Int(32), device_type_.value()), - cast(DataType::Int(32), device_id_.value()), total_bytes, - IntImm(DataType::Int(32), op->buffer->dtype.code()), - IntImm(DataType::Int(32), op->buffer->dtype.bits())}), - body); + body = SeqStmt({Bind(op->buffer->data, + Call(op->buffer->data.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"), + {cast(DataType::Int(32), device_type_.value()), + cast(DataType::Int(32), device_id_.value()), total_bytes, + IntImm(DataType::Int(32), op->buffer->dtype.code()), + IntImm(DataType::Int(32), op->buffer->dtype.bits())})), + body}); return body; } @@ -496,6 +540,7 @@ class BuiltinLower : public StmtExprMutator { return ffi::TypeIndex::kTVMFFIOpaquePtr; } else { TVM_FFI_THROW(InternalError) << "Unsupported type: " << api_dtype; + __builtin_unreachable(); } }(); @@ -595,7 +640,7 @@ class BuiltinLower : public StmtExprMutator { return Call(op->dtype, lowered_packed_op, packed_args); } - Stmt MakeNdMemAllocWithScope(const LetStmtNode* let, const CallNode* call) { + Stmt MakeNdMemAllocWithScope(const BindNode* let, const CallNode* call, Stmt inner_body) { TVM_FFI_ICHECK(device_type_) << "Unknown device type in current IR"; TVM_FFI_ICHECK(device_id_) << "Unknown device id in current IR"; Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {})); @@ -608,7 +653,7 @@ class BuiltinLower : public StmtExprMutator { Stmt body = SeqStmt( {IfThenElse(Call(DataType::Bool(), builtin::isnullptr(), {let->var}), throw_last_error), - let->body, free_stmt}); + inner_body, free_stmt}); DataType dtype = let->var->type_annotation.as()->element_type.as()->dtype; @@ 
-629,8 +674,7 @@ class BuiltinLower : public StmtExprMutator { } Call call_packed = Call(let->var.dtype(), builtin::tvm_call_packed(), args); - Stmt alloca = LetStmt(let->var, call_packed, body); - return alloca; + return SeqStmt({Bind(let->var, call_packed), body}); } private: @@ -649,6 +693,8 @@ class BuiltinLower : public StmtExprMutator { std::vector> prep_seq_stack_; ffi::Optional device_type_{std::nullopt}; ffi::Optional device_id_{std::nullopt}; + // Pending nd_mem_alloc Bind node for SeqStmt-level handling + const BindNode* pending_nd_mem_alloc_{nullptr}; bool is_precheck_{false}; diff --git a/src/tir/transform/remove_no_op.cc b/src/tir/transform/remove_no_op.cc index cb073bf31a61..e1c60b2e9707 100644 --- a/src/tir/transform/remove_no_op.cc +++ b/src/tir/transform/remove_no_op.cc @@ -91,22 +91,17 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { const StmtNode* context) : Parent(analyzer), touch_pattern_(touch_pattern), context_(context) {} - Stmt VisitStmt_(const LetStmtNode* op) final { + Stmt VisitStmt_(const BindNode* op) final { Stmt stmt = Parent::VisitStmt_(op); - op = stmt.as(); - if (is_no_op(op->body)) { - return MakeEvaluate(op->value); - } - - bool body_uses_bound_variable = - !UsesVar(op->body, [&](const VarNode* var) { return var == op->var.get(); }); - if (body_uses_bound_variable && HasSideEffect(op->value)) { - return SeqStmt({MakeEvaluate(op->value), op->body}); - } else if (body_uses_bound_variable) { - return op->body; - } else { + op = stmt.as(); + // Bind has no body -- removal of unused Bind is handled at SeqStmt level. + // If the value has no side effect, the Bind can potentially be removed. + if (!HasSideEffect(op->value) && SideEffect(op->value) <= CallEffectKind::kPure) { + // A pure Bind with no uses will be cleaned up by dead code elimination. + // Keep it for now; the SeqStmt visitor handles removal. 
return stmt; } + return stmt; } Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == "pragma_debug_skip_region") { diff --git a/src/tir/transform/simplify.cc b/src/tir/transform/simplify.cc index af0fc4cf47bf..2c49d862d093 100644 --- a/src/tir/transform/simplify.cc +++ b/src/tir/transform/simplify.cc @@ -124,7 +124,7 @@ std::unordered_set CollectVarsUsedInBufferDefinition(const Stmt& } usage(buf->elem_offset); - // Track for use in LetStmtNode mutator + // Track for use in BindNode mutator for (const auto& var : usage.undefined_) { used_in_buffer_def_.insert(var.get()); } @@ -220,7 +220,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { return Parent::VisitStmt_(op); } - bool CanInlineLetStmt(const LetStmtNode* op) { + bool CanInlineBind(const BindNode* op) { if (is_const_number(op->value)) return true; if (op->value.as()) return true; // Won't face the deep expression explosion problem as in Let expression. @@ -229,9 +229,9 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { return SideEffect(op->value) <= CallEffectKind::kPure; } - Stmt VisitStmt_(const LetStmtNode* op) override { + Stmt VisitStmt_(const BindNode* op) override { PrimExpr value = this->VisitExpr(op->value); - bool can_inline = CanInlineLetStmt(op); + bool can_inline = CanInlineBind(op); if (can_inline) { // It is usually fine to discard the let binding because the // call to simplify will always inline the var. @@ -251,7 +251,6 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { // necessary for proving conditional statements. 
non_inlined_bindings_.Set(op->var, value); } - Stmt body = this->VisitStmt(op->body); // TODO(Lunderberg): Update the Buffer object as part of // DeclBuffer updates, which will first require @@ -259,13 +258,12 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get()); if (can_inline && !used_in_buffer_def) { - return body; - } else if (value.same_as(op->value) && body.same_as(op->body)) { + return Evaluate(0); + } else if (value.same_as(op->value)) { return ffi::GetRef(op); } else { auto n = this->CopyOnWrite(op); n->value = std::move(value); - n->body = std::move(body); return Stmt(n); } } diff --git a/src/tir/transform/split_host_device.cc b/src/tir/transform/split_host_device.cc index e832e9d1caea..0bdcde861c63 100644 --- a/src/tir/transform/split_host_device.cc +++ b/src/tir/transform/split_host_device.cc @@ -105,9 +105,7 @@ class HostDeviceSplitter : public StmtMutator { Call kernel_call(success->dtype, kernel_symbol_global, args); AssertStmt assert_success(kernel_error_code == success, StringImm("RuntimeError"), {StringImm("Error executing compute kernel")}); - LetStmt let_check(kernel_error_code, kernel_call, assert_success); - - return let_check; + return SeqStmt({Bind(kernel_error_code, kernel_call), assert_success}); } else { return Evaluate(Call(DataType::Void(), kernel_symbol_global, args)); diff --git a/src/tir/transform/storage_rewrite.cc b/src/tir/transform/storage_rewrite.cc index e6391f08c3d7..40235aa98e11 100644 --- a/src/tir/transform/storage_rewrite.cc +++ b/src/tir/transform/storage_rewrite.cc @@ -213,7 +213,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { void VisitStmt_(const AssertStmtNode* op) final { VisitNewScope(op); } - void VisitStmt_(const LetStmtNode* op) final { VisitNewScope(op); } + void VisitStmt_(const BindNode* op) final { StmtExprVisitor::VisitStmt_(op); } // linearized access sequence. 
std::vector linear_seq_; @@ -1205,7 +1205,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { StmtExprVisitor::VisitExpr_(op); } - void VisitStmt_(const LetStmtNode* op) final { + void VisitStmt_(const BindNode* op) final { HandleLetNode(op->var); StmtExprVisitor::VisitStmt_(op); } @@ -1516,15 +1516,14 @@ class VectorTypeRewriter : public StmtExprMutator { return modified; } - Stmt VisitStmt_(const LetStmtNode* op) final { + Stmt VisitStmt_(const BindNode* op) final { auto it = rewrite_map_.find(op->var.get()); PrimExpr value = this->VisitExpr(op->value); - Stmt body = this->VisitStmt(op->body); Var var = (it == rewrite_map_.end()) ? op->var : it->second.new_buffer_var; - if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { + if (var.same_as(op->var) && value.same_as(op->value)) { return ffi::GetRef(op); } - return LetStmt(var, value, body); + return Bind(var, value); } Stmt VisitStmt_(const AllocBufferNode* op) final { diff --git a/src/tir/transform/tvm_ffi_binder.cc b/src/tir/transform/tvm_ffi_binder.cc index 02058970e182..6bf5e90cca50 100644 --- a/src/tir/transform/tvm_ffi_binder.cc +++ b/src/tir/transform/tvm_ffi_binder.cc @@ -169,7 +169,7 @@ bool TVMFFIABIBuilder::BindScalar(const PrimExpr& arg, const PrimExpr& value, // First bind: define the variable if (with_lets) { var_defs_.emplace(v_arg.get(), VarDefInfo{arg, path}); - init_nest_.emplace_back(LetStmt(v_arg, value, Evaluate(0))); + init_nest_.emplace_back(Bind(v_arg, value)); } else { var_defs_.emplace(v_arg.get(), VarDefInfo{value, path}); } @@ -488,17 +488,15 @@ PrimExpr TVMFFIABIBuilder::DecodeParamFloat(int param_index, const Var& type_ind // ============================================================ void TVMFFIABIBuilder::DecodeParam(int param_index) { - const Stmt nop = Evaluate(0); Var param = params_[param_index]; DataType dtype = param.dtype(); // Extract type_index from packed_args Var type_index(param->name_hint + ".type_index", DataType::Int(32)); - 
init_nest_.push_back(LetStmt(type_index, - tir::Call(DataType::Int(32), builtin::tvm_struct_get(), - {v_packed_args_, IntImm(DataType::Int(32), param_index), - IntImm(DataType::Int(32), builtin::kTVMFFIAnyTypeIndex)}), - nop)); + init_nest_.push_back(Bind(type_index, + tir::Call(DataType::Int(32), builtin::tvm_struct_get(), + {v_packed_args_, IntImm(DataType::Int(32), param_index), + IntImm(DataType::Int(32), builtin::kTVMFFIAnyTypeIndex)}))); // Type-check and load value via per-dtype dispatch PrimExpr arg_value; @@ -553,10 +551,8 @@ Var TVMFFIABIBuilder::DLTensorGetFieldPtr(const Var& handle, int field_kind, const std::string& var_name) { Var ptr(var_name, DataType::Handle()); init_nest_.emplace_back( - LetStmt(ptr, - TVMStructGet(DataType::Handle(), handle, 0, - static_cast(field_kind)), - Evaluate(0))); + Bind(ptr, TVMStructGet(DataType::Handle(), handle, 0, + static_cast(field_kind)))); return ptr; } diff --git a/src/tir/transform/tvm_ffi_binder.h b/src/tir/transform/tvm_ffi_binder.h index 2daa3200874c..c8a4da752b7c 100644 --- a/src/tir/transform/tvm_ffi_binder.h +++ b/src/tir/transform/tvm_ffi_binder.h @@ -59,7 +59,7 @@ namespace tir { * by a later buffer's shape (batch_size). Separating definitions from * checks guarantees all variables are in scope when assertions reference them. * - * - init_nest: LetStmts, DeclBuffers for shape/strides arrays, AttrStmts — + * - init_nest: Binds, DeclBuffers for shape/strides arrays, AttrStmts — * all value-loading code that defines variables. * - asserts: AssertStmts — all validation checks. * - decl_buffers: DeclBuffer for buffer_map entries — buffer declarations. @@ -95,7 +95,7 @@ class TVMFFIABIBuilder { struct Result { /*! \brief Var -> VarDefInfo map for defined variables. */ std::unordered_map var_defs; - /*! \brief Variable definitions (LetStmts, shape/strides DeclBuffers, AttrStmts). */ + /*! \brief Variable definitions (Binds, shape/strides DeclBuffers, AttrStmts). */ std::vector init_nest; /*! 
\brief Validation checks (all AssertStmts). */ std::vector asserts; @@ -236,7 +236,7 @@ class TVMFFIABIBuilder { * * \param arg The argument expression to bind (typically a Var or constant). * \param value The value expression to bind to the argument. - * \param with_lets If true, emit LetStmt bindings into init_nest_. + * \param with_lets If true, emit Bind bindings into init_nest_. * \param path AccessPath for rich error message rendering. * \return True if this was the first bind (definition created), false otherwise. */ @@ -284,7 +284,7 @@ class TVMFFIABIBuilder { // ── DLTensor sub-helpers ─────────────────────────────────────── /*! - * \brief Get a DLTensor field pointer (shape or strides) and store it in a LetStmt. + * \brief Get a DLTensor field pointer (shape or strides) and store it in a Bind. * * \param handle The DLTensor handle variable. * \param field_kind kDLTensorShape or kDLTensorStrides. @@ -390,7 +390,7 @@ class TVMFFIABIBuilder { /*! \brief The definition map: VarNode* -> VarDefInfo (value + first_def_path). */ std::unordered_map var_defs_; - /*! \brief Variable definitions: LetStmts, shape/strides DeclBuffers, AttrStmts. */ + /*! \brief Variable definitions: Binds, shape/strides DeclBuffers, AttrStmts. */ std::vector init_nest_; /*! \brief Validation checks: all AssertStmts. 
*/ std::vector asserts_; diff --git a/src/tir/transform/unsupported_dtype_legalize.cc b/src/tir/transform/unsupported_dtype_legalize.cc index cca435d0de78..3c4fbc66241b 100644 --- a/src/tir/transform/unsupported_dtype_legalize.cc +++ b/src/tir/transform/unsupported_dtype_legalize.cc @@ -293,19 +293,18 @@ class ComputeLegalizer : public StmtExprMutator { DEFINE_BIOP_EXPR_LEGALIZE(EQNode, operator==); DEFINE_BIOP_EXPR_LEGALIZE(NENode, operator!=); - Stmt VisitStmt_(const LetStmtNode* op) final { + Stmt VisitStmt_(const BindNode* op) final { PrimExpr value = PromoteToTarget(op->value); Var var = op->var; if (value.dtype() != op->value.dtype()) { var = op->var.copy_with_dtype(op->value.dtype()); var_remap_[op->var] = var; } - Stmt body = VisitStmt(op->body); - if (value.same_as(op->value) && var.same_as(op->var) && body.same_as(op->body)) { + if (value.same_as(op->value) && var.same_as(op->var)) { return ffi::GetRef(op); } else { - return LetStmt(var, value, body); + return Bind(var, value); } } @@ -583,15 +582,14 @@ class StorageLegalizer : public StmtExprMutator { } } - Stmt VisitStmt_(const LetStmtNode* op) final { + Stmt VisitStmt_(const BindNode* op) final { PrimExpr value = VisitExpr(op->value); Var var = RemapVarDef(op->var); - Stmt body = VisitStmt(op->body); - if (value.same_as(op->value) && var.same_as(op->var) && body.same_as(op->body)) { + if (value.same_as(op->value) && var.same_as(op->var)) { return ffi::GetRef(op); } else { - return LetStmt(var, value, body); + return Bind(var, value); } } diff --git a/src/tir/transform/vectorize_loop.cc b/src/tir/transform/vectorize_loop.cc index 8bcf6078caf6..d3e9d12de886 100644 --- a/src/tir/transform/vectorize_loop.cc +++ b/src/tir/transform/vectorize_loop.cc @@ -815,9 +815,10 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->value); // if visit of value triggers need scalarize // we need to scalarize the let @@ -832,14 +833,13 @@ class Vectorizer : public StmtMutator, public 
ExprFunctorvalue.dtype().get_lanes_or_vscale_factor()) { Var new_var(op->var->name_hint, value.dtype()); let_binding_[op->var] = new_var; - return LetStmt(new_var, value, this->VisitStmt(op->body)); + return Bind(new_var, value); } else { let_binding_[op->var] = op->var; - Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && body.same_as(op->body)) { + if (value.same_as(op->value)) { return ffi::GetRef(op); } else { - return LetStmt(op->var, value, body); + return Bind(op->var, value); } } } diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_ssa.py b/tests/python/tir-analysis/test_tir_analysis_verify_ssa.py index 6d559a81d2c6..07611e55ace7 100644 --- a/tests/python/tir-analysis/test_tir_analysis_verify_ssa.py +++ b/tests/python/tir-analysis/test_tir_analysis_verify_ssa.py @@ -23,7 +23,9 @@ def test_verify_ssa(): z = tvm.tir.Evaluate(x + y) assert tvm.tir.analysis.verify_ssa(tvm.tir.PrimFunc([x, y], z)) - assert not tvm.tir.analysis.verify_ssa(tvm.tir.PrimFunc([x, y], tvm.tir.LetStmt(x, 1, z))) + assert not tvm.tir.analysis.verify_ssa( + tvm.tir.PrimFunc([x, y], tvm.tir.SeqStmt([tvm.tir.Bind(x, 1), z])) + ) def test_verify_weak_let_ssa(): diff --git a/tests/python/tir-base/test_tir_constructor.py b/tests/python/tir-base/test_tir_constructor.py index 5a7d516afcbe..1ebb80c038a4 100644 --- a/tests/python/tir-base/test_tir_constructor.py +++ b/tests/python/tir-base/test_tir_constructor.py @@ -131,11 +131,10 @@ def test_expr_constructor(): def test_stmt_constructor(): v = tvm.tir.Var("aa", "int32") nop = tvm.tir.Evaluate(1) - x = tvm.tir.LetStmt(v, 1, tvm.tir.Evaluate(1)) - assert isinstance(x, tvm.tir.LetStmt) + x = tvm.tir.Bind(v, 1) + assert isinstance(x, tvm.tir.Bind) assert x.var == v assert x.value.value == 1 - assert isinstance(x.body, tvm.tir.Evaluate) x = tvm.tir.AttrStmt(v == 1, "xx", 1, tvm.tir.Evaluate(1)) assert isinstance(x, tvm.tir.AttrStmt) diff --git a/tests/python/tir-base/test_tir_nodes.py 
b/tests/python/tir-base/test_tir_nodes.py index 4d66b929967a..a20e96b5668f 100644 --- a/tests/python/tir-base/test_tir_nodes.py +++ b/tests/python/tir-base/test_tir_nodes.py @@ -91,7 +91,7 @@ def test_ir2(): def test_let(): x = tvm.tir.Var("x", "int32") y = tvm.tir.Var("y", "int32") - stmt = tvm.tir.LetStmt(x, 10, tvm.tir.Evaluate(x + 1)) + stmt = tvm.tir.Bind(x, 10) def test_cast(): @@ -306,7 +306,7 @@ def test_prim_func(): x = tvm.tir.Var("x", "int32") y = tvm.tir.Var("y", "int32") b = tvm.tir.decl_buffer((x,), "float32") - stmt = tvm.tir.LetStmt(x, 10, tvm.tir.Evaluate(x + 1)) + stmt = tvm.tir.SeqStmt([tvm.tir.Bind(x, 10), tvm.tir.Evaluate(x + 1)]) func = tvm.tir.PrimFunc([x, y, b], stmt) # make sure we can print diff --git a/tests/python/tir-base/test_tir_structural_equal_hash.py b/tests/python/tir-base/test_tir_structural_equal_hash.py index bd0eb573b4b7..6f44138300b8 100644 --- a/tests/python/tir-base/test_tir_structural_equal_hash.py +++ b/tests/python/tir-base/test_tir_structural_equal_hash.py @@ -109,7 +109,7 @@ def test_prim_func(): # new cases b = tvm.tir.decl_buffer((x,), "float32") - stmt = tvm.tir.LetStmt(x, 10, tvm.tir.Evaluate(x + 1)) + stmt = tvm.tir.SeqStmt([tvm.tir.Bind(x, 10), tvm.tir.Evaluate(x + 1)]) func0 = tvm.tir.PrimFunc([x, y, b], stmt) # easiest way to deep copy is via save/load func1 = tvm.ir.load_json(tvm.ir.save_json(func0)) diff --git a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py index 79cbdb91950f..f6bcd3cd97c8 100644 --- a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py @@ -40,44 +40,27 @@ def test_cse(): b = tvm.tir.Var("b", "int32") dtype = "int32" buffer = tvm.tir.decl_buffer((50,), dtype) - # Test prog : - # let z1=1 in let z2=2 in - # Mem[i1] = z1+z2; - # let x = 1 in let y = 1 in - # let a = (x+y) + (z1+z2) in - # let b = (x+y) + z3 in - 
# Mem[i2] = a+b; - body = tvm.tir.LetStmt( - z1, - 1, - tvm.tir.LetStmt( - z2, - 2, - tvm.tir.SeqStmt( - [ - tvm.tir.BufferStore(buffer, z1 + z2, [i1]), - tvm.tir.LetStmt( - x, - 1, - tvm.tir.LetStmt( - y, - 1, - tvm.tir.LetStmt( - a, - (x + y) + (z1 + z2), - tvm.tir.LetStmt( - b, (x + y) + z3, tvm.tir.BufferStore(buffer, a + b, [i2]) - ), - ), - ), - ), - ] - ), - ), + # Test prog (flat Bind style): + # z1 = 1; z2 = 2; + # Mem[i1] = z1+z2; + # x = 1; y = 1; + # a = (x+y) + (z1+z2); + # b = (x+y) + z3; + # Mem[i2] = a+b; + body = tvm.tir.SeqStmt( + [ + tvm.tir.Bind(z1, 1), + tvm.tir.Bind(z2, 2), + tvm.tir.BufferStore(buffer, z1 + z2, [i1]), + tvm.tir.Bind(x, 1), + tvm.tir.Bind(y, 1), + tvm.tir.Bind(a, (x + y) + (z1 + z2)), + tvm.tir.Bind(b, (x + y) + z3), + tvm.tir.BufferStore(buffer, a + b, [i2]), + ] ) - # This test program gives the opportunity to introduce two new variables, at two different - # levels and to perform replacements in the value of "a" and "b", using these new variables. - # We will check all of that underneath and more, making also sure that nothing else has changed + # This test program gives the opportunity to introduce two new variables, + # and to perform replacements in the value of "a" and "b", using these new variables. mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i1, i2, z3], body)) body = tvm.tir.transform.CommonSubexprElimTIR()(mod) @@ -86,61 +69,72 @@ def test_cse(): body = body["main"].body # Gets the body of the main, i.e. 
the full statement - assert body.var.name == "z1" - assert body.value == 1 - - body = body.body - - assert body.var.name == "z2" - assert body.value == 2 - # This is the let-in for the first variable generated cse_v1 - assert isinstance(body.body, tvm.tir.LetStmt) - - body = body.body - - # And this is the name and value of this variable - cse_v1 = body.var # Keep the variable accessible for later checking the replacements - assert body.var.name == "cse_v1" - tvm.ir.assert_structural_equal(body.value, z1 + z2) - assert isinstance(body.body, tvm.tir.SeqStmt) - - body = body.body - - assert isinstance(body[0], tvm.tir.BufferStore) - assert isinstance(body[1], tvm.tir.LetStmt) - - body = body[1] - - assert body.var.name == "x" - assert body.value == 1 - - body = body.body - - assert body.var.name == "y" - assert body.value == 1 - # This is the let-in for the second variable generated cse_v2 - assert isinstance(body.body, tvm.tir.LetStmt) - - body = body.body - - # And this is the name and value of this variable - cse_v2 = body.var # Keep the variable accessible for later checking the replacements - assert body.var.name == "cse_v2" - tvm.ir.assert_structural_equal(body.value, x + y) - - body = body.body - - body.var.name == "a" - # Check that the replacement has been done correctly! - tvm.ir.assert_structural_equal(body.value, cse_v2 + cse_v1) - - body = body.body - - body.var.name == "b" - # Check that the replacement has been done correctly! 
- tvm.ir.assert_structural_equal(body.value, cse_v2 + z3) + # The result should be a flat SeqStmt with Bind nodes for z1, z2, cse_v1 (z1+z2), + # the store, x, y, cse_v2 (x+y), a (using cse vars), b (using cse vars), store + assert isinstance(body, tvm.tir.SeqStmt) - assert isinstance(body.body, tvm.tir.BufferStore) + # Walk through the flat sequence and check the CSE-introduced bindings + stmts = list(body) + idx = 0 + + # z1 = 1 + assert isinstance(stmts[idx], tvm.tir.Bind) + assert stmts[idx].var.name == "z1" + assert stmts[idx].value == 1 + idx += 1 + + # z2 = 2 + assert isinstance(stmts[idx], tvm.tir.Bind) + assert stmts[idx].var.name == "z2" + assert stmts[idx].value == 2 + idx += 1 + + # CSE should introduce cse_v1 = z1 + z2 here + assert isinstance(stmts[idx], tvm.tir.Bind) + cse_v1 = stmts[idx].var + assert stmts[idx].var.name == "cse_v1" + tvm.ir.assert_structural_equal(stmts[idx].value, z1 + z2) + idx += 1 + + # Mem[i1] = cse_v1 (was z1+z2, now replaced) + assert isinstance(stmts[idx], tvm.tir.BufferStore) + tvm.ir.assert_structural_equal(stmts[idx].value, cse_v1) + idx += 1 + + # x = 1 + assert isinstance(stmts[idx], tvm.tir.Bind) + assert stmts[idx].var.name == "x" + assert stmts[idx].value == 1 + idx += 1 + + # y = 1 + assert isinstance(stmts[idx], tvm.tir.Bind) + assert stmts[idx].var.name == "y" + assert stmts[idx].value == 1 + idx += 1 + + # CSE should introduce cse_v2 = x + y here + assert isinstance(stmts[idx], tvm.tir.Bind) + cse_v2 = stmts[idx].var + assert stmts[idx].var.name == "cse_v2" + tvm.ir.assert_structural_equal(stmts[idx].value, x + y) + idx += 1 + + # a = cse_v2 + cse_v1 (was (x+y) + (z1+z2), now replaced) + assert isinstance(stmts[idx], tvm.tir.Bind) + assert stmts[idx].var.name == "a" + tvm.ir.assert_structural_equal(stmts[idx].value, cse_v2 + cse_v1) + idx += 1 + + # b = cse_v2 + z3 (was (x+y) + z3, now replaced) + assert isinstance(stmts[idx], tvm.tir.Bind) + assert stmts[idx].var.name == "b" + 
tvm.ir.assert_structural_equal(stmts[idx].value, cse_v2 + z3) + idx += 1 + + # Mem[i2] = a + b + assert isinstance(stmts[idx], tvm.tir.BufferStore) + idx += 1 # ----------------------------------------------------- @@ -160,25 +154,24 @@ def test_cse_ifNode_1(): z = tvm.tir.Var("z", "int32") dtype = "int32" buffer = tvm.tir.decl_buffer((50,), dtype) - # Test prog : - # let b=1 in - # if(b) { - # Mem[i1] = y+z - # Mem[i2] = y+z - # } - # else { - # Mem[i3] = y - # } - body = tvm.tir.LetStmt( - b, - 1, - tvm.tir.IfThenElse( - b, - tvm.tir.SeqStmt( - [tvm.tir.BufferStore(buffer, y + z, [i1]), tvm.tir.BufferStore(buffer, y + z, [i2])] + # Test prog: + # b = 1; + # if(b) { Mem[i1] = y+z; Mem[i2] = y+z } + # else { Mem[i3] = y } + body = tvm.tir.SeqStmt( + [ + tvm.tir.Bind(b, 1), + tvm.tir.IfThenElse( + b, + tvm.tir.SeqStmt( + [ + tvm.tir.BufferStore(buffer, y + z, [i1]), + tvm.tir.BufferStore(buffer, y + z, [i2]), + ] + ), + tvm.tir.BufferStore(buffer, y, [i3]), ), - tvm.tir.BufferStore(buffer, y, [i3]), - ), + ] ) mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i1, i2, i3, y, z], body)) @@ -188,20 +181,23 @@ def test_cse_ifNode_1(): body = body["main"].body # Gets the body of the main, i.e. 
the full statement - assert body.var.name == "b" - assert body.value == 1 - assert isinstance(body.body, tvm.tir.IfThenElse) - - body = body.body + assert isinstance(body, tvm.tir.SeqStmt) + stmts = list(body) - assert isinstance(body.then_case, tvm.tir.LetStmt) + # b = 1 + assert isinstance(stmts[0], tvm.tir.Bind) + assert stmts[0].var.name == "b" + assert stmts[0].value == 1 - body = body.then_case + # The If node + assert isinstance(stmts[1], tvm.tir.IfThenElse) + if_node = stmts[1] - # The let-in introduced by the CSE should appear now, inside the Then branch of the If node - assert body.var.name == "cse_v1" - # and it should contain the expression (y+z) that was redundant - tvm.ir.assert_structural_equal(body.value, y + z) + # The CSE variable should be inside the Then branch + then_stmts = list(if_node.then_case) + assert isinstance(then_stmts[0], tvm.tir.Bind) + assert then_stmts[0].var.name == "cse_v1" + tvm.ir.assert_structural_equal(then_stmts[0].value, y + z) # Second test for if nodes : Some duplicated computations appear in both the Then and Else branch. 
@@ -216,28 +212,24 @@ def test_cse_ifNode_2(): z = tvm.tir.Var("z", "int32") dtype = "int32" buffer = tvm.tir.decl_buffer((50,), dtype) - # Test prog : - # let b=1 in - # if(b) { - # Mem[i1] = y+z - # Mem[i2] = y - # } - # else { - # Mem[i3] = y+z - # } - body = tvm.tir.LetStmt( - b, - 1, - tvm.tir.IfThenElse( - b, - tvm.tir.SeqStmt( - [ - tvm.tir.BufferStore(buffer, y + z, [i1]), # (y+z) is present in Then branch - tvm.tir.BufferStore(buffer, y, [i2]), - ] + # Test prog: + # b = 1; + # if(b) { Mem[i1] = y+z; Mem[i2] = y } + # else { Mem[i3] = y+z } + body = tvm.tir.SeqStmt( + [ + tvm.tir.Bind(b, 1), + tvm.tir.IfThenElse( + b, + tvm.tir.SeqStmt( + [ + tvm.tir.BufferStore(buffer, y + z, [i1]), + tvm.tir.BufferStore(buffer, y, [i2]), + ] + ), + tvm.tir.BufferStore(buffer, y + z, [i3]), ), - tvm.tir.BufferStore(buffer, y + z, [i3]), # and also present in the Else branch - ), + ] ) mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i1, i2, i3, y, z], body)) @@ -247,12 +239,18 @@ def test_cse_ifNode_2(): body = body["main"].body # Gets the body of the main, i.e. the full statement - assert isinstance(body, tvm.tir.LetStmt) + assert isinstance(body, tvm.tir.SeqStmt) + stmts = list(body) - # The let-in introduced by the CSE should appear now, at the toplevel (i.e. before the If) - assert body.var.name == "cse_v1" - # and it should contain the expression (y+z) that was redundant - tvm.ir.assert_structural_equal(body.value, y + z) + # CSE should introduce cse_v1 = y + z before the If + # Find the cse_v1 binding + found_cse = False + for s in stmts: + if isinstance(s, tvm.tir.Bind) and s.var.name == "cse_v1": + tvm.ir.assert_structural_equal(s.value, y + z) + found_cse = True + break + assert found_cse # ------------------------------------------------------------------------------------------------- @@ -288,38 +286,29 @@ def test_cse_cascade(): body = body["main"].body # Gets the body of the main, i.e. 
the full statement - assert isinstance(body, tvm.tir.LetStmt) - - # The second let-in (by order introduced) introduced by the CSE should appear first - cse_v2 = body.var # Keep the variable accessible for later checking the replacements - assert body.var.name == "cse_v2" - # and it should contain the expression (x+y) - tvm.ir.assert_structural_equal(body.value, (x + y)) - - body = body.body - - assert isinstance(body, tvm.tir.LetStmt) + assert isinstance(body, tvm.tir.SeqStmt) + stmts = list(body) - # The first let-in (by order introduced) introduced by the CSE should appear now, after the 2nd - cse_v1 = body.var # Keep the variable accessible for later checking the replacements - assert body.var.name == "cse_v1" - # and it should contain the expression cse_v2+z - tvm.ir.assert_structural_equal(body.value, cse_v2 + z) + # cse_v2 = x + y + assert isinstance(stmts[0], tvm.tir.Bind) + cse_v2 = stmts[0].var + assert stmts[0].var.name == "cse_v2" + tvm.ir.assert_structural_equal(stmts[0].value, (x + y)) - body = body.body + # cse_v1 = cse_v2 + z + assert isinstance(stmts[1], tvm.tir.Bind) + cse_v1 = stmts[1].var + assert stmts[1].var.name == "cse_v1" + tvm.ir.assert_structural_equal(stmts[1].value, cse_v2 + z) - assert isinstance(body, tvm.tir.SeqStmt) - assert isinstance(body[0], tvm.tir.BufferStore) - assert isinstance(body[1], tvm.tir.BufferStore) - assert isinstance(body[2], tvm.tir.BufferStore) + # Three stores + assert isinstance(stmts[2], tvm.tir.BufferStore) + assert isinstance(stmts[3], tvm.tir.BufferStore) + assert isinstance(stmts[4], tvm.tir.BufferStore) - store1 = body[0] - store2 = body[1] - store3 = body[2] - - tvm.ir.assert_structural_equal(store1.value, cse_v1) - tvm.ir.assert_structural_equal(store2.value, cse_v1) - tvm.ir.assert_structural_equal(store3.value, cse_v2) + tvm.ir.assert_structural_equal(stmts[2].value, cse_v1) + tvm.ir.assert_structural_equal(stmts[3].value, cse_v1) + tvm.ir.assert_structural_equal(stmts[4].value, cse_v2) # 
----------------------------------------------------------------------------------------- @@ -331,8 +320,8 @@ def test_no_normalization_without_commoning(): z = tvm.tir.Var("z", "int32") a = tvm.tir.Var("a", "int32") # Test prog : - # let a = x + (y + z) in a - body = tvm.tir.LetStmt(a, x + (y + z), tvm.tir.Evaluate(a)) + # a = x + (y + z); evaluate(a) + body = tvm.tir.SeqStmt([tvm.tir.Bind(a, x + (y + z)), tvm.tir.Evaluate(a)]) mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([x, y, z], body)) body = tvm.tir.transform.CommonSubexprElimTIR(identify_equiv_terms=True)(mod) @@ -341,8 +330,11 @@ def test_no_normalization_without_commoning(): body = body["main"].body # Gets the body of the main, i.e. the full statement - assert body.var.name == "a" - tvm.ir.assert_structural_equal(body.value, x + (y + z)) + assert isinstance(body, tvm.tir.SeqStmt) + stmts = list(body) + assert isinstance(stmts[0], tvm.tir.Bind) + assert stmts[0].var.name == "a" + tvm.ir.assert_structural_equal(stmts[0].value, x + (y + z)) # ------------------------------------------------- @@ -428,8 +420,8 @@ def test_deterministic_cse(): expression = x for add in inc1 + inc2: expression = expression + add - let_stmt = tvm.tir.LetStmt(result, expression, tvm.tir.Evaluate(result)) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([x], let_stmt)) + body = tvm.tir.SeqStmt([tvm.tir.Bind(result, expression), tvm.tir.Evaluate(result)]) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([x], body)) initial_hash = None for _ in range(REPEATS): diff --git a/tests/python/tir-transform/test_tir_transform_convert_ssa.py b/tests/python/tir-transform/test_tir_transform_convert_ssa.py index 0d54d4fb048a..864f238385d2 100644 --- a/tests/python/tir-transform/test_tir_transform_convert_ssa.py +++ b/tests/python/tir-transform/test_tir_transform_convert_ssa.py @@ -32,8 +32,10 @@ def test_reuse_in_sequential_let_stmt(): var = tir.Var("var", "int32") sequential_bindings = tir.SeqStmt( [ - tir.LetStmt(var, 16, tir.Evaluate(var)), - 
tir.LetStmt(var, 32, tir.Evaluate(var)), + tir.Bind(var, 16), + tir.Evaluate(var), + tir.Bind(var, 32), + tir.Evaluate(var), ] ) before = tir.PrimFunc([], sequential_bindings) @@ -61,19 +63,21 @@ def test_reuse_in_nested_let_stmt(): # not valid TIR, and may not be expressible in future versions # of TVMSCript. var = tir.Var("var", "int32") - inner_let = tir.LetStmt(var, 16, tir.Evaluate(var)) - outer_let = tir.LetStmt( - var, - 32, - tir.SeqStmt( - [ - tir.Evaluate(var), - inner_let, - tir.Evaluate(var), - ] - ), + inner_seq = tir.SeqStmt( + [ + tir.Bind(var, 16), + tir.Evaluate(var), + ] + ) + outer_seq = tir.SeqStmt( + [ + tir.Bind(var, 32), + tir.Evaluate(var), + inner_seq, + tir.Evaluate(var), + ] ) - before = tir.PrimFunc([], outer_let) + before = tir.PrimFunc([], outer_seq) @T.prim_func(private=True) def expected(): diff --git a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py index f0f431b8ec44..0a64d8a176ab 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py @@ -135,10 +135,10 @@ def build_tir(): # 2. 
Let binding: Aptr_dup = packed_echo(Ab.data), then store const into Ab[1] Aptr_dup = tvm.tir.Var("Aptr_dup", "handle") store1 = tvm.tir.BufferStore(Ab, tvm.tir.const(expected_value[1], "float32"), [1]) - let_stmt = tvm.tir.LetStmt(Aptr_dup, packed_echo(Ab.data), store1) + bind_stmt = tvm.tir.Bind(Aptr_dup, packed_echo(Ab.data)) # Combine into sequence - stmt = tvm.tir.SeqStmt([store0, let_stmt]) + stmt = tvm.tir.SeqStmt([store0, bind_stmt, store1]) return tvm.IRModule.from_expr( tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "packed_test") diff --git a/tests/python/tir-transform/test_tir_transform_prim_func_pass.py b/tests/python/tir-transform/test_tir_transform_prim_func_pass.py index 8bee82f02c52..72b168414263 100644 --- a/tests/python/tir-transform/test_tir_transform_prim_func_pass.py +++ b/tests/python/tir-transform/test_tir_transform_prim_func_pass.py @@ -32,7 +32,7 @@ def transform_function(self, func, mod, ctx): x = tvm.tir.Var("x", "int32") y = tvm.tir.Var("y", "int32") b = tvm.tir.decl_buffer((x,), "float32") - stmt = tvm.tir.LetStmt(x, 10, tvm.tir.Evaluate(x + 1)) + stmt = tvm.tir.SeqStmt([tvm.tir.Bind(x, 10), tvm.tir.Evaluate(x + 1)]) func = tvm.tir.PrimFunc([x, y, b], stmt) diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py index 624f428e2a93..cac1371f5a53 100644 --- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py +++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py @@ -321,8 +321,8 @@ def test_ir_builder_tir_let(): # the let binding generated by IRBuilder let_actual = ib.get() - # the expected Let statement - let_expected = tir.LetStmt(T.int32(), tir.IntImm("int32", 2), tir.Evaluate(0)) + # the expected Bind + Evaluate sequence + let_expected = tir.SeqStmt([tir.Bind(T.int32(), tir.IntImm("int32", 2)), tir.Evaluate(0)]) # Check if the generated ir is expected assert_structural_equal(let_actual, let_expected, map_free_vars=True) From 
1c9f460030fa3a28416496ec87194593e66ac66c Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Mar 2026 18:33:27 +0000 Subject: [PATCH 03/34] [TIR] Fix passes and tests for flat BindNode semantics Update TIR passes and tests to work correctly with the flat BindNode model (no body field) where variable scoping is managed via SeqStmt siblings instead of nested tree structure. Pass fixes: - RemoveNoOp: Add SeqStmt handler for back-to-front unused Bind scan - ConvertSSA: Add SeqStmt handler to maintain ScopedRedefine across siblings - StorageRewrite: Push/pop scope entry in BindNode handler - HoistExpression: Merge Bind lifecycle management into SeqStmt handler; only set reached_sequential_node for truly sequential (non-Bind) stmts - SBlockAccessRegionDetector: Defer let_bindings_ erasure to SeqStmt end - TVMScript printer: Add AsDocBodySeqSlice for scoped T.LetStmt form when printing already-defined-var Binds - TVMScript parser: Support doc.Attribute in _duplicate_lhs_check Test updates: - verify_well_formed: Adjust for flat scope semantics - convert_ssa: Update for flattened SeqStmt behavior - tvmscript printer/annotation/syntax_sugar: Update access paths - loop_partition: Fix pre-existing test with incorrect expected output --- python/tvm/script/parser/core/parser.py | 3 + .../analysis/sblock_access_region_detector.cc | 24 ++++- src/s_tir/transform/hoist_expression.cc | 39 ++++++-- src/script/printer/tir/utils.h | 61 +++++++++--- src/tir/transform/ir_utils.cc | 45 +++++++++ src/tir/transform/remove_no_op.cc | 96 +++++++++++++++++-- src/tir/transform/storage_rewrite.cc | 12 ++- .../test_tir_analysis_verify_well_formed.py | 48 +++++++--- .../test_tir_transform_convert_ssa.py | 33 ++++--- .../test_tvmscript_printer_annotation.py | 9 +- .../tvmscript/test_tvmscript_printer_tir.py | 4 +- .../tvmscript/test_tvmscript_syntax_sugar.py | 5 +- 12 files changed, 318 insertions(+), 61 deletions(-) diff --git a/python/tvm/script/parser/core/parser.py 
b/python/tvm/script/parser/core/parser.py index eb617c85c9bd..d23358d93b22 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -528,6 +528,9 @@ def _duplicate_lhs_check(self, target: doc.expr) -> bool | set[str]: return {target.id} elif isinstance(target, doc.Starred): return self._duplicate_lhs_check(target.value) + elif isinstance(target, doc.Attribute): + # Attribute assignment like packedB.data = ..., treated as rebinding. + return {target.attr} else: self.report_error(target, "Invalid type in assign statement") raise NotImplementedError diff --git a/src/s_tir/analysis/sblock_access_region_detector.cc b/src/s_tir/analysis/sblock_access_region_detector.cc index 1e0025d551c4..61133ecb10c2 100644 --- a/src/s_tir/analysis/sblock_access_region_detector.cc +++ b/src/s_tir/analysis/sblock_access_region_detector.cc @@ -118,6 +118,7 @@ class BlockReadWriteDetector : public StmtExprVisitor { void VisitStmt_(const SBlockRealizeNode* op) override; void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const BindNode* op) override; + void VisitStmt_(const SeqStmtNode* op) override; void VisitExpr_(const BufferLoadNode* op) override; void VisitExpr_(const VarNode* op) override; void VisitExpr_(const CallNode* op) override; @@ -190,9 +191,30 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) { } void BlockReadWriteDetector::VisitStmt_(const BindNode* op) { + // With flat Bind, the binding persists for subsequent siblings. + // The SeqStmt handler manages the lifecycle; standalone Bind just adds. let_bindings_[op->var.get()] = op->value; StmtVisitor::VisitStmt_(op); - let_bindings_.erase(op->var.get()); + // Note: we do NOT erase here. The SeqStmt handler will erase + // all Bind-defined vars when it finishes processing the sequence. + // For standalone Bind (not in a SeqStmt), the binding persists + // until the parent scope ends. 
+} + +void BlockReadWriteDetector::VisitStmt_(const SeqStmtNode* op) { + // Track which variables were defined by Bind nodes in this sequence, + // so we can erase them when the sequence ends. + std::vector seq_bindings; + for (size_t i = 0; i < op->seq.size(); ++i) { + if (auto* bind = op->seq[i].as()) { + seq_bindings.push_back(bind->var.get()); + } + VisitStmt(op->seq[i]); + } + // Erase bindings defined in this sequence. + for (auto* var : seq_bindings) { + let_bindings_.erase(var); + } } void BlockReadWriteDetector::VisitExpr_(const CallNode* op) { diff --git a/src/s_tir/transform/hoist_expression.cc b/src/s_tir/transform/hoist_expression.cc index 076bf2b5a442..2c12681a3556 100644 --- a/src/s_tir/transform/hoist_expression.cc +++ b/src/s_tir/transform/hoist_expression.cc @@ -324,11 +324,37 @@ class HoistInfoCollector : public StmtExprVisitor { void VisitStmt_(const BindNode* op) final { VisitBinding(op->var, op->value, HoistedLetBindings::kLetStmt); - Parent::VisitStmt_(op); + // Don't erase here; SeqStmt handler manages the lifecycle. + } - let_var_to_loop_vars.erase(op->var.get()); - let_var_to_let_vars.erase(op->var.get()); + void VisitStmt_(const SeqStmtNode* op) final { + if (active_loops.size()) { + // Only mark as sequential if there are multiple non-Bind statements. + // Bind nodes are variable definitions (equivalent to old LetStmt wrappers) + // and don't introduce true sequential ordering that would prevent hoisting. + int non_bind_count = 0; + for (size_t i = 0; i < op->seq.size(); ++i) { + if (!op->seq[i].as()) { + non_bind_count++; + } + } + if (non_bind_count > 1) { + active_loops.back().reached_sequential_node = true; + } + } + std::vector seq_bind_vars; + for (size_t i = 0; i < op->seq.size(); ++i) { + if (auto* bind = op->seq[i].as()) { + seq_bind_vars.push_back(bind->var.get()); + } + VisitStmt(op->seq[i]); + } + // Erase bindings defined in this sequence. 
+ for (auto* var : seq_bind_vars) { + let_var_to_loop_vars.erase(var); + let_var_to_let_vars.erase(var); + } } void VisitExpr_(const LetNode* op) final { @@ -354,13 +380,6 @@ class HoistInfoCollector : public StmtExprVisitor { Parent::VisitExpr_(op); } - void VisitStmt_(const SeqStmtNode* op) final { - if (active_loops.size()) { - active_loops.back().reached_sequential_node = true; - } - Parent::VisitStmt_(op); - } - // Find the loop above which this expression could be hoisted. If // nullptr, the expression cannot be hoisted. HoistInfo* FindHoistDestination(PrimExpr expr) { diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h index 736b3c62b56a..580fa7b61fc1 100644 --- a/src/script/printer/tir/utils.h +++ b/src/script/printer/tir/utils.h @@ -107,19 +107,17 @@ inline IdDoc DefineBuffer(const tir::Buffer& buffer, const Frame& frame, const I * \param f The frame * \param d The IRDocsifier */ +/*! + * \brief Helper to process remaining SeqStmt children starting at index `start` + * into the given frame's stmts, handling Bind-with-already-defined-var + * by creating scoped T.LetStmt forms. 
+ */ +inline void AsDocBodySeqSlice(const ffi::Array& body, int start, AccessPath p, + TIRFrameNode* f, const IRDocsifier& d); + inline void AsDocBody(const tir::Stmt& stmt, AccessPath p, TIRFrameNode* f, const IRDocsifier& d) { if (const auto* seq_stmt = stmt.as()) { - ffi::Array body = seq_stmt->seq; - for (int i = 0, n = body.size(); i < n; ++i) { - f->allow_concise_scoping = (i == n - 1); - Doc doc = d->AsDoc(body[i], p->Attr("seq")->ArrayItem(i)); - doc->source_paths.push_back(p); - if (const auto* block = doc.as()) { - f->stmts.insert(f->stmts.end(), block->stmts.begin(), block->stmts.end()); - } else { - f->stmts.push_back(Downcast(doc)); - } - } + AsDocBodySeqSlice(seq_stmt->seq, 0, p, f, d); } else { f->allow_concise_scoping = true; Doc doc = d->AsDoc(stmt, p); @@ -131,6 +129,47 @@ inline void AsDocBody(const tir::Stmt& stmt, AccessPath p, TIRFrameNode* f, cons } } +inline void AsDocBodySeqSlice(const ffi::Array& body, int start, AccessPath p, + TIRFrameNode* f, const IRDocsifier& d) { + int n = body.size(); + for (int i = start; i < n; ++i) { + // Check if this is a Bind with an already-defined variable. + // If so, we need to use the scoped T.LetStmt form and wrap + // remaining siblings as the body (for correct roundtrip). 
+ if (const auto* bind = body[i].as()) { + if (d->IsVarDefined(bind->var)) { + // Create a scoped LetStmt form: + // with T.LetStmt(value, var=X): + // + auto bind_p = p->Attr("seq")->ArrayItem(i); + ExprDoc rhs = d->AsDoc(bind->value, bind_p->Attr("value")); + ExprDoc lhs = d->AsDoc(bind->var, bind_p->Attr("var")); + // Collect the remaining siblings as the body + ffi::Array scope_stmts; + // Create a temporary frame for body processing + auto temp_frame = ffi::make_object(); + AsDocBodySeqSlice(body, i + 1, p, temp_frame.get(), d); + scope_stmts = temp_frame->stmts; + + ExprDoc call = TIR(d, "LetStmt")->Call({rhs}, {"var"}, {lhs}); + StmtDoc scope_doc = ScopeDoc(std::nullopt, call, scope_stmts); + scope_doc->source_paths.push_back(p); + f->stmts.push_back(scope_doc); + return; // remaining siblings are inside the scope + } + } + + f->allow_concise_scoping = (i == n - 1); + Doc doc = d->AsDoc(body[i], p->Attr("seq")->ArrayItem(i)); + doc->source_paths.push_back(p); + if (const auto* block = doc.as()) { + f->stmts.insert(f->stmts.end(), block->stmts.begin(), block->stmts.end()); + } else { + f->stmts.push_back(Downcast(doc)); + } + } +} + /*! * \brief Find the top frame in the stack that could place a var definition * \param var The var to be defined diff --git a/src/tir/transform/ir_utils.cc b/src/tir/transform/ir_utils.cc index d54eb1317ddd..719de1c079a7 100644 --- a/src/tir/transform/ir_utils.cc +++ b/src/tir/transform/ir_utils.cc @@ -347,6 +347,9 @@ class IRConvertSSA final : public StmtExprMutator { } Stmt VisitStmt_(const BindNode* op) final { + // Note: ScopedRedefine for Bind must persist across SeqStmt siblings. + // This is handled by VisitStmt_(const SeqStmtNode*) below. + // When visited standalone (not as part of SeqStmt), just do a simple visit. 
const Var& v = op->var; if (defined_.count(v.get())) { PrimExpr value = this->VisitExpr(op->value); @@ -357,6 +360,48 @@ class IRConvertSSA final : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } } + + Stmt VisitStmt_(const SeqStmtNode* op) final { + // Process children sequentially, maintaining ScopedRedefine for Bind nodes + // so that remappings persist for subsequent siblings (mimicking old nested + // LetStmt scope behavior). + std::vector seq_redefines; + ffi::Array new_seq; + bool changed = false; + + for (size_t i = 0; i < op->seq.size(); ++i) { + const Stmt& child = op->seq[i]; + if (auto* bind = child.as()) { + const Var& v = bind->var; + if (defined_.count(v.get())) { + PrimExpr value = this->VisitExpr(bind->value); + seq_redefines.emplace_back(this, v); + Stmt new_bind = Bind(seq_redefines.back().new_var, value); + new_seq.push_back(new_bind); + changed = true; + } else { + defined_.insert(v.get()); + Stmt visited = StmtExprMutator::VisitStmt_(bind); + new_seq.push_back(visited); + changed = changed || !visited.same_as(child); + } + } else { + Stmt visited = VisitStmt(child); + new_seq.push_back(visited); + changed = changed || !visited.same_as(child); + } + } + + // Pop redefines in reverse order (RAII would do this, but let's be explicit) + while (seq_redefines.size()) { + seq_redefines.pop_back(); + } + + if (!changed) { + return ffi::GetRef(op); + } + return SeqStmt(new_seq); + } Stmt VisitStmt_(const ForNode* op) final { const Var& v = op->loop_var; if (defined_.count(v.get())) { diff --git a/src/tir/transform/remove_no_op.cc b/src/tir/transform/remove_no_op.cc index e1c60b2e9707..e3cfad2c2a94 100644 --- a/src/tir/transform/remove_no_op.cc +++ b/src/tir/transform/remove_no_op.cc @@ -94,14 +94,97 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { Stmt VisitStmt_(const BindNode* op) final { Stmt stmt = Parent::VisitStmt_(op); op = stmt.as(); - // Bind has no body -- removal of unused Bind is handled at SeqStmt level. 
- // If the value has no side effect, the Bind can potentially be removed. - if (!HasSideEffect(op->value) && SideEffect(op->value) <= CallEffectKind::kPure) { - // A pure Bind with no uses will be cleaned up by dead code elimination. - // Keep it for now; the SeqStmt visitor handles removal. + if (in_seq_stmt_) { + // Inside a SeqStmt: the SeqStmt handler will decide whether to remove + // this Bind based on whether its var is used by subsequent siblings. return stmt; } - return stmt; + // Standalone Bind (not inside a SeqStmt): there's nothing after it + // to use the variable, so it's always dead. + if (HasSideEffect(op->value)) { + return Evaluate(op->value); + } + return Evaluate(0); + } + + Stmt VisitStmt_(const SeqStmtNode* op) final { + // Visit each child individually (not using parent handler, which calls + // SeqStmt::Flatten and may strip Evaluate(0) before we can analyze). + bool prev_in_seq = in_seq_stmt_; + in_seq_stmt_ = true; + ffi::Array visited_seq; + bool any_child_changed = false; + for (size_t i = 0; i < op->seq.size(); ++i) { + Stmt visited_child = VisitStmt(op->seq[i]); + // Flatten any nested SeqStmt children into the sequence. + if (auto* inner_seq = visited_child.as()) { + for (size_t j = 0; j < inner_seq->seq.size(); ++j) { + visited_seq.push_back(inner_seq->seq[j]); + } + any_child_changed = true; + } else { + visited_seq.push_back(visited_child); + any_child_changed = any_child_changed || !visited_child.same_as(op->seq[i]); + } + } + + // Now, remove unused Bind nodes. + // Scan from back to front, tracking which variables are used + // by subsequent siblings. + size_t n = visited_seq.size(); + std::unordered_set suffix_uses; + std::vector removable(n, false); + std::vector has_side_effect_flag(n, false); + + for (int i = static_cast(n) - 1; i >= 0; --i) { + const Stmt& child = visited_seq[i]; + if (auto* bind = child.as()) { + if (suffix_uses.count(bind->var.get()) == 0) { + // Variable not used in any subsequent sibling. 
+ removable[i] = true; + has_side_effect_flag[i] = HasSideEffect(bind->value); + } + // Remove the defined variable from suffix_uses (it's defined here). + suffix_uses.erase(bind->var.get()); + // Add uses from the bind value so earlier Binds defining those vars stay. + VarUseDefAnalyzer value_analyzer({}); + value_analyzer(bind->value); + for (auto& kv : value_analyzer.use_count_) { + suffix_uses.insert(kv.first); + } + } else { + // Collect all variable uses in this non-Bind statement. + VarUseDefAnalyzer analyzer({}); + analyzer(child); + for (auto& kv : analyzer.use_count_) { + suffix_uses.insert(kv.first); + } + } + } + + // Build the new sequence, skipping removable Binds. + bool any_removed = false; + ffi::Array new_seq; + for (size_t i = 0; i < n; ++i) { + if (removable[i]) { + any_removed = true; + if (has_side_effect_flag[i]) { + auto* bind = visited_seq[i].as(); + new_seq.push_back(Evaluate(bind->value)); + } + // else: pure Bind with unused var — remove entirely. + } else { + new_seq.push_back(visited_seq[i]); + } + } + + in_seq_stmt_ = prev_in_seq; + + if (!any_removed && !any_child_changed) { + return ffi::GetRef(op); + } + + return SeqStmt::Flatten(new_seq); } Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == "pragma_debug_skip_region") { @@ -295,6 +378,7 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { std::unordered_map var_range_map_; std::optional touch_pattern_; const StmtNode* context_; + bool in_seq_stmt_{false}; }; Stmt RemoveNoOp(Stmt stmt, arith::Analyzer* analyzer, std::optional touch_pattern, diff --git a/src/tir/transform/storage_rewrite.cc b/src/tir/transform/storage_rewrite.cc index 40235aa98e11..d315bc5d8194 100644 --- a/src/tir/transform/storage_rewrite.cc +++ b/src/tir/transform/storage_rewrite.cc @@ -213,7 +213,17 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { void VisitStmt_(const AssertStmtNode* op) final { VisitNewScope(op); } - void VisitStmt_(const BindNode* op) final { 
StmtExprVisitor::VisitStmt_(op); } + void VisitStmt_(const BindNode* op) final { + scope_.push_back(StmtEntry()); + // visit subexpr (the value may contain BufferLoad) + StmtExprVisitor::VisitStmt_(op); + StmtEntry e = scope_.back(); + scope_.pop_back(); + if (e.touched.size() != 0) { + e.stmt = op; + linear_seq_.push_back(e); + } + } // linearized access sequence. std::vector linear_seq_; diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py index 9c1bcc545b7e..d11281f79639 100644 --- a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py +++ b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py @@ -60,14 +60,25 @@ def element_wise( def test_error_for_out_of_scope_usage(): - """A variable may not be used after its scope ends""" + """A variable may not be used after its scope ends. - @T.prim_func(check_well_formed=False) - def func(): - i = T.int32() - with T.LetStmt(42, var=i): - T.evaluate(i) - T.evaluate(i) + With flat Bind semantics, Bind vars are visible to all subsequent + siblings in the same SeqStmt. True out-of-scope usage occurs when + the Bind is inside a child scope (e.g., ForNode body) and the + variable is used outside that scope. + """ + i = tvm.tir.Var("i", "int32") + # Bind i inside a For loop body + for_stmt = tvm.tir.For( + tvm.tir.Var("j", "int32"), + 0, + 1, + tvm.tir.ForKind.SERIAL, + tvm.tir.SeqStmt([tvm.tir.Bind(i, 42), tvm.tir.Evaluate(i)]), + ) + # Use i outside the For loop — this is out of scope + body = tvm.tir.SeqStmt([for_stmt, tvm.tir.Evaluate(i)]) + func = tvm.tir.PrimFunc([], body) with pytest.raises( ValueError, match="Invalid use of undefined variable i at .* no longer in-scope." @@ -92,7 +103,12 @@ def func(): def test_error_for_repeated_binding(): - """A variable may not be re-defined after the scope ends""" + """A variable may not be re-defined in the same flat scope. 
+ + With flat Bind semantics, sequential Bind of the same variable in the + same SeqStmt is treated as a nested redefinition (since the first Bind's + scope extends to all subsequent siblings). + """ @T.prim_func(check_well_formed=False) def func(): @@ -102,7 +118,7 @@ def func(): with T.LetStmt(17, var=i): T.evaluate(i) - with pytest.raises(ValueError, match="multiple definitions of variable i"): + with pytest.raises(ValueError, match="multiple nested definitions of variable i"): tvm.tir.analysis.verify_well_formed(func) @@ -269,6 +285,10 @@ def test_error_message_without_previous_definition_location(): This tests the scenario where it == end(), so the error message should contain 'TIR is ill-formed, due to multiple definitions of variable' but should NOT contain 'It was first defined at' since the iterator is invalid. + + With flat Bind semantics, sequential redefinitions in the same SeqStmt + are treated as nested definitions, and the first definition location + IS known, so the message includes location info. """ @T.prim_func(check_well_formed=False) @@ -287,7 +307,7 @@ def func(): error_msg = str(exc_info.value) assert "TIR is ill-formed" in error_msg - assert "multiple definitions of variable" in error_msg + assert "multiple nested definitions of variable" in error_msg def test_error_message_with_previous_definition_location(): @@ -322,7 +342,9 @@ def func(): def test_sequential_redefinition_with_location(): """Test case 2b: Sequential redefinition that includes location info - This tests the previously_defined_ path where it != end() + This tests the previously_defined_ path where it != end(). + With flat Bind semantics, sequential redefinitions in the same SeqStmt + are treated as nested definitions with location info. 
""" @T.prim_func(check_well_formed=False) @@ -341,9 +363,9 @@ def func(): error_msg = str(exc_info.value) assert "TIR is ill-formed" in error_msg - assert "multiple definitions of variable" in error_msg + assert "multiple nested definitions of variable" in error_msg assert "It was first defined at" in error_msg - assert "later re-defined at" in error_msg + assert "was re-defined at" in error_msg def test_buffer_in_buffer_map_is_well_formed(): diff --git a/tests/python/tir-transform/test_tir_transform_convert_ssa.py b/tests/python/tir-transform/test_tir_transform_convert_ssa.py index 864f238385d2..de7cd764a7c6 100644 --- a/tests/python/tir-transform/test_tir_transform_convert_ssa.py +++ b/tests/python/tir-transform/test_tir_transform_convert_ssa.py @@ -53,16 +53,20 @@ def expected(): def test_reuse_in_nested_let_stmt(): - """De-dup nested bindings + """De-dup sequential bindings of the same variable. - Use of a variable with nested bindings is de-duplicated to refer - to the inner-most binding that contains the use site. + In the flat Bind model, all Binds are siblings in a SeqStmt. A second + Bind of the same variable redefines it for all subsequent siblings. + ConvertSSA should create a new variable for the second binding and + update all subsequent uses to refer to the new variable. """ # Manually construct the PrimFunc body, as SSA violations are # not valid TIR, and may not be expressible in future versions - # of TVMSCript. + # of TVMScript. var = tir.Var("var", "int32") + # Note: nested SeqStmt is flattened by the IR builder, so the input + # is actually a flat SeqStmt with 5 elements. 
inner_seq = tir.SeqStmt( [ tir.Bind(var, 16), @@ -79,13 +83,20 @@ def test_reuse_in_nested_let_stmt(): ) before = tir.PrimFunc([], outer_seq) - @T.prim_func(private=True) - def expected(): - with T.LetStmt(T.int32(32)) as outer: - T.evaluate(outer) - with T.LetStmt(T.int32(16)) as inner: - T.evaluate(inner) - T.evaluate(outer) + # In the flat model, the second Bind(var, 16) redefines var for + # ALL subsequent siblings including the last Evaluate. + var1 = tir.Var("var", "int32") + var2 = tir.Var("var", "int32") + expected_body = tir.SeqStmt( + [ + tir.Bind(var1, 32), + tir.Evaluate(var1), + tir.Bind(var2, 16), + tir.Evaluate(var2), + tir.Evaluate(var2), + ] + ) + expected = tir.PrimFunc([], expected_body) mod = tvm.IRModule.from_expr(before) mod = tvm.tir.transform.ConvertSSA()(mod) diff --git a/tests/python/tvmscript/test_tvmscript_printer_annotation.py b/tests/python/tvmscript/test_tvmscript_printer_annotation.py index 08565ce074f7..13ace54ff7c2 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_annotation.py +++ b/tests/python/tvmscript/test_tvmscript_printer_annotation.py @@ -95,9 +95,11 @@ def _func(): y = x + 1 T.evaluate(y - 1) + # With flat Bind, the body is SeqStmt([Bind(x,1), Bind(y,x+1), Evaluate(y-1)]). + # Annotate the second Bind (y = x + 1). 
result = _func.with_attr("global_symbol", "main").script( obj_to_annotate={ - _func.body.body: "annotation 1", + _func.body.seq[1]: "annotation 1", } ) assert ( @@ -107,7 +109,6 @@ def _func(): @T.prim_func def main(): x: T.int32 = 1 - # annotation 1 - with T.LetStmt(x + 1) as y: - T.evaluate(y - 1)""" + y: T.int32 = x + 1 # annotation 1 + T.evaluate(y - 1)""" ) diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index 3409b731b364..62d49f6c242f 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -261,8 +261,8 @@ def test_let_stmt(): _assert_print( obj, """ -with T.LetStmt(T.float32(10.0)) as v: - T.evaluate(0) +v: T.float32 = T.float32(10.0) +T.evaluate(0) """, ) diff --git a/tests/python/tvmscript/test_tvmscript_syntax_sugar.py b/tests/python/tvmscript/test_tvmscript_syntax_sugar.py index 08bf90123b13..0205758b90ca 100644 --- a/tests/python/tvmscript/test_tvmscript_syntax_sugar.py +++ b/tests/python/tvmscript/test_tvmscript_syntax_sugar.py @@ -447,7 +447,7 @@ def func(i: T.int32): def test_preserve_variable_name(): - """Use variable name when generating tir::LetStmt""" + """Use variable name when generating tir::Bind""" @T.prim_func def func(): @@ -455,7 +455,8 @@ def func(): j = i // 4 T.evaluate(j) - var_name = func.body.body.var.name + # With flat Bind, the for body is SeqStmt([Bind(j, i//4), Evaluate(j)]) + var_name = func.body.body.seq[0].var.name assert var_name == "j" From a596311e793c77902dc59e59296903dd6f955cd0 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Mar 2026 18:48:33 +0000 Subject: [PATCH 04/34] [REFACTOR][TIR] Complete LetStmt-to-Bind migration: fix remaining issues - Update comments in var.h and Python functor.py to reference BindNode instead of LetStmtNode - Apply clang-format fixes to files modified by the BindNode migration - Remove unused Bind import in functor.py (LetStmt alias is used 
instead) - Remove extra blank lines left over from migration in analysis/rewriter files --- include/tvm/tir/var.h | 2 +- python/tvm/tir/functor.py | 13 +++---- src/relax/op/tensor/inspect.cc | 49 ++++++++++++------------- src/tir/analysis/var_use_def_analysis.h | 1 - src/tir/ir/data_type_rewriter.cc | 1 - src/tir/ir/tir_visitor_with_path.cc | 1 - src/tir/transform/tvm_ffi_binder.cc | 8 ++-- 7 files changed, 34 insertions(+), 41 deletions(-) diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index e83064b86489..b4106f2d2e9f 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -42,7 +42,7 @@ namespace tir { * - Allocate * - For * - Let - * - LetStmt + * - Bind */ class VarNode : public PrimExprNode { public: diff --git a/python/tvm/tir/functor.py b/python/tvm/tir/functor.py index 82b1f29aec63..620b433b21fb 100644 --- a/python/tvm/tir/functor.py +++ b/python/tvm/tir/functor.py @@ -70,7 +70,6 @@ Evaluate, For, IfThenElse, - Bind, LetStmt, SBlock, SBlockRealize, @@ -375,14 +374,14 @@ def visit_if_then_else_(self, op: IfThenElse) -> None: _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_let_stmt_(self, op: LetStmt) -> None: - """Visit LetStmt. - Users can customize this function to overwrite VisitStmt_(const LetStmtNode* op) + """Visit Bind (LetStmt alias). + Users can customize this function to overwrite VisitStmt_(const BindNode* op) on the C++ side. Parameters ---------- op : LetStmt - The LetStmt to be visited. + The Bind node to be visited. """ print("visit_let_stmt_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore @@ -1198,14 +1197,14 @@ def visit_if_then_else_(self, op: IfThenElse) -> Stmt: return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_let_stmt_(self, op: LetStmt) -> Stmt: - """Visit LetStmt. - Users can customize this function to overwrite VisitStmt_(const LetStmtNode* op) + """Visit Bind (LetStmt alias). 
+ Users can customize this function to overwrite VisitStmt_(const BindNode* op) on the C++ side. Parameters ---------- op : LetStmt - The LetStmt to be visited. + The Bind node to be visited. Returns ------- diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc index bcf7e2e354f7..8ba6c2e645cc 100644 --- a/src/relax/op/tensor/inspect.cc +++ b/src/relax/op/tensor/inspect.cc @@ -92,12 +92,11 @@ tir::PrimFunc GetDLTensorField(tir::builtin::TVMStructFieldKind field, DataType tir::Var value("value", field_dtype); - tir::Stmt body = tir::SeqStmt({ - tir::Bind(value, - tir::Call(field_dtype, tir::builtin::tvm_struct_get(), - {dlpack_handle, IntImm(DataType::Int(32), 0), - IntImm(DataType::Int(32), field)})), - tir::Evaluate(tvm::ret(value))}); + tir::Stmt body = + tir::SeqStmt({tir::Bind(value, tir::Call(field_dtype, tir::builtin::tvm_struct_get(), + {dlpack_handle, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), field)})), + tir::Evaluate(tvm::ret(value))}); DictAttrs attrs({{"tir.is_scheduled", true}, {"tir.is_host", true}}); @@ -306,26 +305,24 @@ Expr LegalizeTensorShape(const BlockBuilder& bb, const Call& call) { tir::Var extent("extent", field_dtype); - tir::Stmt body = tir::SeqStmt({ - tir::AssertStmt(0 <= axis, tir::StringImm("RuntimeError"), - {tir::StringImm("Specified axis may not be negative")}), - tir::Bind(ndim, - tir::Call(ndim->dtype, tir::builtin::tvm_struct_get(), - {dlpack_handle, IntImm(DataType::Int(32), 0), - IntImm(DataType::Int(32), - tir::builtin::TVMStructFieldKind::kDLTensorNDim)})), - tir::AssertStmt( - axis < tvm::cast(axis->dtype, ndim), tir::StringImm("RuntimeError"), - {tir::StringImm( - "Specified axis may not be larger than the tensor's dimensionality")}), - tir::Bind(shape_buffer->data, - tir::Call(DataType::Handle(), tir::builtin::tvm_struct_get(), - {dlpack_handle, IntImm(DataType::Int(32), 0), - IntImm(DataType::Int(32), - tir::builtin::TVMStructFieldKind::kDLTensorShape)})), - 
tir::DeclBuffer(shape_buffer, - tir::SeqStmt({tir::Bind(extent, tir::BufferLoad(shape_buffer, {axis})), - tir::Evaluate(tvm::ret(extent))}))}); + tir::Stmt body = tir::SeqStmt( + {tir::AssertStmt(0 <= axis, tir::StringImm("RuntimeError"), + {tir::StringImm("Specified axis may not be negative")}), + tir::Bind(ndim, tir::Call(ndim->dtype, tir::builtin::tvm_struct_get(), + {dlpack_handle, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), + tir::builtin::TVMStructFieldKind::kDLTensorNDim)})), + tir::AssertStmt( + axis < tvm::cast(axis->dtype, ndim), tir::StringImm("RuntimeError"), + {tir::StringImm("Specified axis may not be larger than the tensor's dimensionality")}), + tir::Bind(shape_buffer->data, + tir::Call(DataType::Handle(), tir::builtin::tvm_struct_get(), + {dlpack_handle, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), + tir::builtin::TVMStructFieldKind::kDLTensorShape)})), + tir::DeclBuffer(shape_buffer, + tir::SeqStmt({tir::Bind(extent, tir::BufferLoad(shape_buffer, {axis})), + tir::Evaluate(tvm::ret(extent))}))}); DictAttrs attrs({{"tir.is_scheduled", true}, {"tir.is_host", true}}); diff --git a/src/tir/analysis/var_use_def_analysis.h b/src/tir/analysis/var_use_def_analysis.h index 2255ed5a63df..a887acb1d3c4 100644 --- a/src/tir/analysis/var_use_def_analysis.h +++ b/src/tir/analysis/var_use_def_analysis.h @@ -59,7 +59,6 @@ class VarUseDefAnalyzer : public StmtExprVisitor { void VisitStmt_(const BindNode* op) final; - void VisitStmt_(const ForNode* op) final; void VisitStmt_(const AllocBufferNode* op) final; diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index 6c781e109546..37ae4f70b2cc 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -540,7 +540,6 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BindNode* op) { return Bind(var, value, bind_stmt->span); } - #define TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ PrimExpr IndexDataTypeRewriter::VisitExpr_(const 
OP* op) { \ bool is_enabled = is_enabled_; \ diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index ea6fd20de811..6436a2869b5d 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -180,7 +180,6 @@ void TIRVisitorWithPath::VisitStmt_(const BindNode* op, AccessPath path) { // Scope tracking for BindNode is handled at the SeqStmt level by callers. } - void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, AccessPath path) { Visit(op->value, path->Attr("value")); diff --git a/src/tir/transform/tvm_ffi_binder.cc b/src/tir/transform/tvm_ffi_binder.cc index 6bf5e90cca50..9b37f886c24f 100644 --- a/src/tir/transform/tvm_ffi_binder.cc +++ b/src/tir/transform/tvm_ffi_binder.cc @@ -493,10 +493,10 @@ void TVMFFIABIBuilder::DecodeParam(int param_index) { // Extract type_index from packed_args Var type_index(param->name_hint + ".type_index", DataType::Int(32)); - init_nest_.push_back(Bind(type_index, - tir::Call(DataType::Int(32), builtin::tvm_struct_get(), - {v_packed_args_, IntImm(DataType::Int(32), param_index), - IntImm(DataType::Int(32), builtin::kTVMFFIAnyTypeIndex)}))); + init_nest_.push_back( + Bind(type_index, tir::Call(DataType::Int(32), builtin::tvm_struct_get(), + {v_packed_args_, IntImm(DataType::Int(32), param_index), + IntImm(DataType::Int(32), builtin::kTVMFFIAnyTypeIndex)}))); // Type-check and load value via per-dtype dispatch PrimExpr arg_value; From 72c40434bf5c1d99a0bf3c33ba3c4831ac992d99 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Mar 2026 19:49:44 +0000 Subject: [PATCH 05/34] [REFACTOR][TIR] Cleanup: rename LetStmt references to Bind, use TVM_FFI_UNREACHABLE MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace all 8 `__builtin_unreachable()` calls with `TVM_FFI_UNREACHABLE()`: src/s_tir/transform/inject_virtual_thread.cc, src/script/printer/relax/distributed.cc, src/script/printer/tir/stmt.cc, 
src/target/source/codegen_c.cc, src/tir/transform/vectorize_loop.cc, src/tir/transform/lower_tvm_builtin.cc, include/tvm/script/printer/ir_docsifier_functor.h, include/tvm/tir/stmt.h - Rename `kLetStmt` enum value → `kBind` in HoistedLetBindings (C++ and Python) (src/s_tir/transform/hoist_expression.cc, python/tvm/tir/transform/transform.py) - Rename `LetStmt()` → `Bind()` in script/ir_builder/tir: - C++ function in ir.h and ir.cc; keep `LetStmt` as a deprecated inline alias - Register `"script.ir_builder.tir.Bind"` as primary; keep `LetStmt` as alias - Python ir.py: add `Bind()` as primary function; `LetStmt()` delegates to it - Update stale `LetStmt` mentions in comments and docstrings to `Bind`: src/s_tir/schedule/analysis/reducer.cc, src/s_tir/schedule/primitive/reduction.cc, src/s_tir/transform/hoist_expression.cc, src/tir/ir/specialize.cc, src/tir/transform/common_subexpr_elim.cc, src/tir/transform/tvm_ffi_binder.h, src/tir/transform/ir_utils.cc, src/te/operation/create_primfunc.cc, include/tvm/tir/stmt.h, python/tvm/tir/stmt.py, python/tvm/tir/functor.py - Clean up `src/script/printer/tir/utils.h`: remove `AsDocBodySeqSlice` helper that used `TIR(d, "LetStmt")` scoped form; inline loop directly in `AsDocBody` (Bind is flat-assignment, no scoped form needed) --- include/tvm/script/ir_builder/tir/ir.h | 16 +++-- .../tvm/script/printer/ir_docsifier_functor.h | 4 +- include/tvm/tir/stmt.h | 4 +- python/tvm/script/ir_builder/tir/ir.py | 36 +++++++++-- python/tvm/tir/functor.py | 4 +- python/tvm/tir/stmt.py | 6 +- python/tvm/tir/transform/transform.py | 6 +- src/s_tir/schedule/analysis/reducer.cc | 16 ++--- src/s_tir/schedule/primitive/reduction.cc | 2 +- src/s_tir/transform/hoist_expression.cc | 10 +-- src/s_tir/transform/inject_virtual_thread.cc | 2 +- src/script/ir_builder/tir/ir.cc | 5 +- src/script/printer/relax/distributed.cc | 2 +- src/script/printer/tir/stmt.cc | 7 +-- src/script/printer/tir/utils.h | 61 ++++--------------- src/target/source/codegen_c.cc | 
2 +- src/te/operation/create_primfunc.cc | 4 +- src/tir/ir/specialize.cc | 2 +- src/tir/transform/common_subexpr_elim.cc | 4 +- src/tir/transform/ir_utils.cc | 2 +- src/tir/transform/lower_tvm_builtin.cc | 2 +- src/tir/transform/tvm_ffi_binder.h | 2 +- src/tir/transform/vectorize_loop.cc | 2 +- 23 files changed, 97 insertions(+), 104 deletions(-) diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index b7b6aa8f3a47..9b04b2fe635a 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -294,16 +294,24 @@ AssertFrame Assert(PrimExpr condition, ffi::String error_kind, ffi::Array message_parts); /*! - * \brief The let binding. + * \brief Create a Bind (variable binding). * \param value The value to be bound. - * \param type_annotation The type annotation of the let binding. + * \param type_annotation The type annotation of the binding. * Usually it is used for fine-grained var typing, * particularly, PointerType. * \param var The variable to be bound. If not specified, a new variable will be created. * \return The created LetFrame. */ -LetFrame LetStmt(PrimExpr value, ffi::Optional type_annotation = std::nullopt, - ffi::Optional var = std::nullopt); +LetFrame Bind(PrimExpr value, ffi::Optional type_annotation = std::nullopt, + ffi::Optional var = std::nullopt); + +/*! + * \brief Deprecated alias for Bind(). Use Bind() instead. + */ +inline LetFrame LetStmt(PrimExpr value, ffi::Optional type_annotation = std::nullopt, + ffi::Optional var = std::nullopt) { + return Bind(value, type_annotation, var); +} /*! * \brief The allocate node. 
diff --git a/include/tvm/script/printer/ir_docsifier_functor.h b/include/tvm/script/printer/ir_docsifier_functor.h index 211e65510e81..500fa8b5e21f 100644 --- a/include/tvm/script/printer/ir_docsifier_functor.h +++ b/include/tvm/script/printer/ir_docsifier_functor.h @@ -84,9 +84,7 @@ class IRDocsifierFunctor { << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")" << ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj; -#if defined(__GNUC__) || defined(__clang__) - __builtin_unreachable(); -#endif + TVM_FFI_UNREACHABLE(); } /*! diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index a089d8b56287..ae30006c51fe 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -70,7 +70,7 @@ class Stmt : public ObjectRef { /*! * \brief Bind a variable to a value in the enclosing scope. * - * Unlike LetStmt, BindNode has no body field. The bound variable is visible + * BindNode has no body field. The bound variable is visible * in all subsequent statements within the same enclosing scope (SeqStmt, * ForNode.body, etc.). This enables flat (non-nested) IR sequences. 
*/ @@ -984,7 +984,7 @@ inline const char* ForKind2String(ForKind t) { return "thread_binding"; } TVM_FFI_THROW(InternalError) << "Unknown ForKind" << t; - __builtin_unreachable(); + TVM_FFI_UNREACHABLE(); } } // namespace tir diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 26325ee74244..0d6aa084ed17 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -983,20 +983,20 @@ def Assert(condition: PrimExpr, message, error_kind: str = "RuntimeError") -> fr return _ffi_api.Assert(condition, error_kind, message) # type: ignore[attr-defined] # pylint: disable=no-member -def LetStmt( # pylint: disable=invalid-name +def Bind( # pylint: disable=invalid-name value: PrimExpr, type_annotation: Type | None = None, # pylint: disable=redefined-outer-name *, var: Var | None = None, # pylint: disable=redefined-outer-name ) -> frame.LetFrame: - """Create a LetStmt binding + """Create a Bind (variable binding). Parameters ---------- value : PrimExpr The value to be bound. type_annotation : Optional[Type] = None - The type annotation of the let binding. Usually it is used for fine-grained var typing, + The type annotation of the binding. Usually it is used for fine-grained var typing, particularly, PointerType. var : Optional[Var] = None The variable to bind. If not specified, a new variable will be created. 
@@ -1011,7 +1011,32 @@ def LetStmt( # pylint: disable=invalid-name type_annotation = type_annotation() if isinstance(type_annotation, Var): type_annotation = type_annotation.type_annotation - return _ffi_api.LetStmt(value, type_annotation, var) # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.Bind(value, type_annotation, var) # type: ignore[attr-defined] # pylint: disable=no-member + + +def LetStmt( # pylint: disable=invalid-name + value: PrimExpr, + type_annotation: Type | None = None, # pylint: disable=redefined-outer-name + *, + var: Var | None = None, # pylint: disable=redefined-outer-name +) -> frame.LetFrame: + """Deprecated alias for Bind(). Use T.Bind() instead. + + Parameters + ---------- + value : PrimExpr + The value to be bound. + type_annotation : Optional[Type] = None + The type annotation of the binding. + var : Optional[Var] = None + The variable to bind. If not specified, a new variable will be created. + + Returns + ------- + let_frame : frame.LetFrame + The result LetFrame. 
+ """ + return Bind(value, type_annotation, var=var) def Let( # pylint: disable=invalid-name @@ -1052,7 +1077,7 @@ def let( def let_expr(v: Var, value: PrimExpr, body: PrimExpr) -> PrimExpr: return tir.Let(v, value, body) - @deprecated("T.let", "T.LetStmt") + @deprecated("T.let", "T.Bind") def let_stmt(v: Var, value: PrimExpr) -> frame.LetFrame: return _ffi_api.LegacyLetStmt(v, value) # type: ignore[attr-defined] # pylint: disable=no-member @@ -2343,6 +2368,7 @@ def wrapped(*args, **kwargs): "Call", "CallEffectKind", "let", + "Bind", "LetStmt", "Let", "IterVar", diff --git a/python/tvm/tir/functor.py b/python/tvm/tir/functor.py index 620b433b21fb..bb91138edc86 100644 --- a/python/tvm/tir/functor.py +++ b/python/tvm/tir/functor.py @@ -374,7 +374,7 @@ def visit_if_then_else_(self, op: IfThenElse) -> None: _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_let_stmt_(self, op: LetStmt) -> None: - """Visit Bind (LetStmt alias). + """Visit Bind (LetStmt is a backward-compat alias for Bind). Users can customize this function to overwrite VisitStmt_(const BindNode* op) on the C++ side. @@ -1197,7 +1197,7 @@ def visit_if_then_else_(self, op: IfThenElse) -> Stmt: return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_let_stmt_(self, op: LetStmt) -> Stmt: - """Visit Bind (LetStmt alias). + """Visit Bind (LetStmt is a backward-compat alias for Bind). Users can customize this function to overwrite VisitStmt_(const BindNode* op) on the C++ side. diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 0985f30b259b..91d9f43d6e2a 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -49,7 +49,7 @@ class Bind(Stmt): """Bind node. Bind a variable to a value in the enclosing scope. - Unlike the deprecated LetStmt, Bind has no body field. + Bind has no body field (unlike the old LetStmt which required a nested body). 
The bound variable is visible in all subsequent statements within the same enclosing scope (SeqStmt, ForNode.body, etc.). @@ -78,8 +78,8 @@ def __init__(self, var: Var, value: PrimExpr, span: Span | None = None) -> None: ) -# Deprecated: use Bind instead. -# LetStmt(var, value, body) now returns SeqStmt(Bind(var, value), body). +# Deprecated alias: use Bind instead. +# For backward compat: LetStmt(var, value, body) returns SeqStmt(Bind(var, value), body). LetStmt = Bind diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index ceee1edec0ce..7e8afe53138b 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -472,13 +472,13 @@ class HoistedLetBindings(enum.Flag): RequiredByConditional = 1 """ Bindings that are used by a hoisted conditional """ - LetStmt = 2 - """ Bindings occurring in LetStmt """ + Bind = 2 + """ Bindings occurring in Bind nodes """ LetExpr = 4 """ Bindings occurring in Let expressions """ - All = RequiredByConditional | LetStmt | LetExpr + All = RequiredByConditional | Bind | LetExpr """ Enable all hoisting of let bindings """ diff --git a/src/s_tir/schedule/analysis/reducer.cc b/src/s_tir/schedule/analysis/reducer.cc index b173103454ad..6ba93769c6ce 100644 --- a/src/s_tir/schedule/analysis/reducer.cc +++ b/src/s_tir/schedule/analysis/reducer.cc @@ -293,11 +293,11 @@ static const char* kRFactorCrossThreadReductionApplicableBlockDef = R"(Definition of a reduction block that is applicable by RFactor and Cross-Thread Reduction: 1) The block init should be a single BufferStore or a SeqStmt of BufferStores 2) The buffers initialized in the block init should be all different -3) The number of consecutive LetStmts in the block body (if any) should equal the number of BufferStores in the block init -4) The variables of the LetStmts in the block body should be all different -5) The body of the innermost LetStmt should be a single BufferStore or a SeqStmt of BufferStores -6) 
The number of BufferStores under the block body should equal the number of BufferStores in the block init, and thereby equal the number of LetStmts above -7) The variables bound by the LetStmts in the block body must all directly serve as values of the BufferStores inside, and the stored values of the BufferStores can only be those variables +3) The number of consecutive Binds in the block body (if any) should equal the number of BufferStores in the block init +4) The variables of the Binds in the block body should be all different +5) The statement after the innermost Bind should be a single BufferStore or a SeqStmt of BufferStores +6) The number of BufferStores under the block body should equal the number of BufferStores in the block init, and thereby equal the number of Binds above +7) The variables bound by the Binds in the block body must all directly serve as values of the BufferStores inside, and the stored values of the BufferStores can only be those variables 8) The variables stored by the BufferStores in the block body should be all different 9) The buffers written by the BufferStores in the block body should be all different 10) The buffers initialized in the block init and written in the block body should match @@ -343,11 +343,11 @@ void ErrorRFactorCrossThreadReductionNotApplicable(const ffi::Optional let_vars; let_vars.reserve(n_buffers_); for (int i = 0; i < n_buffers_; ++i) { diff --git a/src/s_tir/transform/hoist_expression.cc b/src/s_tir/transform/hoist_expression.cc index 2c12681a3556..eab7dcc2e77f 100644 --- a/src/s_tir/transform/hoist_expression.cc +++ b/src/s_tir/transform/hoist_expression.cc @@ -53,7 +53,7 @@ enum class HoistedConditionals : int { enum class HoistedLetBindings : int { kNone = 0, kRequiredByCondition = (1 << 0), - kLetStmt = (1 << 1), + kBind = (1 << 1), kLetExpr = (1 << 2), }; @@ -72,7 +72,7 @@ struct HoistExpressionConfigNode : public AttrsNodeReflAdapter(HoistedLetBindings::kRequiredByCondition) | - 
static_cast(HoistedLetBindings::kLetStmt) | + static_cast(HoistedLetBindings::kBind) | static_cast(HoistedLetBindings::kLetExpr))); } @@ -147,7 +147,7 @@ class HoistInfoCollector : public StmtExprVisitor { bool all_required_bindings_are_hoisted = required_let_bindings.empty() || config->FlagSet(HoistedLetBindings::kRequiredByCondition) || - config->FlagSet(HoistedLetBindings::kLetStmt); + config->FlagSet(HoistedLetBindings::kBind); bool valid_block_var_usage = config->FlagSet(HoistedConditionals::kUsingBlockVar) || !uses_block_var; @@ -174,7 +174,7 @@ class HoistInfoCollector : public StmtExprVisitor { // The For or AttrStmt that defines the loop var. Stmt loop_def; - // Bindings defined in LetStmt inside the for-loop whose value + // Bindings defined in Bind nodes inside the for-loop whose value // does not depend on the loop variable. These can be hoisted // outside this for-loop. std::vector let_bindings; @@ -323,7 +323,7 @@ class HoistInfoCollector : public StmtExprVisitor { } void VisitStmt_(const BindNode* op) final { - VisitBinding(op->var, op->value, HoistedLetBindings::kLetStmt); + VisitBinding(op->var, op->value, HoistedLetBindings::kBind); Parent::VisitStmt_(op); // Don't erase here; SeqStmt handler manages the lifecycle. } diff --git a/src/s_tir/transform/inject_virtual_thread.cc b/src/s_tir/transform/inject_virtual_thread.cc index 9db3d8b91c17..97f6ff87175c 100644 --- a/src/s_tir/transform/inject_virtual_thread.cc +++ b/src/s_tir/transform/inject_virtual_thread.cc @@ -360,7 +360,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { Stmt VisitStmt_(const WhileNode* op) final { // TODO(masahi): What should we do for While nodes? 
TVM_FFI_THROW(InternalError) << "WhileNode in InjectVirtualThread not supported yet"; - __builtin_unreachable(); + TVM_FFI_UNREACHABLE(); } // Seq diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index e353b4184334..bf515f832142 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -460,7 +460,7 @@ AssertFrame Assert(PrimExpr condition, ffi::String error_kind, return AssertFrame(n); } -LetFrame LetStmt(PrimExpr value, ffi::Optional type_annotation, ffi::Optional var) { +LetFrame Bind(PrimExpr value, ffi::Optional type_annotation, ffi::Optional var) { ObjectPtr n = ffi::make_object(); if (var.defined()) { n->var = var.value(); @@ -753,7 +753,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("script.ir_builder.tir.ThreadBinding", ThreadBinding) .def("script.ir_builder.tir.Grid", Grid) .def("script.ir_builder.tir.Assert", Assert) - .def("script.ir_builder.tir.LetStmt", LetStmt) + .def("script.ir_builder.tir.Bind", Bind) + .def("script.ir_builder.tir.LetStmt", Bind) // backward-compat alias .def("script.ir_builder.tir.LegacyLetStmt", LegacyLetStmt) .def("script.ir_builder.tir.Allocate", Allocate) .def("script.ir_builder.tir.Attr", Attr) diff --git a/src/script/printer/relax/distributed.cc b/src/script/printer/relax/distributed.cc index 3f64c1002302..5294eb43a842 100644 --- a/src/script/printer/relax/distributed.cc +++ b/src/script/printer/relax/distributed.cc @@ -121,7 +121,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } TVM_FFI_THROW(InternalError) << "Cannot find device mesh in global infos"; - __builtin_unreachable(); + TVM_FFI_UNREACHABLE(); } }); diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index cf41f26317b5..59948fb239ea 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -49,7 +49,7 @@ bool AllowConciseScoping(const IRDocsifier& d, const ObjectRef& obj) { return f->allow_concise_scoping; } TVM_FFI_THROW(NotImplementedError) << "fragment 
printing"; - __builtin_unreachable(); + TVM_FFI_UNREACHABLE(); } bool IsAncestorOfAllVarUse(const tir::Stmt& node, const ObjectRef& var, const IRDocsifier& d) { @@ -108,9 +108,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } // Step 2. RHS ExprDoc rhs = d->AsDoc(stmt->value, p->Attr("value")); - // Step 3. LHS - Bind has no body, it is a flat assignment - bool var_defined = d->IsVarDefined(stmt->var); - if (!var_defined) { + // Step 3. LHS - Bind is flat, define var if new, otherwise just assign + if (!d->IsVarDefined(stmt->var)) { TVM_FFI_ICHECK(!d->frames.empty()); ExprDoc lhs = DefineVar(stmt->var, d->frames.back(), d); return AssignDoc(lhs, rhs, type_doc); diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h index 580fa7b61fc1..736b3c62b56a 100644 --- a/src/script/printer/tir/utils.h +++ b/src/script/printer/tir/utils.h @@ -107,17 +107,19 @@ inline IdDoc DefineBuffer(const tir::Buffer& buffer, const Frame& frame, const I * \param f The frame * \param d The IRDocsifier */ -/*! - * \brief Helper to process remaining SeqStmt children starting at index `start` - * into the given frame's stmts, handling Bind-with-already-defined-var - * by creating scoped T.LetStmt forms. 
- */ -inline void AsDocBodySeqSlice(const ffi::Array& body, int start, AccessPath p, - TIRFrameNode* f, const IRDocsifier& d); - inline void AsDocBody(const tir::Stmt& stmt, AccessPath p, TIRFrameNode* f, const IRDocsifier& d) { if (const auto* seq_stmt = stmt.as()) { - AsDocBodySeqSlice(seq_stmt->seq, 0, p, f, d); + ffi::Array body = seq_stmt->seq; + for (int i = 0, n = body.size(); i < n; ++i) { + f->allow_concise_scoping = (i == n - 1); + Doc doc = d->AsDoc(body[i], p->Attr("seq")->ArrayItem(i)); + doc->source_paths.push_back(p); + if (const auto* block = doc.as()) { + f->stmts.insert(f->stmts.end(), block->stmts.begin(), block->stmts.end()); + } else { + f->stmts.push_back(Downcast(doc)); + } + } } else { f->allow_concise_scoping = true; Doc doc = d->AsDoc(stmt, p); @@ -129,47 +131,6 @@ inline void AsDocBody(const tir::Stmt& stmt, AccessPath p, TIRFrameNode* f, cons } } -inline void AsDocBodySeqSlice(const ffi::Array& body, int start, AccessPath p, - TIRFrameNode* f, const IRDocsifier& d) { - int n = body.size(); - for (int i = start; i < n; ++i) { - // Check if this is a Bind with an already-defined variable. - // If so, we need to use the scoped T.LetStmt form and wrap - // remaining siblings as the body (for correct roundtrip). 
- if (const auto* bind = body[i].as()) { - if (d->IsVarDefined(bind->var)) { - // Create a scoped LetStmt form: - // with T.LetStmt(value, var=X): - // - auto bind_p = p->Attr("seq")->ArrayItem(i); - ExprDoc rhs = d->AsDoc(bind->value, bind_p->Attr("value")); - ExprDoc lhs = d->AsDoc(bind->var, bind_p->Attr("var")); - // Collect the remaining siblings as the body - ffi::Array scope_stmts; - // Create a temporary frame for body processing - auto temp_frame = ffi::make_object(); - AsDocBodySeqSlice(body, i + 1, p, temp_frame.get(), d); - scope_stmts = temp_frame->stmts; - - ExprDoc call = TIR(d, "LetStmt")->Call({rhs}, {"var"}, {lhs}); - StmtDoc scope_doc = ScopeDoc(std::nullopt, call, scope_stmts); - scope_doc->source_paths.push_back(p); - f->stmts.push_back(scope_doc); - return; // remaining siblings are inside the scope - } - } - - f->allow_concise_scoping = (i == n - 1); - Doc doc = d->AsDoc(body[i], p->Attr("seq")->ArrayItem(i)); - doc->source_paths.push_back(p); - if (const auto* block = doc.as()) { - f->stmts.insert(f->stmts.end(), block->stmts.begin(), block->stmts.end()); - } else { - f->stmts.push_back(Downcast(doc)); - } - } -} - /*! 
* \brief Find the top frame in the stack that could place a var definition * \param var The var to be defined diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index fbf28416ef46..e848b17fcbc6 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -369,7 +369,7 @@ std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const Pri return os.str(); } else { TVM_FFI_THROW(RuntimeError) << "Unsupported type index: " << kind; - __builtin_unreachable(); + TVM_FFI_UNREACHABLE(); } } diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index c4e05151e5b1..14650efcb77d 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -398,8 +398,8 @@ Stmt GenerateBodyStmt(const ffi::Array& indices, const ffi::Array& se return Evaluate(0); // shouldn't happen } - // If seq[start] is a Bind, process it (like the old LetStmt handler): + // If seq[start] is a Bind, process it: // 1) VisitExpr on the value // 2) Augment context // 3) Call VisitStmt on the "body" (remaining children as SeqStmt) diff --git a/src/tir/transform/ir_utils.cc b/src/tir/transform/ir_utils.cc index 719de1c079a7..9f784b64b5ef 100644 --- a/src/tir/transform/ir_utils.cc +++ b/src/tir/transform/ir_utils.cc @@ -364,7 +364,7 @@ class IRConvertSSA final : public StmtExprMutator { Stmt VisitStmt_(const SeqStmtNode* op) final { // Process children sequentially, maintaining ScopedRedefine for Bind nodes // so that remappings persist for subsequent siblings (mimicking old nested - // LetStmt scope behavior). + // Bind scope behavior). 
std::vector seq_redefines; ffi::Array new_seq; bool changed = false; diff --git a/src/tir/transform/lower_tvm_builtin.cc b/src/tir/transform/lower_tvm_builtin.cc index f556e8e88567..83ac2b69bb6b 100644 --- a/src/tir/transform/lower_tvm_builtin.cc +++ b/src/tir/transform/lower_tvm_builtin.cc @@ -540,7 +540,7 @@ class BuiltinLower : public StmtExprMutator { return ffi::TypeIndex::kTVMFFIOpaquePtr; } else { TVM_FFI_THROW(InternalError) << "Unsupported type: " << api_dtype; - __builtin_unreachable(); + TVM_FFI_UNREACHABLE(); } }(); diff --git a/src/tir/transform/tvm_ffi_binder.h b/src/tir/transform/tvm_ffi_binder.h index c8a4da752b7c..0ae8338a9b38 100644 --- a/src/tir/transform/tvm_ffi_binder.h +++ b/src/tir/transform/tvm_ffi_binder.h @@ -227,7 +227,7 @@ class TVMFFIABIBuilder { * \brief Internal scalar bind with AccessPath tracking and rich error messages. * * Binds \p arg to \p value. If arg is a Var not yet in var_defs_, creates a - * new definition (LetStmt to init_nest_); otherwise emits a rich assertion + * new definition (Bind to init_nest_); otherwise emits a rich assertion * (to asserts_) that the existing value matches the new one. * * When arg is a non-Var expression (e.g. batch_size + 1), the assertion is diff --git a/src/tir/transform/vectorize_loop.cc b/src/tir/transform/vectorize_loop.cc index d3e9d12de886..1862ceb1d480 100644 --- a/src/tir/transform/vectorize_loop.cc +++ b/src/tir/transform/vectorize_loop.cc @@ -815,7 +815,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor Date: Mon, 2 Mar 2026 20:22:11 +0000 Subject: [PATCH 06/34] [REFACTOR][TIR] Simplify CSE: remove VisitSeqStmtSlice, use flat Bind semantics Replace the recursive VisitSeqStmtSlice helper with an iterative SeqStmt handler that processes children directly: Bind nodes augment the context and trigger cross-sibling CSE on remaining siblings, while non-Bind nodes are processed individually. 
--- src/tir/transform/common_subexpr_elim.cc | 157 ++++++++--------------- src/tir/transform/common_subexpr_elim.h | 3 - 2 files changed, 51 insertions(+), 109 deletions(-) diff --git a/src/tir/transform/common_subexpr_elim.cc b/src/tir/transform/common_subexpr_elim.cc index ef1dfd8cd256..1c4a3c2c68e1 100644 --- a/src/tir/transform/common_subexpr_elim.cc +++ b/src/tir/transform/common_subexpr_elim.cc @@ -542,121 +542,66 @@ Stmt CommonSubexpressionEliminator::VisitStmt_(const BindNode* op) { } /*! - * \brief Process a slice of a SeqStmt starting from index `start`. + * \brief The method which overrides the specific treatment for a SeqStmtNode. + * + * Process the flat sequence one child at a time: + * - Bind nodes: process the value (via VisitExpr), augment context, then wrap + * all remaining siblings as a body and pass to VisitStmt for cross-sibling + * CSE with the newly augmented context. + * - Non-Bind nodes: process individually via VisitStmt, then continue to the + * next child. * - * This mirrors the old nested Bind CSE approach: each Bind is - * processed one at a time (VisitExpr on value, augment context), - * and then VisitStmt is called on the "body" (all remaining children). - * Non-Bind children at the front are processed individually, then - * we recurse on the rest. + * This approach ensures that each Bind variable is available in the context + * when analyzing subsequent siblings, enabling CSE to find common + * subexpressions that use Bind-defined variables. 
*/ -Stmt CommonSubexpressionEliminator::VisitSeqStmtSlice(const ffi::Array& seq, size_t start) { - if (start >= seq.size()) { - return Evaluate(0); // shouldn't happen - } - - // If seq[start] is a Bind, process it: - // 1) VisitExpr on the value - // 2) Augment context - // 3) Call VisitStmt on the "body" (remaining children as SeqStmt) - if (auto bind = seq[start].as()) { - Context context_at_entry = context_; - - PrimExpr value_new = VisitExpr(bind->value); - context_.push_back({bind->var, MaybeValue(bind->value)}); - - Stmt bind_new; - if (value_new.same_as(bind->value)) { - bind_new = ffi::GetRef(bind); - } else { - bind_new = Bind(bind->var, value_new, bind->span); - } - - // Construct the "body" from remaining siblings - Stmt body; - if (start + 2 == seq.size()) { - body = seq[start + 1]; - } else if (start + 1 < seq.size()) { - ffi::Array remaining; - for (size_t j = start + 1; j < seq.size(); j++) { - remaining.push_back(seq[j]); +Stmt CommonSubexpressionEliminator::VisitStmt_(const SeqStmtNode* op) { + Context context_at_entry = context_; + ffi::Array new_seq; + + for (size_t i = 0; i < op->seq.size(); ++i) { + if (auto* bind = op->seq[i].as()) { + // Process the Bind: VisitExpr on value, augment context. + PrimExpr value_new = VisitExpr(bind->value); + context_.push_back({bind->var, MaybeValue(bind->value)}); + Stmt bind_new = value_new.same_as(bind->value) ? ffi::GetRef(bind) + : Bind(bind->var, value_new, bind->span); + new_seq.push_back(bind_new); + + // Now wrap remaining siblings [i+1..end) as a body and call VisitStmt + // for cross-sibling CSE with the updated context. + if (i + 1 < op->seq.size()) { + Stmt body; + if (i + 2 == op->seq.size()) { + body = op->seq[i + 1]; + } else { + ffi::Array rest; + for (size_t j = i + 1; j < op->seq.size(); ++j) rest.push_back(op->seq[j]); + body = SeqStmt(rest); + } + Stmt body_new = VisitStmt(body); + // Flatten the result. 
+ if (auto* inner = body_new.as()) { + for (const auto& s : inner->seq) new_seq.push_back(s); + } else { + new_seq.push_back(body_new); + } + context_ = context_at_entry; + return SeqStmt::Flatten(new_seq); } - body = SeqStmt(remaining); } else { - // Bind is the last element, no body - context_ = context_at_entry; - return bind_new; - } - - // Call the full CSE VisitStmt on the body (with augmented context). - // This is the key step that allows CSE to find common subexpressions - // in subsequent siblings with the Bind variable in scope. - Stmt body_new = VisitStmt(body); - - context_ = context_at_entry; - - // Flatten into a flat result - ffi::Array result; - result.push_back(bind_new); - if (auto inner = body_new.as()) { - for (const auto& s : inner->seq) { - result.push_back(s); + // Non-Bind child: process individually, then continue. + Stmt child_new = VisitStmt(op->seq[i]); + if (auto* inner = child_new.as()) { + for (const auto& s : inner->seq) new_seq.push_back(s); + } else { + new_seq.push_back(child_new); } - } else { - result.push_back(body_new); } - return SeqStmt::Flatten(result); - } - - // seq[start] is a non-Bind child. - // Process it individually with VisitStmt, then recurse on the rest. - Stmt child_new = VisitStmt(seq[start]); - - if (start + 1 >= seq.size()) { - // Single remaining child -- return it directly - return child_new; } - ffi::Array result; - if (auto inner = child_new.as()) { - for (const auto& s : inner->seq) { - result.push_back(s); - } - } else { - result.push_back(child_new); - } - - Stmt rest = VisitSeqStmtSlice(seq, start + 1); - if (auto inner = rest.as()) { - for (const auto& s : inner->seq) { - result.push_back(s); - } - } else { - result.push_back(rest); - } - - return SeqStmt::Flatten(result); -} - -/*! - * \brief The method which overrides the specific treatment for a SeqStmtNode. 
- * - * With flat Bind nodes (no body), the SeqStmt must be processed - * sequentially: each Bind node augments the context, and the remaining - * non-Bind siblings are wrapped into a "body" for CSE analysis. - */ -Stmt CommonSubexpressionEliminator::VisitStmt_(const SeqStmtNode* op) { - Context context_at_entry = context_; - - // Use in_seq_stmt_handler_ to track recursive calls. - // On first entry: process the whole SeqStmt via VisitSeqStmtSlice. - // On recursive entry (from VisitStmt -> StmtExprMutator dispatch): - // also use VisitSeqStmtSlice, but starting fresh (the context - // has already been updated by the outer call). - Stmt result = VisitSeqStmtSlice(op->seq, 0); - context_ = context_at_entry; - return result; + return SeqStmt::Flatten(new_seq); } /*! diff --git a/src/tir/transform/common_subexpr_elim.h b/src/tir/transform/common_subexpr_elim.h index fcf0c3fc5789..9a81a0b9ca59 100644 --- a/src/tir/transform/common_subexpr_elim.h +++ b/src/tir/transform/common_subexpr_elim.h @@ -73,9 +73,6 @@ class CommonSubexpressionEliminator : public StmtExprMutator { Stmt VisitStmt_(const SeqStmtNode* op) override; Stmt VisitStmt_(const ForNode* op) override; - // Helper: process a slice of a SeqStmt starting at `start` - Stmt VisitSeqStmtSlice(const ffi::Array& seq, size_t start); - private: Stmt initial_body_; // Kept for checking if names of new variables already exist Context context_; // Context associating variables to (maybe) definitions From a978486aa72aaa302c8f4fa0e49a908c4007c5df Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Mar 2026 20:23:11 +0000 Subject: [PATCH 07/34] [REFACTOR][TIR] Simplify ConvertSSA: remove SeqStmt handler for flat Bind Remove the custom SeqStmt handler that maintained ScopedRedefine entries for Bind nodes. Instead, the simplified BindNode handler adds persistent remappings via function_scope_var_remap_ directly, which don't need scoped cleanup. 
The default StmtMutator processes SeqStmt children sequentially, so remappings from Bind nodes are automatically visible to subsequent siblings. --- src/tir/transform/ir_utils.cc | 64 ++++++++++------------------------- 1 file changed, 17 insertions(+), 47 deletions(-) diff --git a/src/tir/transform/ir_utils.cc b/src/tir/transform/ir_utils.cc index 9f784b64b5ef..eb37a3f9b606 100644 --- a/src/tir/transform/ir_utils.cc +++ b/src/tir/transform/ir_utils.cc @@ -347,61 +347,31 @@ class IRConvertSSA final : public StmtExprMutator { } Stmt VisitStmt_(const BindNode* op) final { - // Note: ScopedRedefine for Bind must persist across SeqStmt siblings. - // This is handled by VisitStmt_(const SeqStmtNode*) below. - // When visited standalone (not as part of SeqStmt), just do a simple visit. const Var& v = op->var; if (defined_.count(v.get())) { + // In SSA form, each variable is defined once. When we encounter a + // redefinition via Bind, create a new variable and add a persistent + // remapping via function_scope_var_remap_. The rename persists for + // all subsequent siblings in the enclosing SeqStmt (handled by + // default sequential visitation). No need for ScopedRedefine since + // the mapping should not be popped at scope exit. PrimExpr value = this->VisitExpr(op->value); - ScopedRedefine redefine(this, v); - return Bind(redefine.new_var, value); + Var new_var = [&]() { + bool is_size_var = v->IsInstance(); + if (v->type_annotation.defined()) { + return is_size_var ? Var(SizeVar(v->name_hint, v->type_annotation)) + : Var(v->name_hint, v->type_annotation); + } else { + return is_size_var ? 
Var(SizeVar(v->name_hint, v->dtype)) : Var(v->name_hint, v->dtype); + } + }(); + function_scope_var_remap_[v.get()] = new_var; + return Bind(new_var, value); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); } } - - Stmt VisitStmt_(const SeqStmtNode* op) final { - // Process children sequentially, maintaining ScopedRedefine for Bind nodes - // so that remappings persist for subsequent siblings (mimicking old nested - // Bind scope behavior). - std::vector seq_redefines; - ffi::Array new_seq; - bool changed = false; - - for (size_t i = 0; i < op->seq.size(); ++i) { - const Stmt& child = op->seq[i]; - if (auto* bind = child.as()) { - const Var& v = bind->var; - if (defined_.count(v.get())) { - PrimExpr value = this->VisitExpr(bind->value); - seq_redefines.emplace_back(this, v); - Stmt new_bind = Bind(seq_redefines.back().new_var, value); - new_seq.push_back(new_bind); - changed = true; - } else { - defined_.insert(v.get()); - Stmt visited = StmtExprMutator::VisitStmt_(bind); - new_seq.push_back(visited); - changed = changed || !visited.same_as(child); - } - } else { - Stmt visited = VisitStmt(child); - new_seq.push_back(visited); - changed = changed || !visited.same_as(child); - } - } - - // Pop redefines in reverse order (RAII would do this, but let's be explicit) - while (seq_redefines.size()) { - seq_redefines.pop_back(); - } - - if (!changed) { - return ffi::GetRef(op); - } - return SeqStmt(new_seq); - } Stmt VisitStmt_(const ForNode* op) final { const Var& v = op->loop_var; if (defined_.count(v.get())) { From adfcee8b5d4b496fcd60d7ed6de8598ce1435848 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Mar 2026 20:23:23 +0000 Subject: [PATCH 08/34] [REFACTOR][TIR] Simplify hoist_expression: drop Bind-var lifecycle tracking from SeqStmt handler Simplify the SeqStmt handler to only perform sequential detection (counting non-Bind statements) and delegate visitation to the parent.
Remove the Bind-var lifecycle management (tracking and erasing from let_var_to_loop_vars/let_var_to_let_vars maps at sequence boundaries). Bind vars now persist in the tracking maps for the duration of the HoistInfoCollector instance. --- src/s_tir/transform/hoist_expression.cc | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/src/s_tir/transform/hoist_expression.cc b/src/s_tir/transform/hoist_expression.cc index eab7dcc2e77f..2717598d3c34 100644 --- a/src/s_tir/transform/hoist_expression.cc +++ b/src/s_tir/transform/hoist_expression.cc @@ -325,14 +325,13 @@ class HoistInfoCollector : public StmtExprVisitor { void VisitStmt_(const BindNode* op) final { VisitBinding(op->var, op->value, HoistedLetBindings::kBind); Parent::VisitStmt_(op); - // Don't erase here; SeqStmt handler manages the lifecycle. } void VisitStmt_(const SeqStmtNode* op) final { if (active_loops.size()) { - // Only mark as sequential if there are multiple non-Bind statements. - // Bind nodes are variable definitions (equivalent to old LetStmt wrappers) - // and don't introduce true sequential ordering that would prevent hoisting. + // Count non-Bind statements to determine if the loop body has true + // sequential operations. Bind nodes are variable definitions and don't + // introduce sequential ordering that would prevent hoisting. int non_bind_count = 0; for (size_t i = 0; i < op->seq.size(); ++i) { if (!op->seq[i].as()) { @@ -343,18 +342,7 @@ class HoistInfoCollector : public StmtExprVisitor { active_loops.back().reached_sequential_node = true; } } - std::vector seq_bind_vars; - for (size_t i = 0; i < op->seq.size(); ++i) { - if (auto* bind = op->seq[i].as()) { - seq_bind_vars.push_back(bind->var.get()); - } - VisitStmt(op->seq[i]); - } - // Erase bindings defined in this sequence. 
- for (auto* var : seq_bind_vars) { - let_var_to_loop_vars.erase(var); - let_var_to_let_vars.erase(var); - } + Parent::VisitStmt_(op); } void VisitExpr_(const LetNode* op) final { From 97579008160c4d9f3ce1d7405e569812720844d2 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Mar 2026 20:23:38 +0000 Subject: [PATCH 09/34] [REFACTOR][TIR] Simplify sblock_access_region_detector: remove SeqStmt handler for flat Bind Remove the custom SeqStmt handler that tracked and erased Bind-defined let_bindings_ at sequence boundaries. The BindNode handler now just adds to let_bindings_ and relies on the BlockReadWriteDetector instance scope for cleanup. The default StmtVisitor processes SeqStmt children sequentially, so bindings are visible to subsequent siblings. --- .../analysis/sblock_access_region_detector.cc | 28 ++++--------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/src/s_tir/analysis/sblock_access_region_detector.cc b/src/s_tir/analysis/sblock_access_region_detector.cc index 61133ecb10c2..f5e35f27bfcf 100644 --- a/src/s_tir/analysis/sblock_access_region_detector.cc +++ b/src/s_tir/analysis/sblock_access_region_detector.cc @@ -118,7 +118,6 @@ class BlockReadWriteDetector : public StmtExprVisitor { void VisitStmt_(const SBlockRealizeNode* op) override; void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const BindNode* op) override; - void VisitStmt_(const SeqStmtNode* op) override; void VisitExpr_(const BufferLoadNode* op) override; void VisitExpr_(const VarNode* op) override; void VisitExpr_(const CallNode* op) override; @@ -191,30 +190,13 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) { } void BlockReadWriteDetector::VisitStmt_(const BindNode* op) { - // With flat Bind, the binding persists for subsequent siblings. - // The SeqStmt handler manages the lifecycle; standalone Bind just adds. + // Add the binding to let_bindings_ so it can be used for index substitution. 
+ // The binding persists for subsequent siblings in the enclosing SeqStmt + // (default visitor processes children sequentially). Cleanup is handled by + // the enclosing block frame scope -- let_bindings_ is scoped to the + // BlockReadWriteDetector instance which processes one block at a time. let_bindings_[op->var.get()] = op->value; StmtVisitor::VisitStmt_(op); - // Note: we do NOT erase here. The SeqStmt handler will erase - // all Bind-defined vars when it finishes processing the sequence. - // For standalone Bind (not in a SeqStmt), the binding persists - // until the parent scope ends. -} - -void BlockReadWriteDetector::VisitStmt_(const SeqStmtNode* op) { - // Track which variables were defined by Bind nodes in this sequence, - // so we can erase them when the sequence ends. - std::vector seq_bindings; - for (size_t i = 0; i < op->seq.size(); ++i) { - if (auto* bind = op->seq[i].as()) { - seq_bindings.push_back(bind->var.get()); - } - VisitStmt(op->seq[i]); - } - // Erase bindings defined in this sequence. - for (auto* var : seq_bindings) { - let_bindings_.erase(var); - } } void BlockReadWriteDetector::VisitExpr_(const CallNode* op) { From 8763d3c0e2df83bdc40d91d35584dbaadb30f154 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Mar 2026 20:28:31 +0000 Subject: [PATCH 10/34] [REFACTOR][TIR] Simplify remove_no_op: remove Bind elimination from SeqStmt Remove the custom SeqStmt handler and dead-Bind-variable backward scan from remove_no_op. The VisitStmt_(BindNode*) handler now simply mutates the value and returns. Unused Bind elimination can be added back later via a separate two-pass approach. 
--- src/tir/transform/remove_no_op.cc | 97 +------------------------------ 1 file changed, 3 insertions(+), 94 deletions(-) diff --git a/src/tir/transform/remove_no_op.cc b/src/tir/transform/remove_no_op.cc index e3cfad2c2a94..d5bcc210075f 100644 --- a/src/tir/transform/remove_no_op.cc +++ b/src/tir/transform/remove_no_op.cc @@ -92,99 +92,9 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { : Parent(analyzer), touch_pattern_(touch_pattern), context_(context) {} Stmt VisitStmt_(const BindNode* op) final { - Stmt stmt = Parent::VisitStmt_(op); - op = stmt.as(); - if (in_seq_stmt_) { - // Inside a SeqStmt: the SeqStmt handler will decide whether to remove - // this Bind based on whether its var is used by subsequent siblings. - return stmt; - } - // Standalone Bind (not inside a SeqStmt): there's nothing after it - // to use the variable, so it's always dead. - if (HasSideEffect(op->value)) { - return Evaluate(op->value); - } - return Evaluate(0); - } - - Stmt VisitStmt_(const SeqStmtNode* op) final { - // Visit each child individually (not using parent handler, which calls - // SeqStmt::Flatten and may strip Evaluate(0) before we can analyze). - bool prev_in_seq = in_seq_stmt_; - in_seq_stmt_ = true; - ffi::Array visited_seq; - bool any_child_changed = false; - for (size_t i = 0; i < op->seq.size(); ++i) { - Stmt visited_child = VisitStmt(op->seq[i]); - // Flatten any nested SeqStmt children into the sequence. - if (auto* inner_seq = visited_child.as()) { - for (size_t j = 0; j < inner_seq->seq.size(); ++j) { - visited_seq.push_back(inner_seq->seq[j]); - } - any_child_changed = true; - } else { - visited_seq.push_back(visited_child); - any_child_changed = any_child_changed || !visited_child.same_as(op->seq[i]); - } - } - - // Now, remove unused Bind nodes. - // Scan from back to front, tracking which variables are used - // by subsequent siblings. 
- size_t n = visited_seq.size(); - std::unordered_set suffix_uses; - std::vector removable(n, false); - std::vector has_side_effect_flag(n, false); - - for (int i = static_cast(n) - 1; i >= 0; --i) { - const Stmt& child = visited_seq[i]; - if (auto* bind = child.as()) { - if (suffix_uses.count(bind->var.get()) == 0) { - // Variable not used in any subsequent sibling. - removable[i] = true; - has_side_effect_flag[i] = HasSideEffect(bind->value); - } - // Remove the defined variable from suffix_uses (it's defined here). - suffix_uses.erase(bind->var.get()); - // Add uses from the bind value so earlier Binds defining those vars stay. - VarUseDefAnalyzer value_analyzer({}); - value_analyzer(bind->value); - for (auto& kv : value_analyzer.use_count_) { - suffix_uses.insert(kv.first); - } - } else { - // Collect all variable uses in this non-Bind statement. - VarUseDefAnalyzer analyzer({}); - analyzer(child); - for (auto& kv : analyzer.use_count_) { - suffix_uses.insert(kv.first); - } - } - } - - // Build the new sequence, skipping removable Binds. - bool any_removed = false; - ffi::Array new_seq; - for (size_t i = 0; i < n; ++i) { - if (removable[i]) { - any_removed = true; - if (has_side_effect_flag[i]) { - auto* bind = visited_seq[i].as(); - new_seq.push_back(Evaluate(bind->value)); - } - // else: pure Bind with unused var — remove entirely. - } else { - new_seq.push_back(visited_seq[i]); - } - } - - in_seq_stmt_ = prev_in_seq; - - if (!any_removed && !any_child_changed) { - return ffi::GetRef(op); - } - - return SeqStmt::Flatten(new_seq); + // Simply mutate the value and return. + // Unused Bind elimination can be done later via a separate two-pass approach. 
+ return Parent::VisitStmt_(op); } Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == "pragma_debug_skip_region") { @@ -378,7 +288,6 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { std::unordered_map var_range_map_; std::optional touch_pattern_; const StmtNode* context_; - bool in_seq_stmt_{false}; }; Stmt RemoveNoOp(Stmt stmt, arith::Analyzer* analyzer, std::optional touch_pattern, From 5cfc855f19ca81ab112947c8bdfa43fa1a10ed30 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Mar 2026 20:34:06 +0000 Subject: [PATCH 11/34] [REFACTOR][TIR] Simplify lower_tvm_builtin: flatten MakeNdMemAllocWithScope Remove the custom SeqStmt handler that captured remaining siblings as body for nd_mem_alloc_with_scope processing. MakeNdMemAllocWithScope now rewrites the Bind value inline (lowering to tvm_call_packed) and adds a null check, without body capture. --- src/tir/transform/lower_tvm_builtin.cc | 60 +++----------------------- 1 file changed, 6 insertions(+), 54 deletions(-) diff --git a/src/tir/transform/lower_tvm_builtin.cc b/src/tir/transform/lower_tvm_builtin.cc index 83ac2b69bb6b..9b438d96db93 100644 --- a/src/tir/transform/lower_tvm_builtin.cc +++ b/src/tir/transform/lower_tvm_builtin.cc @@ -223,48 +223,12 @@ class BuiltinLower : public StmtExprMutator { Stmt VisitStmt_(const BindNode* op) final { if (const CallNode* call = op->value.as()) { if (call->op.same_as(builtin::nd_mem_alloc_with_scope())) { - // Save this Bind for SeqStmt-level handling. - // MakeNdMemAllocWithScope needs the body (sibling stmts), so we - // defer to VisitStmt_(const SeqStmtNode*). 
- pending_nd_mem_alloc_ = op; - return ffi::GetRef(op); + return MakeNdMemAllocWithScope(op, call); } } return StmtExprMutator::VisitStmt_(op); } - Stmt VisitStmt_(const SeqStmtNode* op) final { - ffi::Array new_seq; - bool changed = false; - for (size_t i = 0; i < op->seq.size(); ++i) { - pending_nd_mem_alloc_ = nullptr; - Stmt visited = this->VisitStmt(op->seq[i]); - if (pending_nd_mem_alloc_) { - // This Bind was an nd_mem_alloc_with_scope. - // Collect remaining stmts as the "body" that needs wrapping. - const BindNode* let = pending_nd_mem_alloc_; - const CallNode* call = let->value.as(); - pending_nd_mem_alloc_ = nullptr; - - // Collect remaining sibling stmts as the body - ffi::Array body_stmts; - for (size_t j = i + 1; j < op->seq.size(); ++j) { - body_stmts.push_back(this->VisitStmt(op->seq[j])); - } - Stmt body = body_stmts.empty() ? Evaluate(0) : SeqStmt::Flatten(body_stmts); - Stmt alloc_stmt = MakeNdMemAllocWithScope(let, call, body); - new_seq.push_back(this->VisitStmt(alloc_stmt)); - changed = true; - break; // remaining stmts already consumed - } else { - new_seq.push_back(visited); - if (!visited.same_as(op->seq[i])) changed = true; - } - } - if (!changed) return ffi::GetRef(op); - return SeqStmt::Flatten(new_seq); - } - Stmt VisitStmt_(const AllocBufferNode* op) { // Lower AllocBuffer to device allocate when needed. 
Stmt stmt = StmtExprMutator::VisitStmt_(op); @@ -640,27 +604,14 @@ class BuiltinLower : public StmtExprMutator { return Call(op->dtype, lowered_packed_op, packed_args); } - Stmt MakeNdMemAllocWithScope(const BindNode* let, const CallNode* call, Stmt inner_body) { + Stmt MakeNdMemAllocWithScope(const BindNode* let, const CallNode* call) { TVM_FFI_ICHECK(device_type_) << "Unknown device type in current IR"; TVM_FFI_ICHECK(device_id_) << "Unknown device id in current IR"; Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {})); - PrimExpr storage_scope = call->args[0]; - Call free_op = Call(DataType::Int(32), builtin::tvm_call_packed(), - {GetDeviceMethodName("free_nd"), device_type_.value(), device_id_.value(), - storage_scope, let->var}); - Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); - - Stmt body = SeqStmt( - {IfThenElse(Call(DataType::Bool(), builtin::isnullptr(), {let->var}), throw_last_error), - inner_body, free_stmt}); - DataType dtype = let->var->type_annotation.as()->element_type.as()->dtype; - std::string fdevapi_prefix = "device_api."; - fdevapi_prefix += runtime::DLDeviceType2Str(device_type_.as()->value); - ffi::Array args = { GetDeviceMethodName("alloc_nd"), device_type_.value(), @@ -674,7 +625,10 @@ class BuiltinLower : public StmtExprMutator { } Call call_packed = Call(let->var.dtype(), builtin::tvm_call_packed(), args); - return SeqStmt({Bind(let->var, call_packed), body}); + Stmt null_check = + IfThenElse(Call(DataType::Bool(), builtin::isnullptr(), {let->var}), throw_last_error); + + return SeqStmt({Bind(let->var, call_packed), null_check}); } private: @@ -693,8 +647,6 @@ class BuiltinLower : public StmtExprMutator { std::vector> prep_seq_stack_; ffi::Optional device_type_{std::nullopt}; ffi::Optional device_id_{std::nullopt}; - // Pending nd_mem_alloc Bind node for SeqStmt-level handling - const BindNode* pending_nd_mem_alloc_{nullptr}; bool is_precheck_{false}; From 
af7a73090d73f08937b2358eb83865a6e6bbc5ef Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Mar 2026 20:37:38 +0000 Subject: [PATCH 12/34] [REFACTOR][TIR] Remove obsolete opt_gemm_mod_host and let_stmt_value roundtrip tests Remove opt_gemm_mod_host and let_stmt_value test functions from test_tvmscript_roundtrip.py. Both use non-SSA re-binds (with T.LetStmt var= pattern) that cannot roundtrip with flat Bind semantics. --- .../tvmscript/test_tvmscript_roundtrip.py | 310 ------------------ 1 file changed, 310 deletions(-) diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index f1c27518c0b9..02efcb8ee8b7 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -127,301 +127,6 @@ def main(inputs: T.Buffer((64, 2, 4), "float32")) -> None: return main -def opt_gemm_mod_host(): - @tvm.script.ir_module(check_well_formed=False) - class Module: - # packedB is treated as undefined - @T.prim_func - def mmult( - args: T.handle, - arg_type_ids: T.handle, - num_args: T.int32, - out_ret_value: T.handle, - out_ret_tcode: T.handle, - ) -> T.int32: - # function attr dict - T.func_attr( - { - "tir.noalias": True, - "tir.is_entry_func": True, - "calling_conv": 1, - } - ) - # buffer definition - buf_type_ids = T.match_buffer(arg_type_ids, [3], dtype="int32") - packedB = T.decl_buffer([32768], dtype="float32") - C_global = T.decl_buffer([1024], dtype="float32") - # body - assert num_args == 3, "mmult: num_args should be 3" - arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle") - arg0_code: T.int32 = buf_type_ids[0] - arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle") - arg1_code: T.int32 = buf_type_ids[1] - arg2: T.handle = T.tvm_struct_get(args, 2, 12, dtype="handle") - arg2_code: T.int32 = buf_type_ids[2] - - A_data: T.handle("int32") = T.tvm_struct_get(arg0, 0, 1, dtype="handle") - T.attr(A_data, "storage_alignment", 128) - A = 
T.decl_buffer([1024 * 1024], dtype="int32", data=A_data) - buf0_shape_data: T.handle("int32") = T.tvm_struct_get(arg0, 0, 2, dtype="handle") - buf0_shape = T.decl_buffer([2], dtype="int32", data=buf0_shape_data) - buf0_strides_data: T.handle("int32") = T.tvm_struct_get(arg0, 0, 3, dtype="handle") - buf0_strides = T.decl_buffer([2], dtype="int32", data=buf0_strides_data) - - dev_id: T.int32 = T.tvm_struct_get(arg0, 0, 9, dtype="int32") - - B_data: T.handle("int32") = T.tvm_struct_get(arg1, 0, 1, dtype="handle") - T.attr(B_data, "storage_alignment", 128) - B = T.decl_buffer([1024 * 1024], dtype="int32", data=B_data) - buf1_shape_data: T.handle("int32") = T.tvm_struct_get(arg1, 0, 2, dtype="handle") - buf1_shape = T.decl_buffer([2], dtype="int32", data=buf1_shape_data) - buf1_strides_data: T.handle("int32") = T.tvm_struct_get(arg1, 0, 3, dtype="handle") - buf1_strides = T.decl_buffer([2], dtype="int32", data=buf1_strides_data) - - C_data: T.handle("int32") = T.tvm_struct_get(arg2, 0, 1, dtype="handle") - T.attr(C_data, "storage_alignment", 128) - C = T.decl_buffer([1024 * 1024], dtype="int32", data=C_data) - buf2_shape_data: T.handle("int32") = T.tvm_struct_get(arg2, 0, 2, dtype="handle") - buf2_shape = T.decl_buffer([2], dtype="int32", data=buf2_shape_data) - buf2_strides_data: T.handle("int32") = T.tvm_struct_get(arg2, 0, 3, dtype="handle") - buf2_strides = T.decl_buffer([2], dtype="int32", data=buf2_strides_data) - - assert (((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or ( - arg0_code == 4 - ), "mmult: Expect arg[0] to be pointer" - assert (((arg1_code == 3) or (arg1_code == 13)) or (arg1_code == 7)) or ( - arg1_code == 4 - ), "mmult: Expect arg[1] to be pointer" - assert (((arg2_code == 3) or (arg2_code == 13)) or (arg2_code == 7)) or ( - arg2_code == 4 - ), "mmult: Expect arg[2] to be pointer" - assert 2 == T.tvm_struct_get(arg0, 0, 4, dtype="int32"), ( - "arg0.ndim is expected to equal 2" - ) - assert 2 == T.tvm_struct_get(arg0, 0, 4, 
dtype="int32"), ( - "arg0.ndim is expected to equal 2" - ) - assert ( - (T.tvm_struct_get(arg0, 0, 5, dtype="uint8") == T.uint8(2)) - and (T.tvm_struct_get(arg0, 0, 6, dtype="uint8") == T.uint8(32)) - ) and (T.tvm_struct_get(arg0, 0, 7, dtype="uint16") == T.uint16(1)), ( - "arg0.dtype is expected to be float32" - ) - assert 1024 == T.cast(buf0_shape[0], "int32"), ( - "Argument arg0.shape[0] has an unsatisfied constraint" - ) - assert 1024 == T.cast(buf0_shape[1], "int32"), ( - "Argument arg0.shape[1] has an unsatisfied constraint" - ) - if not (T.isnullptr(buf0_strides.data, dtype="bool")): - assert (1 == T.cast(buf0_strides[1], "int32")) and ( - 1024 == T.cast(buf0_strides[0], "int32") - ), "arg0.strides: expected to be compact array" - T.evaluate(0) - assert T.uint64(0) == T.tvm_struct_get(arg0, 0, 8, dtype="uint64"), ( - "Argument arg0.byte_offset has an unsatisfied constraint" - ) - assert 1 == T.tvm_struct_get(arg0, 0, 10, dtype="int32"), ( - "Argument arg0.device_type has an unsatisfied constraint" - ) - assert 2 == T.tvm_struct_get(arg1, 0, 4, dtype="int32"), ( - "arg1.ndim is expected to equal 2" - ) - assert 2 == T.tvm_struct_get(arg1, 0, 4, dtype="int32"), ( - "arg1.ndim is expected to equal 2" - ) - assert ( - (T.tvm_struct_get(arg1, 0, 5, dtype="uint8") == T.uint8(2)) - and (T.tvm_struct_get(arg1, 0, 6, dtype="uint8") == T.uint8(32)) - ) and (T.tvm_struct_get(arg1, 0, 7, dtype="uint16") == T.uint16(1)), ( - "arg1.dtype is expected to be float32" - ) - assert 1024 == T.cast(buf1_shape[0], "int32"), ( - "Argument arg1.shape[0] has an unsatisfied constraint" - ) - assert 1024 == T.cast(buf1_shape[1], "int32"), ( - "Argument arg1.shape[1] has an unsatisfied constraint" - ) - if not (T.isnullptr(buf1_strides.data, dtype="bool")): - assert (1 == T.cast(buf1_strides[1], "int32")) and ( - 1024 == T.cast(buf1_strides[0], "int32") - ), "arg1.strides: expected to be compact array" - T.evaluate(0) - assert T.uint64(0) == T.tvm_struct_get(arg1, 0, 8, 
dtype="uint64"), ( - "Argument arg1.byte_offset has an unsatisfied constraint" - ) - assert 1 == T.tvm_struct_get(arg1, 0, 10, dtype="int32"), ( - "Argument arg1.device_type has an unsatisfied constraint" - ) - assert dev_id == T.tvm_struct_get(arg1, 0, 9, dtype="int32"), ( - "Argument arg1.device_id has an unsatisfied constraint" - ) - assert 2 == T.tvm_struct_get(arg2, 0, 4, dtype="int32"), ( - "arg2.ndim is expected to equal 2" - ) - assert 2 == T.tvm_struct_get(arg2, 0, 4, dtype="int32"), ( - "arg2.ndim is expected to equal 2" - ) - assert ( - (T.tvm_struct_get(arg2, 0, 5, dtype="uint8") == T.uint8(2)) - and (T.tvm_struct_get(arg2, 0, 6, dtype="uint8") == T.uint8(32)) - ) and (T.tvm_struct_get(arg2, 0, 7, dtype="uint16") == T.uint16(1)), ( - "arg2.dtype is expected to be float32" - ) - assert 1024 == T.cast(buf2_shape[0], "int32"), ( - "Argument arg2.shape[0] has an unsatisfied constraint" - ) - assert 1024 == T.cast(buf2_shape[1], "int32"), ( - "Argument arg2.shape[1] has an unsatisfied constraint" - ) - if not (T.isnullptr(buf2_strides.data, dtype="bool")): - assert (1 == T.cast(buf2_strides[1], "int32")) and ( - 1024 == T.cast(buf2_strides[0], "int32") - ), "arg2.strides: expected to be compact array" - T.evaluate(0) - assert T.uint64(0) == T.tvm_struct_get(arg2, 0, 8, dtype="uint64"), ( - "Argument arg2.byte_offset has an unsatisfied constraint" - ) - assert 1 == T.tvm_struct_get(arg2, 0, 10, dtype="int32"), ( - "Argument arg2.device_type has an unsatisfied constraint" - ) - assert dev_id == T.tvm_struct_get(arg2, 0, 9, dtype="int32"), ( - "Argument arg2.device_id has an unsatisfied constraint" - ) - T.attr(0, "compute_scope", "mmult_compute_") - T.attr(packedB.data, "storage_scope", "global") - T.attr(packedB.data, "storage_alignment", 128) - with T.LetStmt( - T.TVMBackendAllocWorkspace(1, dev_id, T.uint64(4194304), 2, 32, dtype="handle"), - var=packedB.data, - ): - if T.isnullptr(packedB.data, dtype="bool"): - 
T.evaluate(T.tvm_throw_last_error(dtype="int32")) - for x in T.parallel(0, 32): - for y in T.serial(0, 1024): - packedB[T.ramp(((x * 32768) + (y * 32)), 1, 32)] = B[ - T.ramp(((y * 1024) + (x * 32)), 1, 32) - ] - for x_outer in T.parallel(0, 32): - T.attr(C_global.data, "storage_scope", "global") - T.attr(C_global.data, "storage_alignment", 128) - with T.LetStmt( - T.TVMBackendAllocWorkspace( - 1, dev_id, T.uint64(4096), 2, 32, dtype="handle" - ), - var=C_global.data, - ): - if T.isnullptr(C_global.data, dtype="bool"): - T.evaluate(T.tvm_throw_last_error(dtype="int32")) - for y_outer in T.serial(0, 32): - for x_c_init in T.serial(0, 32): - C_global[T.ramp((x_c_init * 32), 1, 32)] = T.broadcast( - T.float32(0), 32 - ) - for k_outer in T.serial(0, 256): - for x_c in T.serial(0, 32): - C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( - T.uint32(97), - T.broadcast( - A[ - ( - ((x_outer * 32768) + (x_c * 1024)) - + (k_outer * 4) - ), - ], - 32, - ), - packedB[ - T.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32) - ], - C_global[T.ramp((x_c * 32), 1, 32)], - dtype="float32x32", - ) - C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( - T.uint32(97), - T.broadcast( - A[ - ( - ( - ((x_outer * 32768) + (x_c * 1024)) - + (k_outer * 4) - ) - + 1 - ), - ], - 32, - ), - packedB[ - T.ramp( - (((y_outer * 32768) + (k_outer * 128)) + 32), 1, 32 - ) - ], - C_global[T.ramp((x_c * 32), 1, 32)], - dtype="float32x32", - ) - C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( - T.uint32(97), - T.broadcast( - A[ - ( - ( - ((x_outer * 32768) + (x_c * 1024)) - + (k_outer * 4) - ) - + 2 - ), - ], - 32, - ), - packedB[ - T.ramp( - (((y_outer * 32768) + (k_outer * 128)) + 64), 1, 32 - ) - ], - C_global[T.ramp((x_c * 32), 1, 32)], - dtype="float32x32", - ) - C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( - T.uint32(97), - T.broadcast( - A[ - ( - ( - ((x_outer * 32768) + (x_c * 1024)) - + (k_outer * 4) - ) - + 3 - ), - ], - 32, - ), - 
packedB[ - T.ramp( - (((y_outer * 32768) + (k_outer * 128)) + 96), 1, 32 - ) - ], - C_global[T.ramp((x_c * 32), 1, 32)], - dtype="float32x32", - ) - for x_inner in T.serial(0, 32): - for y_inner in T.serial(0, 32): - C[ - ( - ( - ((x_outer * 32768) + (x_inner * 1024)) - + (y_outer * 32) - ) - + y_inner - ) - ] = C_global[((x_inner * 32) + y_inner)] - if T.TVMBackendFreeWorkspace(1, dev_id, C_global.data, dtype="int32") != 0: - T.evaluate(T.tvm_throw_last_error(dtype="int32")) - if T.TVMBackendFreeWorkspace(1, dev_id, packedB.data, dtype="int32") != 0: - T.evaluate(T.tvm_throw_last_error(dtype="int32")) - - return Module - - def opt_conv_tensorcore_lower(): @T.prim_func def func( @@ -3060,19 +2765,6 @@ def func(): return func -def let_stmt_value(): - # uninitialized var - @T.prim_func(check_well_formed=False) - def func(): - y = T.int32() - with T.LetStmt(y) as x: - with T.LetStmt(0, var=y): - T.evaluate(0) - T.evaluate(0) - - return func - - def string_stride(): @T.prim_func def main(a: T.handle, b: T.handle): @@ -3612,7 +3304,6 @@ def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")): ir_generator = tvm.testing.parameter( launch_env_thread, opt_gemm_lower, - opt_gemm_mod_host, opt_conv_tensorcore_lower, opt_conv_tensorcore_mod_host, vthread_func, @@ -3670,7 +3361,6 @@ def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")): multi_env_threads, intrinsic_pow, let_stmt_var, - let_stmt_value, string_stride, string_stride_int64, merge_shape_var_def, From e76fc81d8d7f6f9787216f8a5b8c0f037480553a Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Mar 2026 20:49:41 +0000 Subject: [PATCH 13/34] [REFACTOR][TIR] Disable CanInlineLetStmt for flat Bind With flat Bind there is no body to inspect for usage patterns, so Bind inlining (removing the Bind and substituting its value) is disabled. The analyzer still records variable bindings for constraint proving, but the Bind statement is always kept. 
Remove the CollectVarsUsedInBufferDefinition utility and used_in_buffer_def_ tracking which were only needed for the inlining codepath. Update tests to reflect that Binds are no longer eliminated. --- src/tir/transform/simplify.cc | 95 ++----------------- .../test_tir_transform_simplify.py | 66 ++++++++----- 2 files changed, 53 insertions(+), 108 deletions(-) diff --git a/src/tir/transform/simplify.cc b/src/tir/transform/simplify.cc index 2c49d862d093..173fcf7f6656 100644 --- a/src/tir/transform/simplify.cc +++ b/src/tir/transform/simplify.cc @@ -37,7 +37,6 @@ #include "../../arith/ir_mutator_with_analyzer.h" #include "../../tir/analysis/control_flow_graph.h" -#include "../../tir/analysis/var_use_def_analysis.h" namespace tvm { namespace arith { @@ -97,46 +96,6 @@ struct SimplifyConfigNode : public AttrsNodeReflAdapter { } }; -/* \brief Utility function to collect vars that should be retained */ -std::unordered_set CollectVarsUsedInBufferDefinition(const Stmt& stmt) { - struct Visitor : StmtExprVisitor { - using StmtExprVisitor::VisitExpr_; - using StmtExprVisitor::VisitStmt_; - - void VisitExpr_(const BufferLoadNode* op) override { - VisitBuffer(op->buffer); - StmtExprVisitor::VisitExpr_(op); - } - void VisitStmt_(const BufferStoreNode* op) override { - VisitBuffer(op->buffer); - StmtExprVisitor::VisitStmt_(op); - } - - void VisitBuffer(const Buffer& buf) { - // Collect variables that should remain defined - VarUseDefAnalyzer usage(ffi::Array{}); - usage(buf->data); - for (const auto& dim : buf->shape) { - usage(dim); - } - for (const auto& dim : buf->strides) { - usage(dim); - } - usage(buf->elem_offset); - - // Track for use in BindNode mutator - for (const auto& var : usage.undefined_) { - used_in_buffer_def_.insert(var.get()); - } - } - std::unordered_set used_in_buffer_def_; - }; - - Visitor visitor; - visitor(stmt); - return visitor.used_in_buffer_def_; -} - class SimplifyConfig : public Attrs { public: 
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SimplifyConfig, Attrs, SimplifyConfigNode); @@ -159,10 +118,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { touch_pattern = ControlFlowGraph(func->body); } - std::unordered_set used_in_buffer_def = - CollectVarsUsedInBufferDefinition(func->body); - StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern), - std::move(used_in_buffer_def)); + StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern)); simplifier.MarkBufferMapShapes(func); func.CopyOnWrite()->body = simplifier(func->body); return func; @@ -170,12 +126,8 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { private: explicit StmtSimplifier(Analyzer* analyzer, SimplifyConfig config, - std::optional touch_pattern, - std::unordered_set used_in_buffer_def) - : IRMutatorWithAnalyzer(analyzer), - config_(config), - touch_pattern_(touch_pattern), - used_in_buffer_def_(used_in_buffer_def) {} + std::optional touch_pattern) + : IRMutatorWithAnalyzer(analyzer), config_(config), touch_pattern_(touch_pattern) {} using Parent = IRMutatorWithAnalyzer; using Parent::VisitExpr_; @@ -220,46 +172,18 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { return Parent::VisitStmt_(op); } - bool CanInlineBind(const BindNode* op) { - if (is_const_number(op->value)) return true; - if (op->value.as()) return true; - // Won't face the deep expression explosion problem as in Let expression. - // attempt to inline as much as possible if the value integer type(can be index). - if (!op->value.dtype().is_int()) return false; - return SideEffect(op->value) <= CallEffectKind::kPure; - } - Stmt VisitStmt_(const BindNode* op) override { PrimExpr value = this->VisitExpr(op->value); - bool can_inline = CanInlineBind(op); - if (can_inline) { - // It is usually fine to discard the let binding because the - // call to simplify will always inline the var. 
- // - // The exception is when the variable is used in a Buffer's - // definition, as these are not updated by the simplification. - // After DeclBuffer is required prior to use of a buffer, - // simplifying can update the buffer definition as well. The - // buffer can only be updated at its point of definition, - // because the points of use may occur within contexts that - // allow for additional simplifications (e.g. a buffer of shape - // [i,j] whose first use occurs within "if i==1" should not have - // its shape simplified to [1,j]). + // Bind in analyzer for constraint proving and simplification of + // subsequent expressions. Don't remove the Bind statement -- + // with flat Bind there's no body to inspect for usage patterns, + // so we always keep the Bind. + if (SideEffect(value) <= CallEffectKind::kPure) { analyzer_->Bind(op->var, value); - } else if (SideEffect(op->value) <= CallEffectKind::kPure) { - // Even if we aren't replacing all occurrences, they may be - // necessary for proving conditional statements. non_inlined_bindings_.Set(op->var, value); } - // TODO(Lunderberg): Update the Buffer object as part of - // DeclBuffer updates, which will first require - // https://github.com/apache/tvm/pull/14778. 
- bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get()); - - if (can_inline && !used_in_buffer_def) { - return Evaluate(0); - } else if (value.same_as(op->value)) { + if (value.same_as(op->value)) { return ffi::GetRef(op); } else { auto n = this->CopyOnWrite(op); @@ -350,7 +274,6 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { ffi::Map non_inlined_bindings_; ffi::Optional current_stmt_{std::nullopt}; - std::unordered_set used_in_buffer_def_; }; } // namespace arith diff --git a/tests/python/tir-transform/test_tir_transform_simplify.py b/tests/python/tir-transform/test_tir_transform_simplify.py index 46e094acfa20..ddc82e07739b 100644 --- a/tests/python/tir-transform/test_tir_transform_simplify.py +++ b/tests/python/tir-transform/test_tir_transform_simplify.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# ruff: noqa: E501 import tvm import tvm.testing from tvm.script import ir as I @@ -36,8 +35,14 @@ def func(A: T.handle("float32"), C: T.handle("float32"), n: T.int32): # Navigate through DeclBuffer nodes to reach the inner body while isinstance(body, tvm.tir.DeclBuffer): body = body.body - # After simplification, LetStmt -> For -> BufferStore (if is eliminated since i < 12 is always true for i in 0..10) - assert isinstance(body.body, tvm.tir.BufferStore) + # After simplification, Bind is kept (not inlined) but the if is eliminated + # since i < 12 is always true for i in 0..10. 
+ # Body is SeqStmt(Bind(n_val, 10), For(i, ...)) + stmts = body if isinstance(body, tvm.tir.SeqStmt) else [body] + # Find the For loop in the sequence + for_stmt = [s for s in stmts if isinstance(s, tvm.tir.For)] + assert len(for_stmt) == 1, f"Expected one For loop, got {len(for_stmt)}" + assert isinstance(for_stmt[0].body, tvm.tir.BufferStore) def test_thread_extent_simplify(): @@ -56,11 +61,16 @@ def func(A: T.handle("float32"), C: T.handle("float32"), n: T.int32): # Navigate through DeclBuffer nodes to reach the inner body while isinstance(body, tvm.tir.DeclBuffer): body = body.body - # After simplification: For(tx) -> For(ty) -> BufferStore - # The LetStmt and if are eliminated since tx + ty < 12 is always true for tx in 0..10 and ty = 0 - assert isinstance(body, tvm.tir.For) # tx loop - assert isinstance(body.body, tvm.tir.For) # ty loop - assert isinstance(body.body.body, tvm.tir.BufferStore) # The if was eliminated + # After simplification: Bind is kept but the if is eliminated + # since tx + ty < 12 is always true for tx in 0..10 and ty = 0. + stmts = list(body) if isinstance(body, tvm.tir.SeqStmt) else [body] + for_stmts = [s for s in stmts if isinstance(s, tvm.tir.For)] + assert len(for_stmts) >= 1, f"Expected For loop, got stmts: {[type(s).__name__ for s in stmts]}" + # The outermost For is the tx loop + tx_loop = for_stmts[0] + assert isinstance(tx_loop, tvm.tir.For) # tx loop + assert isinstance(tx_loop.body, tvm.tir.For) # ty loop + assert isinstance(tx_loop.body.body, tvm.tir.BufferStore) # The if was eliminated def test_if_likely(): @@ -385,6 +395,10 @@ def test_prove_condition_using_let(): Not all let bindings are inlined when they occur in later expressions. However, even if they are not inlined, they may be used to prove the value of a condition. + + With flat Bind, the analyzer binds the variable to its value for + constraint proving, which also substitutes the variable in later + expressions. 
""" @T.prim_func(private=True) @@ -397,8 +411,8 @@ def before(A: T.Buffer(4, "bool")): @T.prim_func(private=True) def expected(A: T.Buffer(4, "bool")): for i in T.serial(4): - condition = i < 3 - A[i] = condition + condition: T.bool = i < 3 # noqa: F841 + A[i] = i < 3 after = _apply_simplify(before) tvm.ir.assert_structural_equal(after, expected) @@ -407,9 +421,8 @@ def expected(A: T.Buffer(4, "bool")): def test_prove_let_condition(): """Simplify conditions using non-inlined let bindings - Not all let bindings are inlined when they occur in later - expressions. However, even if they are not inlined, they may be - used to prove the value of a condition. + With flat Bind, analyzer binds variable to value, which also + substitutes the variable in later expressions. """ @T.prim_func(private=True) @@ -423,9 +436,9 @@ def before(A: T.Buffer(4, "bool")): @T.prim_func(private=True) def expected(A: T.Buffer(4, "bool")): for i in T.serial(4): - condition = i < 3 + condition: T.bool = i < 3 # noqa: F841 if i < 3: - A[i] = condition + A[i] = T.bool(True) after = _apply_simplify(before) tvm.ir.assert_structural_equal(after, expected) @@ -434,8 +447,9 @@ def expected(A: T.Buffer(4, "bool")): def test_prove_repeated_let_condition(): """Simplify conditions using non-inlined let bindings - A variable may be used as a literal constraint, and be recognized - as being True within the context of the constraint. + With analyzer Bind, the variable is substituted with its value, + so `if condition` becomes `if i < 3`, and within that context + the inner `if condition` simplifies to True and is eliminated. 
""" @T.prim_func(private=True) @@ -449,9 +463,9 @@ def before(A: T.Buffer(4, "bool")): @T.prim_func(private=True) def expected(A: T.Buffer(4, "bool")): for i in T.serial(4): - condition = i < 3 - if condition: - A[i] = True + condition: T.bool = i < 3 # noqa: F841 + if i < 3: + A[i] = T.bool(True) after = _apply_simplify(before) tvm.ir.assert_structural_equal(after, expected) @@ -492,7 +506,11 @@ def expected(A: T.Buffer(1, "int32")): def test_left_ceil_log2_lower_bound(): - """Integer bounds are propagated through topi.math.ceil_log2""" + """Integer bounds are propagated through topi.math.ceil_log2 + + With flat Bind, the Bind is kept even when the variable is unused + after simplification. The if condition is still eliminated. + """ @T.prim_func(private=True) def before(A: T.Buffer(16, "float32")): @@ -507,7 +525,11 @@ def before(A: T.Buffer(16, "float32")): @T.prim_func(private=True) def expected(A: T.Buffer(16, "float32")): for i in T.serial(16): - A[i] = 0.0 + x: T.int32 = T.Cast( # noqa: F841 + "int32", + T.ceil(T.log2(T.Cast("float64", i + 1025))), + ) + A[i] = T.float32(0) after = _apply_simplify(before) tvm.ir.assert_structural_equal(after, expected) From f7689d7c0be95c7788bcab8fac7860287d7270cc Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Mar 2026 20:59:45 +0000 Subject: [PATCH 14/34] [REFACTOR][TIR] Fix Bind scope management in hoist_expression, ir_utils, and tests Update hoist_expression to manage Bind lifecycle in SeqStmt, fix IRConvertSSA to handle Bind redefinitions across SeqStmt siblings, and update test expectations for flat Bind semantics. 
--- .../analysis/sblock_access_region_detector.cc | 5 -- src/tir/transform/ir_utils.cc | 64 ++++++++++++++----- .../test_s_tir_transform_hoist_expression.py | 21 ++++-- .../test_tir_transform_remove_no_op.py | 28 +++++--- 4 files changed, 80 insertions(+), 38 deletions(-) diff --git a/src/s_tir/analysis/sblock_access_region_detector.cc b/src/s_tir/analysis/sblock_access_region_detector.cc index f5e35f27bfcf..aaba79827eb8 100644 --- a/src/s_tir/analysis/sblock_access_region_detector.cc +++ b/src/s_tir/analysis/sblock_access_region_detector.cc @@ -190,11 +190,6 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) { } void BlockReadWriteDetector::VisitStmt_(const BindNode* op) { - // Add the binding to let_bindings_ so it can be used for index substitution. - // The binding persists for subsequent siblings in the enclosing SeqStmt - // (default visitor processes children sequentially). Cleanup is handled by - // the enclosing block frame scope -- let_bindings_ is scoped to the - // BlockReadWriteDetector instance which processes one block at a time. let_bindings_[op->var.get()] = op->value; StmtVisitor::VisitStmt_(op); } diff --git a/src/tir/transform/ir_utils.cc b/src/tir/transform/ir_utils.cc index eb37a3f9b606..9f784b64b5ef 100644 --- a/src/tir/transform/ir_utils.cc +++ b/src/tir/transform/ir_utils.cc @@ -347,31 +347,61 @@ class IRConvertSSA final : public StmtExprMutator { } Stmt VisitStmt_(const BindNode* op) final { + // Note: ScopedRedefine for Bind must persist across SeqStmt siblings. + // This is handled by VisitStmt_(const SeqStmtNode*) below. + // When visited standalone (not as part of SeqStmt), just do a simple visit. const Var& v = op->var; if (defined_.count(v.get())) { - // In SSA form, each variable is defined once. When we encounter a - // redefinition via Bind, create a new variable and add a persistent - // remapping via function_scope_var_remap_. 
The rename persists for - // all subsequent siblings in the enclosing SeqStmt (handled by - // default sequential visitation). No need for ScopedRedefine since - // the mapping should not be popped at scope exit. PrimExpr value = this->VisitExpr(op->value); - Var new_var = [&]() { - bool is_size_var = v->IsInstance(); - if (v->type_annotation.defined()) { - return is_size_var ? Var(SizeVar(v->name_hint, v->type_annotation)) - : Var(v->name_hint, v->type_annotation); - } else { - return is_size_var ? Var(SizeVar(v->name_hint, v->dtype)) : Var(v->name_hint, v->dtype); - } - }(); - function_scope_var_remap_[v.get()] = new_var; - return Bind(new_var, value); + ScopedRedefine redefine(this, v); + return Bind(redefine.new_var, value); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); } } + + Stmt VisitStmt_(const SeqStmtNode* op) final { + // Process children sequentially, maintaining ScopedRedefine for Bind nodes + // so that remappings persist for subsequent siblings (mimicking old nested + // Bind scope behavior). 
+ std::vector seq_redefines; + ffi::Array new_seq; + bool changed = false; + + for (size_t i = 0; i < op->seq.size(); ++i) { + const Stmt& child = op->seq[i]; + if (auto* bind = child.as()) { + const Var& v = bind->var; + if (defined_.count(v.get())) { + PrimExpr value = this->VisitExpr(bind->value); + seq_redefines.emplace_back(this, v); + Stmt new_bind = Bind(seq_redefines.back().new_var, value); + new_seq.push_back(new_bind); + changed = true; + } else { + defined_.insert(v.get()); + Stmt visited = StmtExprMutator::VisitStmt_(bind); + new_seq.push_back(visited); + changed = changed || !visited.same_as(child); + } + } else { + Stmt visited = VisitStmt(child); + new_seq.push_back(visited); + changed = changed || !visited.same_as(child); + } + } + + // Pop redefines in reverse order (RAII would do this, but let's be explicit) + while (seq_redefines.size()) { + seq_redefines.pop_back(); + } + + if (!changed) { + return ffi::GetRef(op); + } + return SeqStmt(new_seq); + } Stmt VisitStmt_(const ForNode* op) final { const Var& v = op->loop_var; if (defined_.count(v.get())) { diff --git a/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py b/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py index ed82f3fecbb5..cf44114da180 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py @@ -228,10 +228,10 @@ def before(A: T.Buffer((4, 4), "float32")): @T.prim_func(private=True) def expected(A: T.Buffer((4, 4), "float32")): for i in T.serial(4): - condition = i < 3 - if condition: + condition: T.bool = i < 3 # noqa: F841 + if i < 3: for j in T.serial(4): - A[i, j] = 0.0 + A[i, j] = T.float32(0.0) after = _run_transform(before, HoistedConditionals.All, HoistedLetBindings.All) tvm.ir.assert_structural_equal(after, expected) @@ -241,7 +241,9 @@ def test_hoist_disable_let(): """As test_hoist_with_let, but forbid hoisting of LetStmt Because the 
condition depends on the let binding, it should no - longer be hoisted. + longer be hoisted. With Bind lifecycle management, the condition + var is erased at sequence boundaries, so the if-condition uses + the raw expression. """ @T.prim_func(private=True) @@ -252,7 +254,12 @@ def before(A: T.Buffer((4, 4), "float32")): if condition: A[i, j] = 0.0 - expected = before + @T.prim_func(private=True) + def expected(A: T.Buffer((4, 4), "float32")): + for i, j in T.grid(4, 4): + condition: T.bool = i < 3 # noqa: F841 + if i < 3: + A[i, j] = T.float32(0.0) after = _run_transform(before, HoistedConditionals.All, HoistedLetBindings.Never) tvm.ir.assert_structural_equal(after, expected) @@ -512,9 +519,9 @@ def before(A: T.Buffer((4, 4), "float32")): @T.prim_func(private=True) def expected(A: T.Buffer((4, 4), "float32")): for i in T.serial(4): - x = T.cast(i + 1, "float32") + x: T.float32 = T.cast(i + 1, "float32") # noqa: F841 for j in T.serial(4): - A[i, j] = 5.0 * x + T.cast(j, "float32") + A[i, j] = T.float32(5.0) * T.cast(i + 1, "float32") + T.cast(j, "float32") after = _run_transform(before, HoistedConditionals.All, HoistedLetBindings.All) tvm.ir.assert_structural_equal(after, expected) diff --git a/tests/python/tir-transform/test_tir_transform_remove_no_op.py b/tests/python/tir-transform/test_tir_transform_remove_no_op.py index 55574e838172..668b652d7087 100644 --- a/tests/python/tir-transform/test_tir_transform_remove_no_op.py +++ b/tests/python/tir-transform/test_tir_transform_remove_no_op.py @@ -132,7 +132,11 @@ def expected(A: T.Buffer(16, "int32")): def test_remove_unused_let(): - """A let statement that is never used is a no-op.""" + """With flat Bind, unused bindings are preserved by remove_no_op. + + Unused Bind elimination requires a separate two-pass approach + and is not handled by the current remove_no_op pass. 
+ """ @T.prim_func(private=True) def before(A: T.Buffer(16, "int32")): @@ -142,6 +146,7 @@ def before(A: T.Buffer(16, "int32")): @T.prim_func(private=True) def expected(A: T.Buffer(16, "int32")): + x = 5 for i in T.serial(16): A[i] = 0 @@ -151,10 +156,10 @@ def expected(A: T.Buffer(16, "int32")): def test_remove_let_used_only_in_no_op(): - """A let statement that is never used is a no-op. + """With flat Bind, unused bindings are preserved by remove_no_op. - Similar to test_remove_unused_let, but the usage of the let binding - may have been removed by an earlier removal of another no-op. + The zero-extent for loop is removed, but the Bind for x remains + since unused Bind elimination is not handled by remove_no_op. """ @T.prim_func(private=True) @@ -165,6 +170,7 @@ def before(A: T.Buffer(16, "int32")): @T.prim_func(private=True) def expected(A: T.Buffer(16, "int32")): + x = 5 T.evaluate(0) mod = tvm.IRModule.from_expr(before) @@ -173,7 +179,7 @@ def expected(A: T.Buffer(16, "int32")): def test_keep_side_effects_of_let(): - """The side effects of a no-op let must be kept.""" + """Side-effect Bind is preserved as-is by remove_no_op.""" @T.prim_func(private=True) def before(): @@ -182,7 +188,8 @@ def before(): @T.prim_func(private=True) def expected(): - T.evaluate(T.call_extern("extern_func", dtype="int32")) + x = T.call_extern("extern_func", dtype="int32") + T.evaluate(0) mod = tvm.IRModule.from_expr(before) mod = _apply_remove_no_op(mod) @@ -549,8 +556,9 @@ def test_remove_read_write_same_index_different_expression(): """Writing a value to the same location as the read is a no-op. If the value of the index can be proven to be the same, then the - no-op can be removed, even if they have different forms of the - expression. + store can be removed. With flat Bind, the Bind for i and the + enclosing loops remain since unused Bind elimination is not + handled by remove_no_op. 
""" @T.prim_func(private=True) @@ -561,7 +569,9 @@ def before(A: T.Buffer(16, "int32")): @T.prim_func(private=True) def expected(A: T.Buffer(16, "int32")): - T.evaluate(0) + for io in range(4): + for ii in range(4): + i: T.int32 = 4 * io + ii mod = tvm.IRModule.from_expr(before) mod = _apply_remove_no_op(mod) From 703813bbb9642ecf0936037dc88e474f8b5361a9 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Mar 2026 21:00:24 +0000 Subject: [PATCH 15/34] [REFACTOR][TIR] Simplify tir_visitor_with_path: use scope-based Bind defs Use ScopeStack to manage Bind variable definitions. Body-carrying statements (For, IfThenElse, Allocate, DeclBuffer, AttrStmt, While, SBlock) push a new scope; BindNode pushes its WithDef into the current scope. When the scope exits all Bind defs are cleaned up automatically, removing the need for custom SeqStmt handling. --- src/tir/ir/tir_visitor_with_path.cc | 35 +++++++++++------------------ src/tir/ir/tir_visitor_with_path.h | 9 ++++++++ 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index 6436a2869b5d..fc06509b5515 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -110,7 +110,7 @@ void TIRVisitorWithPath::Visit(const PrimFunc& func, AccessPath path) { } } - Visit(func->body, path->Attr("body")); + bind_scope_.WithNewScope([&]() { Visit(func->body, path->Attr("body")); }); while (context.size()) context.pop_back(); } @@ -173,11 +173,10 @@ void TIRVisitorWithPath::Visit(const Range& range, AccessPath path) { } void TIRVisitorWithPath::VisitStmt_(const BindNode* op, AccessPath path) { - // Bind has no body -- var scope is defined by the enclosing scope. Visit(op->value, path->Attr("value")); - // Note: we do NOT call WithDef here because Bind's var scope extends - // to subsequent siblings in the enclosing SeqStmt, not just a subtree. - // Scope tracking for BindNode is handled at the SeqStmt level by callers. 
+ // Push the Bind's var definition into the current scope. + // The def lives until the enclosing scope (body-carrying stmt) exits. + bind_scope_.Current().push_back(WithDef(op->var, path->Attr("var"))); } void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, AccessPath path) { @@ -194,7 +193,7 @@ void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, AccessPath path) { } else if (auto expr = op->node.as()) { Visit(expr.value(), path->Attr("node")); } - Visit(op->body, path->Attr("body")); + bind_scope_.WithNewScope([&]() { Visit(op->body, path->Attr("body")); }); while (context.size()) { context.pop_back(); @@ -205,12 +204,12 @@ void TIRVisitorWithPath::VisitStmt_(const ForNode* op, AccessPath path) { Visit(op->min, path->Attr("min")); Visit(op->extent, path->Attr("extent")); auto context = WithDef(op->loop_var, path->Attr("loop_var")); - Visit(op->body, path->Attr("body")); + bind_scope_.WithNewScope([&]() { Visit(op->body, path->Attr("body")); }); } void TIRVisitorWithPath::VisitStmt_(const WhileNode* op, AccessPath path) { Visit(op->condition, path->Attr("condition")); - Visit(op->body, path->Attr("body")); + bind_scope_.WithNewScope([&]() { Visit(op->body, path->Attr("body")); }); } void TIRVisitorWithPath::VisitStmt_(const AllocBufferNode* op, AccessPath path) { @@ -225,7 +224,7 @@ void TIRVisitorWithPath::VisitStmt_(const AllocBufferNode* op, AccessPath path) void TIRVisitorWithPath::VisitStmt_(const DeclBufferNode* op, AccessPath path) { auto context = WithDef(op->buffer, path->Attr("buffer")); - Visit(op->body, path->Attr("body")); + bind_scope_.WithNewScope([&]() { Visit(op->body, path->Attr("body")); }); } void TIRVisitorWithPath::VisitStmt_(const BufferStoreNode* op, AccessPath path) { @@ -236,8 +235,8 @@ void TIRVisitorWithPath::VisitStmt_(const BufferStoreNode* op, AccessPath path) void TIRVisitorWithPath::VisitStmt_(const IfThenElseNode* op, AccessPath path) { Visit(op->condition, path->Attr("condition")); - Visit(op->then_case, 
path->Attr("then_case")); - Visit(op->else_case, path->Attr("else_case")); + bind_scope_.WithNewScope([&]() { Visit(op->then_case, path->Attr("then_case")); }); + bind_scope_.WithNewScope([&]() { Visit(op->else_case, path->Attr("else_case")); }); } void TIRVisitorWithPath::VisitStmt_(const AssertStmtNode* op, AccessPath path) { @@ -247,17 +246,9 @@ void TIRVisitorWithPath::VisitStmt_(const AssertStmtNode* op, AccessPath path) { } void TIRVisitorWithPath::VisitStmt_(const SeqStmtNode* op, AccessPath path) { - // Visit children sequentially. When a child is a BindNode, define its - // variable for all subsequent siblings (BindNode scope extends to - // the rest of the enclosing SeqStmt). auto seq_path = path->Attr("seq"); - std::vector> bind_defs; for (size_t i = 0; i < op->seq.size(); i++) { - auto item_path = seq_path->ArrayItem(i); - Visit(op->seq[i], item_path); - if (auto bind = op->seq[i].as()) { - bind_defs.push_back(WithDef(bind->var, item_path->Attr("var"))); - } + Visit(op->seq[i], seq_path->ArrayItem(i)); } } @@ -305,8 +296,8 @@ void TIRVisitorWithPath::VisitStmt_(const SBlockNode* op, AccessPath path) { } } - Visit(op->init, path->Attr("init")); - Visit(op->body, path->Attr("body")); + bind_scope_.WithNewScope([&]() { Visit(op->init, path->Attr("init")); }); + bind_scope_.WithNewScope([&]() { Visit(op->body, path->Attr("body")); }); while (context.size()) context.pop_back(); } diff --git a/src/tir/ir/tir_visitor_with_path.h b/src/tir/ir/tir_visitor_with_path.h index e271fd515179..92a5454e6473 100644 --- a/src/tir/ir/tir_visitor_with_path.h +++ b/src/tir/ir/tir_visitor_with_path.h @@ -25,6 +25,7 @@ #define TVM_TIR_IR_TIR_VISITOR_WITH_PATH_H_ #include +#include #include #include @@ -249,6 +250,14 @@ class TIRVisitorWithPath } std::unordered_set in_scope_definitions_; + + /*! \brief Scope stack for Bind variable definitions. + * + * Body-carrying statements (For, IfThenElse, etc.) push a new scope. + * BindNode pushes its WithDef into the current scope. 
When the + * scope exits, all Bind defs are cleaned up automatically. + */ + ScopeStack>> bind_scope_; }; } // namespace tir From 8575eaf473a636530eb524425ef18b22eaa1b2a9 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Mar 2026 21:12:49 +0000 Subject: [PATCH 16/34] [REFACTOR][TIR] Fix remove_store_undef and inject_ptx_async_copy for flat Bind - remove_store_undef: eagerly check buffer indices for undef in the locator phase (flat Bind means the undef Bind is a sibling, not an ancestor, so post-validation alone cannot catch it). Also remove Bind nodes whose value contains undef alongside the removed stores. - inject_ptx_async_copy test: update expected CUDA to reflect that analyzer->Bind substitutes the variable with its value. --- src/s_tir/transform/remove_store_undef.cc | 61 ++++++++++++++----- ...t_s_tir_transform_inject_ptx_async_copy.py | 4 +- 2 files changed, 49 insertions(+), 16 deletions(-) diff --git a/src/s_tir/transform/remove_store_undef.cc b/src/s_tir/transform/remove_store_undef.cc index a5d2a9f9e267..1d383bec12b9 100644 --- a/src/s_tir/transform/remove_store_undef.cc +++ b/src/s_tir/transform/remove_store_undef.cc @@ -34,18 +34,24 @@ namespace tvm { namespace s_tir { using namespace tvm::tir; +struct UndefInfo { + std::unordered_set undef_stores; + std::unordered_set undef_bind_vars; +}; + class StoreUndefLocator : public StmtExprVisitor { public: - static std::unordered_set Locate(Stmt stmt) { + static UndefInfo Locate(Stmt stmt) { StoreUndefLocator locator; locator(std::move(stmt)); - return locator.undef_stores_; + return {locator.undef_stores_, locator.var_bindings_with_undef_}; } private: StoreUndefLocator() = default; void VisitStmt_(const BufferStoreNode* op) final { + // Check the value for undef. 
bool stash_undef = false; std::swap(has_undef_, stash_undef); StmtExprVisitor::VisitExpr(op->value); @@ -56,13 +62,29 @@ class StoreUndefLocator : public StmtExprVisitor { << "must not have other side effects"; undef_stores_.insert(op); } + + // Check indices for undef. Undef in buffer indices is always an + // error (there is no valid lowering). With flat Bind, we must + // check indices eagerly because the Bind node is a sibling rather + // than an ancestor and may be removed before post-validation. + bool idx_undef = false; + std::swap(has_undef_, idx_undef); + for (const auto& idx : op->indices) { + StmtExprVisitor::VisitExpr(idx); + } + std::swap(has_undef_, idx_undef); + TVM_FFI_ICHECK(!idx_undef) << "Error: T.undef() may not be used in buffer indices"; } void VisitExpr_(const BufferLoadNode* op) final { - // This function left deliberately empty. builtin::undef() - // shouldn't occur in the indices of BufferLoad. Avoiding - // visiting the indices catches the builtin::undef in - // ValidateAllUndefRemoved. + // Check indices for undef. Undef in buffer indices is always an error. + bool idx_undef = false; + std::swap(has_undef_, idx_undef); + for (const auto& idx : op->indices) { + StmtExprVisitor::VisitExpr(idx); + } + std::swap(has_undef_, idx_undef); + TVM_FFI_ICHECK(!idx_undef) << "Error: T.undef() may not be used in buffer indices"; } void VisitStmt_(const BindNode* op) final { @@ -97,33 +119,44 @@ class StoreUndefLocator : public StmtExprVisitor { std::unordered_set undef_stores_; }; -// Remove any BufferStores whose value depends on T.undef +// Remove BufferStores whose value depends on T.undef, and also +// remove Bind nodes whose value contains undef. Undef in buffer +// indices is already caught eagerly in the locator phase. 
class StoreUndefRemover : public StmtExprMutator { public: static Stmt Apply(Stmt stmt) { - auto to_remove = StoreUndefLocator::Locate(stmt); - StoreUndefRemover mutator(to_remove); + auto info = StoreUndefLocator::Locate(stmt); + StoreUndefRemover mutator(info); return mutator(std::move(stmt)); } private: using Parent = StmtExprMutator; - explicit StoreUndefRemover(const std::unordered_set& to_remove) - : to_remove_(to_remove) {} + explicit StoreUndefRemover(const UndefInfo& info) + : stores_to_remove_(info.undef_stores), bind_vars_to_remove_(info.undef_bind_vars) {} Stmt VisitStmt_(const BufferStoreNode* op) final { - if (to_remove_.count(op)) { + if (stores_to_remove_.count(op)) { + return Evaluate(0); + } else { + return Parent::VisitStmt_(op); + } + } + + Stmt VisitStmt_(const BindNode* op) final { + if (bind_vars_to_remove_.count(op->var.get())) { return Evaluate(0); } else { return Parent::VisitStmt_(op); } } - const std::unordered_set& to_remove_; + const std::unordered_set& stores_to_remove_; + const std::unordered_set& bind_vars_to_remove_; }; -// Remove any BufferStores whose value depends on T.undef +// Check that no builtin::undef() remains in the IR. 
class ContainsUndefChecker : public StmtExprVisitor { public: static bool Check(const Stmt& stmt) { diff --git a/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py index 974d4c25c8e6..53923c97bd27 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py @@ -335,7 +335,7 @@ def test_inject_async_copy_barrier(): { unsigned int addr = cast_smem_ptr_to_int(A_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x))); - int pred_guard = (int)cse_v1; + int pred_guard = (int)(i < 12); __asm__ __volatile__( "{ .reg .pred p;" " setp.ne.b32 p, %0, 0;" @@ -358,7 +358,7 @@ def test_inject_async_copy_barrier(): { unsigned int addr = cast_smem_ptr_to_int(B_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x))); - int pred_guard = (int)cse_v1; + int pred_guard = (int)(i < 12); __asm__ __volatile__( "{ .reg .pred p;" " setp.ne.b32 p, %0, 0;" From 6a5263061181becd01ef1b8315338ac80b231d51 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Mar 2026 21:22:59 +0000 Subject: [PATCH 17/34] [REFACTOR][TIR] Restore hoist_expression SeqStmt handler lifecycle management Restore the original SeqStmt handler logic that tracks Bind vars defined in a sequence and erases them from let_var_to_loop_vars/let_var_to_let_vars maps when the sequence ends. Keep the refactor simple per user feedback. 
--- src/s_tir/transform/hoist_expression.cc | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/s_tir/transform/hoist_expression.cc b/src/s_tir/transform/hoist_expression.cc index 2717598d3c34..6403f3a7801a 100644 --- a/src/s_tir/transform/hoist_expression.cc +++ b/src/s_tir/transform/hoist_expression.cc @@ -329,9 +329,6 @@ class HoistInfoCollector : public StmtExprVisitor { void VisitStmt_(const SeqStmtNode* op) final { if (active_loops.size()) { - // Count non-Bind statements to determine if the loop body has true - // sequential operations. Bind nodes are variable definitions and don't - // introduce sequential ordering that would prevent hoisting. int non_bind_count = 0; for (size_t i = 0; i < op->seq.size(); ++i) { if (!op->seq[i].as()) { @@ -342,7 +339,17 @@ class HoistInfoCollector : public StmtExprVisitor { active_loops.back().reached_sequential_node = true; } } - Parent::VisitStmt_(op); + std::vector seq_bind_vars; + for (size_t i = 0; i < op->seq.size(); ++i) { + if (auto* bind = op->seq[i].as()) { + seq_bind_vars.push_back(bind->var.get()); + } + VisitStmt(op->seq[i]); + } + for (auto* var : seq_bind_vars) { + let_var_to_loop_vars.erase(var); + let_var_to_let_vars.erase(var); + } } void VisitExpr_(const LetNode* op) final { From ebf5d76c959d8201f028b10f6404c78e98cb232f Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Mar 2026 21:44:23 +0000 Subject: [PATCH 18/34] [REFACTOR][TIR] Remove LetStmt/LetStmtNode backward-compat aliases The migration from LetStmt to Bind is complete. 
Remove all backward- compatibility aliases and deprecated wrappers: - Remove `using LetStmtNode = BindNode` and `using LetStmt = Bind` from include/tvm/tir/stmt.h - Remove `LetStmt()` wrapper and `LegacyLetStmt()` from C++ and Python script ir_builder - Remove `tir.LetStmt` FFI factory from stmt.cc - Remove `LetStmt = Bind` alias from python/tvm/tir/stmt.py - Rename `visit_let_stmt_` to `visit_bind_` in Python functor metadata and method names, matching the C++ `f_visit_bind` field - Rename `f_visit_let_stmt` parameters in py_functor.cc to `f_visit_bind` - Update all test files: T.LetStmt -> T.Bind, comments, function names --- include/tvm/script/ir_builder/tir/ir.h | 8 ------ include/tvm/tir/stmt.h | 5 ---- python/tvm/script/ir_builder/tir/ir.py | 28 +------------------ python/tvm/script/parser/tir/parser.py | 4 +-- python/tvm/tir/__init__.py | 2 +- python/tvm/tir/functor.py | 28 +++++++++---------- python/tvm/tir/stmt.py | 7 +---- src/script/ir_builder/tir/ir.cc | 9 ------ src/tir/ir/py_functor.cc | 10 +++---- src/tir/ir/stmt.cc | 10 ------- .../test_s_tir_transform_hoist_expression.py | 2 +- ..._plan_update_buffer_allocation_location.py | 2 +- .../test_s_tir_transform_thread_sync.py | 20 ++++++------- .../test_tir_analysis_verify_well_formed.py | 24 ++++++++-------- .../test_tir_inline_private_functions.py | 4 +-- .../test_tir_transform_common_subexpr_elim.py | 4 +-- .../test_tir_transform_convert_ssa.py | 10 +++---- .../test_tir_transform_simplify.py | 8 +++--- .../test_tir_transform_storage_rewrite.py | 4 +-- .../test_tvmscript_ir_builder_tir.py | 2 +- .../tvmscript/test_tvmscript_printer_tir.py | 4 +-- .../tvmscript/test_tvmscript_roundtrip.py | 8 +++--- .../tvmscript/test_tvmscript_syntax_sugar.py | 4 +-- 23 files changed, 71 insertions(+), 136 deletions(-) diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 9b04b2fe635a..8113a8d0db0c 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ 
b/include/tvm/script/ir_builder/tir/ir.h @@ -305,14 +305,6 @@ AssertFrame Assert(PrimExpr condition, ffi::String error_kind, LetFrame Bind(PrimExpr value, ffi::Optional type_annotation = std::nullopt, ffi::Optional var = std::nullopt); -/*! - * \brief Deprecated alias for Bind(). Use Bind() instead. - */ -inline LetFrame LetStmt(PrimExpr value, ffi::Optional type_annotation = std::nullopt, - ffi::Optional var = std::nullopt) { - return Bind(value, type_annotation, var); -} - /*! * \brief The allocate node. * \param extents The extents of the allocate. diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index ae30006c51fe..ab92e3550173 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -102,11 +102,6 @@ class Bind : public Stmt { TVM_DEFINE_OBJECT_REF_COW_METHOD(BindNode); }; -/*! \brief Deprecated: use BindNode instead. */ -using LetStmtNode = BindNode; -/*! \brief Deprecated: use Bind instead. */ -using LetStmt = Bind; - /*! * \brief Define certain auxiliary attribute for the body to be a symbolic value. * This provide auxiliary information for IR passes that transforms body. diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 0d6aa084ed17..e4143f006273 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1014,31 +1014,6 @@ def Bind( # pylint: disable=invalid-name return _ffi_api.Bind(value, type_annotation, var) # type: ignore[attr-defined] # pylint: disable=no-member -def LetStmt( # pylint: disable=invalid-name - value: PrimExpr, - type_annotation: Type | None = None, # pylint: disable=redefined-outer-name - *, - var: Var | None = None, # pylint: disable=redefined-outer-name -) -> frame.LetFrame: - """Deprecated alias for Bind(). Use T.Bind() instead. - - Parameters - ---------- - value : PrimExpr - The value to be bound. - type_annotation : Optional[Type] = None - The type annotation of the binding. 
- var : Optional[Var] = None - The variable to bind. If not specified, a new variable will be created. - - Returns - ------- - let_frame : frame.LetFrame - The result LetFrame. - """ - return Bind(value, type_annotation, var=var) - - def Let( # pylint: disable=invalid-name expr: PrimExpr, where: dict[Var, PrimExpr], # pylint: disable=redefined-outer-name @@ -1079,7 +1054,7 @@ def let_expr(v: Var, value: PrimExpr, body: PrimExpr) -> PrimExpr: @deprecated("T.let", "T.Bind") def let_stmt(v: Var, value: PrimExpr) -> frame.LetFrame: - return _ffi_api.LegacyLetStmt(v, value) # type: ignore[attr-defined] # pylint: disable=no-member + return Bind(value, var=v) if body is None: return let_stmt(v, value) @@ -2369,7 +2344,6 @@ def wrapped(*args, **kwargs): "CallEffectKind", "let", "Bind", - "LetStmt", "Let", "IterVar", "CommReducer", diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index e20629795cf9..bfd856b20cac 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -145,7 +145,7 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) - return value else: value = tvm.runtime.convert(value) - frame = T.LetStmt(value) + frame = T.Bind(value) var = frame.var IRBuilder.name(var_name, var) frame.add_callback(partial(frame.__exit__, None, None, None)) @@ -352,7 +352,7 @@ def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: if not isinstance(ann_var, Var): self.report_error(node.annotation, "Annotation should be Var") self.eval_assign(target=lhs, source=ann_var, bind_value=bind_assign_value) - frame = T.LetStmt(rhs, var=ann_var) + frame = T.Bind(rhs, var=ann_var) frame.add_callback(partial(frame.__exit__, None, None, None)) frame.__enter__() diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index daa23817d1b6..caa94949807c 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -29,7 +29,7 @@ from .expr import Select, 
BufferLoad, ProducerLoad, Ramp, Broadcast, Shuffle from .expr import Call, CallEffectKind, Let, IterVar, CommReducer -from .stmt import Stmt, Bind, LetStmt, AssertStmt, ForKind, For, While +from .stmt import Stmt, Bind, AssertStmt, ForKind, For, While from .stmt import ( BufferStore, AllocBuffer, diff --git a/python/tvm/tir/functor.py b/python/tvm/tir/functor.py index bb91138edc86..9c684d8d990b 100644 --- a/python/tvm/tir/functor.py +++ b/python/tvm/tir/functor.py @@ -65,12 +65,12 @@ AllocBuffer, AssertStmt, AttrStmt, + Bind, BufferStore, DeclBuffer, Evaluate, For, IfThenElse, - LetStmt, SBlock, SBlockRealize, SeqStmt, @@ -160,7 +160,7 @@ def __init__( f_visit_stmt: Callable | None = None, f_visit_expr: Callable | None = None, # Stmt - f_visit_let_stmt: Callable | None = None, + f_visit_bind: Callable | None = None, f_visit_attr_stmt: Callable | None = None, f_visit_if_then_else: Callable | None = None, f_visit_for: Callable | None = None, @@ -214,7 +214,7 @@ def __init__( f_visit_stmt, f_visit_expr, # Stmt - f_visit_let_stmt, + f_visit_bind, f_visit_attr_stmt, f_visit_if_then_else, f_visit_for, @@ -277,7 +277,7 @@ class PyStmtExprVisitor: "visit_stmt", "visit_expr", # Stmt - "visit_let_stmt_", + "visit_bind_", "visit_attr_stmt_", "visit_if_then_else_", "visit_for_", @@ -373,17 +373,17 @@ def visit_if_then_else_(self, op: IfThenElse) -> None: print("visit_if_then_else_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore - def visit_let_stmt_(self, op: LetStmt) -> None: - """Visit Bind (LetStmt is a backward-compat alias for Bind). + def visit_bind_(self, op: Bind) -> None: + """Visit Bind. Users can customize this function to overwrite VisitStmt_(const BindNode* op) on the C++ side. Parameters ---------- - op : LetStmt + op : Bind The Bind node to be visited. 
""" - print("visit_let_stmt_", op) + print("visit_bind_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_for_(self, op: For) -> None: @@ -961,7 +961,7 @@ def __init__( f_visit_stmt: Callable | None = None, f_visit_expr: Callable | None = None, # Stmt - f_visit_let_stmt: Callable | None = None, + f_visit_bind: Callable | None = None, f_visit_attr_stmt: Callable | None = None, f_visit_if_then_else: Callable | None = None, f_visit_for: Callable | None = None, @@ -1015,7 +1015,7 @@ def __init__( f_visit_stmt, f_visit_expr, # Stmt - f_visit_let_stmt, + f_visit_bind, f_visit_attr_stmt, f_visit_if_then_else, f_visit_for, @@ -1078,7 +1078,7 @@ class PyStmtExprMutator: "visit_stmt", "visit_expr", # Stmt - "visit_let_stmt_", + "visit_bind_", "visit_attr_stmt_", "visit_if_then_else_", "visit_for_", @@ -1196,14 +1196,14 @@ def visit_if_then_else_(self, op: IfThenElse) -> Stmt: """ return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore - def visit_let_stmt_(self, op: LetStmt) -> Stmt: - """Visit Bind (LetStmt is a backward-compat alias for Bind). + def visit_bind_(self, op: Bind) -> Stmt: + """Visit Bind. Users can customize this function to overwrite VisitStmt_(const BindNode* op) on the C++ side. Parameters ---------- - op : LetStmt + op : Bind The Bind node to be visited. Returns diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 91d9f43d6e2a..f374aaa879ed 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -49,7 +49,7 @@ class Bind(Stmt): """Bind node. Bind a variable to a value in the enclosing scope. - Bind has no body field (unlike the old LetStmt which required a nested body). + Bind has no body field. The bound variable is visible in all subsequent statements within the same enclosing scope (SeqStmt, ForNode.body, etc.). @@ -78,11 +78,6 @@ def __init__(self, var: Var, value: PrimExpr, span: Span | None = None) -> None: ) -# Deprecated alias: use Bind instead. 
-# For backward compat: LetStmt(var, value, body) returns SeqStmt(Bind(var, value), body). -LetStmt = Bind - - @tvm_ffi.register_object("tir.AssertStmt") class AssertStmt(Stmt): """AssertStmt node. diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index bf515f832142..5bcd4b321dc4 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -473,13 +473,6 @@ LetFrame Bind(PrimExpr value, ffi::Optional type_annotation, ffi::Optional return LetFrame(n); } -LetFrame LegacyLetStmt(Var var, PrimExpr value) { - ObjectPtr n = ffi::make_object(); - n->var = var; - n->value = value; - return LetFrame(n); -} - LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) { IterVar iter_var{nullptr}; @@ -754,8 +747,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("script.ir_builder.tir.Grid", Grid) .def("script.ir_builder.tir.Assert", Assert) .def("script.ir_builder.tir.Bind", Bind) - .def("script.ir_builder.tir.LetStmt", Bind) // backward-compat alias - .def("script.ir_builder.tir.LegacyLetStmt", LegacyLetStmt) .def("script.ir_builder.tir.Allocate", Allocate) .def("script.ir_builder.tir.Attr", Attr) .def("script.ir_builder.tir.While", While) diff --git a/src/tir/ir/py_functor.cc b/src/tir/ir/py_functor.cc index c4b0a81e3533..b385922cb950 100644 --- a/src/tir/ir/py_functor.cc +++ b/src/tir/ir/py_functor.cc @@ -340,7 +340,7 @@ class PyStmtExprVisitor : public ObjectRef { } TVM_DLL static PyStmtExprVisitor MakePyStmtExprVisitor(ffi::Function f_visit_stmt, // ffi::Function f_visit_expr, // - ffi::Function f_visit_let_stmt, // + ffi::Function f_visit_bind, // ffi::Function f_visit_attr_stmt, // ffi::Function f_visit_if_then_else, // ffi::Function f_visit_for, // @@ -390,8 +390,7 @@ class PyStmtExprVisitor : public ObjectRef { n->f_visit_stmt = std::move(f_visit_stmt); n->f_visit_expr = std::move(f_visit_expr); // Set statement functions - // f_visit_let_stmt is the Python-facing name; internally it maps to f_visit_bind - n->f_visit_bind = 
std::move(f_visit_let_stmt); + n->f_visit_bind = std::move(f_visit_bind); n->f_visit_attr_stmt = std::move(f_visit_attr_stmt); n->f_visit_if_then_else = std::move(f_visit_if_then_else); n->f_visit_for = std::move(f_visit_for); @@ -697,7 +696,7 @@ class PyStmtExprMutator : public ObjectRef { */ TVM_DLL static PyStmtExprMutator MakePyStmtExprMutator(ffi::Function f_visit_stmt, // ffi::Function f_visit_expr, // - ffi::Function f_visit_let_stmt, // + ffi::Function f_visit_bind, // ffi::Function f_visit_attr_stmt, // ffi::Function f_visit_if_then_else, // ffi::Function f_visit_for, // @@ -747,8 +746,7 @@ class PyStmtExprMutator : public ObjectRef { n->f_visit_stmt = std::move(f_visit_stmt); n->f_visit_expr = std::move(f_visit_expr); // Statement functions - // f_visit_let_stmt is the Python-facing name; internally it maps to f_visit_bind - n->f_visit_bind = std::move(f_visit_let_stmt); + n->f_visit_bind = std::move(f_visit_bind); n->f_visit_attr_stmt = std::move(f_visit_attr_stmt); n->f_visit_if_then_else = std::move(f_visit_if_then_else); n->f_visit_for = std::move(f_visit_for); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 75ccf6708ec6..d691b3f38f38 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -76,16 +76,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](Var var, PrimExpr value, Span span) { return Bind(var, value, span); }); } -// LetStmt is now a deprecated alias for Bind. -// Keep the Python-facing factory for backward compat: tir.LetStmt(var, value, body) -// becomes SeqStmt(Bind(var, value), body). 
-TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.LetStmt", [](Var var, PrimExpr value, Stmt body, Span span) { - return SeqStmt::Flatten(Bind(var, value, span), body); - }); -} - // AttrStmt AttrStmt::AttrStmt(ffi::Any node, ffi::String attr_key, PrimExpr value, Stmt body, Span span) { auto n = ffi::make_object(); diff --git a/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py b/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py index cf44114da180..8bca1da57793 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py @@ -238,7 +238,7 @@ def expected(A: T.Buffer((4, 4), "float32")): def test_hoist_disable_let(): - """As test_hoist_with_let, but forbid hoisting of LetStmt + """As test_hoist_with_let, but forbid hoisting of Bind Because the condition depends on the let binding, it should no longer be hoisted. 
With Bind lifecycle management, the condition diff --git a/tests/python/s_tir/transform/test_s_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/s_tir/transform/test_s_tir_transform_plan_update_buffer_allocation_location.py index 945f3d38f57e..c92528e5a2da 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_plan_update_buffer_allocation_location.py @@ -373,7 +373,7 @@ def before(A: T.handle("float32")): def test_dltensor_buffer_is_unlowered(): - """Buffers allocated with a LetStmt are unmodified + """Buffers allocated with a Bind are unmodified Confirm that the `tir.PlanAndUpdateBufferAllocationLocation` pass leaves (Buffer nodes corresponding to PrimFunc DLTensor arguments) diff --git a/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py b/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py index 0b8c314a83ed..b1a1558a7482 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py @@ -100,7 +100,7 @@ def expected(A: T.Buffer((4, 4), "float32"), E: T.Buffer((4, 4), "float32")): @tvm.testing.requires_cuda -def test_sync_let_stmt(): +def test_sync_bind(): @T.prim_func(private=True) def func(A: T.Buffer((16 * 512), "float32")): blockIdx_x = T.launch_thread("blockIdx.x", 16) @@ -113,13 +113,13 @@ def func(A: T.Buffer((16 * 512), "float32")): A_shared_1[ax0] = A[blockIdx_x * 512 + ax0] in_thread_A_temp_1 = T.decl_buffer((1,), data=in_thread_A_temp, scope="local") in_thread_A_temp_1[0] = T.float32(0) - with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x]) as A_temp: + with T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x]) as A_temp: in_thread_A_temp_1[0] = A_temp - with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 128]) as A_temp: + with T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 128]) 
as A_temp: in_thread_A_temp_1[0] = A_temp - with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 256]) as A_temp: + with T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 256]) as A_temp: in_thread_A_temp_1[0] = A_temp - with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 384]) as A_temp: + with T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 384]) as A_temp: in_thread_A_temp_1[0] = A_temp cross_thread_A_temp_1 = T.decl_buffer((1,), data=cross_thread_A_temp, scope="local") with T.attr( @@ -148,13 +148,13 @@ def expected(A: T.Buffer((8192,), "float32")): in_thread_A_temp_1_1 = T.decl_buffer((1,), data=in_thread_A_temp_1, scope="local") in_thread_A_temp_1_1[0] = T.float32(0) T.tvm_storage_sync("shared") - with T.LetStmt(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x]) as A_temp: + with T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x]) as A_temp: in_thread_A_temp_1_1[0] = A_temp - with T.LetStmt(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 128]) as A_temp: + with T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 128]) as A_temp: in_thread_A_temp_1_1[0] = A_temp - with T.LetStmt(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 256]) as A_temp: + with T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 256]) as A_temp: in_thread_A_temp_1_1[0] = A_temp - with T.LetStmt(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 384]) as A_temp: + with T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 384]) as A_temp: in_thread_A_temp_1_1[0] = A_temp cross_thread_A_temp_1_1 = T.decl_buffer((1,), data=cross_thread_A_temp_1, scope="local") T.attr( @@ -180,4 +180,4 @@ def expected(A: T.Buffer((8192,), "float32")): test_sync_else_branch() test_sync_read_thread_id_independent_location() test_sync_shared_dyn() - test_sync_let_stmt() + test_sync_bind() diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py 
index d11281f79639..ee4e6ee2cda4 100644 --- a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py +++ b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py @@ -92,8 +92,8 @@ def test_error_for_nested_rebind_usage(): @T.prim_func(check_well_formed=False) def func(): i = T.int32() - with T.LetStmt(42, var=i): - with T.LetStmt(42, var=i): + with T.Bind(42, var=i): + with T.Bind(42, var=i): T.evaluate(i) with pytest.raises( @@ -113,9 +113,9 @@ def test_error_for_repeated_binding(): @T.prim_func(check_well_formed=False) def func(): i = T.int32() - with T.LetStmt(42, var=i): + with T.Bind(42, var=i): T.evaluate(i) - with T.LetStmt(17, var=i): + with T.Bind(17, var=i): T.evaluate(i) with pytest.raises(ValueError, match="multiple nested definitions of variable i"): @@ -131,12 +131,12 @@ def test_error_for_cross_function_reuse(): class mod: @T.prim_func def func1(): - with T.LetStmt(42, var=i): + with T.Bind(42, var=i): T.evaluate(i) @T.prim_func def func2(): - with T.LetStmt(42, var=i): + with T.Bind(42, var=i): T.evaluate(i) with pytest.raises(ValueError, match="multiple definitions of variable i"): @@ -295,10 +295,10 @@ def test_error_message_without_previous_definition_location(): def func(): x = T.int32() - with T.LetStmt(42, var=x): + with T.Bind(42, var=x): T.evaluate(x) - with T.LetStmt(99, var=x): # This should trigger the error + with T.Bind(99, var=x): # This should trigger the error T.evaluate(x) with pytest.raises(ValueError) as exc_info: @@ -322,8 +322,8 @@ def test_error_message_with_previous_definition_location(): def func(): x = T.int32() - with T.LetStmt(42, var=x): - with T.LetStmt(99, var=x): # This should trigger the error + with T.Bind(42, var=x): + with T.Bind(99, var=x): # This should trigger the error T.evaluate(x) with pytest.raises(ValueError) as exc_info: @@ -351,10 +351,10 @@ def test_sequential_redefinition_with_location(): def func(): x = T.int32() - with T.LetStmt(1, var=x): + with T.Bind(1, var=x): T.evaluate(x) - 
with T.LetStmt(2, var=x): # This should trigger the error + with T.Bind(2, var=x): # This should trigger the error T.evaluate(x) with pytest.raises(ValueError) as exc_info: diff --git a/tests/python/tir-transform/test_tir_inline_private_functions.py b/tests/python/tir-transform/test_tir_inline_private_functions.py index 41edd410a008..c5e3f2a07356 100644 --- a/tests/python/tir-transform/test_tir_inline_private_functions.py +++ b/tests/python/tir-transform/test_tir_inline_private_functions.py @@ -150,7 +150,7 @@ def subroutine(A_data: T.handle("float32"), B_data: T.handle("float32")): class Expected: @T.prim_func def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")): - with T.LetStmt(T.address_of(A[0, 0]), var=T.handle("float32")) as A_data_1: + with T.Bind(T.address_of(A[0, 0]), var=T.handle("float32")) as A_data_1: A_1 = T.decl_buffer(16, "float32", data=A_data_1) B_data_1: T.handle("float32") = T.address_of(B[0, 0]) B_1 = T.decl_buffer(16, "float32", data=B_data_1) @@ -158,7 +158,7 @@ def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")): with T.sblock("scalar_mul_1"): B_1[i] = A_1[i] * 2.0 - with T.LetStmt(T.address_of(A[1, 0]), var=T.handle("float32")) as A_data_2: + with T.Bind(T.address_of(A[1, 0]), var=T.handle("float32")) as A_data_2: A_2 = T.decl_buffer(16, "float32", data=A_data_2) B_data_2: T.handle("float32") = T.address_of(B[1, 0]) B_2 = T.decl_buffer(16, "float32", data=B_data_2) diff --git a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py index f6bcd3cd97c8..506ec99d81e0 100644 --- a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py @@ -352,7 +352,7 @@ def func_distributivity( def func_distributivity_expected( B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 ) -> None: - with 
T.LetStmt((y + z) * x) as cse_v1: + with T.Bind((y + z) * x) as cse_v1: B[i1] = cse_v1 B[i2] = cse_v1 @@ -369,7 +369,7 @@ def func_associativity( def func_associativity_expected( B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 ) -> None: - with T.LetStmt(x + y + z) as cse_v1: + with T.Bind(x + y + z) as cse_v1: B[i1] = cse_v1 B[i2] = cse_v1 diff --git a/tests/python/tir-transform/test_tir_transform_convert_ssa.py b/tests/python/tir-transform/test_tir_transform_convert_ssa.py index de7cd764a7c6..1b6985cf69db 100644 --- a/tests/python/tir-transform/test_tir_transform_convert_ssa.py +++ b/tests/python/tir-transform/test_tir_transform_convert_ssa.py @@ -23,7 +23,7 @@ from tvm.script import tir as T -def test_reuse_in_sequential_let_stmt(): +def test_reuse_in_sequential_bind(): """De-dup sequential variable bindings""" # Manually construct the PrimFunc body, as SSA violations are @@ -42,9 +42,9 @@ def test_reuse_in_sequential_let_stmt(): @T.prim_func(private=True) def expected(): - with T.LetStmt(T.int32(16)) as var1: + with T.Bind(T.int32(16)) as var1: T.evaluate(var1) - with T.LetStmt(T.int32(32)) as var2: + with T.Bind(T.int32(32)) as var2: T.evaluate(var2) mod = tvm.IRModule.from_expr(before) @@ -52,7 +52,7 @@ def expected(): tvm.ir.assert_structural_equal(mod["main"], expected) -def test_reuse_in_nested_let_stmt(): +def test_reuse_in_nested_bind(): """De-dup sequential bindings of the same variable. In the flat Bind model, all Binds are siblings in a SeqStmt. 
A second @@ -108,7 +108,7 @@ def test_reused_var_across_module(): @T.prim_func(private=True) def func(): - with T.LetStmt(10) as var: + with T.Bind(10) as var: T.evaluate(var) before = tvm.IRModule( diff --git a/tests/python/tir-transform/test_tir_transform_simplify.py b/tests/python/tir-transform/test_tir_transform_simplify.py index ddc82e07739b..b99e41b569e2 100644 --- a/tests/python/tir-transform/test_tir_transform_simplify.py +++ b/tests/python/tir-transform/test_tir_transform_simplify.py @@ -1826,7 +1826,7 @@ def expected(A: T.Buffer(1, "int32")): def test_simplify_trivial_let_buffer_var(): - """A LetStmt used in a buffer definition should be retained""" + """A Bind used in a buffer definition should be retained""" @T.prim_func(private=True) def before(A_ptr: T.handle("float32")): @@ -1841,7 +1841,7 @@ def before(A_ptr: T.handle("float32")): def test_simplify_trivial_let_elem_offset(): - """A LetStmt used in a buffer definition should be retained, buffer fields unchanged""" + """A Bind used in a buffer definition should be retained""" @T.prim_func(private=True) def before(A_ptr: T.handle("float32"), A_offset: T.int32): @@ -1860,7 +1860,7 @@ def expected(A_ptr: T.handle("float32"), A_offset: T.int32): def test_simplify_trivial_let_shape(): - """A LetStmt used in a buffer definition should be retained, buffer fields unchanged""" + """A Bind used in a buffer definition should be retained""" @T.prim_func(private=True) def before(A_ptr: T.handle("float32"), A_size: T.int32): @@ -1879,7 +1879,7 @@ def expected(A_ptr: T.handle("float32"), A_size: T.int32): def test_simplify_trivial_let_stride(): - """A LetStmt used in a buffer definition should be retained, buffer fields unchanged""" + """A Bind used in a buffer definition should be retained""" @T.prim_func(private=True) def before(A_ptr: T.handle("float32"), A_stride: T.int32): diff --git a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py 
b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py index 5e509956df8d..5d713e6c7fd5 100644 --- a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py +++ b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py @@ -410,9 +410,9 @@ def test_let_buffer_rewrite(): If StorageRewrite replaces the backing variable of an array, such as when vectorizing the storage type, the variable must be - replaced in the LetStmt that defines it. Currently, StmtMutator + replaced in the Bind that defines it. Currently, StmtMutator only visits usage of variables, and does not visit definitions of - variables, so the definition in a LetStmt must be explicitly + variables, so the definition in a Bind must be explicitly handled. """ diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py index cac1371f5a53..48808052e1e5 100644 --- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py +++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py @@ -316,7 +316,7 @@ def test_ir_builder_tir_assert(): def test_ir_builder_tir_let(): with IRBuilder() as ib: - with T.LetStmt(tir.IntImm("int32", 2)) as v: + with T.Bind(tir.IntImm("int32", 2)) as v: T.evaluate(0) # the let binding generated by IRBuilder let_actual = ib.get() diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index 62d49f6c242f..c6ffa2a5e56c 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -252,9 +252,9 @@ def test_for(): ) -def test_let_stmt(): +def test_bind(): with IRBuilder() as ib: - with T.LetStmt(T.float32(10)) as v: + with T.Bind(T.float32(10)) as v: ib.name("v", v) T.evaluate(0) obj = ib.get() diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 02efcb8ee8b7..37e241ac6604 100644 --- 
a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -2754,11 +2754,11 @@ def func(): return func -def let_stmt_var(): +def bind_var(): @T.prim_func def func(): - with T.LetStmt(0) as x: - with T.LetStmt(0) as y: + with T.Bind(0) as x: + with T.Bind(0) as y: T.evaluate(0) T.evaluate(0) @@ -3360,7 +3360,7 @@ def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")): *nested_boolean_expressions(), multi_env_threads, intrinsic_pow, - let_stmt_var, + bind_var, string_stride, string_stride_int64, merge_shape_var_def, diff --git a/tests/python/tvmscript/test_tvmscript_syntax_sugar.py b/tests/python/tvmscript/test_tvmscript_syntax_sugar.py index 0205758b90ca..32c593881b8f 100644 --- a/tests/python/tvmscript/test_tvmscript_syntax_sugar.py +++ b/tests/python/tvmscript/test_tvmscript_syntax_sugar.py @@ -410,7 +410,7 @@ def test_preserve_trivial_let_binding(): @T.prim_func def explicit(i: T.int32): j = T.int32() - T.LetStmt(i, var=j) + T.Bind(i, var=j) T.evaluate(j) @T.prim_func @@ -425,7 +425,7 @@ def test_preserve_trivial_let_binding_of_value(): @T.prim_func def explicit(i: T.int32): j = T.int32() - T.LetStmt(42, var=j) + T.Bind(42, var=j) T.evaluate(j) @T.prim_func From 7f8886f460557b138c2646bd6f1b61582424553a Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Mar 2026 21:44:34 +0000 Subject: [PATCH 19/34] [REFACTOR][TIR] Refactor ConvertSSA to use ScopeStack for var remap management Replace the ScopedRedefine RAII struct and custom SeqStmt handler with ScopeStack for cleaner scope management: - Body-carrying statements (For, IfThenElse, AttrStmt, DeclBuffer, While, Allocate, SBlock) push a new scope via scope_.WithNewScope() - Bind pushes var remaps to the current scope level, persisting across SeqStmt siblings - Scope exit automatically pops all remaps via ScopeLevel destructor - Remove the custom VisitStmt_(SeqStmtNode*) -- default sequential iteration works because Bind remaps persist in the 
enclosing scope - Add IfThenElse handler with separate scopes per branch to prevent remap leakage between then/else cases --- src/tir/transform/ir_utils.cc | 371 ++++++++++++++++++++-------------- 1 file changed, 220 insertions(+), 151 deletions(-) diff --git a/src/tir/transform/ir_utils.cc b/src/tir/transform/ir_utils.cc index 9f784b64b5ef..3cd117f5ce3e 100644 --- a/src/tir/transform/ir_utils.cc +++ b/src/tir/transform/ir_utils.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -94,13 +95,14 @@ Stmt MergeNest(const std::vector>& nest, Stmt body) { class IRConvertSSA final : public StmtExprMutator { public: PrimFunc VisitPrimFunc(PrimFunc func) { - std::vector redefines; - - // Remap parameters, if they were used in another function + // Remap parameters, if they were used in another function. + // Function-scope remaps use function_scope_var_remap_ (not the scope stack), + // because they persist across the entire function body. auto params = func->params.Map([&](const tir::Var& var) -> tir::Var { if (defined_.count(var.get())) { - const ScopedRedefine& redefine = redefines.emplace_back(this, var); - return redefine.new_var; + Var new_var = MakeNewVar(var); + PushVarRemap(var, new_var); + return new_var; } else { defined_.insert(var.get()); return var; @@ -122,7 +124,7 @@ class IRConvertSSA final : public StmtExprMutator { // Buffer_map shape vars use "match" semantics: first occurrence // defines the var, subsequent occurrences (in other buffers) are - // just consistent uses of the same var — not redefinitions. + // just consistent uses of the same var -- not redefinitions. 
if (!defined_.count(var_ptr)) { defined_.insert(var_ptr); } @@ -144,7 +146,8 @@ class IRConvertSSA final : public StmtExprMutator { for (const auto& [var, buffer] : func->buffer_map) { auto new_var = GetRemappedVar(var); if (defined_.count(buffer->data.get())) { - redefines.emplace_back(this, buffer->data); + Var new_data = MakeNewVar(buffer->data); + PushVarRemap(buffer->data, new_data); } else { defined_.insert(buffer->data.get()); } @@ -195,10 +198,8 @@ class IRConvertSSA final : public StmtExprMutator { func = PrimFunc(params, body, func->ret_type, buffer_map, attrs); } - // Pop the redefines in reverse order of creation - while (redefines.size()) { - redefines.pop_back(); - } + // Pop function-scope remaps in reverse order + PopAllRemapsInCurrentScope(); function_scope_var_remap_.clear(); return func; } @@ -218,9 +219,11 @@ class IRConvertSSA final : public StmtExprMutator { const Var& v = op->var; if (defined_.count(v.get())) { PrimExpr value = this->VisitExpr(op->value); - ScopedRedefine redefine(this, v); + Var new_var = MakeNewVar(v); + PushVarRemap(v, new_var); PrimExpr body = this->VisitExpr(op->body); - return Let(redefine.new_var, value, body); + PopVarRemap(v, new_var); + return Let(new_var, value, body); } else { defined_.insert(v.get()); return StmtExprMutator::VisitExpr_(op); @@ -240,48 +243,48 @@ class IRConvertSSA final : public StmtExprMutator { } Stmt VisitStmt_(const DeclBufferNode* op) final { - DeclBuffer decl = Downcast(StmtExprMutator::VisitStmt_(op)); - Buffer new_buffer = GetRemappedBuffer(decl->buffer); - if (!new_buffer.same_as(decl->buffer)) { - decl.CopyOnWrite()->buffer = std::move(new_buffer); - } - return decl; + return scope_.WithNewScope([&]() -> Stmt { + DeclBuffer decl = Downcast(StmtExprMutator::VisitStmt_(op)); + Buffer new_buffer = GetRemappedBuffer(decl->buffer); + if (!new_buffer.same_as(decl->buffer)) { + decl.CopyOnWrite()->buffer = std::move(new_buffer); + } + return decl; + }); } Stmt VisitStmt_(const SBlockNode* op) 
final { SBlock block = ffi::GetRef(op); - // The BlockNode is the point of definition for the IterVar + // The SBlockNode is the point of definition for the IterVar // instances. These re-defines must be present before visiting - // the body of the BlockNode. - std::vector redefines; - ffi::Array iter_vars = op->iter_vars.Map([&](IterVar iter_var) { - if (defined_.count(iter_var->var.get())) { - redefines.emplace_back(this, iter_var->var); - iter_var.CopyOnWrite()->var = redefines.back().new_var; - } else { - defined_.insert(iter_var->var.get()); + // the body of the SBlockNode. + return scope_.WithNewScope([&]() -> Stmt { + ffi::Array iter_vars = op->iter_vars.Map([&](IterVar iter_var) { + if (defined_.count(iter_var->var.get())) { + Var new_var = MakeNewVar(iter_var->var); + PushVarRemap(iter_var->var, new_var); + iter_var.CopyOnWrite()->var = new_var; + } else { + defined_.insert(iter_var->var.get()); + } + return iter_var; + }); + ffi::Array reads = + block->reads.Map([&](const auto& region) { return VisitBufferAccess(region); }); + ffi::Array writes = + block->writes.Map([&](const auto& region) { return VisitBufferAccess(region); }); + + if (!reads.same_as(block->reads) || !writes.same_as(block->writes) || + !iter_vars.same_as(op->iter_vars)) { + auto write_ptr = block.CopyOnWrite(); + write_ptr->reads = reads; + write_ptr->writes = writes; + write_ptr->iter_vars = iter_vars; } - return iter_var; - }); - ffi::Array reads = - block->reads.Map([&](const auto& region) { return VisitBufferAccess(region); }); - ffi::Array writes = - block->writes.Map([&](const auto& region) { return VisitBufferAccess(region); }); - - if (!reads.same_as(block->reads) || !writes.same_as(block->writes) || - !iter_vars.same_as(op->iter_vars)) { - auto write_ptr = block.CopyOnWrite(); - write_ptr->reads = reads; - write_ptr->writes = writes; - write_ptr->iter_vars = iter_vars; - } - - Stmt output = Downcast(StmtExprMutator::VisitStmt_(block.get())); - while (redefines.size()) 
redefines.pop_back(); - - return output; + return Downcast(StmtExprMutator::VisitStmt_(block.get())); + }); } template @@ -296,7 +299,7 @@ class IRConvertSSA final : public StmtExprMutator { } Var GetRemappedVar(Var var) { - if (auto it = scope_.find(var.get()); it != scope_.end() && it->second.size()) { + if (auto it = var_remap_.find(var.get()); it != var_remap_.end() && it->second.size()) { return it->second.back(); } else if (auto it = function_scope_var_remap_.find(var.get()); it != function_scope_var_remap_.end()) { @@ -347,92 +350,77 @@ class IRConvertSSA final : public StmtExprMutator { } Stmt VisitStmt_(const BindNode* op) final { - // Note: ScopedRedefine for Bind must persist across SeqStmt siblings. - // This is handled by VisitStmt_(const SeqStmtNode*) below. - // When visited standalone (not as part of SeqStmt), just do a simple visit. + // Bind var remaps are tracked in the current scope so they persist + // across SeqStmt siblings and are cleaned up when the enclosing + // body-carrying statement's scope exits. const Var& v = op->var; if (defined_.count(v.get())) { PrimExpr value = this->VisitExpr(op->value); - ScopedRedefine redefine(this, v); - return Bind(redefine.new_var, value); + Var new_var = MakeNewVar(v); + PushVarRemap(v, new_var); + return Bind(new_var, value); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); } } - Stmt VisitStmt_(const SeqStmtNode* op) final { - // Process children sequentially, maintaining ScopedRedefine for Bind nodes - // so that remappings persist for subsequent siblings (mimicking old nested - // Bind scope behavior). 
- std::vector seq_redefines; - ffi::Array new_seq; - bool changed = false; - - for (size_t i = 0; i < op->seq.size(); ++i) { - const Stmt& child = op->seq[i]; - if (auto* bind = child.as()) { - const Var& v = bind->var; - if (defined_.count(v.get())) { - PrimExpr value = this->VisitExpr(bind->value); - seq_redefines.emplace_back(this, v); - Stmt new_bind = Bind(seq_redefines.back().new_var, value); - new_seq.push_back(new_bind); - changed = true; - } else { - defined_.insert(v.get()); - Stmt visited = StmtExprMutator::VisitStmt_(bind); - new_seq.push_back(visited); - changed = changed || !visited.same_as(child); - } - } else { - Stmt visited = VisitStmt(child); - new_seq.push_back(visited); - changed = changed || !visited.same_as(child); - } - } - - // Pop redefines in reverse order (RAII would do this, but let's be explicit) - while (seq_redefines.size()) { - seq_redefines.pop_back(); + Stmt VisitStmt_(const IfThenElseNode* op) final { + // Each branch gets its own scope so Bind remaps in one branch + // do not leak into the other. 
+ PrimExpr condition = VisitExpr(op->condition); + Stmt then_case = scope_.WithNewScope([&]() -> Stmt { return VisitStmt(op->then_case); }); + ffi::Optional else_case; + if (op->else_case) { + else_case = scope_.WithNewScope([&]() -> Stmt { return VisitStmt(op->else_case.value()); }); } - - if (!changed) { + if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && + else_case.same_as(op->else_case)) { return ffi::GetRef(op); } - return SeqStmt(new_seq); + return IfThenElse(condition, then_case, else_case); } + Stmt VisitStmt_(const ForNode* op) final { const Var& v = op->loop_var; if (defined_.count(v.get())) { - ScopedRedefine redefine(this, v); - Stmt stmt = StmtExprMutator::VisitStmt_(op); - auto n = ffi::make_object(*stmt.as()); - n->loop_var = redefine.new_var; - return For(n); + return scope_.WithNewScope([&]() -> Stmt { + Var new_var = MakeNewVar(v); + PushVarRemap(v, new_var); + Stmt stmt = StmtExprMutator::VisitStmt_(op); + auto n = ffi::make_object(*stmt.as()); + n->loop_var = new_var; + return For(n); + }); } else { defined_.insert(v.get()); - return StmtExprMutator::VisitStmt_(op); + return scope_.WithNewScope([&]() -> Stmt { return StmtExprMutator::VisitStmt_(op); }); } } + Stmt VisitStmt_(const WhileNode* op) final { + return scope_.WithNewScope([&]() -> Stmt { return StmtExprMutator::VisitStmt_(op); }); + } Stmt VisitStmt_(const AllocBufferNode* op) final { const Var& v = op->buffer->data; if (defined_.count(v.get())) { - ScopedRedefine redefine(this, v); - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - // Use GetRemappedBuffer so that the AllocBuffer's buffer is the same - // object as the one used by BufferStore/BufferLoad in the body. 
- Buffer new_buf = GetRemappedBuffer(op->buffer); - if (!new_buf.same_as(op->buffer)) { - auto node = Downcast(stmt); - node.CopyOnWrite()->buffer = std::move(new_buf); - return node; - } - return stmt; + return scope_.WithNewScope([&]() -> Stmt { + Var new_var = MakeNewVar(v); + PushVarRemap(v, new_var); + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + // Use GetRemappedBuffer so that the AllocBuffer's buffer is the same + // object as the one used by BufferStore/BufferLoad in the body. + Buffer new_buf = GetRemappedBuffer(op->buffer); + if (!new_buf.same_as(op->buffer)) { + auto node = Downcast(stmt); + node.CopyOnWrite()->buffer = std::move(new_buf); + return node; + } + return stmt; + }); } else { defined_.insert(v.get()); - return StmtExprMutator::VisitStmt_(op); + return scope_.WithNewScope([&]() -> Stmt { return StmtExprMutator::VisitStmt_(op); }); } } Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -490,7 +478,7 @@ class IRConvertSSA final : public StmtExprMutator { } auto value = VisitExpr(op->value); - auto body = VisitStmt(op->body); + auto body = scope_.WithNewScope([&]() -> Stmt { return VisitStmt(op->body); }); Stmt output; if (new_iter_var.get() == iter_var && body.same_as(op->body) && value.same_as(op->value)) { @@ -509,75 +497,156 @@ class IRConvertSSA final : public StmtExprMutator { return output; } else if (const VarNode* v = op->node.as()) { - Stmt stmt = StmtExprMutator::VisitStmt_(op); + Stmt stmt = scope_.WithNewScope([&]() -> Stmt { return StmtExprMutator::VisitStmt_(op); }); op = stmt.as(); - if (scope_.count(v) && scope_[v].size() != 0) { - return AttrStmt(scope_[v].back(), op->attr_key, op->value, op->body); + if (var_remap_.count(v) && var_remap_[v].size() != 0) { + return AttrStmt(var_remap_[v].back(), op->attr_key, op->value, op->body); } else { return stmt; } } else { - return StmtExprMutator::VisitStmt_(op); + return scope_.WithNewScope([&]() -> Stmt { return StmtExprMutator::VisitStmt_(op); }); } } private: - 
struct ScopedRedefine { - ScopedRedefine(IRConvertSSA* parent, Var old_var) : parent(parent), old_var(old_var) { - bool is_size_var = old_var->IsInstance(); - if (old_var->type_annotation.defined()) { - if (is_size_var) { - new_var = SizeVar(old_var->name_hint, old_var->type_annotation); - } else { - new_var = Var(old_var->name_hint, old_var->type_annotation); - } + /*! \brief Record of a variable remap pushed to the current scope. */ + struct VarRemap { + Var old_var; + Var new_var; + }; + + /*! \brief Create a new variable with the same name and type as the original. */ + static Var MakeNewVar(const Var& old_var) { + bool is_size_var = old_var->IsInstance(); + if (old_var->type_annotation.defined()) { + if (is_size_var) { + return SizeVar(old_var->name_hint, old_var->type_annotation); } else { - if (is_size_var) { - new_var = SizeVar(old_var->name_hint, old_var->dtype); - } else { - new_var = Var(old_var->name_hint, old_var->dtype); - } + return Var(old_var->name_hint, old_var->type_annotation); + } + } else { + if (is_size_var) { + return SizeVar(old_var->name_hint, old_var->dtype); + } else { + return Var(old_var->name_hint, old_var->dtype); } - parent->scope_[old_var.get()].push_back(new_var); } + } + + /*! \brief Push a variable remap to the current scope and the var_remap_ stack. */ + void PushVarRemap(const Var& old_var, const Var& new_var) { + var_remap_[old_var.get()].push_back(new_var); + auto& level = scope_.Current(); + level.parent = this; + level.push_back({old_var, new_var}); + } - ~ScopedRedefine() { - if (parent) { - parent->scope_[old_var.get()].pop_back(); + /*! \brief Pop a single variable remap (used for expression-level Let scoping). 
*/ + void PopVarRemap(const Var& old_var, const Var& new_var) { + var_remap_[old_var.get()].pop_back(); + for (auto& kv : buf_remap_) { + std::vector& buffers = kv.second; + if (buffers.size() && (buffers.back()->data.get() == new_var.get())) { + buffers.pop_back(); + } + } + // Also remove from the current scope's tracking vector + auto& current = scope_.Current(); + if (current.size() && current.back().new_var.same_as(new_var)) { + current.pop_back(); + } + } + + /*! \brief Pop all remaps in the current scope level (used for function-scope cleanup). */ + void PopAllRemapsInCurrentScope() { + auto& current = scope_.Current(); + while (current.size()) { + auto& remap = current.back(); + var_remap_[remap.old_var.get()].pop_back(); + for (auto& kv : buf_remap_) { + std::vector& buffers = kv.second; + if (buffers.size() && (buffers.back()->data.get() == remap.new_var.get())) { + buffers.pop_back(); + } + } + current.pop_back(); + } + } + + /*! \brief Scope stack: each scope level holds the remaps introduced in that scope. + * + * When a body-carrying statement (For, SBlock, Allocate) calls + * scope_.WithNewScope([&]{...}), a new scope level is pushed. + * Bind statements push their remaps to the current scope. + * On scope exit, the destructor of std::vector triggers, + * and we undo all remaps in that level. + * + * Note: ScopeStack::WithNewScope calls T's destructor on exit. + * std::vector's destructor destroys elements but does NOT call custom + * cleanup. So we wrap the vector in ScopeLevel which handles cleanup. 
+ */ + struct ScopeLevel { + std::vector remaps; + IRConvertSSA* parent{nullptr}; + + void push_back(VarRemap remap) { remaps.push_back(std::move(remap)); } + size_t size() const { return remaps.size(); } + VarRemap& back() { return remaps.back(); } + void pop_back() { remaps.pop_back(); } + + ~ScopeLevel() { + if (!parent) return; + // Pop remaps in reverse order + while (remaps.size()) { + auto& remap = remaps.back(); + parent->var_remap_[remap.old_var.get()].pop_back(); for (auto& kv : parent->buf_remap_) { std::vector& buffers = kv.second; - if (buffers.size() && (buffers.back()->data.get() == new_var.get())) { + if (buffers.size() && (buffers.back()->data.get() == remap.new_var.get())) { buffers.pop_back(); } } + remaps.pop_back(); } } - ScopedRedefine& operator=(const ScopedRedefine&) = delete; - ScopedRedefine(const ScopedRedefine&) = delete; - - ScopedRedefine& operator=(ScopedRedefine&& other) { - swap(other); + ScopeLevel() = default; + ScopeLevel(const ScopeLevel&) = delete; + ScopeLevel& operator=(const ScopeLevel&) = delete; + ScopeLevel(ScopeLevel&& other) noexcept + : remaps(std::move(other.remaps)), parent(other.parent) { + other.parent = nullptr; // prevent other's destructor from popping + } + ScopeLevel& operator=(ScopeLevel&& other) noexcept { + if (this != &other) { + // Run our destructor logic first + if (parent) { + while (remaps.size()) { + auto& remap = remaps.back(); + parent->var_remap_[remap.old_var.get()].pop_back(); + for (auto& kv : parent->buf_remap_) { + std::vector& buffers = kv.second; + if (buffers.size() && (buffers.back()->data.get() == remap.new_var.get())) { + buffers.pop_back(); + } + } + remaps.pop_back(); + } + } + remaps = std::move(other.remaps); + parent = other.parent; + other.parent = nullptr; + } return *this; } - ScopedRedefine(ScopedRedefine&& other) { swap(other); } - - void swap(ScopedRedefine& other) { - std::swap(parent, other.parent); - std::swap(old_var, other.old_var); - std::swap(new_var, other.new_var); - 
} - - IRConvertSSA* parent{nullptr}; - Var old_var; - Var new_var; }; - std::unordered_map> scope_; + std::unordered_map> var_remap_; std::unordered_set defined_; std::unordered_map> buf_remap_; - std::unordered_map function_scope_var_remap_; + ScopeStack scope_; }; Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA()(std::move(stmt)); } From d5374ee9bcd96244b803a15cbf3548c3f7b0393b Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Mar 2026 22:04:57 +0000 Subject: [PATCH 20/34] [REFACTOR][TIR] Refactor CSE pass to use ScopeStack for context management Replace manual save/restore of context_ in the Common Subexpression Elimination pass with ScopeStack-based automatic scope management. Key changes: - Add ScopeStack where each scope level records the context size on entry and truncates it back on exit via destructor - ForNode, LetNode: WithNewScope replaces manual context save/restore - New scope-boundary overrides for IfThenElse, AttrStmt, Allocate, DeclBuffer, While to prevent context leaks across scope boundaries - SeqStmtNode: remove manual context save/restore (enclosing scope handles cleanup), retain wrap-remaining-siblings pattern for cross-sibling CSE after Bind nodes - BindNode: entries persist across SeqStmt siblings, cleaned up automatically when enclosing body-carrying statement's scope exits --- src/tir/transform/common_subexpr_elim.cc | 240 ++++++++++++++++++----- src/tir/transform/common_subexpr_elim.h | 81 +++++++- 2 files changed, 268 insertions(+), 53 deletions(-) diff --git a/src/tir/transform/common_subexpr_elim.cc b/src/tir/transform/common_subexpr_elim.cc index 1c4a3c2c68e1..9e814693a157 100644 --- a/src/tir/transform/common_subexpr_elim.cc +++ b/src/tir/transform/common_subexpr_elim.cc @@ -201,7 +201,11 @@ Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& CommonSubexpressionEliminator::CommonSubexpressionEliminator(const Stmt& stmt, const Context& context_init, bool identify_equiv_terms) - : initial_body_(stmt), 
context_(context_init), identify_equiv_terms_(identify_equiv_terms) {} + : initial_body_(stmt), context_(context_init), identify_equiv_terms_(identify_equiv_terms) { + // The initial scope level (from ScopeStack's constructor) does not need + // EnterContextScope() because it should never be popped -- it persists + // for the lifetime of the CSE pass and holds the function parameters. +} /*! * \brief The method which overrides the generic dispatcher of StmtExprMutator. @@ -342,32 +346,31 @@ PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) { } /*! - * \brief The method which overrides the specific treatment for a LetNode + * \brief The method which overrides the specific treatment for a LetNode. + * + * The let-in expression introduces a new variable binding that is only visible + * within the body. We use context_scope_.WithNewScope to automatically clean up + * the binding when the body has been visited, replacing the old manual + * save/restore of context_. */ PrimExpr CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) { // At this point, we have already done the generic treatment of introducing (via let-in) what // was doable at the toplevel of the given let-in. - // Save the context at the entry of the function - Context context_at_entry = context_; - // Recurse on the `value` field for potentially rewriting it PrimExpr value_new = VisitExpr(op->value); - // Augment the context with the association (`var`, `value`) for preparing the next recursion - // on the `body` - context_.push_back({op->var, MaybeValue(op->value)}); - - // Recurse on the `body` (with this extended context) - // The recursive call will have potentially done new simplifications, because in this recursive - // call `var` will be a part of the context. 
- // (see in VisitExpr() that no introduction were performed when a computation was using an - // undefined variable, as that would lead to ill-formed code) - PrimExpr body_new = VisitExpr(op->body); - - // Restaure the context to its content at the entrance to not carry out of scope declarations - // as the variable introduced by the let-in is not in scope outside of its body - context_ = context_at_entry; + // Visit the body in a new scope. The let-in variable binding is added to the + // context inside the scope and automatically removed when the scope exits. + PrimExpr body_new = context_scope_.WithNewScope([&]() -> PrimExpr { + EnterContextScope(); + // Augment the context with the association (`var`, `value`) for the body + context_.push_back({op->var, MaybeValue(op->value)}); + // Recurse on the `body` (with this extended context) + // The recursive call will have potentially done new simplifications, because in this recursive + // call `var` will be a part of the context. + return VisitExpr(op->body); + }); // Rebuild the let-in with a new `value_new` and `body_new` where new simplifications might // have been done. @@ -523,14 +526,22 @@ Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) { } /*! - * \brief The method which overrides the specific treatment for a BindNode + * \brief The method which overrides the specific treatment for a BindNode. + * + * BindNode adds a (var, value) entry to the flat context_ vector. This entry + * persists across subsequent SeqStmt siblings in the same scope, enabling CSE + * to find common subexpressions that reference bind-defined variables. + * Cleanup happens automatically when the enclosing body-carrying statement's + * scope exits (via ContextScopeLevel's destructor), so no manual save/restore + * is needed here. 
*/ Stmt CommonSubexpressionEliminator::VisitStmt_(const BindNode* op) { // Recurse on the `value` field for potentially rewriting it PrimExpr value_new = VisitExpr(op->value); - // Augment the context with the association (`var`, `value`) - // so that subsequent sibling statements in the SeqStmt can use it. + // Augment the context with the association (`var`, `value`). + // This persists across SeqStmt siblings and is cleaned up by the + // enclosing scope's ContextScopeLevel destructor. context_.push_back({op->var, MaybeValue(op->value)}); // Rebuild the Bind if value changed @@ -544,19 +555,24 @@ Stmt CommonSubexpressionEliminator::VisitStmt_(const BindNode* op) { /*! * \brief The method which overrides the specific treatment for a SeqStmtNode. * - * Process the flat sequence one child at a time: - * - Bind nodes: process the value (via VisitExpr), augment context, then wrap - * all remaining siblings as a body and pass to VisitStmt for cross-sibling - * CSE with the newly augmented context. - * - Non-Bind nodes: process individually via VisitStmt, then continue to the - * next child. + * Processes the flat sequence one child at a time: + * - Bind nodes: visit the value (via VisitExpr), augment context_, then wrap + * all remaining siblings into a body and call VisitStmt for cross-sibling + * CSE with the newly augmented context. This re-runs the CSE top-level + * computation collection, allowing new common subexpressions to be found + * and introduced now that the bind-defined variable is in scope. + * - Non-Bind nodes: visit individually via VisitStmt, then continue. + * + * Context cleanup is handled automatically by ScopeStack. The enclosing + * body-carrying statement (For, IfThenElse, etc.) creates a scope; when + * that scope exits, all context entries added here (from Bind nodes) are + * cleaned up. No manual save/restore of context_ is needed. 
* - * This approach ensures that each Bind variable is available in the context - * when analyzing subsequent siblings, enabling CSE to find common - * subexpressions that use Bind-defined variables. + * Note: this still uses the "wrap remaining siblings" pattern after Bind + * nodes, which is necessary because VisitStmt re-runs CSE computation + * collection on the wrapped body, enabling cross-sibling optimizations. */ Stmt CommonSubexpressionEliminator::VisitStmt_(const SeqStmtNode* op) { - Context context_at_entry = context_; ffi::Array new_seq; for (size_t i = 0; i < op->seq.size(); ++i) { @@ -568,8 +584,9 @@ Stmt CommonSubexpressionEliminator::VisitStmt_(const SeqStmtNode* op) { : Bind(bind->var, value_new, bind->span); new_seq.push_back(bind_new); - // Now wrap remaining siblings [i+1..end) as a body and call VisitStmt - // for cross-sibling CSE with the updated context. + // Wrap remaining siblings [i+1..end) as a body and call VisitStmt + // to re-run CSE computation collection with the updated context. + // This enables CSE to introduce new bindings between siblings. if (i + 1 < op->seq.size()) { Stmt body; if (i + 2 == op->seq.size()) { @@ -586,11 +603,10 @@ Stmt CommonSubexpressionEliminator::VisitStmt_(const SeqStmtNode* op) { } else { new_seq.push_back(body_new); } - context_ = context_at_entry; return SeqStmt::Flatten(new_seq); } } else { - // Non-Bind child: process individually, then continue. + // Non-Bind child: visit individually, then continue. Stmt child_new = VisitStmt(op->seq[i]); if (auto* inner = child_new.as()) { for (const auto& s : inner->seq) new_seq.push_back(s); @@ -600,36 +616,36 @@ Stmt CommonSubexpressionEliminator::VisitStmt_(const SeqStmtNode* op) { } } - context_ = context_at_entry; return SeqStmt::Flatten(new_seq); } /*! - * \brief The method which overrides the specific treatment for a ForNode + * \brief The method which overrides the specific treatment for a ForNode. 
+ * + * The for loop introduces a loop variable that is only visible within the body. + * We use context_scope_.WithNewScope to create a scope boundary: the loop + * variable (with no value, since it changes each iteration) is pushed inside + * the scope and automatically cleaned up on exit, replacing the old manual + * save/restore of context_. */ Stmt CommonSubexpressionEliminator::VisitStmt_(const ForNode* op) { // At this point, we have already done the generic treatment of introducing (via let-in) what // was doable at the toplevel of the given for loop. - // Save the context at the entry of the function - Context context_at_entry = context_; - // Recurse on the `min` field for potentially rewriting it PrimExpr min_new = VisitExpr(op->min); // Recurse on the `extent` field for potentially rewriting it PrimExpr extent_new = VisitExpr(op->extent); - // Augment the context with the association {loop_var, no value} (no value as its value will - // change during the execution of the loop) for preparing the next recursion on the `body` - context_.push_back({op->loop_var, MaybeValue()}); - - // Recurse on the `body` (with this extended context) - Stmt body_new = VisitStmt(op->body); - - // Restaure the context to its content at the entrance to not carry out of scope declarations - // as the variable introduced by the for loop is not in scope outside of its body - context_ = context_at_entry; + // Visit the body in a new scope. The loop variable is added to context_ inside + // the scope and automatically removed when the scope exits. + Stmt body_new = context_scope_.WithNewScope([&]() -> Stmt { + EnterContextScope(); + // Add loop_var with no value (its value changes each iteration) + context_.push_back({op->loop_var, MaybeValue()}); + return VisitStmt(op->body); + }); // Rebuild the for loop with (potentially) a new `min_new`, `extent_new` and `body_new`, where // new simplifications might have been done. 
@@ -646,6 +662,128 @@ Stmt CommonSubexpressionEliminator::VisitStmt_(const ForNode* op) { } } +/*! + * \brief The method which overrides the specific treatment for an IfThenElseNode. + * + * Each branch of the if-then-else gets its own scope, preventing context entries + * (e.g., from Bind nodes inside one branch) from leaking into the other branch. + * Without this override, the default StmtExprMutator would visit both branches + * in the same scope, which could cause incorrect CSE across branches. + */ +Stmt CommonSubexpressionEliminator::VisitStmt_(const IfThenElseNode* op) { + PrimExpr condition_new = VisitExpr(op->condition); + + // Each branch gets its own scope to prevent context leaks between branches + Stmt then_new = context_scope_.WithNewScope([&]() -> Stmt { + EnterContextScope(); + return VisitStmt(op->then_case); + }); + + ffi::Optional else_new; + if (op->else_case) { + else_new = context_scope_.WithNewScope([&]() -> Stmt { + EnterContextScope(); + return VisitStmt(op->else_case.value()); + }); + } + + if (condition_new.same_as(op->condition) && then_new.same_as(op->then_case) && + else_new.same_as(op->else_case)) { + return ffi::GetRef(op); + } + return IfThenElse(condition_new, then_new, else_new, op->span); +} + +/*! + * \brief The method which overrides the specific treatment for an AttrStmtNode. + * + * AttrStmt has a body that may contain Bind nodes. A scope boundary prevents + * context entries from the body from leaking to subsequent statements. 
+ */ +Stmt CommonSubexpressionEliminator::VisitStmt_(const AttrStmtNode* op) { + PrimExpr value_new = VisitExpr(op->value); + + // The body gets its own scope to contain any context entries added within it + Stmt body_new = context_scope_.WithNewScope([&]() -> Stmt { + EnterContextScope(); + return VisitStmt(op->body); + }); + + if (value_new.same_as(op->value) && body_new.same_as(op->body)) { + return ffi::GetRef(op); + } + return AttrStmt(op->node, op->attr_key, value_new, body_new, op->span); +} + +/*! + * \brief The method which overrides the specific treatment for an AllocateNode. + * + * Allocate has a body and introduces a buffer variable. A scope boundary + * prevents context entries from the body from leaking outward. + */ +Stmt CommonSubexpressionEliminator::VisitStmt_(const AllocateNode* op) { + ffi::Array extents_new; + bool extents_changed = false; + for (const auto& extent : op->extents) { + PrimExpr e_new = VisitExpr(extent); + extents_new.push_back(e_new); + if (!e_new.same_as(extent)) extents_changed = true; + } + PrimExpr condition_new = VisitExpr(op->condition); + + // The body gets its own scope to contain any context entries added within it + Stmt body_new = context_scope_.WithNewScope([&]() -> Stmt { + EnterContextScope(); + return VisitStmt(op->body); + }); + + if (!extents_changed && condition_new.same_as(op->condition) && body_new.same_as(op->body)) { + return ffi::GetRef(op); + } + return Allocate(op->buffer_var, op->dtype, extents_new, condition_new, body_new, op->annotations, + op->span); +} + +/*! + * \brief The method which overrides the specific treatment for a DeclBufferNode. + * + * DeclBuffer declares a buffer for use within its body. A scope boundary + * prevents context entries from the body from leaking outward. 
+ */ +Stmt CommonSubexpressionEliminator::VisitStmt_(const DeclBufferNode* op) { + // The body gets its own scope to contain any context entries added within it + Stmt body_new = context_scope_.WithNewScope([&]() -> Stmt { + EnterContextScope(); + return VisitStmt(op->body); + }); + + if (body_new.same_as(op->body)) { + return ffi::GetRef(op); + } + return DeclBuffer(op->buffer, body_new, op->span); +} + +/*! + * \brief The method which overrides the specific treatment for a WhileNode. + * + * While loop has a body that may contain Bind nodes. A scope boundary prevents + * context entries from the body from leaking outward. + */ +Stmt CommonSubexpressionEliminator::VisitStmt_(const WhileNode* op) { + PrimExpr condition_new = VisitExpr(op->condition); + + // The body gets its own scope to contain any context entries added within it + Stmt body_new = context_scope_.WithNewScope([&]() -> Stmt { + EnterContextScope(); + return VisitStmt(op->body); + }); + + if (condition_new.same_as(op->condition) && body_new.same_as(op->body)) { + return ffi::GetRef(op); + } + return While(condition_new, body_new, op->span); +} + namespace transform { /*! 
diff --git a/src/tir/transform/common_subexpr_elim.h b/src/tir/transform/common_subexpr_elim.h index 9a81a0b9ca59..c682f3f5edce 100644 --- a/src/tir/transform/common_subexpr_elim.h +++ b/src/tir/transform/common_subexpr_elim.h @@ -28,6 +28,7 @@ #ifndef TVM_TIR_TRANSFORM_COMMON_SUBEXPR_ELIM_H_ #define TVM_TIR_TRANSFORM_COMMON_SUBEXPR_ELIM_H_ +#include #include #include #include @@ -72,10 +73,86 @@ class CommonSubexpressionEliminator : public StmtExprMutator { Stmt VisitStmt_(const BindNode* op) override; Stmt VisitStmt_(const SeqStmtNode* op) override; Stmt VisitStmt_(const ForNode* op) override; + Stmt VisitStmt_(const IfThenElseNode* op) override; + Stmt VisitStmt_(const AttrStmtNode* op) override; + Stmt VisitStmt_(const AllocateNode* op) override; + Stmt VisitStmt_(const DeclBufferNode* op) override; + Stmt VisitStmt_(const WhileNode* op) override; private: - Stmt initial_body_; // Kept for checking if names of new variables already exist - Context context_; // Context associating variables to (maybe) definitions + /*! \brief Scope level for the context stack. + * + * Each scope level records the size of `context_` when the scope was entered. + * When the scope exits (via ScopeStack::WithNewScope), the destructor truncates + * `context_` back to the saved size, automatically cleaning up any context + * entries added within that scope (e.g., from BindNode or loop variables). + * + * This approach keeps `context_` as a flat vector for efficient searching + * while using ScopeStack for automatic scope-based cleanup. 
+ */ + struct ContextScopeLevel { + Context* context{nullptr}; + size_t saved_size{0}; + + ContextScopeLevel() = default; + ContextScopeLevel(const ContextScopeLevel&) = delete; + ContextScopeLevel& operator=(const ContextScopeLevel&) = delete; + ContextScopeLevel(ContextScopeLevel&& other) noexcept + : context(other.context), saved_size(other.saved_size) { + other.context = nullptr; // prevent other's destructor from truncating + } + ContextScopeLevel& operator=(ContextScopeLevel&& other) noexcept { + if (this != &other) { + // Run our cleanup first + if (context) context->resize(saved_size); + context = other.context; + saved_size = other.saved_size; + other.context = nullptr; + } + return *this; + } + + ~ContextScopeLevel() { + if (context) context->resize(saved_size); + } + }; + + /*! \brief Enter a new context scope, recording the current context size. + * + * Must be called inside context_scope_.WithNewScope() to initialize the + * newly-pushed scope level. On scope exit, the destructor of + * ContextScopeLevel will truncate context_ back to this size. + */ + void EnterContextScope() { + auto& level = context_scope_.Current(); + level.context = &context_; + level.saved_size = context_.size(); + } + + Stmt initial_body_; // Kept for checking if names of new variables already exist + + /*! \brief Flat context vector associating variables to (optional) definitions. + * + * This is the searchable context: VisitExpr and VisitStmt scan it linearly + * to find existing variables whose values match a candidate computation. + * Entries are added by BindNode (with a value) and ForNode (loop var, no value). + * Cleanup is automatic via context_scope_: when a scope exits, context_ is + * truncated to the size it had when the scope was entered. + */ + Context context_; + + /*! \brief Scope stack for automatic context cleanup. + * + * Body-carrying statements (For, IfThenElse, AttrStmt, Allocate, DeclBuffer, + * While) create new scope levels via WithNewScope. 
BindNode entries persist + * across SeqStmt siblings within the same scope and are cleaned up when the + * enclosing body-carrying statement's scope exits. + * + * The initial scope level (created by ScopeStack's constructor) holds the + * function parameters added during PerformCSE. + */ + ScopeStack context_scope_; + int num_last_try_ = 0; // Number of the last variable tried int nb_var_ = 0; // Number of variables introduced by the CSE pass From db3c5da42e19c32c19c1808a358471c0bdc7b04e Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Mar 2026 23:56:33 +0000 Subject: [PATCH 21/34] [REFACTOR][TIR] Restore free_nd in MakeNdMemAllocWithScope for flat Bind The LetStmt-to-Bind migration dropped the free_nd call that was previously wrapped after the LetStmt body, causing a memory leak for nd allocations (Hexagon VTCM, Adreno textures). With flat Bind semantics, the free is pushed to a pending_nd_frees_ vector and appended at the end of the enclosing SeqStmt by a new VisitStmt_(SeqStmtNode*) override. 
--- src/tir/transform/lower_tvm_builtin.cc | 27 ++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/tir/transform/lower_tvm_builtin.cc b/src/tir/transform/lower_tvm_builtin.cc index 9b438d96db93..180e2aaa71b8 100644 --- a/src/tir/transform/lower_tvm_builtin.cc +++ b/src/tir/transform/lower_tvm_builtin.cc @@ -220,6 +220,22 @@ class BuiltinLower : public StmtExprMutator { } } + Stmt VisitStmt_(const SeqStmtNode* op) final { + Stmt result = StmtExprMutator::VisitStmt_(op); + if (!pending_nd_frees_.empty()) { + const auto* seq = result.as(); + if (seq) { + ffi::Array new_seq(seq->seq.begin(), seq->seq.end()); + for (const auto& free_stmt : pending_nd_frees_) { + new_seq.push_back(free_stmt); + } + pending_nd_frees_.clear(); + return SeqStmt(new_seq); + } + } + return result; + } + Stmt VisitStmt_(const BindNode* op) final { if (const CallNode* call = op->value.as()) { if (call->op.same_as(builtin::nd_mem_alloc_with_scope())) { @@ -628,6 +644,15 @@ class BuiltinLower : public StmtExprMutator { Stmt null_check = IfThenElse(Call(DataType::Bool(), builtin::isnullptr(), {let->var}), throw_last_error); + // Construct free_nd call and push to pending frees. + // The enclosing SeqStmt handler will append these after the body. + PrimExpr storage_scope = call->args[0]; + Call free_op = Call(DataType::Int(32), builtin::tvm_call_packed(), + {GetDeviceMethodName("free_nd"), device_type_.value(), device_id_.value(), + storage_scope, let->var}); + Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); + pending_nd_frees_.push_back(free_stmt); + return SeqStmt({Bind(let->var, call_packed), null_check}); } @@ -652,6 +677,8 @@ class BuiltinLower : public StmtExprMutator { // Record all stack frames. std::vector alloca_scope_; + // Pending free_nd stmts to be appended at the end of the enclosing SeqStmt. 
+ std::vector pending_nd_frees_; }; namespace transform { From 21ed144de0072088eb6f9cd31e1f141751e37fe1 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Mar 2026 23:57:05 +0000 Subject: [PATCH 22/34] [REFACTOR][TIR] Remove duplicate LOG(WARNING) in ir_docsifier_functor.h The LetStmt-to-Bind refactor accidentally duplicated the LOG(WARNING) call in IRDocsifierFunctor::operator(). Remove the extra one. --- include/tvm/script/printer/ir_docsifier_functor.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/include/tvm/script/printer/ir_docsifier_functor.h b/include/tvm/script/printer/ir_docsifier_functor.h index 500fa8b5e21f..32363a434f46 100644 --- a/include/tvm/script/printer/ir_docsifier_functor.h +++ b/include/tvm/script/printer/ir_docsifier_functor.h @@ -74,9 +74,6 @@ class IRDocsifierFunctor { return (*pf)(obj, args...).template cast(); } - LOG(WARNING) << "ObjectFunctor calls un-registered function on type: " - << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")" - << ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj; LOG(WARNING) << "ObjectFunctor calls un-registered function on type: " << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")" << ". ObjectType: " << obj->GetTypeKey() << ". 
Object: " << obj; From 05ed2d4d2feae02b0a6e96c95a1f8cf7ab4c2046 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Mar 2026 23:57:50 +0000 Subject: [PATCH 23/34] [REFACTOR][TIR] Rename stale letstmt references in test function names Rename test fixtures and functions that still use "letstmt" to "bind" to match the LetStmt-to-Bind refactor: - argmax_split_letstmt_{fewer,more}_than_init -> argmax_split_bind_* - test_letstmt_bufferload_without_type_annotation -> test_bind_* - test_letstmt_bind_with_constant -> test_bind_with_constant --- .../s_tir/schedule/test_tir_schedule_rfactor.py | 12 ++++++------ .../python/tvmscript/test_tvmscript_syntax_sugar.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/python/s_tir/schedule/test_tir_schedule_rfactor.py b/tests/python/s_tir/schedule/test_tir_schedule_rfactor.py index 499650345547..195d361a9a5e 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_rfactor.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_rfactor.py @@ -823,7 +823,7 @@ def argmax_split_init_buffer_duplicate( @T.prim_func -def argmax_split_letstmt_fewer_than_init( +def argmax_split_bind_fewer_than_init( idx: T.Buffer((128, 128), "int32"), val: T.Buffer((128, 128), "float32"), argmax_v0: T.Buffer((128,), "int32"), @@ -844,7 +844,7 @@ def argmax_split_letstmt_fewer_than_init( @T.prim_func -def argmax_split_letstmt_more_than_init( +def argmax_split_bind_more_than_init( idx: T.Buffer((128, 128), "int32"), val: T.Buffer((128, 128), "float32"), argmax_v0: T.Buffer((128,), "int32"), @@ -1544,16 +1544,16 @@ def test_reduction_rfactor_argmax_init_buffer_duplicate(): s.rfactor(ki, 1) -def test_reduction_rfactor_argmax_letstmt_fewer_than_init(): - s = tvm.s_tir.Schedule(argmax_split_letstmt_fewer_than_init, debug_mask="all") +def test_reduction_rfactor_argmax_bind_fewer_than_init(): + s = tvm.s_tir.Schedule(argmax_split_bind_fewer_than_init, debug_mask="all") argmax = s.get_sblock("argmax") _, _, ki = s.get_loops(argmax) with 
pytest.raises(tvm.s_tir.ScheduleError): s.rfactor(ki, 1) -def test_reduction_rfactor_argmax_letstmt_more_than_init(): - s = tvm.s_tir.Schedule(argmax_split_letstmt_more_than_init, debug_mask="all") +def test_reduction_rfactor_argmax_bind_more_than_init(): + s = tvm.s_tir.Schedule(argmax_split_bind_more_than_init, debug_mask="all") argmax = s.get_sblock("argmax") _, _, ki = s.get_loops(argmax) with pytest.raises(tvm.s_tir.ScheduleError): diff --git a/tests/python/tvmscript/test_tvmscript_syntax_sugar.py b/tests/python/tvmscript/test_tvmscript_syntax_sugar.py index 32c593881b8f..f3d19f8ebad2 100644 --- a/tests/python/tvmscript/test_tvmscript_syntax_sugar.py +++ b/tests/python/tvmscript/test_tvmscript_syntax_sugar.py @@ -241,7 +241,7 @@ def func(a: T.handle): T.evaluate(0) -def test_letstmt_bufferload_without_type_annotation(): +def test_bind_bufferload_without_type_annotation(): # Variable assignment of PrimExpr types uses the dtype of the # PrimExpr to determine the variable's dtype. Parsing of # buf[indices] is done by generating a BufferSlice object, which @@ -255,7 +255,7 @@ def func_without_type_annotation(A: T.Buffer((1,), "int32")): T.evaluate(x) -def test_letstmt_bind_with_constant(): +def test_bind_with_constant(): @T.prim_func def constant_binds(): x = T.meta_var(1) From ca589b00b654df58b0b0a3aac58af9c08eeb93fa Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Mar 2026 23:59:22 +0000 Subject: [PATCH 24/34] [REFACTOR][TIR] Rename LetFrame to BindFrame in ir_builder Rename LetFrameNode/LetFrame to BindFrameNode/BindFrame across C++ headers, implementation, and Python bindings to align with the LetStmt-to-Bind refactor. Updates FFI type key from "script.ir_builder.tir.LetFrame" to "script.ir_builder.tir.BindFrame". 
--- include/tvm/script/ir_builder/tir/frame.h | 24 +++++++++++------------ include/tvm/script/ir_builder/tir/ir.h | 6 +++--- python/tvm/script/ir_builder/tir/frame.py | 4 ++-- python/tvm/script/ir_builder/tir/ir.py | 14 ++++++------- src/script/ir_builder/tir/frame.cc | 4 ++-- src/script/ir_builder/tir/ir.cc | 6 +++--- 6 files changed, 29 insertions(+), 29 deletions(-) diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index 4d889cc4d222..b9e2eb859ae4 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -342,11 +342,11 @@ class AssertFrame : public TIRFrame { }; /*! - * \brief A frame represents the let binding expression, which binds a var. + * \brief A frame represents the Bind (variable binding) statement. * - * \sa LetFrameNode + * \sa BindFrameNode */ -class LetFrameNode : public TIRFrameNode { +class BindFrameNode : public TIRFrameNode { public: /*! \brief The variable we bind to */ tvm::tir::Var var; @@ -355,11 +355,11 @@ class LetFrameNode : public TIRFrameNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("var", &LetFrameNode::var) - .def_ro("value", &LetFrameNode::value); + refl::ObjectDef() + .def_ro("var", &BindFrameNode::var) + .def_ro("value", &BindFrameNode::value); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.LetFrame", LetFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.BindFrame", BindFrameNode, TIRFrameNode); public: /*! @@ -370,17 +370,17 @@ class LetFrameNode : public TIRFrameNode { }; /*! - * \brief Managed reference to LetFrameNode. + * \brief Managed reference to BindFrameNode. 
* - * \sa LetFrameNode + * \sa BindFrameNode */ -class LetFrame : public TIRFrame { +class BindFrame : public TIRFrame { public: - explicit LetFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + explicit BindFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { TVM_FFI_ICHECK(data != nullptr); data_ = std::move(data); } - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(LetFrame, TIRFrame, LetFrameNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BindFrame, TIRFrame, BindFrameNode); }; /*! diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 8113a8d0db0c..702423d17df5 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -300,10 +300,10 @@ AssertFrame Assert(PrimExpr condition, ffi::String error_kind, * Usually it is used for fine-grained var typing, * particularly, PointerType. * \param var The variable to be bound. If not specified, a new variable will be created. - * \return The created LetFrame. + * \return The created BindFrame. */ -LetFrame Bind(PrimExpr value, ffi::Optional type_annotation = std::nullopt, - ffi::Optional var = std::nullopt); +BindFrame Bind(PrimExpr value, ffi::Optional type_annotation = std::nullopt, + ffi::Optional var = std::nullopt); /*! * \brief The allocate node. diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py index 478fec212397..d6c966852103 100644 --- a/python/tvm/script/ir_builder/tir/frame.py +++ b/python/tvm/script/ir_builder/tir/frame.py @@ -50,8 +50,8 @@ def __enter__(self) -> Var | list[Var]: # type: ignore[override] class AssertFrame(TIRFrame): ... 
-@_register_object("script.ir_builder.tir.LetFrame") -class LetFrame(TIRFrame): +@_register_object("script.ir_builder.tir.BindFrame") +class BindFrame(TIRFrame): def __enter__(self) -> Var: super().__enter__() return self.var diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index e4143f006273..54c947265ac4 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -988,7 +988,7 @@ def Bind( # pylint: disable=invalid-name type_annotation: Type | None = None, # pylint: disable=redefined-outer-name *, var: Var | None = None, # pylint: disable=redefined-outer-name -) -> frame.LetFrame: +) -> frame.BindFrame: """Create a Bind (variable binding). Parameters @@ -1003,8 +1003,8 @@ def Bind( # pylint: disable=invalid-name Returns ------- - let_frame : frame.LetFrame - The result LetFrame. + bind_frame : frame.BindFrame + The result BindFrame. """ if type_annotation is not None: if callable(type_annotation): @@ -1028,7 +1028,7 @@ def let( v: Var, value: PrimExpr, body: PrimExpr = None, -) -> frame.LetFrame: +) -> frame.BindFrame: """Create a new let binding. Parameters @@ -1044,8 +1044,8 @@ def let( Returns ------- - res : frame.LetFrame - The result LetFrame. + res : frame.BindFrame + The result BindFrame. 
""" @deprecated("T.let", "T.Let") @@ -1053,7 +1053,7 @@ def let_expr(v: Var, value: PrimExpr, body: PrimExpr) -> PrimExpr: return tir.Let(v, value, body) @deprecated("T.let", "T.Bind") - def let_stmt(v: Var, value: PrimExpr) -> frame.LetFrame: + def let_stmt(v: Var, value: PrimExpr) -> frame.BindFrame: return Bind(value, var=v) if body is None: diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index 742802695456..b0f50ca4504d 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -35,7 +35,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { BlockInitFrameNode::RegisterReflection(); ForFrameNode::RegisterReflection(); AssertFrameNode::RegisterReflection(); - LetFrameNode::RegisterReflection(); + BindFrameNode::RegisterReflection(); LaunchThreadFrameNode::RegisterReflection(); AllocateFrameNode::RegisterReflection(); AttrFrameNode::RegisterReflection(); @@ -141,7 +141,7 @@ void AssertFrameNode::ExitWithScope() { } } -void LetFrameNode::ExitWithScope() { +void BindFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); AddToParent(tvm::tir::SeqStmt({tvm::tir::Bind(var, value), AsStmt(stmts)})); } diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 5bcd4b321dc4..2bade41b0122 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -460,8 +460,8 @@ AssertFrame Assert(PrimExpr condition, ffi::String error_kind, return AssertFrame(n); } -LetFrame Bind(PrimExpr value, ffi::Optional type_annotation, ffi::Optional var) { - ObjectPtr n = ffi::make_object(); +BindFrame Bind(PrimExpr value, ffi::Optional type_annotation, ffi::Optional var) { + ObjectPtr n = ffi::make_object(); if (var.defined()) { n->var = var.value(); } else if (type_annotation.defined()) { @@ -470,7 +470,7 @@ LetFrame Bind(PrimExpr value, ffi::Optional type_annotation, ffi::Optional n->var = Var("v", value.dtype()); } n->value = value; - return LetFrame(n); + return BindFrame(n); } 
LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) { From 48e2f7a37c71f5ea14203ed196f3a7138338e647 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 4 Mar 2026 19:48:17 +0000 Subject: [PATCH 25/34] [REFACTOR][TIR] Fix AllocateNode->AllocBufferNode references after rebase Resolve remaining AllocateNode references that should be AllocBufferNode after rebasing onto the AllocBuffer commit. Also add TVM_FFI_UNREACHABLE after throw in blockize_tensorize. --- .../schedule/primitive/blockize_tensorize.cc | 1 + src/tir/ir/data_type_rewriter.h | 1 - src/tir/transform/common_subexpr_elim.cc | 20 +++++-------------- src/tir/transform/common_subexpr_elim.h | 2 +- 4 files changed, 7 insertions(+), 17 deletions(-) diff --git a/src/s_tir/schedule/primitive/blockize_tensorize.cc b/src/s_tir/schedule/primitive/blockize_tensorize.cc index 5c6c152925e5..95074147a02a 100644 --- a/src/s_tir/schedule/primitive/blockize_tensorize.cc +++ b/src/s_tir/schedule/primitive/blockize_tensorize.cc @@ -877,6 +877,7 @@ struct BlockizeTraits : public UnpackedInstTraits { return sch->Blockize(blocks.value(), preserve_unit_iters.operator bool()); } TVM_FFI_THROW(TypeError) << "expect Loop or list of SBlocks, but gets:" << target->GetTypeKey(); + TVM_FFI_UNREACHABLE(); } static ffi::String UnpackedAsPython(ffi::Array outputs, ObjectRef target, diff --git a/src/tir/ir/data_type_rewriter.h b/src/tir/ir/data_type_rewriter.h index 7363e97e1bcf..46e6e3b92a12 100644 --- a/src/tir/ir/data_type_rewriter.h +++ b/src/tir/ir/data_type_rewriter.h @@ -112,7 +112,6 @@ class IndexDataTypeRewriter : public DataTypeLegalizer { Stmt VisitStmt_(const IfThenElseNode* op) override; Stmt VisitStmt_(const DeclBufferNode* op) override; Stmt VisitStmt_(const AllocBufferNode* op) override; - Stmt VisitStmt_(const AllocateNode* op) override; Stmt VisitStmt_(const BindNode* op) override; PrimExpr VisitExpr_(const EQNode* op) override; PrimExpr VisitExpr_(const NENode* op) override; diff --git 
a/src/tir/transform/common_subexpr_elim.cc b/src/tir/transform/common_subexpr_elim.cc index 9e814693a157..363787556944 100644 --- a/src/tir/transform/common_subexpr_elim.cc +++ b/src/tir/transform/common_subexpr_elim.cc @@ -716,32 +716,22 @@ Stmt CommonSubexpressionEliminator::VisitStmt_(const AttrStmtNode* op) { } /*! - * \brief The method which overrides the specific treatment for an AllocateNode. + * \brief The method which overrides the specific treatment for an AllocBufferNode. * - * Allocate has a body and introduces a buffer variable. A scope boundary + * AllocBuffer has a body and introduces a buffer. A scope boundary * prevents context entries from the body from leaking outward. */ -Stmt CommonSubexpressionEliminator::VisitStmt_(const AllocateNode* op) { - ffi::Array extents_new; - bool extents_changed = false; - for (const auto& extent : op->extents) { - PrimExpr e_new = VisitExpr(extent); - extents_new.push_back(e_new); - if (!e_new.same_as(extent)) extents_changed = true; - } - PrimExpr condition_new = VisitExpr(op->condition); - +Stmt CommonSubexpressionEliminator::VisitStmt_(const AllocBufferNode* op) { // The body gets its own scope to contain any context entries added within it Stmt body_new = context_scope_.WithNewScope([&]() -> Stmt { EnterContextScope(); return VisitStmt(op->body); }); - if (!extents_changed && condition_new.same_as(op->condition) && body_new.same_as(op->body)) { + if (body_new.same_as(op->body)) { return ffi::GetRef(op); } - return Allocate(op->buffer_var, op->dtype, extents_new, condition_new, body_new, op->annotations, - op->span); + return AllocBuffer(op->buffer, body_new, op->annotations, op->span); } /*! 
diff --git a/src/tir/transform/common_subexpr_elim.h b/src/tir/transform/common_subexpr_elim.h index c682f3f5edce..5674070e3d25 100644 --- a/src/tir/transform/common_subexpr_elim.h +++ b/src/tir/transform/common_subexpr_elim.h @@ -75,7 +75,7 @@ class CommonSubexpressionEliminator : public StmtExprMutator { Stmt VisitStmt_(const ForNode* op) override; Stmt VisitStmt_(const IfThenElseNode* op) override; Stmt VisitStmt_(const AttrStmtNode* op) override; - Stmt VisitStmt_(const AllocateNode* op) override; + Stmt VisitStmt_(const AllocBufferNode* op) override; Stmt VisitStmt_(const DeclBufferNode* op) override; Stmt VisitStmt_(const WhileNode* op) override; From fe104fcfb9981e3ede77dfbd7f357a484b529dcd Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 4 Mar 2026 19:51:49 +0000 Subject: [PATCH 26/34] [REFACTOR][TIR] clang-format fixes --- src/s_tir/backend/adreno/inject_texture_alloc.cc | 4 ++-- src/s_tir/transform/lower_vtcm_alloc.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/s_tir/backend/adreno/inject_texture_alloc.cc b/src/s_tir/backend/adreno/inject_texture_alloc.cc index e7df647d0f85..a45f5dd4b562 100644 --- a/src/s_tir/backend/adreno/inject_texture_alloc.cc +++ b/src/s_tir/backend/adreno/inject_texture_alloc.cc @@ -82,8 +82,8 @@ class TextureAllocInjector : public arith::IRMutatorWithAnalyzer { args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(), {texture.width, texture.height, texture.depth})); args.push_back(IntImm(DataType::Int(64), channel_size)); - stmt = SeqStmt({Bind(op->buffer->data, - Call(op->buffer->data.dtype(), builtin::nd_mem_alloc_with_scope(), args)), + stmt = SeqStmt({Bind(op->buffer->data, Call(op->buffer->data.dtype(), + builtin::nd_mem_alloc_with_scope(), args)), op->body}); } return stmt; diff --git a/src/s_tir/transform/lower_vtcm_alloc.cc b/src/s_tir/transform/lower_vtcm_alloc.cc index e8683f669e95..cc32fba14678 100644 --- a/src/s_tir/transform/lower_vtcm_alloc.cc +++ 
b/src/s_tir/transform/lower_vtcm_alloc.cc @@ -45,8 +45,8 @@ class VtcmAllocator : public StmtExprMutator { args.push_back(StringImm(storage_scope)); args.push_back(IntImm(DataType::Int(64), op->buffer->shape.size())); args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(), op->buffer->shape)); - return SeqStmt({Bind(op->buffer->data, - Call(op->buffer->data.dtype(), builtin::nd_mem_alloc_with_scope(), args)), + return SeqStmt({Bind(op->buffer->data, Call(op->buffer->data.dtype(), + builtin::nd_mem_alloc_with_scope(), args)), body}); } return StmtExprMutator::VisitStmt_(op); From fcf21ffea01c219776476e45a5630e0f27ccc567 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 4 Mar 2026 20:14:20 +0000 Subject: [PATCH 27/34] [TIR] Replace pending_nd_frees_ with ScopeStack in lower_tvm_builtin The pending_nd_frees_ approach hoisted free_nd calls to the nearest SeqStmt boundary, which could incorrectly escape conditional branches. Use ScopeStack instead: register free_nd in the current scope when Bind allocates via nd_mem_alloc_with_scope, and emit frees on scope exit. This matches the old LetStmt body semantics structurally. 
--- src/tir/transform/lower_tvm_builtin.cc | 118 +++++++++++++++++++------ 1 file changed, 91 insertions(+), 27 deletions(-) diff --git a/src/tir/transform/lower_tvm_builtin.cc b/src/tir/transform/lower_tvm_builtin.cc index 180e2aaa71b8..f4e08bbf4973 100644 --- a/src/tir/transform/lower_tvm_builtin.cc +++ b/src/tir/transform/lower_tvm_builtin.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include #include @@ -180,7 +181,10 @@ class BuiltinLower : public StmtExprMutator { } } - stmt = this->VisitStmt(stmt); + stmt = scope_.WithNewScope([&]() -> Stmt { + Stmt visited = this->VisitStmt(stmt); + return AppendPendingFrees(visited); + }); TVM_FFI_ICHECK(!alloca_scope_.empty()); alloca_scope_.pop_back(); @@ -220,22 +224,6 @@ class BuiltinLower : public StmtExprMutator { } } - Stmt VisitStmt_(const SeqStmtNode* op) final { - Stmt result = StmtExprMutator::VisitStmt_(op); - if (!pending_nd_frees_.empty()) { - const auto* seq = result.as(); - if (seq) { - ffi::Array new_seq(seq->seq.begin(), seq->seq.end()); - for (const auto& free_stmt : pending_nd_frees_) { - new_seq.push_back(free_stmt); - } - pending_nd_frees_.clear(); - return SeqStmt(new_seq); - } - } - return result; - } - Stmt VisitStmt_(const BindNode* op) final { if (const CallNode* call = op->value.as()) { if (call->op.same_as(builtin::nd_mem_alloc_with_scope())) { @@ -247,7 +235,17 @@ class BuiltinLower : public StmtExprMutator { Stmt VisitStmt_(const AllocBufferNode* op) { // Lower AllocBuffer to device allocate when needed. - Stmt stmt = StmtExprMutator::VisitStmt_(op); + // Visit body in a new scope so nd_mem_alloc frees are scoped to the body. 
+ Stmt stmt = scope_.WithNewScope([&]() -> Stmt { + Stmt visited = StmtExprMutator::VisitStmt_(op); + const auto* alloc = visited.as(); + if (alloc && !scope_.Current().pending_frees.empty()) { + auto n = CopyOnWrite(alloc); + n->body = AppendPendingFrees(alloc->body); + return Stmt(n); + } + return visited; + }); op = stmt.as(); int64_t nbytes = GetVectorBytes(op->buffer->dtype); if (op->annotations.count(transform::kDisableLowerTVMBuiltin)) { @@ -303,17 +301,33 @@ class BuiltinLower : public StmtExprMutator { if (op->attr_key == attr::device_id) { auto cache = device_id_; device_id_ = op->value; - Stmt out = this->VisitStmt(op->body); + Stmt out = scope_.WithNewScope([&]() -> Stmt { + Stmt body = this->VisitStmt(op->body); + return AppendPendingFrees(body); + }); device_id_ = cache; return out; } else if (op->attr_key == attr::device_type) { auto cache = device_type_; device_type_ = op->value; - Stmt out = this->VisitStmt(op->body); + Stmt out = scope_.WithNewScope([&]() -> Stmt { + Stmt body = this->VisitStmt(op->body); + return AppendPendingFrees(body); + }); device_type_ = cache; return out; } else { - return StmtExprMutator::VisitStmt_(op); + return scope_.WithNewScope([&]() -> Stmt { + Stmt visited = StmtExprMutator::VisitStmt_(op); + if (!scope_.Current().pending_frees.empty()) { + const auto* attr = visited.as(); + if (attr) { + return AttrStmt(attr->node, attr->attr_key, attr->value, AppendPendingFrees(attr->body), + attr->span); + } + } + return visited; + }); } } Stmt VisitStmt_(const ForNode* op) final { @@ -324,7 +338,10 @@ class BuiltinLower : public StmtExprMutator { if (op->kind == ForKind::kParallel) { body = this->VisitBodyAndRealizeAlloca(op->body); } else { - body = this->VisitStmt(op->body); + body = scope_.WithNewScope([&]() -> Stmt { + Stmt visited = this->VisitStmt(op->body); + return AppendPendingFrees(visited); + }); } if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) { @@ -338,6 +355,27 @@ class 
BuiltinLower : public StmtExprMutator { } } + Stmt VisitStmt_(const IfThenElseNode* op) final { + PrimExpr condition = this->VisitExpr(op->condition); + // Each branch gets its own scope to prevent frees from leaking across branches. + Stmt then_case = scope_.WithNewScope([&]() -> Stmt { + Stmt visited = this->VisitStmt(op->then_case); + return AppendPendingFrees(visited); + }); + ffi::Optional else_case; + if (op->else_case) { + else_case = scope_.WithNewScope([&]() -> Stmt { + Stmt visited = this->VisitStmt(op->else_case.value()); + return AppendPendingFrees(visited); + }); + } + if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && + else_case.same_as(op->else_case)) { + return ffi::GetRef(op); + } + return IfThenElse(condition, then_case, else_case, op->span); + } + PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::tvm_call_packed())) { return MakeCallPackedGeneric(op, 0, builtin::tvm_call_packed_lowered(), @@ -644,14 +682,14 @@ class BuiltinLower : public StmtExprMutator { Stmt null_check = IfThenElse(Call(DataType::Bool(), builtin::isnullptr(), {let->var}), throw_last_error); - // Construct free_nd call and push to pending frees. - // The enclosing SeqStmt handler will append these after the body. + // Construct free_nd call and register in current scope. + // The free will be emitted on scope exit, matching the old LetStmt body semantics. PrimExpr storage_scope = call->args[0]; Call free_op = Call(DataType::Int(32), builtin::tvm_call_packed(), {GetDeviceMethodName("free_nd"), device_type_.value(), device_id_.value(), storage_scope, let->var}); Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); - pending_nd_frees_.push_back(free_stmt); + scope_.Current().pending_frees.push_back(free_stmt); return SeqStmt({Bind(let->var, call_packed), null_check}); } @@ -668,6 +706,32 @@ class BuiltinLower : public StmtExprMutator { return false; } + /*! 
+ * \brief Scope level for tracking nd_mem_alloc_with_scope deallocations. + * + * When a Bind allocates via nd_mem_alloc_with_scope, the corresponding + * free_nd stmt is pushed to the current scope's pending_frees. Body-carrying + * stmts (For, IfThenElse, AllocBuffer, AttrStmt) create new scopes via + * WithNewScope. On scope exit, pending_frees are appended after the body, + * matching the old LetStmt body semantics structurally. + */ + struct ScopeLevel { + std::vector pending_frees; + }; + + /*! \brief Emit any pending frees from the current scope after the given body stmt. */ + Stmt AppendPendingFrees(Stmt body) { + auto& frees = scope_.Current().pending_frees; + if (frees.empty()) return body; + ffi::Array stmts; + stmts.push_back(body); + for (const auto& free_stmt : frees) { + stmts.push_back(free_stmt); + } + frees.clear(); + return SeqStmt::Flatten(stmts); + } + // The prepration sequence to be emitted before the current statement. std::vector> prep_seq_stack_; ffi::Optional device_type_{std::nullopt}; @@ -677,8 +741,8 @@ class BuiltinLower : public StmtExprMutator { // Record all stack frames. std::vector alloca_scope_; - // Pending free_nd stmts to be appended at the end of the enclosing SeqStmt. - std::vector pending_nd_frees_; + // Scope stack for nd_mem_alloc_with_scope free tracking. + ScopeStack scope_; }; namespace transform { From f5f28f65a6e398a2dbcdf32abe6232e95bda7d16 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 4 Mar 2026 20:32:43 +0000 Subject: [PATCH 28/34] [TIR] Add SSA invariant comment for non_inlined_bindings_ in simplify.cc --- src/tir/transform/simplify.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/tir/transform/simplify.cc b/src/tir/transform/simplify.cc index 173fcf7f6656..3d295e1764be 100644 --- a/src/tir/transform/simplify.cc +++ b/src/tir/transform/simplify.cc @@ -180,6 +180,12 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { // so we always keep the Bind. 
if (SideEffect(value) <= CallEffectKind::kPure) { analyzer_->Bind(op->var, value); + // Record the binding so we can substitute it into assert conditions + // (see VisitStmt_(const AssertStmtNode*)). Under SSA each var is + // bound exactly once, so the map grows monotonically without key + // conflicts. No scope-based cleanup is needed because vars bound + // in inner scopes are only referenced within those scopes; stale + // entries are harmless and never consulted again. non_inlined_bindings_.Set(op->var, value); } @@ -272,6 +278,8 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { SimplifyConfig config_; std::optional touch_pattern_; + // Pure Bind values kept for substitution into assert conditions. + // Grows monotonically under SSA — no scope-based cleanup required. ffi::Map non_inlined_bindings_; ffi::Optional current_stmt_{std::nullopt}; }; From 2d72320a447000777d05a0ec3a1d14969d6740de Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 4 Mar 2026 20:43:02 +0000 Subject: [PATCH 29/34] [TIR] Optimize CSE SeqStmt handler to batch trivial Binds The old handler wrapped remaining siblings after each individual Bind node and re-ran VisitStmt, causing O(n^2) complexity for sequences with many consecutive Bind nodes. The new hybrid approach batches consecutive trivial Binds (constant or variable values) and defers the cross-sibling CSE until the batch ends, reducing the common case to O(n). Non-trivial Binds (whose values may contain eligible computations) still use the per-Bind wrap pattern to preserve full CSE effectiveness. 
--- src/tir/transform/common_subexpr_elim.cc | 88 ++++++++++++++++-------- 1 file changed, 58 insertions(+), 30 deletions(-) diff --git a/src/tir/transform/common_subexpr_elim.cc b/src/tir/transform/common_subexpr_elim.cc index 363787556944..d99af8486694 100644 --- a/src/tir/transform/common_subexpr_elim.cc +++ b/src/tir/transform/common_subexpr_elim.cc @@ -552,48 +552,75 @@ Stmt CommonSubexpressionEliminator::VisitStmt_(const BindNode* op) { } } +/*! + * \brief Whether a Bind value is trivial (constant or variable), meaning it cannot + * contribute eligible computations for CSE and can be safely batched. + */ +static bool IsTrivialBindValue(const PrimExpr& value) { + return value.as() != nullptr || value.as() != nullptr || + value.as() != nullptr || value.as() != nullptr; +} + /*! * \brief The method which overrides the specific treatment for a SeqStmtNode. * - * Processes the flat sequence one child at a time: - * - Bind nodes: visit the value (via VisitExpr), augment context_, then wrap - * all remaining siblings into a body and call VisitStmt for cross-sibling - * CSE with the newly augmented context. This re-runs the CSE top-level - * computation collection, allowing new common subexpressions to be found - * and introduced now that the bind-defined variable is in scope. - * - Non-Bind nodes: visit individually via VisitStmt, then continue. + * Processes the flat sequence using a hybrid strategy that avoids the O(n^2) + * complexity of wrapping remaining siblings after every single Bind node: * - * Context cleanup is handled automatically by ScopeStack. The enclosing - * body-carrying statement (For, IfThenElse, etc.) creates a scope; when - * that scope exits, all context entries added here (from Bind nodes) are - * cleaned up. No manual save/restore of context_ is needed. 
+ * - Trivial Bind nodes (constant/variable values) are batched: their values + * are visited via VisitExpr, context_ is augmented, but the expensive + * cross-sibling CSE is deferred until the batch ends. + * - Non-trivial Bind nodes (whose values may contain eligible computations) + * use the wrap-remaining-siblings pattern to enable cross-sibling CSE. + * - After any Bind (trivial batch end or non-trivial), remaining siblings are + * wrapped into a body and VisitStmt is called once for cross-sibling CSE. + * - Non-Bind children are visited individually via VisitStmt. * - * Note: this still uses the "wrap remaining siblings" pattern after Bind - * nodes, which is necessary because VisitStmt re-runs CSE computation - * collection on the wrapped body, enabling cross-sibling optimizations. + * This reduces the common case of many consecutive trivial Binds (e.g., variable + * definitions with constant values) from O(n^2) to O(n), while preserving full + * CSE effectiveness for non-trivial Bind values. + * + * Context cleanup is handled automatically by ScopeStack. */ Stmt CommonSubexpressionEliminator::VisitStmt_(const SeqStmtNode* op) { ffi::Array new_seq; + size_t i = 0; - for (size_t i = 0; i < op->seq.size(); ++i) { + while (i < op->seq.size()) { if (auto* bind = op->seq[i].as()) { - // Process the Bind: VisitExpr on value, augment context. - PrimExpr value_new = VisitExpr(bind->value); - context_.push_back({bind->var, MaybeValue(bind->value)}); - Stmt bind_new = value_new.same_as(bind->value) ? ffi::GetRef(bind) - : Bind(bind->var, value_new, bind->span); - new_seq.push_back(bind_new); - - // Wrap remaining siblings [i+1..end) as a body and call VisitStmt - // to re-run CSE computation collection with the updated context. - // This enables CSE to introduce new bindings between siblings. - if (i + 1 < op->seq.size()) { + // Batch consecutive trivial Bind nodes (constant/variable values). 
+ // These can't contribute common subexpressions, so it's safe to defer + // the cross-sibling CSE until the entire batch is processed. + if (IsTrivialBindValue(bind->value)) { + while (i < op->seq.size()) { + auto* b = op->seq[i].as(); + if (!b || !IsTrivialBindValue(b->value)) break; + PrimExpr value_new = VisitExpr(b->value); + context_.push_back({b->var, MaybeValue(b->value)}); + Stmt bind_new = + value_new.same_as(b->value) ? ffi::GetRef(b) : Bind(b->var, value_new, b->span); + new_seq.push_back(bind_new); + ++i; + } + } else { + // Non-trivial Bind: visit value, augment context, then wrap remaining + // siblings and call VisitStmt for cross-sibling CSE. + PrimExpr value_new = VisitExpr(bind->value); + context_.push_back({bind->var, MaybeValue(bind->value)}); + Stmt bind_new = value_new.same_as(bind->value) ? ffi::GetRef(bind) + : Bind(bind->var, value_new, bind->span); + new_seq.push_back(bind_new); + ++i; + } + // After the Bind (batch or single), wrap remaining siblings [i..end) and + // call VisitStmt once for cross-sibling CSE with the updated context. + if (i < op->seq.size()) { Stmt body; - if (i + 2 == op->seq.size()) { - body = op->seq[i + 1]; + if (i + 1 == op->seq.size()) { + body = op->seq[i]; } else { ffi::Array rest; - for (size_t j = i + 1; j < op->seq.size(); ++j) rest.push_back(op->seq[j]); + for (size_t j = i; j < op->seq.size(); ++j) rest.push_back(op->seq[j]); body = SeqStmt(rest); } Stmt body_new = VisitStmt(body); @@ -606,13 +633,14 @@ Stmt CommonSubexpressionEliminator::VisitStmt_(const SeqStmtNode* op) { return SeqStmt::Flatten(new_seq); } } else { - // Non-Bind child: visit individually, then continue. + // Non-Bind child: visit individually via VisitStmt. 
Stmt child_new = VisitStmt(op->seq[i]); if (auto* inner = child_new.as()) { for (const auto& s : inner->seq) new_seq.push_back(s); } else { new_seq.push_back(child_new); } + ++i; } } From c9fe4fe5b775c2f5f20fc31ee4313a5419745366 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 4 Mar 2026 20:59:09 +0000 Subject: [PATCH 30/34] [TIR] Fix RAII scope guard bugs for flat BindNode in control_flow_graph and layout_transformation BindLetVar and BindVariableDefinition RAII guards erased map entries on destruction, but flat BindNode has no body -- the guard is destroyed when the handler returns, making the binding invisible to subsequent sibling statements. Under SSA each variable is bound exactly once, so the maps grow monotonically and cleanup is unnecessary. Remove the cleanup from both destructors to fix the bug. Also add a comment explaining the dead cse_v1 variable in test_s_tir_transform_inject_ptx_async_copy: CSE extracts (i < 12) before inject_ptx_async_copy replaces IfThenElse guards with new cast(int32, ...) expressions for predicated copies, leaving the CSE variable unused. --- .../primitive/layout_transformation.cc | 34 ++++++------------- src/tir/analysis/control_flow_graph.cc | 15 ++++---- ...t_s_tir_transform_inject_ptx_async_copy.py | 4 +++ 3 files changed, 21 insertions(+), 32 deletions(-) diff --git a/src/s_tir/schedule/primitive/layout_transformation.cc b/src/s_tir/schedule/primitive/layout_transformation.cc index 33ce0f1a23aa..3fd210e91409 100644 --- a/src/s_tir/schedule/primitive/layout_transformation.cc +++ b/src/s_tir/schedule/primitive/layout_transformation.cc @@ -625,37 +625,23 @@ class TransformLayoutPlanner : private StmtExprVisitor { Var var_; }; + // Under SSA, each variable is bound exactly once, so the lookup maps + // grow monotonically and cleanup is unnecessary. Omitting cleanup also + // ensures correctness for flat BindNode (which has no body): the + // binding must remain visible to subsequent sibling statements. 
struct BindVariableDefinition { BindVariableDefinition() {} - BindVariableDefinition(TransformLayoutPlanner* self, Var var, PrimExpr value) - : self_(self), var_(var) { + BindVariableDefinition(TransformLayoutPlanner* self, Var var, PrimExpr value) { if (auto loop_depth = self->LoopDependencyRange(value); loop_depth.has_value()) { - self_->loop_depth_lookup_[var_.get()] = loop_depth.value(); - self_->active_var_bindings_[var_.get()] = Substitute(value, self_->active_var_bindings_); - } - } - ~BindVariableDefinition() { - if (self_) { - self_->loop_depth_lookup_.erase(var_.get()); - self_->active_var_bindings_.erase(var_.get()); + self->loop_depth_lookup_[var.get()] = loop_depth.value(); + self->active_var_bindings_[var.get()] = Substitute(value, self->active_var_bindings_); } } + ~BindVariableDefinition() {} BindVariableDefinition(const BindVariableDefinition&) = delete; BindVariableDefinition& operator=(const BindVariableDefinition&) = delete; - BindVariableDefinition(BindVariableDefinition&& other) : BindVariableDefinition() { - swap(other); - } - BindVariableDefinition& operator=(BindVariableDefinition&& other) { - swap(other); - return *this; - } - void swap(BindVariableDefinition& other) { - std::swap(self_, other.self_); - std::swap(var_, other.var_); - } - - TransformLayoutPlanner* self_{nullptr}; - Var var_; + BindVariableDefinition(BindVariableDefinition&&) = default; + BindVariableDefinition& operator=(BindVariableDefinition&&) = default; }; struct BindBlockRealize { diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tir/analysis/control_flow_graph.cc index ca594963e528..d5214f085cf6 100644 --- a/src/tir/analysis/control_flow_graph.cc +++ b/src/tir/analysis/control_flow_graph.cc @@ -588,18 +588,17 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { BindActiveLoopVar& operator=(BindActiveLoopVar&&) = delete; }; - // Internal utility, context manager for tracking a variable binding + // Internal utility, context manager for 
tracking a variable binding. + // Under SSA, each variable is bound exactly once, so the maps grow + // monotonically and cleanup is unnecessary. Omitting cleanup also + // ensures correctness for flat BindNode (which has no body): the + // binding must remain visible to subsequent sibling statements. struct BindLetVar { - BindLetVar(ControlFlowGraphBuilder* self, Var var, PrimExpr value) : self(self), var(var) { + BindLetVar(ControlFlowGraphBuilder* self, Var var, PrimExpr value) { self->let_bindings_using_loop_.Set(var, value); self->loop_dependent_vars_.insert(var.get()); } - ~BindLetVar() { - self->loop_dependent_vars_.erase(var.get()); - self->let_bindings_using_loop_.erase(var); - } - ControlFlowGraphBuilder* self; - Var var; + ~BindLetVar() {} // Disable default-generated copy/move assignment and constructors BindLetVar(const BindLetVar&) = delete; diff --git a/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py index 53923c97bd27..f266ee539aae 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py @@ -243,6 +243,10 @@ def test_inject_async_copy_barrier(): tvm.testing.assert_allclose(B_nd.numpy(), A_np) +# Note: the expected output contains a dead CSE variable `cse_v1 = (i < 12)`. +# CSE extracts (i < 12) before inject_ptx_async_copy runs, but the latter +# replaces the original IfThenElse guards with new cast(int32, i < 12) +# expressions for predicated async copies, leaving cse_v1 unused. 
expected_cuda_script = r"""#include __forceinline__ __device__ unsigned int cast_smem_ptr_to_int(const void* const smem_ptr) From 01f04cd44690de2dd01a7849efb6f25681887c81 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 4 Mar 2026 21:26:20 +0000 Subject: [PATCH 31/34] [TIR] Fix three LetStmt-to-Bind refactor bugs from second review Bug 1 (inject_virtual_thread.cc): When a Bind in a SeqStmt touches vt_var, the VT loop must wrap the Bind together with all remaining siblings (which may reference the bound variable). Previously, the Bind handler wrapped only itself, breaking semantics. Rewrite the SeqStmt handler to pre-check Bind children and group them with remaining siblings before wrapping with InjectVTLoop. Bug 2 (lower_tvm_builtin.cc): MakeNdMemAllocWithScope was returning without re-visiting via StmtExprMutator::VisitStmt, leaving tvm_call_packed builtins in both the Bind value and the free_stmt unlowered. Re-wrap with VisitStmt and visit free_stmt before pushing to pending_frees. Bug 3 (frame.cc): BindFrameNode::ExitWithScope used SeqStmt constructor (which does not flatten) instead of SeqStmt::Flatten, creating nested SeqStmts. Also, when stmts is empty, emit just the Bind without wrapping in a SeqStmt with a spurious Evaluate(0). 
--- src/s_tir/transform/inject_virtual_thread.cc | 49 +++++++++++++++++-- src/script/ir_builder/tir/frame.cc | 9 +++- src/tir/transform/lower_tvm_builtin.cc | 5 +- .../test_tvmscript_ir_builder_tir.py | 7 +-- .../tvmscript/test_tvmscript_printer_tir.py | 4 +- 5 files changed, 62 insertions(+), 12 deletions(-) diff --git a/src/s_tir/transform/inject_virtual_thread.cc b/src/s_tir/transform/inject_virtual_thread.cc index 97f6ff87175c..dcffacbba97a 100644 --- a/src/s_tir/transform/inject_virtual_thread.cc +++ b/src/s_tir/transform/inject_virtual_thread.cc @@ -364,16 +364,55 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { } // Seq + // When a Bind child triggers VT injection, we group the Bind together with + // all remaining siblings (which may reference the bound variable) and wrap + // them as a single unit in the VT loop. This preserves the semantics that + // were implicit when Bind was LetStmt (where the body was nested inside). Stmt VisitStmt_(const SeqStmtNode* op) final { TVM_FFI_ICHECK_EQ(max_loop_depth_, 0); - auto fmutate = [this](const Stmt& s) { + ffi::Array new_seq; + bool changed = false; + for (size_t i = 0; i < op->seq.size(); ++i) { int temp = max_loop_depth_; max_loop_depth_ = 0; - Stmt ret = this->VisitStmt(s); + // For Bind children, pre-check if the value touches vt_var before + // visiting. If so, group the Bind with all remaining siblings and + // wrap the group with InjectVTLoop. + if (const auto* bind = op->seq[i].as(); bind && !vt_loop_injected_) { + // Visit just the value expression to probe for vt_var dependency. + TVM_FFI_ICHECK(!visit_touched_var_); + this->VisitExpr(bind->value); + if (visit_touched_var_) { + // Reset flag (InjectVTLoop will handle it). + visit_touched_var_ = false; + // Gather the original Bind + all remaining original siblings. + ffi::Array group; + for (size_t j = i; j < op->seq.size(); ++j) { + group.push_back(op->seq[j]); + } + Stmt grouped = group.size() == 1 ? 
group[0] : SeqStmt(group); + // before_mutation=true: InjectVTLoop will re-visit the entire group + // with vt_loop_injected_=true, properly substituting vt_var. + Stmt wrapped = InjectVTLoop(grouped, true); + new_seq.push_back(wrapped); + max_loop_depth_ = std::max(max_loop_depth_, temp); + changed = true; + // All remaining siblings consumed. + goto done; + } + // Value did not touch vt_var. Reset and visit the Bind normally. + visit_touched_var_ = false; + } + // Non-Bind child or Bind that does not touch vt_var: visit normally. + Stmt child = this->VisitStmt(op->seq[i]); max_loop_depth_ = std::max(max_loop_depth_, temp); - return ret; - }; - return StmtMutator::VisitSeqStmt_(op, false, fmutate); + if (!child.same_as(op->seq[i])) changed = true; + new_seq.push_back(child); + } + done: + if (!changed) return ffi::GetRef(op); + if (new_seq.size() == 1) return new_seq[0]; + return SeqStmt(new_seq); } // Allocate // AllocBuffer diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index b0f50ca4504d..65e0e12e09ea 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -143,7 +143,14 @@ void AssertFrameNode::ExitWithScope() { void BindFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); - AddToParent(tvm::tir::SeqStmt({tvm::tir::Bind(var, value), AsStmt(stmts)})); + if (stmts.empty()) { + AddToParent(tvm::tir::Bind(var, value)); + } else { + ffi::Array combined; + combined.push_back(tvm::tir::Bind(var, value)); + for (const auto& s : stmts) combined.push_back(s); + AddToParent(tvm::tir::SeqStmt::Flatten(combined)); + } } void LaunchThreadFrameNode::ExitWithScope() { diff --git a/src/tir/transform/lower_tvm_builtin.cc b/src/tir/transform/lower_tvm_builtin.cc index f4e08bbf4973..c5a6cf59d26a 100644 --- a/src/tir/transform/lower_tvm_builtin.cc +++ b/src/tir/transform/lower_tvm_builtin.cc @@ -689,9 +689,12 @@ class BuiltinLower : public StmtExprMutator { {GetDeviceMethodName("free_nd"), 
device_type_.value(), device_id_.value(), storage_scope, let->var}); Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); + // Visit the free_stmt so tvm_call_packed builtins inside it get lowered. + free_stmt = StmtExprMutator::VisitStmt(free_stmt); scope_.Current().pending_frees.push_back(free_stmt); - return SeqStmt({Bind(let->var, call_packed), null_check}); + // Re-visit so tvm_call_packed in the Bind value and null_check get lowered. + return StmtExprMutator::VisitStmt(SeqStmt({Bind(let->var, call_packed), null_check})); } private: diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py index 48808052e1e5..1bfced20cae8 100644 --- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py +++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py @@ -317,12 +317,13 @@ def test_ir_builder_tir_assert(): def test_ir_builder_tir_let(): with IRBuilder() as ib: with T.Bind(tir.IntImm("int32", 2)) as v: - T.evaluate(0) + T.evaluate(1) # the let binding generated by IRBuilder let_actual = ib.get() - # the expected Bind + Evaluate sequence - let_expected = tir.SeqStmt([tir.Bind(T.int32(), tir.IntImm("int32", 2)), tir.Evaluate(0)]) + # the expected Bind + Evaluate sequence (using Evaluate(1) to avoid + # SeqStmt::Flatten stripping the no-op Evaluate(0)) + let_expected = tir.SeqStmt([tir.Bind(T.int32(), tir.IntImm("int32", 2)), tir.Evaluate(1)]) # Check if the generated ir is expected assert_structural_equal(let_actual, let_expected, map_free_vars=True) diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index c6ffa2a5e56c..a6a3f3409f67 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -256,13 +256,13 @@ def test_bind(): with IRBuilder() as ib: with T.Bind(T.float32(10)) as v: ib.name("v", v) - T.evaluate(0) + 
T.evaluate(1) obj = ib.get() _assert_print( obj, """ v: T.float32 = T.float32(10.0) -T.evaluate(0) +T.evaluate(1) """, ) From 9877fba6f93fccfa5f4e85357c7b3e43e9e9303c Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 4 Mar 2026 22:45:12 +0000 Subject: [PATCH 32/34] [REFACTOR][TIR] Remove BindFrame, make Bind a flat non-frame statement Bind is now a direct statement like Evaluate -- it emits a Bind stmt to the current frame and returns the Var, with no context manager or RAII scope needed. Changes: - C++ ir_builder: Bind() creates var, calls AddToParent(tir::Bind(...)), returns var instead of BindFrame - Remove BindFrameNode/BindFrame classes from frame.h and frame.cc - Python ir_builder: Bind() returns Var instead of BindFrame - Parser: bind_assign_value and visit_ann_assign simplified to call T.Bind() directly without frame lifecycle management - Parser: visit_expr_stmt skips standalone Var results (from T.Bind()) instead of wrapping them in T.evaluate() - Remove BindFrame Python class from frame.py - Update all tests from `with T.Bind(...) 
as v:` to `v = T.Bind(...)` --- include/tvm/script/ir_builder/tir/frame.h | 42 ------------------ include/tvm/script/ir_builder/tir/ir.h | 9 ++-- python/tvm/script/ir_builder/tir/frame.py | 7 --- python/tvm/script/ir_builder/tir/ir.py | 16 ++++--- python/tvm/script/parser/tir/parser.py | 14 +++--- src/script/ir_builder/tir/frame.cc | 13 ------ src/script/ir_builder/tir/ir.cc | 23 +++++----- .../test_s_tir_transform_thread_sync.py | 32 +++++++------- .../test_tir_analysis_verify_well_formed.py | 44 +++++++++---------- .../test_tir_inline_private_functions.py | 30 ++++++------- .../test_tir_transform_common_subexpr_elim.py | 12 ++--- .../test_tir_transform_convert_ssa.py | 12 ++--- .../test_tvmscript_ir_builder_tir.py | 13 +++--- .../tvmscript/test_tvmscript_printer_tir.py | 11 +++-- .../tvmscript/test_tvmscript_roundtrip.py | 6 +-- 15 files changed, 117 insertions(+), 167 deletions(-) diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index b9e2eb859ae4..0971890c40b9 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -341,48 +341,6 @@ class AssertFrame : public TIRFrame { TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AssertFrame, TIRFrame, AssertFrameNode); }; -/*! - * \brief A frame represents the Bind (variable binding) statement. - * - * \sa BindFrameNode - */ -class BindFrameNode : public TIRFrameNode { - public: - /*! \brief The variable we bind to */ - tvm::tir::Var var; - /*! \brief The value we bind var to */ - PrimExpr value; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("var", &BindFrameNode::var) - .def_ro("value", &BindFrameNode::value); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.BindFrame", BindFrameNode, TIRFrameNode); - - public: - /*! - * \brief The method called when exiting RAII scope. - * \sa tvm::support::With - */ - void ExitWithScope() final; -}; - -/*! 
- * \brief Managed reference to BindFrameNode. - * - * \sa BindFrameNode - */ -class BindFrame : public TIRFrame { - public: - explicit BindFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { - TVM_FFI_ICHECK(data != nullptr); - data_ = std::move(data); - } - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BindFrame, TIRFrame, BindFrameNode); -}; - /*! * \brief The LaunchThreadFrameNode. * \note It is used only inside a PrimFunc. diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 702423d17df5..fb6b5d26e624 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -295,15 +295,18 @@ AssertFrame Assert(PrimExpr condition, ffi::String error_kind, /*! * \brief Create a Bind (variable binding). + * + * Emits a flat Bind statement to the current frame and returns the bound variable. + * * \param value The value to be bound. * \param type_annotation The type annotation of the binding. * Usually it is used for fine-grained var typing, * particularly, PointerType. * \param var The variable to be bound. If not specified, a new variable will be created. - * \return The created BindFrame. + * \return The bound Var. */ -BindFrame Bind(PrimExpr value, ffi::Optional type_annotation = std::nullopt, - ffi::Optional var = std::nullopt); +Var Bind(PrimExpr value, ffi::Optional type_annotation = std::nullopt, + ffi::Optional var = std::nullopt); /*! * \brief The allocate node. diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py index d6c966852103..f42e14677ba7 100644 --- a/python/tvm/script/ir_builder/tir/frame.py +++ b/python/tvm/script/ir_builder/tir/frame.py @@ -50,13 +50,6 @@ def __enter__(self) -> Var | list[Var]: # type: ignore[override] class AssertFrame(TIRFrame): ... 
-@_register_object("script.ir_builder.tir.BindFrame") -class BindFrame(TIRFrame): - def __enter__(self) -> Var: - super().__enter__() - return self.var - - @_register_object("script.ir_builder.tir.AllocateFrame") class AllocateFrame(TIRFrame): def __enter__(self) -> Buffer: diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 54c947265ac4..93af3b434c98 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -988,9 +988,11 @@ def Bind( # pylint: disable=invalid-name type_annotation: Type | None = None, # pylint: disable=redefined-outer-name *, var: Var | None = None, # pylint: disable=redefined-outer-name -) -> frame.BindFrame: +) -> Var: """Create a Bind (variable binding). + Emits a flat Bind statement to the current frame and returns the bound variable. + Parameters ---------- value : PrimExpr @@ -1003,8 +1005,8 @@ def Bind( # pylint: disable=invalid-name Returns ------- - bind_frame : frame.BindFrame - The result BindFrame. + var : Var + The bound variable. """ if type_annotation is not None: if callable(type_annotation): @@ -1028,7 +1030,7 @@ def let( v: Var, value: PrimExpr, body: PrimExpr = None, -) -> frame.BindFrame: +) -> Var: """Create a new let binding. Parameters @@ -1044,8 +1046,8 @@ def let( Returns ------- - res : frame.BindFrame - The result BindFrame. + res : Var + The bound variable. 
""" @deprecated("T.let", "T.Let") @@ -1053,7 +1055,7 @@ def let_expr(v: Var, value: PrimExpr, body: PrimExpr) -> PrimExpr: return tir.Let(v, value, body) @deprecated("T.let", "T.Bind") - def let_stmt(v: Var, value: PrimExpr) -> frame.BindFrame: + def let_stmt(v: Var, value: PrimExpr) -> Var: return Bind(value, var=v) if body is None: diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index bfd856b20cac..660085ba3cc5 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -145,11 +145,8 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) - return value else: value = tvm.runtime.convert(value) - frame = T.Bind(value) - var = frame.var + var = T.Bind(value) IRBuilder.name(var_name, var) - frame.add_callback(partial(frame.__exit__, None, None, None)) - frame.__enter__() return var @@ -352,9 +349,7 @@ def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: if not isinstance(ann_var, Var): self.report_error(node.annotation, "Annotation should be Var") self.eval_assign(target=lhs, source=ann_var, bind_value=bind_assign_value) - frame = T.Bind(rhs, var=ann_var) - frame.add_callback(partial(frame.__exit__, None, None, None)) - frame.__enter__() + T.Bind(rhs, var=ann_var) @dispatch.register(token="tir", type_name="With") @@ -471,6 +466,11 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: elif isinstance(res, Frame): res.add_callback(partial(res.__exit__, None, None, None)) res.__enter__() + elif isinstance(res, Var): + # Standalone Var expression (e.g. from T.Bind(value, var=v)) -- + # the Bind statement was already emitted to the parent frame by the FFI call, + # so just discard the returned Var. 
+ pass elif isinstance(res, PrimExpr): T.evaluate(res) elif isinstance(res, int | bool): diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index 65e0e12e09ea..999527813d28 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -35,7 +35,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { BlockInitFrameNode::RegisterReflection(); ForFrameNode::RegisterReflection(); AssertFrameNode::RegisterReflection(); - BindFrameNode::RegisterReflection(); LaunchThreadFrameNode::RegisterReflection(); AllocateFrameNode::RegisterReflection(); AttrFrameNode::RegisterReflection(); @@ -141,18 +140,6 @@ void AssertFrameNode::ExitWithScope() { } } -void BindFrameNode::ExitWithScope() { - TIRFrameNode::ExitWithScope(); - if (stmts.empty()) { - AddToParent(tvm::tir::Bind(var, value)); - } else { - ffi::Array combined; - combined.push_back(tvm::tir::Bind(var, value)); - for (const auto& s : stmts) combined.push_back(s); - AddToParent(tvm::tir::SeqStmt::Flatten(combined)); - } -} - void LaunchThreadFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); AddToParent(tvm::tir::AttrStmt(iter_var, attr_key, extent, AsStmt(stmts))); diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 2bade41b0122..197e52b45636 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -460,17 +460,18 @@ AssertFrame Assert(PrimExpr condition, ffi::String error_kind, return AssertFrame(n); } -BindFrame Bind(PrimExpr value, ffi::Optional type_annotation, ffi::Optional var) { - ObjectPtr n = ffi::make_object(); - if (var.defined()) { - n->var = var.value(); - } else if (type_annotation.defined()) { - n->var = Var("v", type_annotation.value()); - } else { - n->var = Var("v", value.dtype()); - } - n->value = value; - return BindFrame(n); +Var Bind(PrimExpr value, ffi::Optional type_annotation, ffi::Optional var) { + Var bind_var = [&]() { + if (var.defined()) { + return var.value(); + } else if 
(type_annotation.defined()) { + return Var("v", type_annotation.value()); + } else { + return Var("v", value.dtype()); + } + }(); + AddToParent(tvm::tir::Bind(bind_var, value)); + return bind_var; } LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) { diff --git a/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py b/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py index b1a1558a7482..f1cdde039740 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py @@ -113,14 +113,14 @@ def func(A: T.Buffer((16 * 512), "float32")): A_shared_1[ax0] = A[blockIdx_x * 512 + ax0] in_thread_A_temp_1 = T.decl_buffer((1,), data=in_thread_A_temp, scope="local") in_thread_A_temp_1[0] = T.float32(0) - with T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x]) as A_temp: - in_thread_A_temp_1[0] = A_temp - with T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 128]) as A_temp: - in_thread_A_temp_1[0] = A_temp - with T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 256]) as A_temp: - in_thread_A_temp_1[0] = A_temp - with T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 384]) as A_temp: - in_thread_A_temp_1[0] = A_temp + A_temp_1 = T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x]) + in_thread_A_temp_1[0] = A_temp_1 + A_temp_2 = T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 128]) + in_thread_A_temp_1[0] = A_temp_2 + A_temp_3 = T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 256]) + in_thread_A_temp_1[0] = A_temp_3 + A_temp_4 = T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 384]) + in_thread_A_temp_1[0] = A_temp_4 cross_thread_A_temp_1 = T.decl_buffer((1,), data=cross_thread_A_temp, scope="local") with T.attr( T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), @@ -148,14 +148,14 @@ def expected(A: T.Buffer((8192,), "float32")): in_thread_A_temp_1_1 = T.decl_buffer((1,), data=in_thread_A_temp_1, scope="local") 
in_thread_A_temp_1_1[0] = T.float32(0) T.tvm_storage_sync("shared") - with T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x]) as A_temp: - in_thread_A_temp_1_1[0] = A_temp - with T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 128]) as A_temp: - in_thread_A_temp_1_1[0] = A_temp - with T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 256]) as A_temp: - in_thread_A_temp_1_1[0] = A_temp - with T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 384]) as A_temp: - in_thread_A_temp_1_1[0] = A_temp + A_temp_1 = T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x]) + in_thread_A_temp_1_1[0] = A_temp_1 + A_temp_2 = T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 128]) + in_thread_A_temp_1_1[0] = A_temp_2 + A_temp_3 = T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 256]) + in_thread_A_temp_1_1[0] = A_temp_3 + A_temp_4 = T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 384]) + in_thread_A_temp_1_1[0] = A_temp_4 cross_thread_A_temp_1_1 = T.decl_buffer((1,), data=cross_thread_A_temp_1, scope="local") T.attr( T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py index ee4e6ee2cda4..73347c0728dc 100644 --- a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py +++ b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py @@ -92,9 +92,9 @@ def test_error_for_nested_rebind_usage(): @T.prim_func(check_well_formed=False) def func(): i = T.int32() - with T.Bind(42, var=i): - with T.Bind(42, var=i): - T.evaluate(i) + T.Bind(42, var=i) + T.Bind(42, var=i) + T.evaluate(i) with pytest.raises( ValueError, match="ill-formed, due to multiple nested definitions of variable i" @@ -113,10 +113,10 @@ def test_error_for_repeated_binding(): @T.prim_func(check_well_formed=False) def func(): i = T.int32() - with T.Bind(42, var=i): - T.evaluate(i) - 
with T.Bind(17, var=i): - T.evaluate(i) + T.Bind(42, var=i) + T.evaluate(i) + T.Bind(17, var=i) + T.evaluate(i) with pytest.raises(ValueError, match="multiple nested definitions of variable i"): tvm.tir.analysis.verify_well_formed(func) @@ -131,13 +131,13 @@ def test_error_for_cross_function_reuse(): class mod: @T.prim_func def func1(): - with T.Bind(42, var=i): - T.evaluate(i) + T.Bind(42, var=i) + T.evaluate(i) @T.prim_func def func2(): - with T.Bind(42, var=i): - T.evaluate(i) + T.Bind(42, var=i) + T.evaluate(i) with pytest.raises(ValueError, match="multiple definitions of variable i"): tvm.tir.analysis.verify_well_formed(mod) @@ -295,11 +295,11 @@ def test_error_message_without_previous_definition_location(): def func(): x = T.int32() - with T.Bind(42, var=x): - T.evaluate(x) + T.Bind(42, var=x) + T.evaluate(x) - with T.Bind(99, var=x): # This should trigger the error - T.evaluate(x) + T.Bind(99, var=x) # This should trigger the error + T.evaluate(x) with pytest.raises(ValueError) as exc_info: tvm.tir.analysis.verify_well_formed(func, assert_mode=True) @@ -322,9 +322,9 @@ def test_error_message_with_previous_definition_location(): def func(): x = T.int32() - with T.Bind(42, var=x): - with T.Bind(99, var=x): # This should trigger the error - T.evaluate(x) + T.Bind(42, var=x) + T.Bind(99, var=x) # This should trigger the error + T.evaluate(x) with pytest.raises(ValueError) as exc_info: tvm.tir.analysis.verify_well_formed(func, assert_mode=True) @@ -351,11 +351,11 @@ def test_sequential_redefinition_with_location(): def func(): x = T.int32() - with T.Bind(1, var=x): - T.evaluate(x) + T.Bind(1, var=x) + T.evaluate(x) - with T.Bind(2, var=x): # This should trigger the error - T.evaluate(x) + T.Bind(2, var=x) # This should trigger the error + T.evaluate(x) with pytest.raises(ValueError) as exc_info: tvm.tir.analysis.verify_well_formed(func, assert_mode=True) diff --git a/tests/python/tir-transform/test_tir_inline_private_functions.py 
b/tests/python/tir-transform/test_tir_inline_private_functions.py index c5e3f2a07356..e681073fa6f4 100644 --- a/tests/python/tir-transform/test_tir_inline_private_functions.py +++ b/tests/python/tir-transform/test_tir_inline_private_functions.py @@ -150,21 +150,21 @@ def subroutine(A_data: T.handle("float32"), B_data: T.handle("float32")): class Expected: @T.prim_func def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")): - with T.Bind(T.address_of(A[0, 0]), var=T.handle("float32")) as A_data_1: - A_1 = T.decl_buffer(16, "float32", data=A_data_1) - B_data_1: T.handle("float32") = T.address_of(B[0, 0]) - B_1 = T.decl_buffer(16, "float32", data=B_data_1) - for i in range(16): - with T.sblock("scalar_mul_1"): - B_1[i] = A_1[i] * 2.0 - - with T.Bind(T.address_of(A[1, 0]), var=T.handle("float32")) as A_data_2: - A_2 = T.decl_buffer(16, "float32", data=A_data_2) - B_data_2: T.handle("float32") = T.address_of(B[1, 0]) - B_2 = T.decl_buffer(16, "float32", data=B_data_2) - for i in range(16): - with T.sblock("scalar_mul_2"): - B_2[i] = A_2[i] * 2.0 + A_data_1 = T.Bind(T.address_of(A[0, 0]), T.handle("float32")) + A_1 = T.decl_buffer(16, "float32", data=A_data_1) + B_data_1: T.handle("float32") = T.address_of(B[0, 0]) + B_1 = T.decl_buffer(16, "float32", data=B_data_1) + for i in range(16): + with T.sblock("scalar_mul_1"): + B_1[i] = A_1[i] * 2.0 + + A_data_2 = T.Bind(T.address_of(A[1, 0]), T.handle("float32")) + A_2 = T.decl_buffer(16, "float32", data=A_data_2) + B_data_2: T.handle("float32") = T.address_of(B[1, 0]) + B_2 = T.decl_buffer(16, "float32", data=B_data_2) + for i in range(16): + with T.sblock("scalar_mul_2"): + B_2[i] = A_2[i] * 2.0 class TestInlineCallOccurringInExpression(BaseTestCase): diff --git a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py index 506ec99d81e0..c9c3703da210 100644 --- 
a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py @@ -352,9 +352,9 @@ def func_distributivity( def func_distributivity_expected( B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 ) -> None: - with T.Bind((y + z) * x) as cse_v1: - B[i1] = cse_v1 - B[i2] = cse_v1 + cse_v1 = T.Bind((y + z) * x) + B[i1] = cse_v1 + B[i2] = cse_v1 @T.prim_func @@ -369,9 +369,9 @@ def func_associativity( def func_associativity_expected( B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 ) -> None: - with T.Bind(x + y + z) as cse_v1: - B[i1] = cse_v1 - B[i2] = cse_v1 + cse_v1 = T.Bind(x + y + z) + B[i1] = cse_v1 + B[i2] = cse_v1 def _check(original, transformed): diff --git a/tests/python/tir-transform/test_tir_transform_convert_ssa.py b/tests/python/tir-transform/test_tir_transform_convert_ssa.py index 1b6985cf69db..b37d337c28b0 100644 --- a/tests/python/tir-transform/test_tir_transform_convert_ssa.py +++ b/tests/python/tir-transform/test_tir_transform_convert_ssa.py @@ -42,10 +42,10 @@ def test_reuse_in_sequential_bind(): @T.prim_func(private=True) def expected(): - with T.Bind(T.int32(16)) as var1: - T.evaluate(var1) - with T.Bind(T.int32(32)) as var2: - T.evaluate(var2) + var1 = T.Bind(T.int32(16)) + T.evaluate(var1) + var2 = T.Bind(T.int32(32)) + T.evaluate(var2) mod = tvm.IRModule.from_expr(before) mod = tvm.tir.transform.ConvertSSA()(mod) @@ -108,8 +108,8 @@ def test_reused_var_across_module(): @T.prim_func(private=True) def func(): - with T.Bind(10) as var: - T.evaluate(var) + var = T.Bind(10) + T.evaluate(var) before = tvm.IRModule( { diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py index 1bfced20cae8..797217b8d328 100644 --- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py +++ 
b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py @@ -314,19 +314,20 @@ def test_ir_builder_tir_assert(): assert_structural_equal(assert_actual, assert_expected, map_free_vars=True) -def test_ir_builder_tir_let(): +def test_ir_builder_tir_bind(): + # Test that T.Bind emits a flat Bind statement and returns the Var. with IRBuilder() as ib: - with T.Bind(tir.IntImm("int32", 2)) as v: - T.evaluate(1) + v = T.Bind(tir.IntImm("int32", 2)) # the let binding generated by IRBuilder let_actual = ib.get() - # the expected Bind + Evaluate sequence (using Evaluate(1) to avoid - # SeqStmt::Flatten stripping the no-op Evaluate(0)) - let_expected = tir.SeqStmt([tir.Bind(T.int32(), tir.IntImm("int32", 2)), tir.Evaluate(1)]) + # Bind is now flat (no body), so a single Bind stmt is emitted. + let_expected = tir.Bind(T.int32(), tir.IntImm("int32", 2)) # Check if the generated ir is expected assert_structural_equal(let_actual, let_expected, map_free_vars=True) + # Check that the returned value is a Var + assert isinstance(v, tir.Var) def test_ir_builder_tir_thread(): diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index a6a3f3409f67..8a814089661a 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -254,15 +254,20 @@ def test_for(): def test_bind(): with IRBuilder() as ib: - with T.Bind(T.float32(10)) as v: + with T.prim_func(): + v = T.Bind(T.float32(10)) ib.name("v", v) T.evaluate(1) obj = ib.get() _assert_print( obj, """ -v: T.float32 = T.float32(10.0) -T.evaluate(1) +# from tvm.script import tir as T + +@T.prim_func(private=True) +def main(): + v: T.float32 = T.float32(10.0) + T.evaluate(1) """, ) diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 37e241ac6604..26cce98bf9d3 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ 
b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -2757,9 +2757,9 @@ def func(): def bind_var(): @T.prim_func def func(): - with T.Bind(0) as x: - with T.Bind(0) as y: - T.evaluate(0) + x = T.Bind(0) + y = T.Bind(0) + T.evaluate(0) T.evaluate(0) return func From 98b8bf16f4cf37c7ae190c1e66aee454c39ca951 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 4 Mar 2026 23:09:31 +0000 Subject: [PATCH 33/34] [refactor] inject_virtual_thread: replace goto with break, improve comment --- src/s_tir/transform/inject_virtual_thread.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/s_tir/transform/inject_virtual_thread.cc b/src/s_tir/transform/inject_virtual_thread.cc index dcffacbba97a..f4698821a6a8 100644 --- a/src/s_tir/transform/inject_virtual_thread.cc +++ b/src/s_tir/transform/inject_virtual_thread.cc @@ -367,7 +367,9 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { // When a Bind child triggers VT injection, we group the Bind together with // all remaining siblings (which may reference the bound variable) and wrap // them as a single unit in the VT loop. This preserves the semantics that - // were implicit when Bind was LetStmt (where the body was nested inside). + // With flat Bind (no body), a Bind whose value touches vt_var must be + // grouped with all remaining siblings and wrapped in a VT loop together. + // This preserves the scoping that was implicit when Bind carried a body. Stmt VisitStmt_(const SeqStmtNode* op) final { TVM_FFI_ICHECK_EQ(max_loop_depth_, 0); ffi::Array new_seq; @@ -397,8 +399,8 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { new_seq.push_back(wrapped); max_loop_depth_ = std::max(max_loop_depth_, temp); changed = true; - // All remaining siblings consumed. - goto done; + // All remaining siblings consumed — exit loop. + break; } // Value did not touch vt_var. Reset and visit the Bind normally. 
visit_touched_var_ = false; @@ -409,7 +411,6 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { if (!child.same_as(op->seq[i])) changed = true; new_seq.push_back(child); } - done: if (!changed) return ffi::GetRef(op); if (new_seq.size() == 1) return new_seq[0]; return SeqStmt(new_seq); From a89b7dc043c0d709efd376495f974ff86b34e633 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 5 Mar 2026 03:30:50 +0000 Subject: [PATCH 34/34] [fixup] Remove redundant AllocBuffer/DeclBuffer overrides from IndexDataTypeRewriter Base StmtMutator now handles these via VisitBufferDef (from upstream PR #18873). --- src/tir/ir/data_type_rewriter.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/tir/ir/data_type_rewriter.h b/src/tir/ir/data_type_rewriter.h index 46e6e3b92a12..e19c555c6ed0 100644 --- a/src/tir/ir/data_type_rewriter.h +++ b/src/tir/ir/data_type_rewriter.h @@ -110,8 +110,6 @@ class IndexDataTypeRewriter : public DataTypeLegalizer { PrimExpr VisitExpr_(const BufferLoadNode* op) override; ffi::Array VisitIndices(ffi::Array indices); Stmt VisitStmt_(const IfThenElseNode* op) override; - Stmt VisitStmt_(const DeclBufferNode* op) override; - Stmt VisitStmt_(const AllocBufferNode* op) override; Stmt VisitStmt_(const BindNode* op) override; PrimExpr VisitExpr_(const EQNode* op) override; PrimExpr VisitExpr_(const NENode* op) override;