diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index 4d889cc4d222..0971890c40b9 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -341,48 +341,6 @@ class AssertFrame : public TIRFrame { TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AssertFrame, TIRFrame, AssertFrameNode); }; -/*! - * \brief A frame represents the let binding expression, which binds a var. - * - * \sa LetFrameNode - */ -class LetFrameNode : public TIRFrameNode { - public: - /*! \brief The variable we bind to */ - tvm::tir::Var var; - /*! \brief The value we bind var to */ - PrimExpr value; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("var", &LetFrameNode::var) - .def_ro("value", &LetFrameNode::value); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.LetFrame", LetFrameNode, TIRFrameNode); - - public: - /*! - * \brief The method called when exiting RAII scope. - * \sa tvm::support::With - */ - void ExitWithScope() final; -}; - -/*! - * \brief Managed reference to LetFrameNode. - * - * \sa LetFrameNode - */ -class LetFrame : public TIRFrame { - public: - explicit LetFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { - TVM_FFI_ICHECK(data != nullptr); - data_ = std::move(data); - } - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(LetFrame, TIRFrame, LetFrameNode); -}; - /*! * \brief The LaunchThreadFrameNode. * \note It is used only inside a PrimFunc. diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index b7b6aa8f3a47..fb6b5d26e624 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -294,16 +294,19 @@ AssertFrame Assert(PrimExpr condition, ffi::String error_kind, ffi::Array message_parts); /*! - * \brief The let binding. + * \brief Create a Bind (variable binding). 
+ * + * Emits a flat Bind statement to the current frame and returns the bound variable. + * * \param value The value to be bound. - * \param type_annotation The type annotation of the let binding. + * \param type_annotation The type annotation of the binding. * Usually it is used for fine-grained var typing, * particularly, PointerType. * \param var The variable to be bound. If not specified, a new variable will be created. - * \return The created LetFrame. + * \return The bound Var. */ -LetFrame LetStmt(PrimExpr value, ffi::Optional type_annotation = std::nullopt, - ffi::Optional var = std::nullopt); +Var Bind(PrimExpr value, ffi::Optional type_annotation = std::nullopt, + ffi::Optional var = std::nullopt); /*! * \brief The allocate node. diff --git a/include/tvm/script/printer/ir_docsifier_functor.h b/include/tvm/script/printer/ir_docsifier_functor.h index 2cc1782d9240..32363a434f46 100644 --- a/include/tvm/script/printer/ir_docsifier_functor.h +++ b/include/tvm/script/printer/ir_docsifier_functor.h @@ -81,6 +81,7 @@ class IRDocsifierFunctor { << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")" << ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj; + TVM_FFI_UNREACHABLE(); } /*! diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 0ded1e977fa2..ab92e3550173 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -68,37 +68,38 @@ class Stmt : public ObjectRef { }; /*! - * \brief Let binding, bind var to value, then run body. + * \brief Bind a variable to a value in the enclosing scope. + * + * BindNode has no body field. The bound variable is visible + * in all subsequent statements within the same enclosing scope (SeqStmt, + * ForNode.body, etc.). This enables flat (non-nested) IR sequences. */ -class LetStmtNode : public StmtNode { +class BindNode : public StmtNode { public: - /*! \brief The variable. */ + /*! \brief The variable being bound. */ Var var; - /*! \brief The value to be bound. */ + /*! 
\brief The value to bind to the variable. */ PrimExpr value; - /*! \brief The body block. */ - Stmt body; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("var", &LetStmtNode::var, refl::AttachFieldFlag::SEqHashDef()) - .def_ro("value", &LetStmtNode::value) - .def_ro("body", &LetStmtNode::body); + refl::ObjectDef() + .def_ro("var", &BindNode::var, refl::AttachFieldFlag::SEqHashDef()) + .def_ro("value", &BindNode::value); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.LetStmt", LetStmtNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Bind", BindNode, StmtNode); }; /*! - * \brief Managed reference to LetStmtNode. - * \sa LetStmtNode + * \brief Managed reference to BindNode. + * \sa BindNode */ -class LetStmt : public Stmt { +class Bind : public Stmt { public: - TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body, Span span = Span()); + TVM_DLL Bind(Var var, PrimExpr value, Span span = Span()); - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LetStmt, Stmt, LetStmtNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(LetStmtNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Bind, Stmt, BindNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(BindNode); }; /*! @@ -978,6 +979,7 @@ inline const char* ForKind2String(ForKind t) { return "thread_binding"; } TVM_FFI_THROW(InternalError) << "Unknown ForKind" << t; + TVM_FFI_UNREACHABLE(); } } // namespace tir diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index e86c6bb125dd..d99cdfb84e59 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -84,7 +84,7 @@ class StmtFunctor { return vtable(n, this, std::forward(args)...); } // Functions that can be overriden by subclass - virtual R VisitStmt_(const LetStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const BindNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AttrStmtNode* op, Args... 
args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const IfThenElseNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT; @@ -106,7 +106,7 @@ class StmtFunctor { // initialize the vtable. static FType InitVTable() { FType vtable; - IR_STMT_FUNCTOR_DISPATCH(LetStmtNode); + IR_STMT_FUNCTOR_DISPATCH(BindNode); IR_STMT_FUNCTOR_DISPATCH(AttrStmtNode); IR_STMT_FUNCTOR_DISPATCH(IfThenElseNode); IR_STMT_FUNCTOR_DISPATCH(ForNode); @@ -159,9 +159,9 @@ class TVM_DLL StmtVisitor : protected StmtFunctor { */ virtual void VisitBufferUse(const Buffer& buffer); // statement visitor + void VisitStmt_(const BindNode* op) override; void VisitStmt_(const AttrStmtNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; - void VisitStmt_(const LetStmtNode* op) override; void VisitStmt_(const ForNode* op) override; void VisitStmt_(const WhileNode* op) override; void VisitStmt_(const AllocBufferNode* op) override; @@ -273,9 +273,9 @@ class TVM_DLL StmtMutator : protected StmtFunctor { */ virtual Buffer VisitBufferUse(const Buffer& buffer); // statement visitor + Stmt VisitStmt_(const BindNode* op) override; Stmt VisitStmt_(const AttrStmtNode* op) override; Stmt VisitStmt_(const IfThenElseNode* op) override; - Stmt VisitStmt_(const LetStmtNode* op) override; Stmt VisitStmt_(const ForNode* op) override; Stmt VisitStmt_(const WhileNode* op) override; Stmt VisitStmt_(const AllocBufferNode* op) override; diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index e83064b86489..b4106f2d2e9f 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -42,7 +42,7 @@ namespace tir { * - Allocate * - For * - Let - * - LetStmt + * - Bind */ class VarNode : public PrimExprNode { public: diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py index 478fec212397..f42e14677ba7 100644 --- a/python/tvm/script/ir_builder/tir/frame.py +++ 
b/python/tvm/script/ir_builder/tir/frame.py @@ -50,13 +50,6 @@ def __enter__(self) -> Var | list[Var]: # type: ignore[override] class AssertFrame(TIRFrame): ... -@_register_object("script.ir_builder.tir.LetFrame") -class LetFrame(TIRFrame): - def __enter__(self) -> Var: - super().__enter__() - return self.var - - @_register_object("script.ir_builder.tir.AllocateFrame") class AllocateFrame(TIRFrame): def __enter__(self) -> Buffer: diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 26325ee74244..93af3b434c98 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -983,35 +983,37 @@ def Assert(condition: PrimExpr, message, error_kind: str = "RuntimeError") -> fr return _ffi_api.Assert(condition, error_kind, message) # type: ignore[attr-defined] # pylint: disable=no-member -def LetStmt( # pylint: disable=invalid-name +def Bind( # pylint: disable=invalid-name value: PrimExpr, type_annotation: Type | None = None, # pylint: disable=redefined-outer-name *, var: Var | None = None, # pylint: disable=redefined-outer-name -) -> frame.LetFrame: - """Create a LetStmt binding +) -> Var: + """Create a Bind (variable binding). + + Emits a flat Bind statement to the current frame and returns the bound variable. Parameters ---------- value : PrimExpr The value to be bound. type_annotation : Optional[Type] = None - The type annotation of the let binding. Usually it is used for fine-grained var typing, + The type annotation of the binding. Usually it is used for fine-grained var typing, particularly, PointerType. var : Optional[Var] = None The variable to bind. If not specified, a new variable will be created. Returns ------- - let_frame : frame.LetFrame - The result LetFrame. + var : Var + The bound variable. 
""" if type_annotation is not None: if callable(type_annotation): type_annotation = type_annotation() if isinstance(type_annotation, Var): type_annotation = type_annotation.type_annotation - return _ffi_api.LetStmt(value, type_annotation, var) # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.Bind(value, type_annotation, var) # type: ignore[attr-defined] # pylint: disable=no-member def Let( # pylint: disable=invalid-name @@ -1028,7 +1030,7 @@ def let( v: Var, value: PrimExpr, body: PrimExpr = None, -) -> frame.LetFrame: +) -> Var: """Create a new let binding. Parameters @@ -1044,17 +1046,17 @@ def let( Returns ------- - res : frame.LetFrame - The result LetFrame. + res : Var + The bound variable. """ @deprecated("T.let", "T.Let") def let_expr(v: Var, value: PrimExpr, body: PrimExpr) -> PrimExpr: return tir.Let(v, value, body) - @deprecated("T.let", "T.LetStmt") - def let_stmt(v: Var, value: PrimExpr) -> frame.LetFrame: - return _ffi_api.LegacyLetStmt(v, value) # type: ignore[attr-defined] # pylint: disable=no-member + @deprecated("T.let", "T.Bind") + def let_stmt(v: Var, value: PrimExpr) -> Var: + return Bind(value, var=v) if body is None: return let_stmt(v, value) @@ -2343,7 +2345,7 @@ def wrapped(*args, **kwargs): "Call", "CallEffectKind", "let", - "LetStmt", + "Bind", "Let", "IterVar", "CommReducer", diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index eb617c85c9bd..d23358d93b22 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -528,6 +528,9 @@ def _duplicate_lhs_check(self, target: doc.expr) -> bool | set[str]: return {target.id} elif isinstance(target, doc.Starred): return self._duplicate_lhs_check(target.value) + elif isinstance(target, doc.Attribute): + # Attribute assignment like packedB.data = ..., treated as rebinding. 
+ return {target.attr} else: self.report_error(target, "Invalid type in assign statement") raise NotImplementedError diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index e20629795cf9..660085ba3cc5 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -145,11 +145,8 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) - return value else: value = tvm.runtime.convert(value) - frame = T.LetStmt(value) - var = frame.var + var = T.Bind(value) IRBuilder.name(var_name, var) - frame.add_callback(partial(frame.__exit__, None, None, None)) - frame.__enter__() return var @@ -352,9 +349,7 @@ def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: if not isinstance(ann_var, Var): self.report_error(node.annotation, "Annotation should be Var") self.eval_assign(target=lhs, source=ann_var, bind_value=bind_assign_value) - frame = T.LetStmt(rhs, var=ann_var) - frame.add_callback(partial(frame.__exit__, None, None, None)) - frame.__enter__() + T.Bind(rhs, var=ann_var) @dispatch.register(token="tir", type_name="With") @@ -471,6 +466,11 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: elif isinstance(res, Frame): res.add_callback(partial(res.__exit__, None, None, None)) res.__enter__() + elif isinstance(res, Var): + # Standalone Var expression (e.g. from T.Bind(value, var=v)) -- + # the Bind statement was already emitted to the parent frame by the FFI call, + # so just discard the returned Var. 
+ pass elif isinstance(res, PrimExpr): T.evaluate(res) elif isinstance(res, int | bool): diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index c645dccda3b8..caa94949807c 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -29,7 +29,7 @@ from .expr import Select, BufferLoad, ProducerLoad, Ramp, Broadcast, Shuffle from .expr import Call, CallEffectKind, Let, IterVar, CommReducer -from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While +from .stmt import Stmt, Bind, AssertStmt, ForKind, For, While from .stmt import ( BufferStore, AllocBuffer, diff --git a/python/tvm/tir/functor.py b/python/tvm/tir/functor.py index 5b19fd19b2c9..9c684d8d990b 100644 --- a/python/tvm/tir/functor.py +++ b/python/tvm/tir/functor.py @@ -65,12 +65,12 @@ AllocBuffer, AssertStmt, AttrStmt, + Bind, BufferStore, DeclBuffer, Evaluate, For, IfThenElse, - LetStmt, SBlock, SBlockRealize, SeqStmt, @@ -160,7 +160,7 @@ def __init__( f_visit_stmt: Callable | None = None, f_visit_expr: Callable | None = None, # Stmt - f_visit_let_stmt: Callable | None = None, + f_visit_bind: Callable | None = None, f_visit_attr_stmt: Callable | None = None, f_visit_if_then_else: Callable | None = None, f_visit_for: Callable | None = None, @@ -214,7 +214,7 @@ def __init__( f_visit_stmt, f_visit_expr, # Stmt - f_visit_let_stmt, + f_visit_bind, f_visit_attr_stmt, f_visit_if_then_else, f_visit_for, @@ -277,7 +277,7 @@ class PyStmtExprVisitor: "visit_stmt", "visit_expr", # Stmt - "visit_let_stmt_", + "visit_bind_", "visit_attr_stmt_", "visit_if_then_else_", "visit_for_", @@ -373,17 +373,17 @@ def visit_if_then_else_(self, op: IfThenElse) -> None: print("visit_if_then_else_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore - def visit_let_stmt_(self, op: LetStmt) -> None: - """Visit LetStmt. - Users can customize this function to overwrite VisitStmt_(const LetStmtNode* op) + def visit_bind_(self, op: Bind) -> None: + """Visit Bind. 
+ Users can customize this function to overwrite VisitStmt_(const BindNode* op) on the C++ side. Parameters ---------- - op : LetStmt - The LetStmt to be visited. + op : Bind + The Bind node to be visited. """ - print("visit_let_stmt_", op) + print("visit_bind_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_for_(self, op: For) -> None: @@ -961,7 +961,7 @@ def __init__( f_visit_stmt: Callable | None = None, f_visit_expr: Callable | None = None, # Stmt - f_visit_let_stmt: Callable | None = None, + f_visit_bind: Callable | None = None, f_visit_attr_stmt: Callable | None = None, f_visit_if_then_else: Callable | None = None, f_visit_for: Callable | None = None, @@ -1015,7 +1015,7 @@ def __init__( f_visit_stmt, f_visit_expr, # Stmt - f_visit_let_stmt, + f_visit_bind, f_visit_attr_stmt, f_visit_if_then_else, f_visit_for, @@ -1078,7 +1078,7 @@ class PyStmtExprMutator: "visit_stmt", "visit_expr", # Stmt - "visit_let_stmt_", + "visit_bind_", "visit_attr_stmt_", "visit_if_then_else_", "visit_for_", @@ -1196,15 +1196,15 @@ def visit_if_then_else_(self, op: IfThenElse) -> Stmt: """ return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore - def visit_let_stmt_(self, op: LetStmt) -> Stmt: - """Visit LetStmt. - Users can customize this function to overwrite VisitStmt_(const LetStmtNode* op) + def visit_bind_(self, op: Bind) -> Stmt: + """Visit Bind. + Users can customize this function to overwrite VisitStmt_(const BindNode* op) on the C++ side. Parameters ---------- - op : LetStmt - The LetStmt to be visited. + op : Bind + The Bind node to be visited. Returns ------- diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index cc945ba4fcec..f374aaa879ed 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -44,9 +44,14 @@ class Stmt(Object, Scriptable): """Base class of all the statements.""" -@tvm_ffi.register_object("tir.LetStmt") -class LetStmt(Stmt): - """LetStmt node. 
+@tvm_ffi.register_object("tir.Bind") +class Bind(Stmt): + """Bind node. + + Bind a variable to a value in the enclosing scope. + Bind has no body field. + The bound variable is visible in all subsequent statements + within the same enclosing scope (SeqStmt, ForNode.body, etc.). Parameters ---------- @@ -54,10 +59,7 @@ class LetStmt(Stmt): The variable in the binding. value : PrimExpr - The value in to be bound. - - body : Stmt - The body statement. + The value to be bound. span : Optional[Span] The location of the stmt in the source code. @@ -65,15 +67,13 @@ class LetStmt(Stmt): var: Var value: PrimExpr - body: Stmt span: Span | None - def __init__(self, var: Var, value: PrimExpr, body: Stmt, span: Span | None = None) -> None: + def __init__(self, var: Var, value: PrimExpr, span: Span | None = None) -> None: self.__init_handle_by_constructor__( - _ffi_api.LetStmt, + _ffi_api.Bind, var, value, - body, span, # type: ignore ) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index ceee1edec0ce..7e8afe53138b 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -472,13 +472,13 @@ class HoistedLetBindings(enum.Flag): RequiredByConditional = 1 """ Bindings that are used by a hoisted conditional """ - LetStmt = 2 - """ Bindings occurring in LetStmt """ + Bind = 2 + """ Bindings occurring in Bind nodes """ LetExpr = 4 """ Bindings occurring in Let expressions """ - All = RequiredByConditional | LetStmt | LetExpr + All = RequiredByConditional | Bind | LetExpr """ Enable all hoisting of let bindings """ diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index f6c8db016132..11caef56850b 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -80,20 +80,16 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const SBlockNode* op) { }); } -Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) { +Stmt 
IRMutatorWithAnalyzer::VisitStmt_(const BindNode* op) { PrimExpr value = this->VisitExpr(op->value); if (SideEffect(value) <= CallEffectKind::kPure) { analyzer_->Bind(op->var, value); } - // We keep the let-binding here - // as sub-class may or maynot choose to replace it. - Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && body.same_as(op->body)) { + if (value.same_as(op->value)) { return ffi::GetRef(op); } else { auto n = this->CopyOnWrite(op); n->value = std::move(value); - n->body = std::move(body); return Stmt(n); } } diff --git a/src/arith/ir_mutator_with_analyzer.h b/src/arith/ir_mutator_with_analyzer.h index 8810a8f78f62..0f03fef7d25e 100644 --- a/src/arith/ir_mutator_with_analyzer.h +++ b/src/arith/ir_mutator_with_analyzer.h @@ -54,7 +54,7 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator { // override functions that need to populate the context information. tir::Stmt VisitStmt_(const tir::ForNode* op) override; tir::Stmt VisitStmt_(const tir::SBlockNode* op) override; - tir::Stmt VisitStmt_(const tir::LetStmtNode* op) override; + tir::Stmt VisitStmt_(const tir::BindNode* op) override; tir::Stmt VisitStmt_(const tir::IfThenElseNode* op) override; tir::Stmt VisitStmt_(const tir::AttrStmtNode* op) override; tir::Stmt VisitStmt_(const tir::AssertStmtNode* op) override; diff --git a/src/arith/ir_visitor_with_analyzer.cc b/src/arith/ir_visitor_with_analyzer.cc index 736e148d7a31..e5041b159f8d 100644 --- a/src/arith/ir_visitor_with_analyzer.cc +++ b/src/arith/ir_visitor_with_analyzer.cc @@ -48,10 +48,9 @@ void IRVisitorWithAnalyzer::VisitStmt_(const SBlockNode* op) { }); } -void IRVisitorWithAnalyzer::VisitStmt_(const LetStmtNode* op) { +void IRVisitorWithAnalyzer::VisitStmt_(const BindNode* op) { this->VisitExpr(op->value); analyzer_.Bind(op->var, op->value); - this->VisitStmt(op->body); } void IRVisitorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { diff --git a/src/arith/ir_visitor_with_analyzer.h 
b/src/arith/ir_visitor_with_analyzer.h index f0553a1c428c..a5455659d0fe 100644 --- a/src/arith/ir_visitor_with_analyzer.h +++ b/src/arith/ir_visitor_with_analyzer.h @@ -43,7 +43,7 @@ class IRVisitorWithAnalyzer : public tir::StmtExprVisitor { void VisitStmt_(const tir::ForNode* op); void VisitStmt_(const tir::SBlockNode* op); - void VisitStmt_(const tir::LetStmtNode* op); + void VisitStmt_(const tir::BindNode* op); void VisitStmt_(const tir::IfThenElseNode* op); void VisitStmt_(const tir::AttrStmtNode* op); void VisitStmt_(const tir::AssertStmtNode* op); diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc index 0130ee1063d9..8ba6c2e645cc 100644 --- a/src/relax/op/tensor/inspect.cc +++ b/src/relax/op/tensor/inspect.cc @@ -92,11 +92,11 @@ tir::PrimFunc GetDLTensorField(tir::builtin::TVMStructFieldKind field, DataType tir::Var value("value", field_dtype); - tir::LetStmt body( - value, - tir::Call(field_dtype, tir::builtin::tvm_struct_get(), - {dlpack_handle, IntImm(DataType::Int(32), 0), IntImm(DataType::Int(32), field)}), - tir::Evaluate(tvm::ret(value))); + tir::Stmt body = + tir::SeqStmt({tir::Bind(value, tir::Call(field_dtype, tir::builtin::tvm_struct_get(), + {dlpack_handle, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), field)})), + tir::Evaluate(tvm::ret(value))}); DictAttrs attrs({{"tir.is_scheduled", true}, {"tir.is_host", true}}); @@ -305,33 +305,24 @@ Expr LegalizeTensorShape(const BlockBuilder& bb, const Call& call) { tir::Var extent("extent", field_dtype); - tir::Stmt body = tir::Evaluate(tvm::ret(extent)); - - body = tir::LetStmt(extent, tir::BufferLoad(shape_buffer, {axis}), body); - body = tir::DeclBuffer(shape_buffer, body); - body = tir::LetStmt( - shape_buffer->data, - tir::Call(DataType::Handle(), tir::builtin::tvm_struct_get(), - {dlpack_handle, IntImm(DataType::Int(32), 0), - IntImm(DataType::Int(32), tir::builtin::TVMStructFieldKind::kDLTensorShape)}), - body); - - body = tir::SeqStmt( - {tir::AssertStmt( + 
tir::Stmt body = tir::SeqStmt( + {tir::AssertStmt(0 <= axis, tir::StringImm("RuntimeError"), + {tir::StringImm("Specified axis may not be negative")}), + tir::Bind(ndim, tir::Call(ndim->dtype, tir::builtin::tvm_struct_get(), + {dlpack_handle, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), + tir::builtin::TVMStructFieldKind::kDLTensorNDim)})), + tir::AssertStmt( axis < tvm::cast(axis->dtype, ndim), tir::StringImm("RuntimeError"), {tir::StringImm("Specified axis may not be larger than the tensor's dimensionality")}), - body}); - - body = tir::LetStmt( - ndim, - tir::Call(ndim->dtype, tir::builtin::tvm_struct_get(), - {dlpack_handle, IntImm(DataType::Int(32), 0), - IntImm(DataType::Int(32), tir::builtin::TVMStructFieldKind::kDLTensorNDim)}), - body); - - body = tir::SeqStmt({tir::AssertStmt(0 <= axis, tir::StringImm("RuntimeError"), - {tir::StringImm("Specified axis may not be negative")}), - body}); + tir::Bind(shape_buffer->data, + tir::Call(DataType::Handle(), tir::builtin::tvm_struct_get(), + {dlpack_handle, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), + tir::builtin::TVMStructFieldKind::kDLTensorShape)})), + tir::DeclBuffer(shape_buffer, + tir::SeqStmt({tir::Bind(extent, tir::BufferLoad(shape_buffer, {axis})), + tir::Evaluate(tvm::ret(extent))}))}); DictAttrs attrs({{"tir.is_scheduled", true}, {"tir.is_host", true}}); diff --git a/src/s_tir/analysis/estimate_flops.cc b/src/s_tir/analysis/estimate_flops.cc index 08bebea0d3ba..6cd405ebbb27 100644 --- a/src/s_tir/analysis/estimate_flops.cc +++ b/src/s_tir/analysis/estimate_flops.cc @@ -182,9 +182,8 @@ class FlopEstimator : private ExprFunctor, return result; } - TResult VisitStmt_(const LetStmtNode* let) override { + TResult VisitStmt_(const BindNode* let) override { TResult value = VisitExpr(let->value); - value += VisitStmt(let->body); return value; } diff --git a/src/s_tir/analysis/sblock_access_region_detector.cc b/src/s_tir/analysis/sblock_access_region_detector.cc index 
22c0ed5ad920..aaba79827eb8 100644 --- a/src/s_tir/analysis/sblock_access_region_detector.cc +++ b/src/s_tir/analysis/sblock_access_region_detector.cc @@ -117,7 +117,7 @@ class BlockReadWriteDetector : public StmtExprVisitor { void VisitStmt_(const IfThenElseNode* op) override; void VisitStmt_(const SBlockRealizeNode* op) override; void VisitStmt_(const BufferStoreNode* op) override; - void VisitStmt_(const LetStmtNode* op) override; + void VisitStmt_(const BindNode* op) override; void VisitExpr_(const BufferLoadNode* op) override; void VisitExpr_(const VarNode* op) override; void VisitExpr_(const CallNode* op) override; @@ -189,10 +189,9 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) { } } -void BlockReadWriteDetector::VisitStmt_(const LetStmtNode* op) { +void BlockReadWriteDetector::VisitStmt_(const BindNode* op) { let_bindings_[op->var.get()] = op->value; StmtVisitor::VisitStmt_(op); - let_bindings_.erase(op->var.get()); } void BlockReadWriteDetector::VisitExpr_(const CallNode* op) { diff --git a/src/s_tir/backend/adreno/inject_texture_alloc.cc b/src/s_tir/backend/adreno/inject_texture_alloc.cc index 400f3edc97d3..a45f5dd4b562 100644 --- a/src/s_tir/backend/adreno/inject_texture_alloc.cc +++ b/src/s_tir/backend/adreno/inject_texture_alloc.cc @@ -82,9 +82,9 @@ class TextureAllocInjector : public arith::IRMutatorWithAnalyzer { args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(), {texture.width, texture.height, texture.depth})); args.push_back(IntImm(DataType::Int(64), channel_size)); - stmt = LetStmt(op->buffer->data, - Call(op->buffer->data.dtype(), builtin::nd_mem_alloc_with_scope(), args), - op->body); + stmt = SeqStmt({Bind(op->buffer->data, Call(op->buffer->data.dtype(), + builtin::nd_mem_alloc_with_scope(), args)), + op->body}); } return stmt; } diff --git a/src/s_tir/schedule/analysis/reducer.cc b/src/s_tir/schedule/analysis/reducer.cc index 7559a5bfb9d7..6ba93769c6ce 100644 --- 
a/src/s_tir/schedule/analysis/reducer.cc +++ b/src/s_tir/schedule/analysis/reducer.cc @@ -293,11 +293,11 @@ static const char* kRFactorCrossThreadReductionApplicableBlockDef = R"(Definition of a reduction block that is applicable by RFactor and Cross-Thread Reduction: 1) The block init should be a single BufferStore or a SeqStmt of BufferStores 2) The buffers initialized in the block init should be all different -3) The number of consecutive LetStmts in the block body (if any) should equal the number of BufferStores in the block init -4) The variables of the LetStmts in the block body should be all different -5) The body of the innermost LetStmt should be a single BufferStore or a SeqStmt of BufferStores -6) The number of BufferStores under the block body should equal the number of BufferStores in the block init, and thereby equal the number of LetStmts above -7) The variables bound by the LetStmts in the block body must all directly serve as values of the BufferStores inside, and the stored values of the BufferStores can only be those variables +3) The number of consecutive Binds in the block body (if any) should equal the number of BufferStores in the block init +4) The variables of the Binds in the block body should be all different +5) The statement after the innermost Bind should be a single BufferStore or a SeqStmt of BufferStores +6) The number of BufferStores under the block body should equal the number of BufferStores in the block init, and thereby equal the number of Binds above +7) The variables bound by the Binds in the block body must all directly serve as values of the BufferStores inside, and the stored values of the BufferStores can only be those variables 8) The variables stored by the BufferStores in the block body should be all different 9) The buffers written by the BufferStores in the block body should be all different 10) The buffers initialized in the block init and written in the block body should match @@ -343,18 +343,18 @@ void 
ErrorRFactorCrossThreadReductionNotApplicable(const ffi::Optional& self, SBlock block, - const LetStmtNode* let, int n_buffers, + const ffi::Array& stmts, int n_buffers, ffi::Array* updates, std::unordered_map* buf2index) { std::unordered_map var2index; @@ -363,37 +363,35 @@ void ExtractReductionUpdates(const ffi::Optional& self, SBlock bl updates->resize(n_buffers); // Step 1. - // - Extract the BufferStore values from the LetStmts. - // - Construct the mapping from let variables to the index. + // - Extract the Bind values from the sequence. + // - Construct the mapping from bind variables to the index. + // The first n_buffers stmts should be Bind nodes. for (int i = 0; i < n_buffers; ++i) { - if (let == nullptr) { + if (i >= static_cast(stmts.size())) { ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/3); } - - let_values.push_back(let->value); - auto insert_result = var2index.insert(std::make_pair(let->var.get(), i)); + const auto* bind = stmts[i].as(); + if (bind == nullptr) { + ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/3); + } + let_values.push_back(bind->value); + auto insert_result = var2index.insert(std::make_pair(bind->var.get(), i)); if (!insert_result.second) { ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/4); } - if (i != n_buffers - 1) { - let = let->body.as(); - } } - // There should be no more LetStmt. - if (let->body->IsInstance()) { + // There should be no more Bind after the first n_buffers. 
+ if (n_buffers < static_cast(stmts.size()) && stmts[n_buffers]->IsInstance()) { ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/3); } - // Now `let` is expected to be the innermost LetStmt, whose body should either be a SeqStmt or - // a BufferStore - const auto* p_seq = let->body.as(); - const auto* p_buf_store = let->body.as(); - if (p_seq == nullptr && p_buf_store == nullptr) { - ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/5); + // The remaining stmts after the Bind nodes should be BufferStores. + // Collect them into a sequence. + ffi::Array seq; + for (int i = n_buffers; i < static_cast(stmts.size()); ++i) { + seq.push_back(stmts[i]); } - ffi::Array seq = - p_seq != nullptr ? p_seq->seq : ffi::Array{ffi::GetRef(p_buf_store)}; if (static_cast(seq.size()) != n_buffers) { ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/6); } @@ -460,9 +458,10 @@ std::pair, ffi::Array> GetInitValuesAndUpdates if (const auto* update = block->body.as()) { updates.push_back(ffi::GetRef(update)); buf2index[update->buffer.get()] = 0; + } else if (const auto* seq = block->body.as()) { + ExtractReductionUpdates(self, block, seq->seq, n_buffers, &updates, &buf2index); } else { - const auto* let = block->body.as(); - ExtractReductionUpdates(self, block, let, n_buffers, &updates, &buf2index); + ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/3); } TVM_FFI_ICHECK_EQ(updates.size(), n_buffers); diff --git a/src/s_tir/schedule/primitive/blockize_tensorize.cc b/src/s_tir/schedule/primitive/blockize_tensorize.cc index 5c6c152925e5..95074147a02a 100644 --- a/src/s_tir/schedule/primitive/blockize_tensorize.cc +++ b/src/s_tir/schedule/primitive/blockize_tensorize.cc @@ -877,6 +877,7 @@ struct BlockizeTraits : public UnpackedInstTraits { return sch->Blockize(blocks.value(), preserve_unit_iters.operator bool()); } TVM_FFI_THROW(TypeError) 
<< "expect Loop or list of SBlocks, but gets:" << target->GetTypeKey(); + TVM_FFI_UNREACHABLE(); } static ffi::String UnpackedAsPython(ffi::Array outputs, ObjectRef target, diff --git a/src/s_tir/schedule/primitive/layout_transformation.cc b/src/s_tir/schedule/primitive/layout_transformation.cc index d20e002603ed..3fd210e91409 100644 --- a/src/s_tir/schedule/primitive/layout_transformation.cc +++ b/src/s_tir/schedule/primitive/layout_transformation.cc @@ -131,7 +131,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { StmtExprVisitor::VisitStmt_(op); } - void VisitStmt_(const LetStmtNode* op) override { + void VisitStmt_(const BindNode* op) override { BindVariableDefinition context(this, op->var, op->value); StmtExprVisitor::VisitStmt_(op); } @@ -625,37 +625,23 @@ class TransformLayoutPlanner : private StmtExprVisitor { Var var_; }; + // Under SSA, each variable is bound exactly once, so the lookup maps + // grow monotonically and cleanup is unnecessary. Omitting cleanup also + // ensures correctness for flat BindNode (which has no body): the + // binding must remain visible to subsequent sibling statements. 
struct BindVariableDefinition { BindVariableDefinition() {} - BindVariableDefinition(TransformLayoutPlanner* self, Var var, PrimExpr value) - : self_(self), var_(var) { + BindVariableDefinition(TransformLayoutPlanner* self, Var var, PrimExpr value) { if (auto loop_depth = self->LoopDependencyRange(value); loop_depth.has_value()) { - self_->loop_depth_lookup_[var_.get()] = loop_depth.value(); - self_->active_var_bindings_[var_.get()] = Substitute(value, self_->active_var_bindings_); - } - } - ~BindVariableDefinition() { - if (self_) { - self_->loop_depth_lookup_.erase(var_.get()); - self_->active_var_bindings_.erase(var_.get()); + self->loop_depth_lookup_[var.get()] = loop_depth.value(); + self->active_var_bindings_[var.get()] = Substitute(value, self->active_var_bindings_); } } + ~BindVariableDefinition() {} BindVariableDefinition(const BindVariableDefinition&) = delete; BindVariableDefinition& operator=(const BindVariableDefinition&) = delete; - BindVariableDefinition(BindVariableDefinition&& other) : BindVariableDefinition() { - swap(other); - } - BindVariableDefinition& operator=(BindVariableDefinition&& other) { - swap(other); - return *this; - } - void swap(BindVariableDefinition& other) { - std::swap(self_, other.self_); - std::swap(var_, other.var_); - } - - TransformLayoutPlanner* self_{nullptr}; - Var var_; + BindVariableDefinition(BindVariableDefinition&&) = default; + BindVariableDefinition& operator=(BindVariableDefinition&&) = default; }; struct BindBlockRealize { diff --git a/src/s_tir/schedule/primitive/reduction.cc b/src/s_tir/schedule/primitive/reduction.cc index 087eabbce812..8a7056f4896b 100644 --- a/src/s_tir/schedule/primitive/reduction.cc +++ b/src/s_tir/schedule/primitive/reduction.cc @@ -736,7 +736,7 @@ class BaseBlockCreator { } // Case 3. In case the reduction is for multiple buffers, we should create the reduction with - // LetStmt so that the reduction execution generates correct results. 
+ // Bind nodes so that the reduction execution generates correct results. ffi::Array let_vars; let_vars.reserve(n_buffers_); for (int i = 0; i < n_buffers_; ++i) { @@ -744,11 +744,14 @@ class BaseBlockCreator { let_vars.push_back(var); buf_stores.push_back(BufferStore(update_buffers_[i], var, update_indices_[i])); } - Stmt body = SeqStmt(buf_stores); - for (int i = n_buffers_ - 1; i >= 0; --i) { - body = LetStmt(let_vars[i], stored_values[i], std::move(body)); + ffi::Array stmts; + for (int i = 0; i < n_buffers_; ++i) { + stmts.push_back(tir::Bind(let_vars[i], stored_values[i])); + } + for (const auto& store : buf_stores) { + stmts.push_back(store); } - return body; + return SeqStmt(stmts); } ffi::Optional CreateBlockInit(bool has_reduce_iter) { diff --git a/src/s_tir/transform/compact_buffer_region.cc b/src/s_tir/transform/compact_buffer_region.cc index 9b2edde22189..ff811e3aa842 100644 --- a/src/s_tir/transform/compact_buffer_region.cc +++ b/src/s_tir/transform/compact_buffer_region.cc @@ -169,16 +169,12 @@ class BufferAccessRegionCollector : public StmtExprVisitor { ancestor_iters_.pop_back(); } - void VisitStmt_(const LetStmtNode* op) final { + void VisitStmt_(const BindNode* op) final { StmtExprVisitor::VisitExpr(op->value); if (arith::IsIndexType(op->value->dtype)) { dom_analyzer_.Bind(op->var, op->value); dom_map_.emplace(op->var.get(), arith::IntSet::SinglePoint(op->value)); } - StmtExprVisitor::VisitStmt(op->body); - if (arith::IsIndexType(op->value->dtype)) { - dom_map_.erase(op->var.get()); - } } void VisitExpr_(const LetNode* op) final { diff --git a/src/s_tir/transform/hoist_expression.cc b/src/s_tir/transform/hoist_expression.cc index 7858ff0e14dd..6403f3a7801a 100644 --- a/src/s_tir/transform/hoist_expression.cc +++ b/src/s_tir/transform/hoist_expression.cc @@ -53,7 +53,7 @@ enum class HoistedConditionals : int { enum class HoistedLetBindings : int { kNone = 0, kRequiredByCondition = (1 << 0), - kLetStmt = (1 << 1), + kBind = (1 << 1), kLetExpr = (1 
<< 2), }; @@ -72,7 +72,7 @@ struct HoistExpressionConfigNode : public AttrsNodeReflAdapter(HoistedLetBindings::kRequiredByCondition) | - static_cast(HoistedLetBindings::kLetStmt) | + static_cast(HoistedLetBindings::kBind) | static_cast(HoistedLetBindings::kLetExpr))); } @@ -147,7 +147,7 @@ class HoistInfoCollector : public StmtExprVisitor { bool all_required_bindings_are_hoisted = required_let_bindings.empty() || config->FlagSet(HoistedLetBindings::kRequiredByCondition) || - config->FlagSet(HoistedLetBindings::kLetStmt); + config->FlagSet(HoistedLetBindings::kBind); bool valid_block_var_usage = config->FlagSet(HoistedConditionals::kUsingBlockVar) || !uses_block_var; @@ -174,7 +174,7 @@ class HoistInfoCollector : public StmtExprVisitor { // The For or AttrStmt that defines the loop var. Stmt loop_def; - // Bindings defined in LetStmt inside the for-loop whose value + // Bindings defined in Bind nodes inside the for-loop whose value // does not depend on the loop variable. These can be hoisted // outside this for-loop. 
std::vector let_bindings; @@ -322,13 +322,34 @@ class HoistInfoCollector : public StmtExprVisitor { let_var_to_let_vars[var.get()] = std::move(let_bindings_used); } - void VisitStmt_(const LetStmtNode* op) final { - VisitBinding(op->var, op->value, HoistedLetBindings::kLetStmt); - + void VisitStmt_(const BindNode* op) final { + VisitBinding(op->var, op->value, HoistedLetBindings::kBind); Parent::VisitStmt_(op); + } - let_var_to_loop_vars.erase(op->var.get()); - let_var_to_let_vars.erase(op->var.get()); + void VisitStmt_(const SeqStmtNode* op) final { + if (active_loops.size()) { + int non_bind_count = 0; + for (size_t i = 0; i < op->seq.size(); ++i) { + if (!op->seq[i].as()) { + non_bind_count++; + } + } + if (non_bind_count > 1) { + active_loops.back().reached_sequential_node = true; + } + } + std::vector seq_bind_vars; + for (size_t i = 0; i < op->seq.size(); ++i) { + if (auto* bind = op->seq[i].as()) { + seq_bind_vars.push_back(bind->var.get()); + } + VisitStmt(op->seq[i]); + } + for (auto* var : seq_bind_vars) { + let_var_to_loop_vars.erase(var); + let_var_to_let_vars.erase(var); + } } void VisitExpr_(const LetNode* op) final { @@ -354,13 +375,6 @@ class HoistInfoCollector : public StmtExprVisitor { Parent::VisitExpr_(op); } - void VisitStmt_(const SeqStmtNode* op) final { - if (active_loops.size()) { - active_loops.back().reached_sequential_node = true; - } - Parent::VisitStmt_(op); - } - // Find the loop above which this expression could be hoisted. If // nullptr, the expression cannot be hoisted. 
HoistInfo* FindHoistDestination(PrimExpr expr) { @@ -482,9 +496,16 @@ class ExpressionHoister : public arith::IRMutatorWithAnalyzer { } } } - for (auto let_it = info.let_bindings.rbegin(); let_it != info.let_bindings.rend(); let_it++) { - if (hoisted_let_bindings.count(let_it->var.get())) { - stmt = LetStmt(let_it->var, let_it->value, stmt); + { + ffi::Array binds; + for (auto let_it = info.let_bindings.begin(); let_it != info.let_bindings.end(); let_it++) { + if (hoisted_let_bindings.count(let_it->var.get())) { + binds.push_back(Bind(let_it->var, let_it->value)); + } + } + if (!binds.empty()) { + binds.push_back(stmt); + stmt = SeqStmt(binds); } } @@ -511,9 +532,10 @@ class ExpressionHoister : public arith::IRMutatorWithAnalyzer { } } - Stmt VisitStmt_(const LetStmtNode* op) final { + Stmt VisitStmt_(const BindNode* op) final { if (hoisted_let_bindings.count(op->var.get())) { - return this->VisitStmt(op->body); + // The binding was hoisted; remove it from this location. + return Evaluate(0); } else { return Parent::VisitStmt_(op); } diff --git a/src/s_tir/transform/inject_virtual_thread.cc b/src/s_tir/transform/inject_virtual_thread.cc index d914fa81cd20..f4698821a6a8 100644 --- a/src/s_tir/transform/inject_virtual_thread.cc +++ b/src/s_tir/transform/inject_virtual_thread.cc @@ -99,11 +99,10 @@ class ExprTouched final : public StmtExprVisitor { // Analyze if the buffers are invariant to value of var class VarTouchedAnalysis : public StmtVisitor { public: - void VisitStmt_(const LetStmtNode* op) final { + void VisitStmt_(const BindNode* op) final { ExprTouched tc(touched_var_, false); tc(op->value); Record(op->var.get(), tc); - this->VisitStmt(op->body); } void VisitStmt_(const BufferStoreNode* op) final { @@ -299,18 +298,17 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { } } } - // LetStmt - Stmt VisitStmt_(const LetStmtNode* op) final { + // Bind + Stmt VisitStmt_(const BindNode* op) final { PrimExpr value = this->VisitExpr(op->value); if 
(visit_touched_var_ && !vt_loop_injected_) { return InjectVTLoop(ffi::GetRef(op), true); } visit_touched_var_ = false; - Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && body.same_as(op->body)) { + if (value.same_as(op->value)) { return ffi::GetRef(op); } else { - return LetStmt(op->var, value, body); + return Bind(op->var, value); } } // For @@ -362,19 +360,60 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { Stmt VisitStmt_(const WhileNode* op) final { // TODO(masahi): What should we do for While nodes? TVM_FFI_THROW(InternalError) << "WhileNode in InjectVirtualThread not supported yet"; + TVM_FFI_UNREACHABLE(); } // Seq + // When a Bind child triggers VT injection, we group the Bind together with + // all remaining siblings (which may reference the bound variable) and wrap + // them as a single unit in the VT loop. This preserves the semantics that + // With flat Bind (no body), a Bind whose value touches vt_var must be + // grouped with all remaining siblings and wrapped in a VT loop together. + // This preserves the scoping that was implicit when Bind carried a body. Stmt VisitStmt_(const SeqStmtNode* op) final { TVM_FFI_ICHECK_EQ(max_loop_depth_, 0); - auto fmutate = [this](const Stmt& s) { + ffi::Array new_seq; + bool changed = false; + for (size_t i = 0; i < op->seq.size(); ++i) { int temp = max_loop_depth_; max_loop_depth_ = 0; - Stmt ret = this->VisitStmt(s); + // For Bind children, pre-check if the value touches vt_var before + // visiting. If so, group the Bind with all remaining siblings and + // wrap the group with InjectVTLoop. + if (const auto* bind = op->seq[i].as(); bind && !vt_loop_injected_) { + // Visit just the value expression to probe for vt_var dependency. + TVM_FFI_ICHECK(!visit_touched_var_); + this->VisitExpr(bind->value); + if (visit_touched_var_) { + // Reset flag (InjectVTLoop will handle it). + visit_touched_var_ = false; + // Gather the original Bind + all remaining original siblings. 
+ ffi::Array group; + for (size_t j = i; j < op->seq.size(); ++j) { + group.push_back(op->seq[j]); + } + Stmt grouped = group.size() == 1 ? group[0] : SeqStmt(group); + // before_mutation=true: InjectVTLoop will re-visit the entire group + // with vt_loop_injected_=true, properly substituting vt_var. + Stmt wrapped = InjectVTLoop(grouped, true); + new_seq.push_back(wrapped); + max_loop_depth_ = std::max(max_loop_depth_, temp); + changed = true; + // All remaining siblings consumed — exit loop. + break; + } + // Value did not touch vt_var. Reset and visit the Bind normally. + visit_touched_var_ = false; + } + // Non-Bind child or Bind that does not touch vt_var: visit normally. + Stmt child = this->VisitStmt(op->seq[i]); max_loop_depth_ = std::max(max_loop_depth_, temp); - return ret; - }; - return StmtMutator::VisitSeqStmt_(op, false, fmutate); + if (!child.same_as(op->seq[i])) changed = true; + new_seq.push_back(child); + } + if (!changed) return ffi::GetRef(op); + if (new_seq.size() == 1) return new_seq[0]; + return SeqStmt(new_seq); } // Allocate // AllocBuffer diff --git a/src/s_tir/transform/lower_thread_allreduce.cc b/src/s_tir/transform/lower_thread_allreduce.cc index f5764cbef8a3..13a8c3359f3a 100644 --- a/src/s_tir/transform/lower_thread_allreduce.cc +++ b/src/s_tir/transform/lower_thread_allreduce.cc @@ -645,10 +645,14 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { fstore({in_warp_local_vars.begin(), in_warp_local_vars.end()})); in_let_statement.emplace_back(SyncThread("warp")); - Stmt body = SeqStmt::Flatten(in_let_statement); + ffi::Array bind_stmts; for (size_t i = 0; i < size; i++) { - body = LetStmt(in_warp_local_vars[i], loads[i], body); + bind_stmts.push_back(Bind(in_warp_local_vars[i], loads[i])); } + for (const auto& s : in_let_statement) { + bind_stmts.push_back(s); + } + Stmt body = SeqStmt::Flatten(bind_stmts); in_warp_seq.push_back(body); } diff --git a/src/s_tir/transform/lower_vtcm_alloc.cc 
b/src/s_tir/transform/lower_vtcm_alloc.cc index d086f0d45e86..cc32fba14678 100644 --- a/src/s_tir/transform/lower_vtcm_alloc.cc +++ b/src/s_tir/transform/lower_vtcm_alloc.cc @@ -45,9 +45,9 @@ class VtcmAllocator : public StmtExprMutator { args.push_back(StringImm(storage_scope)); args.push_back(IntImm(DataType::Int(64), op->buffer->shape.size())); args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(), op->buffer->shape)); - return LetStmt(op->buffer->data, - Call(op->buffer->data.dtype(), builtin::nd_mem_alloc_with_scope(), args), - body); + return SeqStmt({Bind(op->buffer->data, Call(op->buffer->data.dtype(), + builtin::nd_mem_alloc_with_scope(), args)), + body}); } return StmtExprMutator::VisitStmt_(op); } diff --git a/src/s_tir/transform/profile_instrumentation.cc b/src/s_tir/transform/profile_instrumentation.cc index c0e1c483cbe2..07cb8c72bf0d 100644 --- a/src/s_tir/transform/profile_instrumentation.cc +++ b/src/s_tir/transform/profile_instrumentation.cc @@ -135,9 +135,17 @@ class LoopAnalyzer : public StmtExprVisitor { loop_info.height = height; loops[f] = loop_info; return height + 1; - } else if (stmt->IsInstance()) { - const LetStmtNode* n = stmt.as(); - return TraverseLoop(n->body, parent_depth, has_parallel); + } else if (stmt->IsInstance()) { + // Bind has no body; skip it and return 0 (not a loop). + return 0; + } else if (stmt->IsInstance()) { + // For flat sequences, traverse children looking for loops. 
+ const SeqStmtNode* seq = stmt.as(); + unsigned max_height = 0; + for (const auto& s : seq->seq) { + max_height = std::max(max_height, TraverseLoop(s, parent_depth, has_parallel)); + } + return max_height; } else if (stmt->IsInstance()) { const AttrStmtNode* n = stmt.as(); return TraverseLoop(n->body, parent_depth, has_parallel); diff --git a/src/s_tir/transform/remove_store_undef.cc b/src/s_tir/transform/remove_store_undef.cc index 54c91fb7f016..1d383bec12b9 100644 --- a/src/s_tir/transform/remove_store_undef.cc +++ b/src/s_tir/transform/remove_store_undef.cc @@ -34,18 +34,24 @@ namespace tvm { namespace s_tir { using namespace tvm::tir; +struct UndefInfo { + std::unordered_set undef_stores; + std::unordered_set undef_bind_vars; +}; + class StoreUndefLocator : public StmtExprVisitor { public: - static std::unordered_set Locate(Stmt stmt) { + static UndefInfo Locate(Stmt stmt) { StoreUndefLocator locator; locator(std::move(stmt)); - return locator.undef_stores_; + return {locator.undef_stores_, locator.var_bindings_with_undef_}; } private: StoreUndefLocator() = default; void VisitStmt_(const BufferStoreNode* op) final { + // Check the value for undef. bool stash_undef = false; std::swap(has_undef_, stash_undef); StmtExprVisitor::VisitExpr(op->value); @@ -56,16 +62,32 @@ class StoreUndefLocator : public StmtExprVisitor { << "must not have other side effects"; undef_stores_.insert(op); } + + // Check indices for undef. Undef in buffer indices is always an + // error (there is no valid lowering). With flat Bind, we must + // check indices eagerly because the Bind node is a sibling rather + // than an ancestor and may be removed before post-validation. 
+ bool idx_undef = false; + std::swap(has_undef_, idx_undef); + for (const auto& idx : op->indices) { + StmtExprVisitor::VisitExpr(idx); + } + std::swap(has_undef_, idx_undef); + TVM_FFI_ICHECK(!idx_undef) << "Error: T.undef() may not be used in buffer indices"; } void VisitExpr_(const BufferLoadNode* op) final { - // This function left deliberately empty. builtin::undef() - // shouldn't occur in the indices of BufferLoad. Avoiding - // visiting the indices catches the builtin::undef in - // ValidateAllUndefRemoved. + // Check indices for undef. Undef in buffer indices is always an error. + bool idx_undef = false; + std::swap(has_undef_, idx_undef); + for (const auto& idx : op->indices) { + StmtExprVisitor::VisitExpr(idx); + } + std::swap(has_undef_, idx_undef); + TVM_FFI_ICHECK(!idx_undef) << "Error: T.undef() may not be used in buffer indices"; } - void VisitStmt_(const LetStmtNode* op) final { + void VisitStmt_(const BindNode* op) final { bool stash_undef = false; std::swap(has_undef_, stash_undef); StmtExprVisitor::VisitExpr(op->value); @@ -76,8 +98,6 @@ class StoreUndefLocator : public StmtExprVisitor { << "must not have other side effects"; var_bindings_with_undef_.insert(op->var.get()); } - - StmtExprVisitor::VisitStmt(op->body); } void VisitExpr_(const VarNode* op) final { @@ -99,33 +119,44 @@ class StoreUndefLocator : public StmtExprVisitor { std::unordered_set undef_stores_; }; -// Remove any BufferStores whose value depends on T.undef +// Remove BufferStores whose value depends on T.undef, and also +// remove Bind nodes whose value contains undef. Undef in buffer +// indices is already caught eagerly in the locator phase. 
class StoreUndefRemover : public StmtExprMutator { public: static Stmt Apply(Stmt stmt) { - auto to_remove = StoreUndefLocator::Locate(stmt); - StoreUndefRemover mutator(to_remove); + auto info = StoreUndefLocator::Locate(stmt); + StoreUndefRemover mutator(info); return mutator(std::move(stmt)); } private: using Parent = StmtExprMutator; - explicit StoreUndefRemover(const std::unordered_set& to_remove) - : to_remove_(to_remove) {} + explicit StoreUndefRemover(const UndefInfo& info) + : stores_to_remove_(info.undef_stores), bind_vars_to_remove_(info.undef_bind_vars) {} Stmt VisitStmt_(const BufferStoreNode* op) final { - if (to_remove_.count(op)) { + if (stores_to_remove_.count(op)) { + return Evaluate(0); + } else { + return Parent::VisitStmt_(op); + } + } + + Stmt VisitStmt_(const BindNode* op) final { + if (bind_vars_to_remove_.count(op->var.get())) { return Evaluate(0); } else { return Parent::VisitStmt_(op); } } - const std::unordered_set& to_remove_; + const std::unordered_set& stores_to_remove_; + const std::unordered_set& bind_vars_to_remove_; }; -// Remove any BufferStores whose value depends on T.undef +// Check that no builtin::undef() remains in the IR. 
class ContainsUndefChecker : public StmtExprVisitor { public: static bool Check(const Stmt& stmt) { diff --git a/src/s_tir/transform/renew_defs.cc b/src/s_tir/transform/renew_defs.cc index 224fccbadbf8..f34fb3cd856d 100644 --- a/src/s_tir/transform/renew_defs.cc +++ b/src/s_tir/transform/renew_defs.cc @@ -99,7 +99,7 @@ class RenewDefMutator : public StmtExprMutator { } private: - STMT_REGENERATE_VAR_DEF(LetStmtNode, var); + STMT_REGENERATE_VAR_DEF(BindNode, var); STMT_REGENERATE_VAR_DEF(ForNode, loop_var); // Override VisitBufferDef to create fresh buffer copies at definition sites diff --git a/src/s_tir/transform/storage_access.cc b/src/s_tir/transform/storage_access.cc index dae797486da0..ad0c9a5bd84d 100644 --- a/src/s_tir/transform/storage_access.cc +++ b/src/s_tir/transform/storage_access.cc @@ -95,7 +95,7 @@ void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) { allow_append_ = false; } -void StorageAccessVisitor::VisitStmt_(const LetStmtNode* op) { +void StorageAccessVisitor::VisitStmt_(const BindNode* op) { allow_append_ = true; TVM_FFI_ICHECK_EQ(curr_stmt_.access.size(), 0U); curr_stmt_.stmt = op; @@ -105,8 +105,6 @@ void StorageAccessVisitor::VisitStmt_(const LetStmtNode* op) { // clear access entry. 
curr_stmt_.access.clear(); allow_append_ = false; - // traverse body block - this->VisitStmt(op->body); } void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) { diff --git a/src/s_tir/transform/storage_access.h b/src/s_tir/transform/storage_access.h index 848d8edcf546..7635996acb7a 100644 --- a/src/s_tir/transform/storage_access.h +++ b/src/s_tir/transform/storage_access.h @@ -85,7 +85,7 @@ class StorageAccessVisitor : public StmtExprVisitor { void VisitExpr_(const BufferLoadNode* op) final; void VisitStmt_(const BufferStoreNode* op) final; void VisitStmt_(const EvaluateNode* op) final; - void VisitStmt_(const LetStmtNode* op) final; + void VisitStmt_(const BindNode* op) final; void VisitStmt_(const AttrStmtNode* op) final; void VisitStmt_(const ForNode* op) final; void VisitStmt_(const IfThenElseNode* op) final; diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index 349374952109..999527813d28 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -35,7 +35,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { BlockInitFrameNode::RegisterReflection(); ForFrameNode::RegisterReflection(); AssertFrameNode::RegisterReflection(); - LetFrameNode::RegisterReflection(); LaunchThreadFrameNode::RegisterReflection(); AllocateFrameNode::RegisterReflection(); AttrFrameNode::RegisterReflection(); @@ -141,11 +140,6 @@ void AssertFrameNode::ExitWithScope() { } } -void LetFrameNode::ExitWithScope() { - TIRFrameNode::ExitWithScope(); - AddToParent(tvm::tir::LetStmt(var, value, AsStmt(stmts))); -} - void LaunchThreadFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); AddToParent(tvm::tir::AttrStmt(iter_var, attr_key, extent, AsStmt(stmts))); diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index e353b4184334..197e52b45636 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -460,24 +460,18 @@ AssertFrame Assert(PrimExpr condition, ffi::String 
error_kind, return AssertFrame(n); } -LetFrame LetStmt(PrimExpr value, ffi::Optional type_annotation, ffi::Optional var) { - ObjectPtr n = ffi::make_object(); - if (var.defined()) { - n->var = var.value(); - } else if (type_annotation.defined()) { - n->var = Var("v", type_annotation.value()); - } else { - n->var = Var("v", value.dtype()); - } - n->value = value; - return LetFrame(n); -} - -LetFrame LegacyLetStmt(Var var, PrimExpr value) { - ObjectPtr n = ffi::make_object(); - n->var = var; - n->value = value; - return LetFrame(n); +Var Bind(PrimExpr value, ffi::Optional type_annotation, ffi::Optional var) { + Var bind_var = [&]() { + if (var.defined()) { + return var.value(); + } else if (type_annotation.defined()) { + return Var("v", type_annotation.value()); + } else { + return Var("v", value.dtype()); + } + }(); + AddToParent(tvm::tir::Bind(bind_var, value)); + return bind_var; } LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) { @@ -753,8 +747,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("script.ir_builder.tir.ThreadBinding", ThreadBinding) .def("script.ir_builder.tir.Grid", Grid) .def("script.ir_builder.tir.Assert", Assert) - .def("script.ir_builder.tir.LetStmt", LetStmt) - .def("script.ir_builder.tir.LegacyLetStmt", LegacyLetStmt) + .def("script.ir_builder.tir.Bind", Bind) .def("script.ir_builder.tir.Allocate", Allocate) .def("script.ir_builder.tir.Attr", Attr) .def("script.ir_builder.tir.While", While) diff --git a/src/script/printer/relax/distributed.cc b/src/script/printer/relax/distributed.cc index 51fae05bf626..5294eb43a842 100644 --- a/src/script/printer/relax/distributed.cc +++ b/src/script/printer/relax/distributed.cc @@ -121,6 +121,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } TVM_FFI_THROW(InternalError) << "Cannot find device mesh in global infos"; + TVM_FFI_UNREACHABLE(); } }); diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index 647f12c8ff7b..59948fb239ea 100644 --- a/src/script/printer/tir/stmt.cc +++ 
b/src/script/printer/tir/stmt.cc @@ -49,6 +49,7 @@ bool AllowConciseScoping(const IRDocsifier& d, const ObjectRef& obj) { return f->allow_concise_scoping; } TVM_FFI_THROW(NotImplementedError) << "fragment printing"; + TVM_FFI_UNREACHABLE(); } bool IsAncestorOfAllVarUse(const tir::Stmt& node, const ObjectRef& var, const IRDocsifier& d) { @@ -96,8 +97,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::LetStmt stmt, AccessPath p, IRDocsifier d) -> Doc { - bool concise = AllowConciseScoping(d, stmt); + .set_dispatch("", [](tir::Bind stmt, AccessPath p, IRDocsifier d) -> Doc { // Step 1. Type annotation ffi::Optional type_doc = d->AsDoc(stmt->var->type_annotation, // p->Attr("var")->Attr("type_annotation")); @@ -108,25 +108,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } // Step 2. RHS ExprDoc rhs = d->AsDoc(stmt->value, p->Attr("value")); - // Step 3. LHS and body - With f(d, stmt); - ffi::Array* stmts = &(*f)->stmts; - bool var_defined = d->IsVarDefined(stmt->var); - if (!var_defined) { - DefineVar(stmt->var, *f, d); - } - ExprDoc lhs = d->AsDoc(stmt->var, p->Attr("var")); - AsDocBody(stmt->body, p->Attr("body"), f->get(), d); - // Step 4. Dispatch - if (var_defined) { - return ScopeDoc(std::nullopt, TIR(d, "LetStmt")->Call({rhs}, {"var"}, {lhs}), *stmts); - } else if (concise) { - stmts->insert(stmts->begin(), AssignDoc(lhs, rhs, type_doc)); - return StmtBlockDoc(*stmts); - } else if (type_doc.defined() && !stmt->var->type_annotation->IsInstance()) { - return ScopeDoc(lhs, TIR(d, "LetStmt")->Call({rhs, type_doc.value()}), *stmts); + // Step 3. 
LHS - Bind is flat, define var if new, otherwise just assign + if (!d->IsVarDefined(stmt->var)) { + TVM_FFI_ICHECK(!d->frames.empty()); + ExprDoc lhs = DefineVar(stmt->var, d->frames.back(), d); + return AssignDoc(lhs, rhs, type_doc); } else { - return ScopeDoc(lhs, TIR(d, "LetStmt")->Call({rhs}), *stmts); + ExprDoc lhs = d->AsDoc(stmt->var, p->Attr("var")); + return AssignDoc(lhs, rhs, std::nullopt); } }); @@ -259,7 +248,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return DoConciseScoping(lhs, rhs.value(), &(*f)->stmts, concise); }); -TVM_SCRIPT_REPR(tir::LetStmtNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::BindNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::AttrStmtNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::AssertStmtNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::WhileNode, ReprPrintTIR); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 74317f8f2a33..44b5af7e82f3 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -2047,7 +2047,7 @@ void CodeGenLLVM::VisitStmt_(const AssertStmtNode* op) { // Constraint scoping is handled by ScopeStack in analysis passes. 
} -void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) { +void CodeGenLLVM::VisitStmt_(const BindNode* op) { EmitDebugLocation(op); const VarNode* v = op->var.get(); TVM_FFI_ICHECK(!var_map_.count(v)); @@ -2081,7 +2081,6 @@ void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) { alloc_storage_info_[v].alignment); } AddDebugInformation(value, op->var); - this->VisitStmt(op->body); } void CodeGenLLVM::VisitStmt_(const SeqStmtNode* op) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 08a2b07ec707..3fdbbec86fa9 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -227,7 +227,7 @@ class CodeGenLLVM : public ExprFunctor, void VisitStmt_(const AllocBufferNode* op) override; void VisitStmt_(const AttrStmtNode* op) override; void VisitStmt_(const AssertStmtNode* op) override; - void VisitStmt_(const LetStmtNode* op) override; + void VisitStmt_(const BindNode* op) override; void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; void VisitStmt_(const DeclBufferNode* op) override; diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 7fb7e0e382cb..e848b17fcbc6 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -369,6 +369,7 @@ std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const Pri return os.str(); } else { TVM_FFI_THROW(RuntimeError) << "Unsupported type index: " << kind; + TVM_FFI_UNREACHABLE(); } } @@ -1028,7 +1029,7 @@ void CodeGenC::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(* os << ")"; } -void CodeGenC::VisitStmt_(const LetStmtNode* op) { +void CodeGenC::VisitStmt_(const BindNode* op) { std::string value = PrintExpr(op->value); if (print_ssa_form_) { TVM_FFI_ICHECK(!var_idmap_.count(op->var.get())); @@ -1045,7 +1046,6 @@ void CodeGenC::VisitStmt_(const LetStmtNode* op) { this->stream << ' ' << AllocVarID(op->var.get()) << " = " << value << ";\n"; } } - 
PrintStmt(op->body); } void CodeGenC::VisitStmt_(const AllocBufferNode* op) { diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index b683bf9105e0..2142c0aa2fc5 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -187,7 +187,7 @@ class CodeGenC : public ExprFunctor, void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*) // statment - void VisitStmt_(const LetStmtNode* op) override; + void VisitStmt_(const BindNode* op) override; void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const ForNode* op) override; void VisitStmt_(const WhileNode* op) override; diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index 86006e89ae84..423c7e5850b7 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -569,7 +569,7 @@ void CodeGenWebGPU::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // } } -void CodeGenWebGPU::VisitStmt_(const LetStmtNode* op) { +void CodeGenWebGPU::VisitStmt_(const BindNode* op) { // use ssa form. 
if (print_ssa_form_) { std::string value = PrintExpr(op->value); @@ -582,7 +582,6 @@ void CodeGenWebGPU::VisitStmt_(const LetStmtNode* op) { PrintType(op->var.dtype(), this->stream); this->stream << " = " << value << ";\n"; } - PrintStmt(op->body); } void CodeGenWebGPU::VisitStmt_(const BufferStoreNode* op) { diff --git a/src/target/source/codegen_webgpu.h b/src/target/source/codegen_webgpu.h index d90d719e8d38..f53d090e586a 100644 --- a/src/target/source/codegen_webgpu.h +++ b/src/target/source/codegen_webgpu.h @@ -73,7 +73,7 @@ class CodeGenWebGPU final : public CodeGenC { void VisitExpr_(const IntImmNode* op, std::ostream& os) final; // NOLINT(*) // stmt printing - void VisitStmt_(const LetStmtNode* op) final; + void VisitStmt_(const BindNode* op) final; void VisitStmt_(const BufferStoreNode* op) final; void VisitStmt_(const ForNode* op) final; void VisitStmt_(const AllocBufferNode* op) final; diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index fe424066cbce..97fc60a00741 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -888,12 +888,11 @@ void CodeGenSPIRV::VisitStmt_(const AssertStmtNode* op) { // AssertStmt is a leaf — no body to visit. 
} -void CodeGenSPIRV::VisitStmt_(const LetStmtNode* op) { +void CodeGenSPIRV::VisitStmt_(const BindNode* op) { TVM_FFI_ICHECK(!var_map_.count(op->var.get())); TVM_FFI_ICHECK(!op->var.dtype().is_handle()); var_map_[op->var.get()] = MakeValue(op->value); analyzer_->Bind(op->var, op->value); - this->VisitStmt(op->body); } void CodeGenSPIRV::VisitStmt_(const SeqStmtNode* op) { diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index 70f88128cf8d..8daac154ecd9 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -112,7 +112,7 @@ class CodeGenSPIRV : public ExprFunctor, void VisitStmt_(const AllocBufferNode* op) override; void VisitStmt_(const AttrStmtNode* op) override; void VisitStmt_(const AssertStmtNode* op) override; - void VisitStmt_(const LetStmtNode* op) override; + void VisitStmt_(const BindNode* op) override; void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 3d2536e423eb..14650efcb77d 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -398,8 +398,8 @@ Stmt GenerateBodyStmt(const ffi::Array& indices, const ffi::Array& indices, const ffi::Array 1) { - // When there are multiple buffers, we wrap the body with LetStmts. - for (int i = n_buffers - 1; i >= 0; --i) { + // When there are multiple buffers, we wrap the body with Bind stmts. + ffi::Array bind_stmts; + for (int i = 0; i < n_buffers; ++i) { PrimExpr value = f_transform_and_remap(reduce->combiner.get()->operator()(lhs, rhs)[i]); - body = LetStmt(temp_vars[i], std::move(value), std::move(body)); + bind_stmts.push_back(Bind(temp_vars[i], std::move(value))); } + bind_stmts.push_back(body); + body = SeqStmt(bind_stmts); } } else { // Case 2. 
Data parallel compute diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tir/analysis/control_flow_graph.cc index 8d7a13b0b5c7..d5214f085cf6 100644 --- a/src/tir/analysis/control_flow_graph.cc +++ b/src/tir/analysis/control_flow_graph.cc @@ -332,7 +332,7 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { Parent::VisitExpr_(op); } - void VisitStmt_(const LetStmtNode* op) override { + void VisitStmt_(const BindNode* op) override { std::optional binding; if (UsesLoopVar(op->value)) { binding.emplace(this, op->var, op->value); @@ -588,18 +588,17 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { BindActiveLoopVar& operator=(BindActiveLoopVar&&) = delete; }; - // Internal utility, context manager for tracking a variable binding + // Internal utility, context manager for tracking a variable binding. + // Under SSA, each variable is bound exactly once, so the maps grow + // monotonically and cleanup is unnecessary. Omitting cleanup also + // ensures correctness for flat BindNode (which has no body): the + // binding must remain visible to subsequent sibling statements. 
struct BindLetVar { - BindLetVar(ControlFlowGraphBuilder* self, Var var, PrimExpr value) : self(self), var(var) { + BindLetVar(ControlFlowGraphBuilder* self, Var var, PrimExpr value) { self->let_bindings_using_loop_.Set(var, value); self->loop_dependent_vars_.insert(var.get()); } - ~BindLetVar() { - self->loop_dependent_vars_.erase(var.get()); - self->let_bindings_using_loop_.erase(var); - } - ControlFlowGraphBuilder* self; - Var var; + ~BindLetVar() {} // Disable default-generated copy/move assignment and constructors BindLetVar(const BindLetVar&) = delete; diff --git a/src/tir/analysis/var_use_def_analysis.cc b/src/tir/analysis/var_use_def_analysis.cc index b2236e28cee5..6951e25f8c99 100644 --- a/src/tir/analysis/var_use_def_analysis.cc +++ b/src/tir/analysis/var_use_def_analysis.cc @@ -54,7 +54,7 @@ void VarUseDefAnalyzer::VisitStmt_(const AttrStmtNode* op) { } } -void VarUseDefAnalyzer::VisitStmt_(const LetStmtNode* op) { +void VarUseDefAnalyzer::VisitStmt_(const BindNode* op) { this->HandleDef(op->var); StmtExprVisitor::VisitStmt_(op); } diff --git a/src/tir/analysis/var_use_def_analysis.h b/src/tir/analysis/var_use_def_analysis.h index 7196fb2e8fde..a887acb1d3c4 100644 --- a/src/tir/analysis/var_use_def_analysis.h +++ b/src/tir/analysis/var_use_def_analysis.h @@ -57,7 +57,7 @@ class VarUseDefAnalyzer : public StmtExprVisitor { std::unordered_map let_binding_; void VisitStmt_(const AttrStmtNode* op) final; - void VisitStmt_(const LetStmtNode* op) final; + void VisitStmt_(const BindNode* op) final; void VisitStmt_(const ForNode* op) final; diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 19bc55bf64ec..35f682519c2a 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -72,7 +72,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { void VisitStmt(const Stmt& n) final { StmtExprVisitor::VisitStmt(n); } - void VisitStmt_(const LetStmtNode* op) final { + void VisitStmt_(const 
BindNode* op) final { // Book keep definitions defs_[op->var.get()] = op->value; return StmtExprVisitor::VisitStmt_(op); diff --git a/src/tir/analysis/verify_ssa.cc b/src/tir/analysis/verify_ssa.cc index e78e1cb58b69..b8fb99d701e8 100644 --- a/src/tir/analysis/verify_ssa.cc +++ b/src/tir/analysis/verify_ssa.cc @@ -67,7 +67,7 @@ class SSAVerifier final : public StmtExprVisitor { StmtExprVisitor::VisitExpr_(op); } - void VisitStmt_(const LetStmtNode* op) final { + void VisitStmt_(const BindNode* op) final { MarkDef(op->var, op->value); StmtExprVisitor::VisitStmt_(op); } diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index f7c508dac9f1..37ae4f70b2cc 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -140,7 +140,7 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const LetNode* op) { } } -Stmt DataTypeLegalizer::VisitStmt_(const LetStmtNode* op) { +Stmt DataTypeLegalizer::VisitStmt_(const BindNode* op) { PrimExpr value = this->VisitExpr(op->value); Var var = op->var; @@ -149,12 +149,10 @@ Stmt DataTypeLegalizer::VisitStmt_(const LetStmtNode* op) { var_remap_[op->var.get()] = var; } - Stmt new_body = this->VisitStmt(op->body); - - if (value.same_as(op->value) && new_body.same_as(op->body)) { + if (value.same_as(op->value) && var.same_as(op->var)) { return ffi::GetRef(op); } else { - return LetStmt(var, value, new_body, op->span); + return Bind(var, value, op->span); } } @@ -528,19 +526,18 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) { } } -Stmt IndexDataTypeRewriter::VisitStmt_(const LetStmtNode* op) { - LetStmt let_stmt = Downcast(DataTypeLegalizer::VisitStmt_(op)); - if (var_remap_.find(let_stmt->var.get()) == var_remap_.end()) { - return let_stmt; +Stmt IndexDataTypeRewriter::VisitStmt_(const BindNode* op) { + Bind bind_stmt = Downcast(DataTypeLegalizer::VisitStmt_(op)); + if (var_remap_.find(bind_stmt->var.get()) == var_remap_.end()) { + return bind_stmt; } bool is_enabled = is_enabled_; 
is_enabled_ = true; PrimExpr value = VisitExpr(op->value); - Var var = var_remap_[let_stmt->var.get()]; + Var var = var_remap_[bind_stmt->var.get()]; is_enabled_ = is_enabled; TVM_FFI_ICHECK(value.dtype() == var.dtype()); - // No need to re-visit body - return LetStmt(var, value, let_stmt->body, let_stmt->span); + return Bind(var, value, bind_stmt->span); } #define TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ diff --git a/src/tir/ir/data_type_rewriter.h b/src/tir/ir/data_type_rewriter.h index e886777096bd..e19c555c6ed0 100644 --- a/src/tir/ir/data_type_rewriter.h +++ b/src/tir/ir/data_type_rewriter.h @@ -53,7 +53,7 @@ class DataTypeLegalizer : public StmtExprMutator { Stmt VisitStmt_(const AttrStmtNode* op) override; Stmt VisitStmt_(const SBlockRealizeNode* op) override; Stmt VisitStmt_(const SBlockNode* op) override; - Stmt VisitStmt_(const LetStmtNode* op) override; + Stmt VisitStmt_(const BindNode* op) override; PrimExpr VisitExpr_(const VarNode* op) override; PrimExpr VisitExpr_(const SelectNode* op) override; PrimExpr VisitExpr_(const RampNode* op) override; @@ -110,7 +110,7 @@ class IndexDataTypeRewriter : public DataTypeLegalizer { PrimExpr VisitExpr_(const BufferLoadNode* op) override; ffi::Array VisitIndices(ffi::Array indices); Stmt VisitStmt_(const IfThenElseNode* op) override; - Stmt VisitStmt_(const LetStmtNode* op) override; + Stmt VisitStmt_(const BindNode* op) override; PrimExpr VisitExpr_(const EQNode* op) override; PrimExpr VisitExpr_(const NENode* op) override; PrimExpr VisitExpr_(const LTNode* op) override; diff --git a/src/tir/ir/py_functor.cc b/src/tir/ir/py_functor.cc index 867a740bcaa1..b385922cb950 100644 --- a/src/tir/ir/py_functor.cc +++ b/src/tir/ir/py_functor.cc @@ -170,13 +170,13 @@ class PyStmtExprVisitorNode : public Object, public StmtExprVisitor { // Statement functions /*! \brief The packed function to the `VisitStmt(const Stmt& stmt)` function. */ ffi::Function f_visit_stmt{nullptr}; - /*! 
\brief The packed function to the `VisitStmt_(const LetStmtNode* op)` function. */ + /*! \brief The packed function to the `VisitStmt_(const BindNode* op)` function. */ + ffi::Function f_visit_bind{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const AttrStmtNode* op)` function. */ ffi::Function f_visit_attr_stmt{nullptr}; /*! \brief The packed function to the `VisitStmt_(const IfThenElseNode* op)` function. */ ffi::Function f_visit_if_then_else{nullptr}; // NOLINT(readability/braces) /*! \brief The packed function to the `VisitStmt_(const ForNode* op)` function. */ - ffi::Function f_visit_let_stmt{nullptr}; - /*! \brief The packed function to the `VisitStmt_(const AttrStmtNode* op)` function. */ ffi::Function f_visit_for{nullptr}; /*! \brief The packed function to the `VisitStmt_(const WhileNode* op)` function. */ ffi::Function f_visit_while{nullptr}; @@ -220,7 +220,7 @@ class PyStmtExprVisitorNode : public Object, public StmtExprVisitor { private: // Statement functions - PY_STMT_VISITOR_DISPATCH(LetStmtNode, f_visit_let_stmt); + PY_STMT_VISITOR_DISPATCH(BindNode, f_visit_bind); PY_STMT_VISITOR_DISPATCH(AttrStmtNode, f_visit_attr_stmt); PY_STMT_VISITOR_DISPATCH(IfThenElseNode, f_visit_if_then_else); PY_STMT_VISITOR_DISPATCH(ForNode, f_visit_for); @@ -311,7 +311,7 @@ class PyStmtExprVisitorNode : public Object, public StmtExprVisitor { static FStmtType InitStmtVTable() { FStmtType vtable; - PY_STMT_VISITOR_DEFAULT_DISPATCH(LetStmtNode); + PY_STMT_VISITOR_DEFAULT_DISPATCH(BindNode); PY_STMT_VISITOR_DEFAULT_DISPATCH(AttrStmtNode); PY_STMT_VISITOR_DEFAULT_DISPATCH(IfThenElseNode); PY_STMT_VISITOR_DEFAULT_DISPATCH(ForNode); @@ -340,7 +340,7 @@ class PyStmtExprVisitor : public ObjectRef { } TVM_DLL static PyStmtExprVisitor MakePyStmtExprVisitor(ffi::Function f_visit_stmt, // ffi::Function f_visit_expr, // - ffi::Function f_visit_let_stmt, // + ffi::Function f_visit_bind, // ffi::Function f_visit_attr_stmt, // ffi::Function f_visit_if_then_else, // 
ffi::Function f_visit_for, // @@ -390,7 +390,7 @@ class PyStmtExprVisitor : public ObjectRef { n->f_visit_stmt = std::move(f_visit_stmt); n->f_visit_expr = std::move(f_visit_expr); // Set statement functions - n->f_visit_let_stmt = std::move(f_visit_let_stmt); + n->f_visit_bind = std::move(f_visit_bind); n->f_visit_attr_stmt = std::move(f_visit_attr_stmt); n->f_visit_if_then_else = std::move(f_visit_if_then_else); n->f_visit_for = std::move(f_visit_for); @@ -525,8 +525,8 @@ class PyStmtExprMutatorNode : public Object, public StmtExprMutator { // Statement functions /*! \brief The packed function to the `VisitStmt(const Stmt& stmt)` function. */ ffi::Function f_visit_stmt{nullptr}; - /*! \brief The packed function to the `VisitStmt_(const LetStmtNode* op)` function. */ - ffi::Function f_visit_let_stmt{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const BindNode* op)` function. */ + ffi::Function f_visit_bind{nullptr}; /*! \brief The packed function to the `VisitStmt_(const AttrStmtNode* op)` function. */ ffi::Function f_visit_attr_stmt{nullptr}; /*! \brief The packed function to the `VisitStmt_(const IfThenElseNode* op)` function. 
*/ @@ -575,7 +575,7 @@ class PyStmtExprMutatorNode : public Object, public StmtExprMutator { private: // Statement functions - PY_STMT_MUTATOR_DISPATCH(LetStmtNode, f_visit_let_stmt); + PY_STMT_MUTATOR_DISPATCH(BindNode, f_visit_bind); PY_STMT_MUTATOR_DISPATCH(AttrStmtNode, f_visit_attr_stmt); PY_STMT_MUTATOR_DISPATCH(IfThenElseNode, f_visit_if_then_else); PY_STMT_MUTATOR_DISPATCH(ForNode, f_visit_for); @@ -666,7 +666,7 @@ class PyStmtExprMutatorNode : public Object, public StmtExprMutator { static FStmtType InitStmtVTable() { FStmtType vtable; - PY_STMT_MUTATOR_DEFAULT_DISPATCH(LetStmtNode); + PY_STMT_MUTATOR_DEFAULT_DISPATCH(BindNode); PY_STMT_MUTATOR_DEFAULT_DISPATCH(AttrStmtNode); PY_STMT_MUTATOR_DEFAULT_DISPATCH(IfThenElseNode); PY_STMT_MUTATOR_DEFAULT_DISPATCH(ForNode); @@ -696,7 +696,7 @@ class PyStmtExprMutator : public ObjectRef { */ TVM_DLL static PyStmtExprMutator MakePyStmtExprMutator(ffi::Function f_visit_stmt, // ffi::Function f_visit_expr, // - ffi::Function f_visit_let_stmt, // + ffi::Function f_visit_bind, // ffi::Function f_visit_attr_stmt, // ffi::Function f_visit_if_then_else, // ffi::Function f_visit_for, // @@ -746,7 +746,7 @@ class PyStmtExprMutator : public ObjectRef { n->f_visit_stmt = std::move(f_visit_stmt); n->f_visit_expr = std::move(f_visit_expr); // Statement functions - n->f_visit_let_stmt = std::move(f_visit_let_stmt); + n->f_visit_bind = std::move(f_visit_bind); n->f_visit_attr_stmt = std::move(f_visit_attr_stmt); n->f_visit_if_then_else = std::move(f_visit_if_then_else); n->f_visit_for = std::move(f_visit_for); diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index 1ad074971107..b3e9f795e4f9 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -154,7 +154,7 @@ class PrimFuncSpecializer : public StmtExprMutator { // If the buffer variable is being remapped to an expression, we // still need a tir::Var to be used as a the buffer variable. 
- // Therefore, generate a LetStmt that will provide a tir::Var for + // Therefore, generate a Bind that will provide a tir::Var for // the buffer to use. // // This step is only required when a buffer definition is using a @@ -170,7 +170,7 @@ class PrimFuncSpecializer : public StmtExprMutator { if (new_buffer_var.same_as(old_buffer_var)) { auto remapped_data = VisitExpr(old_buffer_var); if (!remapped_data.same_as(old_buffer_var)) { - stmt = LetStmt(old_buffer_var, remapped_data, stmt); + stmt = SeqStmt({Bind(old_buffer_var, remapped_data), stmt}); } } diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 6e0fec885fe6..d691b3f38f38 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -34,7 +34,8 @@ namespace tir { TVM_FFI_STATIC_INIT_BLOCK() { StmtNode::RegisterReflection(); - LetStmtNode::RegisterReflection(); + BindNode::RegisterReflection(); + AttrStmtNode::RegisterReflection(); AssertStmtNode::RegisterReflection(); BufferStoreNode::RegisterReflection(); @@ -51,32 +52,28 @@ TVM_FFI_STATIC_INIT_BLOCK() { SBlockRealizeNode::RegisterReflection(); } -// LetStmt -LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) { +// Bind +Bind::Bind(Var var, PrimExpr value, Span span) { TVM_FFI_ICHECK(value.defined()); - TVM_FFI_ICHECK(body.defined()); auto vdtype = value.dtype(); - // It is still valid to bind a pointer type - // var to a value that is of type handle. + // It is still valid to bind a pointer type var to a value that is of type handle. 
if (var->type_annotation.as()) { TVM_FFI_ICHECK(vdtype.is_handle()); } else { TVM_FFI_ICHECK_EQ(value.dtype(), var.dtype()); } - ObjectPtr node = ffi::make_object(); + ObjectPtr node = ffi::make_object(); node->var = std::move(var); node->value = std::move(value); - node->body = std::move(body); node->span = std::move(span); data_ = std::move(node); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.LetStmt", [](Var var, PrimExpr value, Stmt body, Span span) { - return LetStmt(var, value, body, span); - }); + refl::GlobalDef().def("tir.Bind", + [](Var var, PrimExpr value, Span span) { return Bind(var, value, span); }); } // AttrStmt diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index ff79c374db44..6c9072cae5e7 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -33,9 +33,9 @@ namespace tvm { namespace tir { -void StmtVisitor::VisitStmt_(const LetStmtNode* op) { +void StmtVisitor::VisitStmt_(const BindNode* op) { + // Bind has no body -- only visit the value expression. this->VisitExpr(op->value); - this->VisitStmt(op->body); } void StmtVisitor::VisitStmt_(const AttrStmtNode* op) { @@ -249,20 +249,19 @@ class StmtMutator::Internal { } }; -Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) { +Stmt StmtMutator::VisitStmt_(const BindNode* op) { + // Bind has no body -- only mutate the value expression. 
PrimExpr value = this->VisitExpr(op->value); - Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && body.same_as(op->body)) { + if (value.same_as(op->value)) { return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->value = std::move(value); - n->body = std::move(body); return Stmt(n); } } -Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) { +Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) { PrimExpr value = this->VisitExpr(op->value); Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index 5436e73d57c0..fc06509b5515 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -110,7 +110,7 @@ void TIRVisitorWithPath::Visit(const PrimFunc& func, AccessPath path) { } } - Visit(func->body, path->Attr("body")); + bind_scope_.WithNewScope([&]() { Visit(func->body, path->Attr("body")); }); while (context.size()) context.pop_back(); } @@ -172,10 +172,11 @@ void TIRVisitorWithPath::Visit(const Range& range, AccessPath path) { Visit(range->extent, path->Attr("extent")); } -void TIRVisitorWithPath::VisitStmt_(const LetStmtNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitStmt_(const BindNode* op, AccessPath path) { Visit(op->value, path->Attr("value")); - auto context = WithDef(op->var, path->Attr("var")); - Visit(op->body, path->Attr("body")); + // Push the Bind's var definition into the current scope. + // The def lives until the enclosing scope (body-carrying stmt) exits. 
+ bind_scope_.Current().push_back(WithDef(op->var, path->Attr("var"))); } void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, AccessPath path) { @@ -192,7 +193,7 @@ void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, AccessPath path) { } else if (auto expr = op->node.as()) { Visit(expr.value(), path->Attr("node")); } - Visit(op->body, path->Attr("body")); + bind_scope_.WithNewScope([&]() { Visit(op->body, path->Attr("body")); }); while (context.size()) { context.pop_back(); @@ -203,12 +204,12 @@ void TIRVisitorWithPath::VisitStmt_(const ForNode* op, AccessPath path) { Visit(op->min, path->Attr("min")); Visit(op->extent, path->Attr("extent")); auto context = WithDef(op->loop_var, path->Attr("loop_var")); - Visit(op->body, path->Attr("body")); + bind_scope_.WithNewScope([&]() { Visit(op->body, path->Attr("body")); }); } void TIRVisitorWithPath::VisitStmt_(const WhileNode* op, AccessPath path) { Visit(op->condition, path->Attr("condition")); - Visit(op->body, path->Attr("body")); + bind_scope_.WithNewScope([&]() { Visit(op->body, path->Attr("body")); }); } void TIRVisitorWithPath::VisitStmt_(const AllocBufferNode* op, AccessPath path) { @@ -223,7 +224,7 @@ void TIRVisitorWithPath::VisitStmt_(const AllocBufferNode* op, AccessPath path) void TIRVisitorWithPath::VisitStmt_(const DeclBufferNode* op, AccessPath path) { auto context = WithDef(op->buffer, path->Attr("buffer")); - Visit(op->body, path->Attr("body")); + bind_scope_.WithNewScope([&]() { Visit(op->body, path->Attr("body")); }); } void TIRVisitorWithPath::VisitStmt_(const BufferStoreNode* op, AccessPath path) { @@ -234,8 +235,8 @@ void TIRVisitorWithPath::VisitStmt_(const BufferStoreNode* op, AccessPath path) void TIRVisitorWithPath::VisitStmt_(const IfThenElseNode* op, AccessPath path) { Visit(op->condition, path->Attr("condition")); - Visit(op->then_case, path->Attr("then_case")); - Visit(op->else_case, path->Attr("else_case")); + bind_scope_.WithNewScope([&]() { Visit(op->then_case, 
path->Attr("then_case")); }); + bind_scope_.WithNewScope([&]() { Visit(op->else_case, path->Attr("else_case")); }); } void TIRVisitorWithPath::VisitStmt_(const AssertStmtNode* op, AccessPath path) { @@ -245,7 +246,10 @@ void TIRVisitorWithPath::VisitStmt_(const AssertStmtNode* op, AccessPath path) { } void TIRVisitorWithPath::VisitStmt_(const SeqStmtNode* op, AccessPath path) { - Visit(op->seq, path->Attr("seq")); + auto seq_path = path->Attr("seq"); + for (size_t i = 0; i < op->seq.size(); i++) { + Visit(op->seq[i], seq_path->ArrayItem(i)); + } } void TIRVisitorWithPath::VisitStmt_(const EvaluateNode* op, AccessPath path) { @@ -292,8 +296,8 @@ void TIRVisitorWithPath::VisitStmt_(const SBlockNode* op, AccessPath path) { } } - Visit(op->init, path->Attr("init")); - Visit(op->body, path->Attr("body")); + bind_scope_.WithNewScope([&]() { Visit(op->init, path->Attr("init")); }); + bind_scope_.WithNewScope([&]() { Visit(op->body, path->Attr("body")); }); while (context.size()) context.pop_back(); } diff --git a/src/tir/ir/tir_visitor_with_path.h b/src/tir/ir/tir_visitor_with_path.h index f5189ae61cee..92a5454e6473 100644 --- a/src/tir/ir/tir_visitor_with_path.h +++ b/src/tir/ir/tir_visitor_with_path.h @@ -25,6 +25,7 @@ #define TVM_TIR_IR_TIR_VISITOR_WITH_PATH_H_ #include +#include #include #include @@ -106,9 +107,9 @@ class TIRVisitorWithPath } using StmtFunctor::VisitStmt; + void VisitStmt_(const BindNode* op, ffi::reflection::AccessPath path) override; void VisitStmt_(const AttrStmtNode* op, ffi::reflection::AccessPath path) override; void VisitStmt_(const IfThenElseNode* op, ffi::reflection::AccessPath path) override; - void VisitStmt_(const LetStmtNode* op, ffi::reflection::AccessPath path) override; void VisitStmt_(const ForNode* op, ffi::reflection::AccessPath path) override; void VisitStmt_(const WhileNode* op, ffi::reflection::AccessPath path) override; void VisitStmt_(const AllocBufferNode* op, ffi::reflection::AccessPath path) override; @@ -249,6 +250,14 @@ 
class TIRVisitorWithPath } std::unordered_set in_scope_definitions_; + + /*! \brief Scope stack for Bind variable definitions. + * + * Body-carrying statements (For, IfThenElse, etc.) push a new scope. + * BindNode pushes its WithDef into the current scope. When the + * scope exits, all Bind defs are cleaned up automatically. + */ + ScopeStack>> bind_scope_; }; } // namespace tir diff --git a/src/tir/transform/common_subexpr_elim.cc b/src/tir/transform/common_subexpr_elim.cc index 9b9619fae937..d99af8486694 100644 --- a/src/tir/transform/common_subexpr_elim.cc +++ b/src/tir/transform/common_subexpr_elim.cc @@ -201,7 +201,11 @@ Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& CommonSubexpressionEliminator::CommonSubexpressionEliminator(const Stmt& stmt, const Context& context_init, bool identify_equiv_terms) - : initial_body_(stmt), context_(context_init), identify_equiv_terms_(identify_equiv_terms) {} + : initial_body_(stmt), context_(context_init), identify_equiv_terms_(identify_equiv_terms) { + // The initial scope level (from ScopeStack's constructor) does not need + // EnterContextScope() because it should never be popped -- it persists + // for the lifetime of the CSE pass and holds the function parameters. +} /*! * \brief The method which overrides the generic dispatcher of StmtExprMutator. @@ -342,32 +346,31 @@ PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) { } /*! - * \brief The method which overrides the specific treatment for a LetNode + * \brief The method which overrides the specific treatment for a LetNode. + * + * The let-in expression introduces a new variable binding that is only visible + * within the body. We use context_scope_.WithNewScope to automatically clean up + * the binding when the body has been visited, replacing the old manual + * save/restore of context_. 
*/ PrimExpr CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) { // At this point, we have already done the generic treatment of introducing (via let-in) what // was doable at the toplevel of the given let-in. - // Save the context at the entry of the function - Context context_at_entry = context_; - // Recurse on the `value` field for potentially rewriting it PrimExpr value_new = VisitExpr(op->value); - // Augment the context with the association (`var`, `value`) for preparing the next recursion - // on the `body` - context_.push_back({op->var, MaybeValue(op->value)}); - - // Recurse on the `body` (with this extended context) - // The recursive call will have potentially done new simplifications, because in this recursive - // call `var` will be a part of the context. - // (see in VisitExpr() that no introduction were performed when a computation was using an - // undefined variable, as that would lead to ill-formed code) - PrimExpr body_new = VisitExpr(op->body); - - // Restaure the context to its content at the entrance to not carry out of scope declarations - // as the variable introduced by the let-in is not in scope outside of its body - context_ = context_at_entry; + // Visit the body in a new scope. The let-in variable binding is added to the + // context inside the scope and automatically removed when the scope exits. + PrimExpr body_new = context_scope_.WithNewScope([&]() -> PrimExpr { + EnterContextScope(); + // Augment the context with the association (`var`, `value`) for the body + context_.push_back({op->var, MaybeValue(op->value)}); + // Recurse on the `body` (with this extended context) + // The recursive call will have potentially done new simplifications, because in this recursive + // call `var` will be a part of the context. + return VisitExpr(op->body); + }); // Rebuild the let-in with a new `value_new` and `body_new` where new simplifications might // have been done. 
@@ -478,8 +481,8 @@ Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) { // right to dive. result = ReplaceSelectedExpr::ReplaceSelectedExprInStmt(result, predicate_selector, new_var, CanContainEligibleComputations); - // Build a let-in that introduces the new variable in the current `result` - result = LetStmt(new_var, computation_and_nb.first, result); + // Build a bind that introduces the new variable before the current `result` + result = SeqStmt({Bind(new_var, computation_and_nb.first), result}); // We don't add the variable to the context because the invariant is that the // context is the context in which 'result' makes sense, and we've just updated it. } else { @@ -523,73 +526,154 @@ Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) { } /*! - * \brief The method which overrides the specific treatment for a LetStmtNode + * \brief The method which overrides the specific treatment for a BindNode. + * + * BindNode adds a (var, value) entry to the flat context_ vector. This entry + * persists across subsequent SeqStmt siblings in the same scope, enabling CSE + * to find common subexpressions that reference bind-defined variables. + * Cleanup happens automatically when the enclosing body-carrying statement's + * scope exits (via ContextScopeLevel's destructor), so no manual save/restore + * is needed here. */ -Stmt CommonSubexpressionEliminator::VisitStmt_(const LetStmtNode* op) { - // At this point, we have already done the generic treatment of introducing (via let-in) what - // was doable at the toplevel of the given let-in. 
- - // Save the context at the entry of the function - Context context_at_entry = context_; - +Stmt CommonSubexpressionEliminator::VisitStmt_(const BindNode* op) { // Recurse on the `value` field for potentially rewriting it PrimExpr value_new = VisitExpr(op->value); - // Augment the context with the association (`var`, `value`) for preparing the next recursion - // on the `body` + // Augment the context with the association (`var`, `value`). + // This persists across SeqStmt siblings and is cleaned up by the + // enclosing scope's ContextScopeLevel destructor. context_.push_back({op->var, MaybeValue(op->value)}); - // Recurse on the `body` (with this extended context) - // The recursive call will have potentially done new simplifications, because in this recursive - // call `var` will be a part of the context. - // (see in VisitStmt() that no introduction were performed when a computation was using an - // undefined variable, as that would lead to ill-formed code) - Stmt body_new = VisitStmt(op->body); - - // Restaure the context to its content at the entrance to not carry out of scope declarations - // as the variable introduced by the let-in is not in scope outside of its body - context_ = context_at_entry; - - // Rebuild the let-in with a new `value_new` and `body_new` where new simplifications might - // have been done. - - // If the `value` and the `body` of the let-in have been rewritten to the same thing - if (value_new.same_as(op->value) && body_new.same_as(op->body)) { - // Return a reference to the same node + // Rebuild the Bind if value changed + if (value_new.same_as(op->value)) { return ffi::GetRef(op); } else { - // Otherwise return a let-in built with the new `value_new` and the new `body_new` that - // have just been obtained - return LetStmt(op->var, value_new, body_new, op->span); + return Bind(op->var, value_new, op->span); } } /*! 
- * \brief The method which overrides the specific treatment for a ForNode + * \brief Whether a Bind value is trivial (constant or variable), meaning it cannot + * contribute eligible computations for CSE and can be safely batched. + */ +static bool IsTrivialBindValue(const PrimExpr& value) { + return value.as() != nullptr || value.as() != nullptr || + value.as() != nullptr || value.as() != nullptr; +} + +/*! + * \brief The method which overrides the specific treatment for a SeqStmtNode. + * + * Processes the flat sequence using a hybrid strategy that avoids the O(n^2) + * complexity of wrapping remaining siblings after every single Bind node: + * + * - Trivial Bind nodes (constant/variable values) are batched: their values + * are visited via VisitExpr, context_ is augmented, but the expensive + * cross-sibling CSE is deferred until the batch ends. + * - Non-trivial Bind nodes (whose values may contain eligible computations) + * use the wrap-remaining-siblings pattern to enable cross-sibling CSE. + * - After any Bind (trivial batch end or non-trivial), remaining siblings are + * wrapped into a body and VisitStmt is called once for cross-sibling CSE. + * - Non-Bind children are visited individually via VisitStmt. + * + * This reduces the common case of many consecutive trivial Binds (e.g., variable + * definitions with constant values) from O(n^2) to O(n), while preserving full + * CSE effectiveness for non-trivial Bind values. + * + * Context cleanup is handled automatically by ScopeStack. + */ +Stmt CommonSubexpressionEliminator::VisitStmt_(const SeqStmtNode* op) { + ffi::Array new_seq; + size_t i = 0; + + while (i < op->seq.size()) { + if (auto* bind = op->seq[i].as()) { + // Batch consecutive trivial Bind nodes (constant/variable values). + // These can't contribute common subexpressions, so it's safe to defer + // the cross-sibling CSE until the entire batch is processed. 
+ if (IsTrivialBindValue(bind->value)) { + while (i < op->seq.size()) { + auto* b = op->seq[i].as(); + if (!b || !IsTrivialBindValue(b->value)) break; + PrimExpr value_new = VisitExpr(b->value); + context_.push_back({b->var, MaybeValue(b->value)}); + Stmt bind_new = + value_new.same_as(b->value) ? ffi::GetRef(b) : Bind(b->var, value_new, b->span); + new_seq.push_back(bind_new); + ++i; + } + } else { + // Non-trivial Bind: visit value, augment context, then wrap remaining + // siblings and call VisitStmt for cross-sibling CSE. + PrimExpr value_new = VisitExpr(bind->value); + context_.push_back({bind->var, MaybeValue(bind->value)}); + Stmt bind_new = value_new.same_as(bind->value) ? ffi::GetRef(bind) + : Bind(bind->var, value_new, bind->span); + new_seq.push_back(bind_new); + ++i; + } + // After the Bind (batch or single), wrap remaining siblings [i..end) and + // call VisitStmt once for cross-sibling CSE with the updated context. + if (i < op->seq.size()) { + Stmt body; + if (i + 1 == op->seq.size()) { + body = op->seq[i]; + } else { + ffi::Array rest; + for (size_t j = i; j < op->seq.size(); ++j) rest.push_back(op->seq[j]); + body = SeqStmt(rest); + } + Stmt body_new = VisitStmt(body); + // Flatten the result. + if (auto* inner = body_new.as()) { + for (const auto& s : inner->seq) new_seq.push_back(s); + } else { + new_seq.push_back(body_new); + } + return SeqStmt::Flatten(new_seq); + } + } else { + // Non-Bind child: visit individually via VisitStmt. + Stmt child_new = VisitStmt(op->seq[i]); + if (auto* inner = child_new.as()) { + for (const auto& s : inner->seq) new_seq.push_back(s); + } else { + new_seq.push_back(child_new); + } + ++i; + } + } + + return SeqStmt::Flatten(new_seq); +} + +/*! + * \brief The method which overrides the specific treatment for a ForNode. + * + * The for loop introduces a loop variable that is only visible within the body. 
+ * We use context_scope_.WithNewScope to create a scope boundary: the loop + * variable (with no value, since it changes each iteration) is pushed inside + * the scope and automatically cleaned up on exit, replacing the old manual + * save/restore of context_. */ Stmt CommonSubexpressionEliminator::VisitStmt_(const ForNode* op) { // At this point, we have already done the generic treatment of introducing (via let-in) what // was doable at the toplevel of the given for loop. - // Save the context at the entry of the function - Context context_at_entry = context_; - // Recurse on the `min` field for potentially rewriting it PrimExpr min_new = VisitExpr(op->min); // Recurse on the `extent` field for potentially rewriting it PrimExpr extent_new = VisitExpr(op->extent); - // Augment the context with the association {loop_var, no value} (no value as its value will - // change during the execution of the loop) for preparing the next recursion on the `body` - context_.push_back({op->loop_var, MaybeValue()}); - - // Recurse on the `body` (with this extended context) - Stmt body_new = VisitStmt(op->body); - - // Restaure the context to its content at the entrance to not carry out of scope declarations - // as the variable introduced by the for loop is not in scope outside of its body - context_ = context_at_entry; + // Visit the body in a new scope. The loop variable is added to context_ inside + // the scope and automatically removed when the scope exits. + Stmt body_new = context_scope_.WithNewScope([&]() -> Stmt { + EnterContextScope(); + // Add loop_var with no value (its value changes each iteration) + context_.push_back({op->loop_var, MaybeValue()}); + return VisitStmt(op->body); + }); // Rebuild the for loop with (potentially) a new `min_new`, `extent_new` and `body_new`, where // new simplifications might have been done. @@ -606,6 +690,118 @@ Stmt CommonSubexpressionEliminator::VisitStmt_(const ForNode* op) { } } +/*! 
+ * \brief The method which overrides the specific treatment for an IfThenElseNode. + * + * Each branch of the if-then-else gets its own scope, preventing context entries + * (e.g., from Bind nodes inside one branch) from leaking into the other branch. + * Without this override, the default StmtExprMutator would visit both branches + * in the same scope, which could cause incorrect CSE across branches. + */ +Stmt CommonSubexpressionEliminator::VisitStmt_(const IfThenElseNode* op) { + PrimExpr condition_new = VisitExpr(op->condition); + + // Each branch gets its own scope to prevent context leaks between branches + Stmt then_new = context_scope_.WithNewScope([&]() -> Stmt { + EnterContextScope(); + return VisitStmt(op->then_case); + }); + + ffi::Optional else_new; + if (op->else_case) { + else_new = context_scope_.WithNewScope([&]() -> Stmt { + EnterContextScope(); + return VisitStmt(op->else_case.value()); + }); + } + + if (condition_new.same_as(op->condition) && then_new.same_as(op->then_case) && + else_new.same_as(op->else_case)) { + return ffi::GetRef(op); + } + return IfThenElse(condition_new, then_new, else_new, op->span); +} + +/*! + * \brief The method which overrides the specific treatment for an AttrStmtNode. + * + * AttrStmt has a body that may contain Bind nodes. A scope boundary prevents + * context entries from the body from leaking to subsequent statements. + */ +Stmt CommonSubexpressionEliminator::VisitStmt_(const AttrStmtNode* op) { + PrimExpr value_new = VisitExpr(op->value); + + // The body gets its own scope to contain any context entries added within it + Stmt body_new = context_scope_.WithNewScope([&]() -> Stmt { + EnterContextScope(); + return VisitStmt(op->body); + }); + + if (value_new.same_as(op->value) && body_new.same_as(op->body)) { + return ffi::GetRef(op); + } + return AttrStmt(op->node, op->attr_key, value_new, body_new, op->span); +} + +/*! + * \brief The method which overrides the specific treatment for an AllocBufferNode. 
+ * + * AllocBuffer has a body and introduces a buffer. A scope boundary + * prevents context entries from the body from leaking outward. + */ +Stmt CommonSubexpressionEliminator::VisitStmt_(const AllocBufferNode* op) { + // The body gets its own scope to contain any context entries added within it + Stmt body_new = context_scope_.WithNewScope([&]() -> Stmt { + EnterContextScope(); + return VisitStmt(op->body); + }); + + if (body_new.same_as(op->body)) { + return ffi::GetRef(op); + } + return AllocBuffer(op->buffer, body_new, op->annotations, op->span); +} + +/*! + * \brief The method which overrides the specific treatment for a DeclBufferNode. + * + * DeclBuffer declares a buffer for use within its body. A scope boundary + * prevents context entries from the body from leaking outward. + */ +Stmt CommonSubexpressionEliminator::VisitStmt_(const DeclBufferNode* op) { + // The body gets its own scope to contain any context entries added within it + Stmt body_new = context_scope_.WithNewScope([&]() -> Stmt { + EnterContextScope(); + return VisitStmt(op->body); + }); + + if (body_new.same_as(op->body)) { + return ffi::GetRef(op); + } + return DeclBuffer(op->buffer, body_new, op->span); +} + +/*! + * \brief The method which overrides the specific treatment for a WhileNode. + * + * While loop has a body that may contain Bind nodes. A scope boundary prevents + * context entries from the body from leaking outward. + */ +Stmt CommonSubexpressionEliminator::VisitStmt_(const WhileNode* op) { + PrimExpr condition_new = VisitExpr(op->condition); + + // The body gets its own scope to contain any context entries added within it + Stmt body_new = context_scope_.WithNewScope([&]() -> Stmt { + EnterContextScope(); + return VisitStmt(op->body); + }); + + if (condition_new.same_as(op->condition) && body_new.same_as(op->body)) { + return ffi::GetRef(op); + } + return While(condition_new, body_new, op->span); +} + namespace transform { /*! 
diff --git a/src/tir/transform/common_subexpr_elim.h b/src/tir/transform/common_subexpr_elim.h index 814161cc3535..5674070e3d25 100644 --- a/src/tir/transform/common_subexpr_elim.h +++ b/src/tir/transform/common_subexpr_elim.h @@ -28,6 +28,7 @@ #ifndef TVM_TIR_TRANSFORM_COMMON_SUBEXPR_ELIM_H_ #define TVM_TIR_TRANSFORM_COMMON_SUBEXPR_ELIM_H_ +#include #include #include #include @@ -69,12 +70,89 @@ class CommonSubexpressionEliminator : public StmtExprMutator { PrimExpr VisitExpr_(const LetNode* op) override; - Stmt VisitStmt_(const LetStmtNode* op) override; + Stmt VisitStmt_(const BindNode* op) override; + Stmt VisitStmt_(const SeqStmtNode* op) override; Stmt VisitStmt_(const ForNode* op) override; + Stmt VisitStmt_(const IfThenElseNode* op) override; + Stmt VisitStmt_(const AttrStmtNode* op) override; + Stmt VisitStmt_(const AllocBufferNode* op) override; + Stmt VisitStmt_(const DeclBufferNode* op) override; + Stmt VisitStmt_(const WhileNode* op) override; private: - Stmt initial_body_; // Kept for checking if names of new variables already exist - Context context_; // Context associating variables to (maybe) definitions + /*! \brief Scope level for the context stack. + * + * Each scope level records the size of `context_` when the scope was entered. + * When the scope exits (via ScopeStack::WithNewScope), the destructor truncates + * `context_` back to the saved size, automatically cleaning up any context + * entries added within that scope (e.g., from BindNode or loop variables). + * + * This approach keeps `context_` as a flat vector for efficient searching + * while using ScopeStack for automatic scope-based cleanup. 
+ */ + struct ContextScopeLevel { + Context* context{nullptr}; + size_t saved_size{0}; + + ContextScopeLevel() = default; + ContextScopeLevel(const ContextScopeLevel&) = delete; + ContextScopeLevel& operator=(const ContextScopeLevel&) = delete; + ContextScopeLevel(ContextScopeLevel&& other) noexcept + : context(other.context), saved_size(other.saved_size) { + other.context = nullptr; // prevent other's destructor from truncating + } + ContextScopeLevel& operator=(ContextScopeLevel&& other) noexcept { + if (this != &other) { + // Run our cleanup first + if (context) context->resize(saved_size); + context = other.context; + saved_size = other.saved_size; + other.context = nullptr; + } + return *this; + } + + ~ContextScopeLevel() { + if (context) context->resize(saved_size); + } + }; + + /*! \brief Enter a new context scope, recording the current context size. + * + * Must be called inside context_scope_.WithNewScope() to initialize the + * newly-pushed scope level. On scope exit, the destructor of + * ContextScopeLevel will truncate context_ back to this size. + */ + void EnterContextScope() { + auto& level = context_scope_.Current(); + level.context = &context_; + level.saved_size = context_.size(); + } + + Stmt initial_body_;  // Kept for checking if names of new variables already exist + + /*! \brief Flat context vector associating variables to (optional) definitions. + * + * This is the searchable context: VisitExpr and VisitStmt scan it linearly + * to find existing variables whose values match a candidate computation. + * Entries are added by BindNode (with a value) and ForNode (loop var, no value). + * Cleanup is automatic via context_scope_: when a scope exits, context_ is + * truncated to the size it had when the scope was entered. + */ + Context context_; + + /*! \brief Scope stack for automatic context cleanup. + * + * Body-carrying statements (For, IfThenElse, AttrStmt, AllocBuffer, DeclBuffer, + * While) create new scope levels via WithNewScope.
BindNode entries persist + * across SeqStmt siblings within the same scope and are cleaned up when the + * enclosing body-carrying statement's scope exits. + * + * The initial scope level (created by ScopeStack's constructor) holds the + * function parameters added during PerformCSE. + */ + ScopeStack context_scope_; + int num_last_try_ = 0; // Number of the last variable tried int nb_var_ = 0; // Number of variables introduced by the CSE pass diff --git a/src/tir/transform/ir_utils.cc b/src/tir/transform/ir_utils.cc index 9398c2561ead..3cd117f5ce3e 100644 --- a/src/tir/transform/ir_utils.cc +++ b/src/tir/transform/ir_utils.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -46,11 +47,9 @@ Stmt MergeNest(const std::vector& nest, Stmt body) { TVM_FFI_ICHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); - } else if (const auto* let = s.as()) { - auto n = ffi::make_object(*let); - TVM_FFI_ICHECK(is_no_op(n->body)); - n->body = body; - body = Stmt(n); + } else if (const auto* bind = s.as()) { + // Bind has no body -- prepend it before the accumulated body in a SeqStmt. + body = SeqStmt::Flatten(ffi::GetRef(bind), body); } else if (const auto* attr = s.as()) { auto n = ffi::make_object(*attr); TVM_FFI_ICHECK(is_no_op(n->body)); @@ -96,13 +95,14 @@ Stmt MergeNest(const std::vector>& nest, Stmt body) { class IRConvertSSA final : public StmtExprMutator { public: PrimFunc VisitPrimFunc(PrimFunc func) { - std::vector redefines; - - // Remap parameters, if they were used in another function + // Remap parameters, if they were used in another function. + // Function-scope remaps use function_scope_var_remap_ (not the scope stack), + // because they persist across the entire function body. 
auto params = func->params.Map([&](const tir::Var& var) -> tir::Var { if (defined_.count(var.get())) { - const ScopedRedefine& redefine = redefines.emplace_back(this, var); - return redefine.new_var; + Var new_var = MakeNewVar(var); + PushVarRemap(var, new_var); + return new_var; } else { defined_.insert(var.get()); return var; @@ -124,7 +124,7 @@ class IRConvertSSA final : public StmtExprMutator { // Buffer_map shape vars use "match" semantics: first occurrence // defines the var, subsequent occurrences (in other buffers) are - // just consistent uses of the same var — not redefinitions. + // just consistent uses of the same var -- not redefinitions. if (!defined_.count(var_ptr)) { defined_.insert(var_ptr); } @@ -146,7 +146,8 @@ class IRConvertSSA final : public StmtExprMutator { for (const auto& [var, buffer] : func->buffer_map) { auto new_var = GetRemappedVar(var); if (defined_.count(buffer->data.get())) { - redefines.emplace_back(this, buffer->data); + Var new_data = MakeNewVar(buffer->data); + PushVarRemap(buffer->data, new_data); } else { defined_.insert(buffer->data.get()); } @@ -197,10 +198,8 @@ class IRConvertSSA final : public StmtExprMutator { func = PrimFunc(params, body, func->ret_type, buffer_map, attrs); } - // Pop the redefines in reverse order of creation - while (redefines.size()) { - redefines.pop_back(); - } + // Pop function-scope remaps in reverse order + PopAllRemapsInCurrentScope(); function_scope_var_remap_.clear(); return func; } @@ -220,9 +219,11 @@ class IRConvertSSA final : public StmtExprMutator { const Var& v = op->var; if (defined_.count(v.get())) { PrimExpr value = this->VisitExpr(op->value); - ScopedRedefine redefine(this, v); + Var new_var = MakeNewVar(v); + PushVarRemap(v, new_var); PrimExpr body = this->VisitExpr(op->body); - return Let(redefine.new_var, value, body); + PopVarRemap(v, new_var); + return Let(new_var, value, body); } else { defined_.insert(v.get()); return StmtExprMutator::VisitExpr_(op); @@ -242,48 +243,48 @@ 
class IRConvertSSA final : public StmtExprMutator { } Stmt VisitStmt_(const DeclBufferNode* op) final { - DeclBuffer decl = Downcast(StmtExprMutator::VisitStmt_(op)); - Buffer new_buffer = GetRemappedBuffer(decl->buffer); - if (!new_buffer.same_as(decl->buffer)) { - decl.CopyOnWrite()->buffer = std::move(new_buffer); - } - return decl; + return scope_.WithNewScope([&]() -> Stmt { + DeclBuffer decl = Downcast(StmtExprMutator::VisitStmt_(op)); + Buffer new_buffer = GetRemappedBuffer(decl->buffer); + if (!new_buffer.same_as(decl->buffer)) { + decl.CopyOnWrite()->buffer = std::move(new_buffer); + } + return decl; + }); } Stmt VisitStmt_(const SBlockNode* op) final { SBlock block = ffi::GetRef(op); - // The BlockNode is the point of definition for the IterVar + // The SBlockNode is the point of definition for the IterVar // instances. These re-defines must be present before visiting - // the body of the BlockNode. - std::vector redefines; - ffi::Array iter_vars = op->iter_vars.Map([&](IterVar iter_var) { - if (defined_.count(iter_var->var.get())) { - redefines.emplace_back(this, iter_var->var); - iter_var.CopyOnWrite()->var = redefines.back().new_var; - } else { - defined_.insert(iter_var->var.get()); + // the body of the SBlockNode. 
+ return scope_.WithNewScope([&]() -> Stmt { + ffi::Array iter_vars = op->iter_vars.Map([&](IterVar iter_var) { + if (defined_.count(iter_var->var.get())) { + Var new_var = MakeNewVar(iter_var->var); + PushVarRemap(iter_var->var, new_var); + iter_var.CopyOnWrite()->var = new_var; + } else { + defined_.insert(iter_var->var.get()); + } + return iter_var; + }); + ffi::Array reads = + block->reads.Map([&](const auto& region) { return VisitBufferAccess(region); }); + ffi::Array writes = + block->writes.Map([&](const auto& region) { return VisitBufferAccess(region); }); + + if (!reads.same_as(block->reads) || !writes.same_as(block->writes) || + !iter_vars.same_as(op->iter_vars)) { + auto write_ptr = block.CopyOnWrite(); + write_ptr->reads = reads; + write_ptr->writes = writes; + write_ptr->iter_vars = iter_vars; } - return iter_var; - }); - ffi::Array reads = - block->reads.Map([&](const auto& region) { return VisitBufferAccess(region); }); - ffi::Array writes = - block->writes.Map([&](const auto& region) { return VisitBufferAccess(region); }); - if (!reads.same_as(block->reads) || !writes.same_as(block->writes) || - !iter_vars.same_as(op->iter_vars)) { - auto write_ptr = block.CopyOnWrite(); - write_ptr->reads = reads; - write_ptr->writes = writes; - write_ptr->iter_vars = iter_vars; - } - - Stmt output = Downcast(StmtExprMutator::VisitStmt_(block.get())); - - while (redefines.size()) redefines.pop_back(); - - return output; + return Downcast(StmtExprMutator::VisitStmt_(block.get())); + }); } template @@ -298,7 +299,7 @@ class IRConvertSSA final : public StmtExprMutator { } Var GetRemappedVar(Var var) { - if (auto it = scope_.find(var.get()); it != scope_.end() && it->second.size()) { + if (auto it = var_remap_.find(var.get()); it != var_remap_.end() && it->second.size()) { return it->second.back(); } else if (auto it = function_scope_var_remap_.find(var.get()); it != function_scope_var_remap_.end()) { @@ -348,49 +349,78 @@ class IRConvertSSA final : public 
StmtExprMutator { return new_buf; } - Stmt VisitStmt_(const LetStmtNode* op) final { + Stmt VisitStmt_(const BindNode* op) final { + // Bind var remaps are tracked in the current scope so they persist + // across SeqStmt siblings and are cleaned up when the enclosing + // body-carrying statement's scope exits. const Var& v = op->var; if (defined_.count(v.get())) { PrimExpr value = this->VisitExpr(op->value); - ScopedRedefine redefine(this, v); - Stmt body = this->VisitStmt(op->body); - return LetStmt(redefine.new_var, value, body); + Var new_var = MakeNewVar(v); + PushVarRemap(v, new_var); + return Bind(new_var, value); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); } } + + Stmt VisitStmt_(const IfThenElseNode* op) final { + // Each branch gets its own scope so Bind remaps in one branch + // do not leak into the other. + PrimExpr condition = VisitExpr(op->condition); + Stmt then_case = scope_.WithNewScope([&]() -> Stmt { return VisitStmt(op->then_case); }); + ffi::Optional else_case; + if (op->else_case) { + else_case = scope_.WithNewScope([&]() -> Stmt { return VisitStmt(op->else_case.value()); }); + } + if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && + else_case.same_as(op->else_case)) { + return ffi::GetRef(op); + } + return IfThenElse(condition, then_case, else_case); + } + Stmt VisitStmt_(const ForNode* op) final { const Var& v = op->loop_var; if (defined_.count(v.get())) { - ScopedRedefine redefine(this, v); - Stmt stmt = StmtExprMutator::VisitStmt_(op); - auto n = ffi::make_object(*stmt.as()); - n->loop_var = redefine.new_var; - return For(n); + return scope_.WithNewScope([&]() -> Stmt { + Var new_var = MakeNewVar(v); + PushVarRemap(v, new_var); + Stmt stmt = StmtExprMutator::VisitStmt_(op); + auto n = ffi::make_object(*stmt.as()); + n->loop_var = new_var; + return For(n); + }); } else { defined_.insert(v.get()); - return StmtExprMutator::VisitStmt_(op); + return scope_.WithNewScope([&]() -> Stmt { 
return StmtExprMutator::VisitStmt_(op); }); } } + Stmt VisitStmt_(const WhileNode* op) final { + return scope_.WithNewScope([&]() -> Stmt { return StmtExprMutator::VisitStmt_(op); }); + } Stmt VisitStmt_(const AllocBufferNode* op) final { const Var& v = op->buffer->data; if (defined_.count(v.get())) { - ScopedRedefine redefine(this, v); - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - // Use GetRemappedBuffer so that the AllocBuffer's buffer is the same - // object as the one used by BufferStore/BufferLoad in the body. - Buffer new_buf = GetRemappedBuffer(op->buffer); - if (!new_buf.same_as(op->buffer)) { - auto node = Downcast(stmt); - node.CopyOnWrite()->buffer = std::move(new_buf); - return node; - } - return stmt; + return scope_.WithNewScope([&]() -> Stmt { + Var new_var = MakeNewVar(v); + PushVarRemap(v, new_var); + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + // Use GetRemappedBuffer so that the AllocBuffer's buffer is the same + // object as the one used by BufferStore/BufferLoad in the body. 
+ Buffer new_buf = GetRemappedBuffer(op->buffer); + if (!new_buf.same_as(op->buffer)) { + auto node = Downcast(stmt); + node.CopyOnWrite()->buffer = std::move(new_buf); + return node; + } + return stmt; + }); } else { defined_.insert(v.get()); - return StmtExprMutator::VisitStmt_(op); + return scope_.WithNewScope([&]() -> Stmt { return StmtExprMutator::VisitStmt_(op); }); } } Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -448,7 +478,7 @@ class IRConvertSSA final : public StmtExprMutator { } auto value = VisitExpr(op->value); - auto body = VisitStmt(op->body); + auto body = scope_.WithNewScope([&]() -> Stmt { return VisitStmt(op->body); }); Stmt output; if (new_iter_var.get() == iter_var && body.same_as(op->body) && value.same_as(op->value)) { @@ -467,75 +497,156 @@ class IRConvertSSA final : public StmtExprMutator { return output; } else if (const VarNode* v = op->node.as()) { - Stmt stmt = StmtExprMutator::VisitStmt_(op); + Stmt stmt = scope_.WithNewScope([&]() -> Stmt { return StmtExprMutator::VisitStmt_(op); }); op = stmt.as(); - if (scope_.count(v) && scope_[v].size() != 0) { - return AttrStmt(scope_[v].back(), op->attr_key, op->value, op->body); + if (var_remap_.count(v) && var_remap_[v].size() != 0) { + return AttrStmt(var_remap_[v].back(), op->attr_key, op->value, op->body); } else { return stmt; } } else { - return StmtExprMutator::VisitStmt_(op); + return scope_.WithNewScope([&]() -> Stmt { return StmtExprMutator::VisitStmt_(op); }); } } private: - struct ScopedRedefine { - ScopedRedefine(IRConvertSSA* parent, Var old_var) : parent(parent), old_var(old_var) { - bool is_size_var = old_var->IsInstance(); - if (old_var->type_annotation.defined()) { - if (is_size_var) { - new_var = SizeVar(old_var->name_hint, old_var->type_annotation); - } else { - new_var = Var(old_var->name_hint, old_var->type_annotation); - } + /*! \brief Record of a variable remap pushed to the current scope. */ + struct VarRemap { + Var old_var; + Var new_var; + }; + + /*! 
\brief Create a new variable with the same name and type as the original. */ + static Var MakeNewVar(const Var& old_var) { + bool is_size_var = old_var->IsInstance(); + if (old_var->type_annotation.defined()) { + if (is_size_var) { + return SizeVar(old_var->name_hint, old_var->type_annotation); } else { - if (is_size_var) { - new_var = SizeVar(old_var->name_hint, old_var->dtype); - } else { - new_var = Var(old_var->name_hint, old_var->dtype); + return Var(old_var->name_hint, old_var->type_annotation); + } + } else { + if (is_size_var) { + return SizeVar(old_var->name_hint, old_var->dtype); + } else { + return Var(old_var->name_hint, old_var->dtype); + } + } + } + + /*! \brief Push a variable remap to the current scope and the var_remap_ stack. */ + void PushVarRemap(const Var& old_var, const Var& new_var) { + var_remap_[old_var.get()].push_back(new_var); + auto& level = scope_.Current(); + level.parent = this; + level.push_back({old_var, new_var}); + } + + /*! \brief Pop a single variable remap (used for expression-level Let scoping). */ + void PopVarRemap(const Var& old_var, const Var& new_var) { + var_remap_[old_var.get()].pop_back(); + for (auto& kv : buf_remap_) { + std::vector& buffers = kv.second; + if (buffers.size() && (buffers.back()->data.get() == new_var.get())) { + buffers.pop_back(); + } + } + // Also remove from the current scope's tracking vector + auto& current = scope_.Current(); + if (current.size() && current.back().new_var.same_as(new_var)) { + current.pop_back(); + } + } + + /*! \brief Pop all remaps in the current scope level (used for function-scope cleanup). 
*/ + void PopAllRemapsInCurrentScope() { + auto& current = scope_.Current(); + while (current.size()) { + auto& remap = current.back(); + var_remap_[remap.old_var.get()].pop_back(); + for (auto& kv : buf_remap_) { + std::vector& buffers = kv.second; + if (buffers.size() && (buffers.back()->data.get() == remap.new_var.get())) { + buffers.pop_back(); } } - parent->scope_[old_var.get()].push_back(new_var); + current.pop_back(); } + } + + /*! \brief Scope stack: each scope level holds the remaps introduced in that scope. + * + * When a body-carrying statement (For, SBlock, AllocBuffer) calls + * scope_.WithNewScope([&]{...}), a new scope level is pushed. + * Bind statements push their remaps to the current scope. + * On scope exit, the destructor of ScopeLevel triggers, + * and we undo all remaps in that level. + * + * Note: ScopeStack::WithNewScope calls T's destructor on exit. + * std::vector's destructor destroys elements but does NOT call custom + * cleanup. So we wrap the vector in ScopeLevel which handles cleanup. 
+ */ + struct ScopeLevel { + std::vector remaps; + IRConvertSSA* parent{nullptr}; - ~ScopedRedefine() { - if (parent) { - parent->scope_[old_var.get()].pop_back(); + void push_back(VarRemap remap) { remaps.push_back(std::move(remap)); } + size_t size() const { return remaps.size(); } + VarRemap& back() { return remaps.back(); } + void pop_back() { remaps.pop_back(); } + + ~ScopeLevel() { + if (!parent) return; + // Pop remaps in reverse order + while (remaps.size()) { + auto& remap = remaps.back(); + parent->var_remap_[remap.old_var.get()].pop_back(); for (auto& kv : parent->buf_remap_) { std::vector& buffers = kv.second; - if (buffers.size() && (buffers.back()->data.get() == new_var.get())) { + if (buffers.size() && (buffers.back()->data.get() == remap.new_var.get())) { buffers.pop_back(); } } + remaps.pop_back(); } } - ScopedRedefine& operator=(const ScopedRedefine&) = delete; - ScopedRedefine(const ScopedRedefine&) = delete; - - ScopedRedefine& operator=(ScopedRedefine&& other) { - swap(other); - return *this; + ScopeLevel() = default; + ScopeLevel(const ScopeLevel&) = delete; + ScopeLevel& operator=(const ScopeLevel&) = delete; + ScopeLevel(ScopeLevel&& other) noexcept + : remaps(std::move(other.remaps)), parent(other.parent) { + other.parent = nullptr; // prevent other's destructor from popping } - ScopedRedefine(ScopedRedefine&& other) { swap(other); } - - void swap(ScopedRedefine& other) { - std::swap(parent, other.parent); - std::swap(old_var, other.old_var); - std::swap(new_var, other.new_var); + ScopeLevel& operator=(ScopeLevel&& other) noexcept { + if (this != &other) { + // Run our destructor logic first + if (parent) { + while (remaps.size()) { + auto& remap = remaps.back(); + parent->var_remap_[remap.old_var.get()].pop_back(); + for (auto& kv : parent->buf_remap_) { + std::vector& buffers = kv.second; + if (buffers.size() && (buffers.back()->data.get() == remap.new_var.get())) { + buffers.pop_back(); + } + } + remaps.pop_back(); + } + } + remaps = 
std::move(other.remaps); + parent = other.parent; + other.parent = nullptr; + } + return *this; } - - IRConvertSSA* parent{nullptr}; - Var old_var; - Var new_var; }; - std::unordered_map> scope_; + std::unordered_map> var_remap_; std::unordered_set defined_; std::unordered_map> buf_remap_; - std::unordered_map function_scope_var_remap_; + ScopeStack scope_; }; Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA()(std::move(stmt)); } diff --git a/src/tir/transform/ir_utils.h b/src/tir/transform/ir_utils.h index 8077ebdea807..1282778ae09d 100644 --- a/src/tir/transform/ir_utils.h +++ b/src/tir/transform/ir_utils.h @@ -46,7 +46,7 @@ namespace tvm { namespace tir { /*! * \brief combine the nest stmt, whose body is not defined. - * \param nest A list of For and LetStmt, whose body is not defined. + * \param nest A list of For and Bind, whose body is not defined. * \param body body * \return The combined Stmt */ @@ -54,7 +54,7 @@ Stmt MergeNest(const std::vector& nest, Stmt body); /*! * \brief combine the nest stmt, whose body is not defined. - * \param nest A list of For and LetStmt, whose body is not defined. + * \param nest A list of For and Bind, whose body is not defined. * \param body body * \return The combined Stmt */ diff --git a/src/tir/transform/lower_tvm_builtin.cc b/src/tir/transform/lower_tvm_builtin.cc index 9fe3412389ce..c5a6cf59d26a 100644 --- a/src/tir/transform/lower_tvm_builtin.cc +++ b/src/tir/transform/lower_tvm_builtin.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include #include @@ -154,25 +155,36 @@ class BuiltinLower : public StmtExprMutator { // used when mutating. 
scope.max_sizes = GetMaxStack(stmt); + // Build a flat list of Bind stmts followed by the body + ffi::Array alloca_stmts; + if (scope.max_sizes.arg_stack != 0) { + alloca_stmts.push_back( + Bind(scope.stack_ffi_any, StackAlloca("tvm_ffi_any", scope.max_sizes.arg_stack))); + } + + if (scope.max_sizes.array_stack != 0) { + alloca_stmts.push_back( + Bind(scope.stack_array, StackAlloca("array", scope.max_sizes.array_stack))); + } + if (scope.max_sizes.shape_stack != -1) { scope.stack_shape = decl_buffer({IntImm(DataType::Int(64), scope.max_sizes.shape_stack)}, DataType::Int(64), "stack_shape"); + alloca_stmts.push_back( + Bind(scope.stack_shape->data, StackAlloca("shape", scope.max_sizes.shape_stack))); stmt = DeclBuffer(scope.stack_shape, stmt); - stmt = LetStmt(scope.stack_shape->data, StackAlloca("shape", scope.max_sizes.shape_stack), - stmt); } - if (scope.max_sizes.array_stack != 0) { - stmt = LetStmt(scope.stack_array, StackAlloca("array", scope.max_sizes.array_stack), stmt); - } - - if (scope.max_sizes.arg_stack != 0) { - stmt = LetStmt(scope.stack_ffi_any, StackAlloca("tvm_ffi_any", scope.max_sizes.arg_stack), - stmt); + if (!alloca_stmts.empty()) { + alloca_stmts.push_back(stmt); + stmt = SeqStmt(alloca_stmts); } } - stmt = this->VisitStmt(stmt); + stmt = scope_.WithNewScope([&]() -> Stmt { + Stmt visited = this->VisitStmt(stmt); + return AppendPendingFrees(visited); + }); TVM_FFI_ICHECK(!alloca_scope_.empty()); alloca_scope_.pop_back(); @@ -212,10 +224,10 @@ class BuiltinLower : public StmtExprMutator { } } - Stmt VisitStmt_(const LetStmtNode* op) final { + Stmt VisitStmt_(const BindNode* op) final { if (const CallNode* call = op->value.as()) { if (call->op.same_as(builtin::nd_mem_alloc_with_scope())) { - return StmtExprMutator::VisitStmt(MakeNdMemAllocWithScope(op, call)); + return MakeNdMemAllocWithScope(op, call); } } return StmtExprMutator::VisitStmt_(op); @@ -223,7 +235,17 @@ class BuiltinLower : public StmtExprMutator { Stmt VisitStmt_(const 
AllocBufferNode* op) { // Lower AllocBuffer to device allocate when needed. - Stmt stmt = StmtExprMutator::VisitStmt_(op); + // Visit body in a new scope so nd_mem_alloc frees are scoped to the body. + Stmt stmt = scope_.WithNewScope([&]() -> Stmt { + Stmt visited = StmtExprMutator::VisitStmt_(op); + const auto* alloc = visited.as(); + if (alloc && !scope_.Current().pending_frees.empty()) { + auto n = CopyOnWrite(alloc); + n->body = AppendPendingFrees(alloc->body); + return Stmt(n); + } + return visited; + }); op = stmt.as(); int64_t nbytes = GetVectorBytes(op->buffer->dtype); if (op->annotations.count(transform::kDisableLowerTVMBuiltin)) { @@ -264,13 +286,13 @@ class BuiltinLower : public StmtExprMutator { body = AttrStmt(op->buffer->data, attr::storage_alignment, make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body); - body = LetStmt(op->buffer->data, - Call(op->buffer->data.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"), - {cast(DataType::Int(32), device_type_.value()), - cast(DataType::Int(32), device_id_.value()), total_bytes, - IntImm(DataType::Int(32), op->buffer->dtype.code()), - IntImm(DataType::Int(32), op->buffer->dtype.bits())}), - body); + body = SeqStmt({Bind(op->buffer->data, + Call(op->buffer->data.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"), + {cast(DataType::Int(32), device_type_.value()), + cast(DataType::Int(32), device_id_.value()), total_bytes, + IntImm(DataType::Int(32), op->buffer->dtype.code()), + IntImm(DataType::Int(32), op->buffer->dtype.bits())})), + body}); return body; } @@ -279,17 +301,33 @@ class BuiltinLower : public StmtExprMutator { if (op->attr_key == attr::device_id) { auto cache = device_id_; device_id_ = op->value; - Stmt out = this->VisitStmt(op->body); + Stmt out = scope_.WithNewScope([&]() -> Stmt { + Stmt body = this->VisitStmt(op->body); + return AppendPendingFrees(body); + }); device_id_ = cache; return out; } else if (op->attr_key == attr::device_type) { auto cache = device_type_; device_type_ = 
op->value; - Stmt out = this->VisitStmt(op->body); + Stmt out = scope_.WithNewScope([&]() -> Stmt { + Stmt body = this->VisitStmt(op->body); + return AppendPendingFrees(body); + }); device_type_ = cache; return out; } else { - return StmtExprMutator::VisitStmt_(op); + return scope_.WithNewScope([&]() -> Stmt { + Stmt visited = StmtExprMutator::VisitStmt_(op); + if (!scope_.Current().pending_frees.empty()) { + const auto* attr = visited.as(); + if (attr) { + return AttrStmt(attr->node, attr->attr_key, attr->value, AppendPendingFrees(attr->body), + attr->span); + } + } + return visited; + }); } } Stmt VisitStmt_(const ForNode* op) final { @@ -300,7 +338,10 @@ class BuiltinLower : public StmtExprMutator { if (op->kind == ForKind::kParallel) { body = this->VisitBodyAndRealizeAlloca(op->body); } else { - body = this->VisitStmt(op->body); + body = scope_.WithNewScope([&]() -> Stmt { + Stmt visited = this->VisitStmt(op->body); + return AppendPendingFrees(visited); + }); } if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) { @@ -314,6 +355,27 @@ class BuiltinLower : public StmtExprMutator { } } + Stmt VisitStmt_(const IfThenElseNode* op) final { + PrimExpr condition = this->VisitExpr(op->condition); + // Each branch gets its own scope to prevent frees from leaking across branches. 
+ Stmt then_case = scope_.WithNewScope([&]() -> Stmt { + Stmt visited = this->VisitStmt(op->then_case); + return AppendPendingFrees(visited); + }); + ffi::Optional else_case; + if (op->else_case) { + else_case = scope_.WithNewScope([&]() -> Stmt { + Stmt visited = this->VisitStmt(op->else_case.value()); + return AppendPendingFrees(visited); + }); + } + if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && + else_case.same_as(op->else_case)) { + return ffi::GetRef(op); + } + return IfThenElse(condition, then_case, else_case, op->span); + } + PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::tvm_call_packed())) { return MakeCallPackedGeneric(op, 0, builtin::tvm_call_packed_lowered(), @@ -496,6 +558,7 @@ class BuiltinLower : public StmtExprMutator { return ffi::TypeIndex::kTVMFFIOpaquePtr; } else { TVM_FFI_THROW(InternalError) << "Unsupported type: " << api_dtype; + TVM_FFI_UNREACHABLE(); } }(); @@ -595,27 +658,14 @@ class BuiltinLower : public StmtExprMutator { return Call(op->dtype, lowered_packed_op, packed_args); } - Stmt MakeNdMemAllocWithScope(const LetStmtNode* let, const CallNode* call) { + Stmt MakeNdMemAllocWithScope(const BindNode* let, const CallNode* call) { TVM_FFI_ICHECK(device_type_) << "Unknown device type in current IR"; TVM_FFI_ICHECK(device_id_) << "Unknown device id in current IR"; Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {})); - PrimExpr storage_scope = call->args[0]; - Call free_op = Call(DataType::Int(32), builtin::tvm_call_packed(), - {GetDeviceMethodName("free_nd"), device_type_.value(), device_id_.value(), - storage_scope, let->var}); - Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); - - Stmt body = SeqStmt( - {IfThenElse(Call(DataType::Bool(), builtin::isnullptr(), {let->var}), throw_last_error), - let->body, free_stmt}); - DataType dtype = let->var->type_annotation.as()->element_type.as()->dtype; - 
std::string fdevapi_prefix = "device_api."; - fdevapi_prefix += runtime::DLDeviceType2Str(device_type_.as()->value); - ffi::Array args = { GetDeviceMethodName("alloc_nd"), device_type_.value(), @@ -629,8 +679,22 @@ class BuiltinLower : public StmtExprMutator { } Call call_packed = Call(let->var.dtype(), builtin::tvm_call_packed(), args); - Stmt alloca = LetStmt(let->var, call_packed, body); - return alloca; + Stmt null_check = + IfThenElse(Call(DataType::Bool(), builtin::isnullptr(), {let->var}), throw_last_error); + + // Construct free_nd call and register in current scope. + // The free will be emitted on scope exit, matching the old LetStmt body semantics. + PrimExpr storage_scope = call->args[0]; + Call free_op = Call(DataType::Int(32), builtin::tvm_call_packed(), + {GetDeviceMethodName("free_nd"), device_type_.value(), device_id_.value(), + storage_scope, let->var}); + Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); + // Visit the free_stmt so tvm_call_packed builtins inside it get lowered. + free_stmt = StmtExprMutator::VisitStmt(free_stmt); + scope_.Current().pending_frees.push_back(free_stmt); + + // Re-visit so tvm_call_packed in the Bind value and null_check get lowered. + return StmtExprMutator::VisitStmt(SeqStmt({Bind(let->var, call_packed), null_check})); } private: @@ -645,6 +709,32 @@ class BuiltinLower : public StmtExprMutator { return false; } + /*! + * \brief Scope level for tracking nd_mem_alloc_with_scope deallocations. + * + * When a Bind allocates via nd_mem_alloc_with_scope, the corresponding + * free_nd stmt is pushed to the current scope's pending_frees. Body-carrying + * stmts (For, IfThenElse, AllocBuffer, AttrStmt) create new scopes via + * WithNewScope. On scope exit, pending_frees are appended after the body, + * matching the old LetStmt body semantics structurally. + */ + struct ScopeLevel { + std::vector pending_frees; + }; + + /*! 
\brief Emit any pending frees from the current scope after the given body stmt. */ + Stmt AppendPendingFrees(Stmt body) { + auto& frees = scope_.Current().pending_frees; + if (frees.empty()) return body; + ffi::Array stmts; + stmts.push_back(body); + for (const auto& free_stmt : frees) { + stmts.push_back(free_stmt); + } + frees.clear(); + return SeqStmt::Flatten(stmts); + } + // The prepration sequence to be emitted before the current statement. std::vector> prep_seq_stack_; ffi::Optional device_type_{std::nullopt}; @@ -654,6 +744,8 @@ class BuiltinLower : public StmtExprMutator { // Record all stack frames. std::vector alloca_scope_; + // Scope stack for nd_mem_alloc_with_scope free tracking. + ScopeStack scope_; }; namespace transform { diff --git a/src/tir/transform/remove_no_op.cc b/src/tir/transform/remove_no_op.cc index cb073bf31a61..d5bcc210075f 100644 --- a/src/tir/transform/remove_no_op.cc +++ b/src/tir/transform/remove_no_op.cc @@ -91,22 +91,10 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { const StmtNode* context) : Parent(analyzer), touch_pattern_(touch_pattern), context_(context) {} - Stmt VisitStmt_(const LetStmtNode* op) final { - Stmt stmt = Parent::VisitStmt_(op); - op = stmt.as(); - if (is_no_op(op->body)) { - return MakeEvaluate(op->value); - } - - bool body_uses_bound_variable = - !UsesVar(op->body, [&](const VarNode* var) { return var == op->var.get(); }); - if (body_uses_bound_variable && HasSideEffect(op->value)) { - return SeqStmt({MakeEvaluate(op->value), op->body}); - } else if (body_uses_bound_variable) { - return op->body; - } else { - return stmt; - } + Stmt VisitStmt_(const BindNode* op) final { + // Simply mutate the value and return. + // Unused Bind elimination can be done later via a separate two-pass approach. 
+ return Parent::VisitStmt_(op); } Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == "pragma_debug_skip_region") { diff --git a/src/tir/transform/simplify.cc b/src/tir/transform/simplify.cc index af0fc4cf47bf..3d295e1764be 100644 --- a/src/tir/transform/simplify.cc +++ b/src/tir/transform/simplify.cc @@ -37,7 +37,6 @@ #include "../../arith/ir_mutator_with_analyzer.h" #include "../../tir/analysis/control_flow_graph.h" -#include "../../tir/analysis/var_use_def_analysis.h" namespace tvm { namespace arith { @@ -97,46 +96,6 @@ struct SimplifyConfigNode : public AttrsNodeReflAdapter { } }; -/* \brief Utility function to collect vars that should be retained */ -std::unordered_set CollectVarsUsedInBufferDefinition(const Stmt& stmt) { - struct Visitor : StmtExprVisitor { - using StmtExprVisitor::VisitExpr_; - using StmtExprVisitor::VisitStmt_; - - void VisitExpr_(const BufferLoadNode* op) override { - VisitBuffer(op->buffer); - StmtExprVisitor::VisitExpr_(op); - } - void VisitStmt_(const BufferStoreNode* op) override { - VisitBuffer(op->buffer); - StmtExprVisitor::VisitStmt_(op); - } - - void VisitBuffer(const Buffer& buf) { - // Collect variables that should remain defined - VarUseDefAnalyzer usage(ffi::Array{}); - usage(buf->data); - for (const auto& dim : buf->shape) { - usage(dim); - } - for (const auto& dim : buf->strides) { - usage(dim); - } - usage(buf->elem_offset); - - // Track for use in LetStmtNode mutator - for (const auto& var : usage.undefined_) { - used_in_buffer_def_.insert(var.get()); - } - } - std::unordered_set used_in_buffer_def_; - }; - - Visitor visitor; - visitor(stmt); - return visitor.used_in_buffer_def_; -} - class SimplifyConfig : public Attrs { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SimplifyConfig, Attrs, SimplifyConfigNode); @@ -159,10 +118,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { touch_pattern = ControlFlowGraph(func->body); } - std::unordered_set used_in_buffer_def = - 
CollectVarsUsedInBufferDefinition(func->body); - StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern), - std::move(used_in_buffer_def)); + StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern)); simplifier.MarkBufferMapShapes(func); func.CopyOnWrite()->body = simplifier(func->body); return func; @@ -170,12 +126,8 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { private: explicit StmtSimplifier(Analyzer* analyzer, SimplifyConfig config, - std::optional touch_pattern, - std::unordered_set used_in_buffer_def) - : IRMutatorWithAnalyzer(analyzer), - config_(config), - touch_pattern_(touch_pattern), - used_in_buffer_def_(used_in_buffer_def) {} + std::optional touch_pattern) + : IRMutatorWithAnalyzer(analyzer), config_(config), touch_pattern_(touch_pattern) {} using Parent = IRMutatorWithAnalyzer; using Parent::VisitExpr_; @@ -220,52 +172,28 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { return Parent::VisitStmt_(op); } - bool CanInlineLetStmt(const LetStmtNode* op) { - if (is_const_number(op->value)) return true; - if (op->value.as()) return true; - // Won't face the deep expression explosion problem as in Let expression. - // attempt to inline as much as possible if the value integer type(can be index). - if (!op->value.dtype().is_int()) return false; - return SideEffect(op->value) <= CallEffectKind::kPure; - } - - Stmt VisitStmt_(const LetStmtNode* op) override { + Stmt VisitStmt_(const BindNode* op) override { PrimExpr value = this->VisitExpr(op->value); - bool can_inline = CanInlineLetStmt(op); - if (can_inline) { - // It is usually fine to discard the let binding because the - // call to simplify will always inline the var. - // - // The exception is when the variable is used in a Buffer's - // definition, as these are not updated by the simplification. - // After DeclBuffer is required prior to use of a buffer, - // simplifying can update the buffer definition as well. 
The - // buffer can only be updated at its point of definition, - // because the points of use may occur within contexts that - // allow for additional simplifications (e.g. a buffer of shape - // [i,j] whose first use occurs within "if i==1" should not have - // its shape simplified to [1,j]). + // Bind in analyzer for constraint proving and simplification of + // subsequent expressions. Don't remove the Bind statement -- + // with flat Bind there's no body to inspect for usage patterns, + // so we always keep the Bind. + if (SideEffect(value) <= CallEffectKind::kPure) { analyzer_->Bind(op->var, value); - } else if (SideEffect(op->value) <= CallEffectKind::kPure) { - // Even if we aren't replacing all occurrences, they may be - // necessary for proving conditional statements. + // Record the binding so we can substitute it into assert conditions + // (see VisitStmt_(const AssertStmtNode*)). Under SSA each var is + // bound exactly once, so the map grows monotonically without key + // conflicts. No scope-based cleanup is needed because vars bound + // in inner scopes are only referenced within those scopes; stale + // entries are harmless and never consulted again. non_inlined_bindings_.Set(op->var, value); } - Stmt body = this->VisitStmt(op->body); - - // TODO(Lunderberg): Update the Buffer object as part of - // DeclBuffer updates, which will first require - // https://github.com/apache/tvm/pull/14778. - bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get()); - if (can_inline && !used_in_buffer_def) { - return body; - } else if (value.same_as(op->value) && body.same_as(op->body)) { + if (value.same_as(op->value)) { return ffi::GetRef(op); } else { auto n = this->CopyOnWrite(op); n->value = std::move(value); - n->body = std::move(body); return Stmt(n); } } @@ -350,9 +278,10 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { SimplifyConfig config_; std::optional touch_pattern_; + // Pure Bind values kept for substitution into assert conditions. 
+ // Grows monotonically under SSA — no scope-based cleanup required. ffi::Map non_inlined_bindings_; ffi::Optional current_stmt_{std::nullopt}; - std::unordered_set used_in_buffer_def_; }; } // namespace arith diff --git a/src/tir/transform/split_host_device.cc b/src/tir/transform/split_host_device.cc index e832e9d1caea..0bdcde861c63 100644 --- a/src/tir/transform/split_host_device.cc +++ b/src/tir/transform/split_host_device.cc @@ -105,9 +105,7 @@ class HostDeviceSplitter : public StmtMutator { Call kernel_call(success->dtype, kernel_symbol_global, args); AssertStmt assert_success(kernel_error_code == success, StringImm("RuntimeError"), {StringImm("Error executing compute kernel")}); - LetStmt let_check(kernel_error_code, kernel_call, assert_success); - - return let_check; + return SeqStmt({Bind(kernel_error_code, kernel_call), assert_success}); } else { return Evaluate(Call(DataType::Void(), kernel_symbol_global, args)); diff --git a/src/tir/transform/storage_rewrite.cc b/src/tir/transform/storage_rewrite.cc index e6391f08c3d7..d315bc5d8194 100644 --- a/src/tir/transform/storage_rewrite.cc +++ b/src/tir/transform/storage_rewrite.cc @@ -213,7 +213,17 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { void VisitStmt_(const AssertStmtNode* op) final { VisitNewScope(op); } - void VisitStmt_(const LetStmtNode* op) final { VisitNewScope(op); } + void VisitStmt_(const BindNode* op) final { + scope_.push_back(StmtEntry()); + // visit subexpr (the value may contain BufferLoad) + StmtExprVisitor::VisitStmt_(op); + StmtEntry e = scope_.back(); + scope_.pop_back(); + if (e.touched.size() != 0) { + e.stmt = op; + linear_seq_.push_back(e); + } + } // linearized access sequence. 
std::vector linear_seq_; @@ -1205,7 +1215,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { StmtExprVisitor::VisitExpr_(op); } - void VisitStmt_(const LetStmtNode* op) final { + void VisitStmt_(const BindNode* op) final { HandleLetNode(op->var); StmtExprVisitor::VisitStmt_(op); } @@ -1516,15 +1526,14 @@ class VectorTypeRewriter : public StmtExprMutator { return modified; } - Stmt VisitStmt_(const LetStmtNode* op) final { + Stmt VisitStmt_(const BindNode* op) final { auto it = rewrite_map_.find(op->var.get()); PrimExpr value = this->VisitExpr(op->value); - Stmt body = this->VisitStmt(op->body); Var var = (it == rewrite_map_.end()) ? op->var : it->second.new_buffer_var; - if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { + if (var.same_as(op->var) && value.same_as(op->value)) { return ffi::GetRef(op); } - return LetStmt(var, value, body); + return Bind(var, value); } Stmt VisitStmt_(const AllocBufferNode* op) final { diff --git a/src/tir/transform/tvm_ffi_binder.cc b/src/tir/transform/tvm_ffi_binder.cc index 02058970e182..9b37f886c24f 100644 --- a/src/tir/transform/tvm_ffi_binder.cc +++ b/src/tir/transform/tvm_ffi_binder.cc @@ -169,7 +169,7 @@ bool TVMFFIABIBuilder::BindScalar(const PrimExpr& arg, const PrimExpr& value, // First bind: define the variable if (with_lets) { var_defs_.emplace(v_arg.get(), VarDefInfo{arg, path}); - init_nest_.emplace_back(LetStmt(v_arg, value, Evaluate(0))); + init_nest_.emplace_back(Bind(v_arg, value)); } else { var_defs_.emplace(v_arg.get(), VarDefInfo{value, path}); } @@ -488,17 +488,15 @@ PrimExpr TVMFFIABIBuilder::DecodeParamFloat(int param_index, const Var& type_ind // ============================================================ void TVMFFIABIBuilder::DecodeParam(int param_index) { - const Stmt nop = Evaluate(0); Var param = params_[param_index]; DataType dtype = param.dtype(); // Extract type_index from packed_args Var type_index(param->name_hint + ".type_index", DataType::Int(32)); - 
init_nest_.push_back(LetStmt(type_index, - tir::Call(DataType::Int(32), builtin::tvm_struct_get(), - {v_packed_args_, IntImm(DataType::Int(32), param_index), - IntImm(DataType::Int(32), builtin::kTVMFFIAnyTypeIndex)}), - nop)); + init_nest_.push_back( + Bind(type_index, tir::Call(DataType::Int(32), builtin::tvm_struct_get(), + {v_packed_args_, IntImm(DataType::Int(32), param_index), + IntImm(DataType::Int(32), builtin::kTVMFFIAnyTypeIndex)}))); // Type-check and load value via per-dtype dispatch PrimExpr arg_value; @@ -553,10 +551,8 @@ Var TVMFFIABIBuilder::DLTensorGetFieldPtr(const Var& handle, int field_kind, const std::string& var_name) { Var ptr(var_name, DataType::Handle()); init_nest_.emplace_back( - LetStmt(ptr, - TVMStructGet(DataType::Handle(), handle, 0, - static_cast(field_kind)), - Evaluate(0))); + Bind(ptr, TVMStructGet(DataType::Handle(), handle, 0, + static_cast(field_kind)))); return ptr; } diff --git a/src/tir/transform/tvm_ffi_binder.h b/src/tir/transform/tvm_ffi_binder.h index 2daa3200874c..0ae8338a9b38 100644 --- a/src/tir/transform/tvm_ffi_binder.h +++ b/src/tir/transform/tvm_ffi_binder.h @@ -59,7 +59,7 @@ namespace tir { * by a later buffer's shape (batch_size). Separating definitions from * checks guarantees all variables are in scope when assertions reference them. * - * - init_nest: LetStmts, DeclBuffers for shape/strides arrays, AttrStmts — + * - init_nest: Binds, DeclBuffers for shape/strides arrays, AttrStmts — * all value-loading code that defines variables. * - asserts: AssertStmts — all validation checks. * - decl_buffers: DeclBuffer for buffer_map entries — buffer declarations. @@ -95,7 +95,7 @@ class TVMFFIABIBuilder { struct Result { /*! \brief Var -> VarDefInfo map for defined variables. */ std::unordered_map var_defs; - /*! \brief Variable definitions (LetStmts, shape/strides DeclBuffers, AttrStmts). */ + /*! \brief Variable definitions (Binds, shape/strides DeclBuffers, AttrStmts). */ std::vector init_nest; /*! 
\brief Validation checks (all AssertStmts). */ std::vector asserts; @@ -227,7 +227,7 @@ class TVMFFIABIBuilder { * \brief Internal scalar bind with AccessPath tracking and rich error messages. * * Binds \p arg to \p value. If arg is a Var not yet in var_defs_, creates a - * new definition (LetStmt to init_nest_); otherwise emits a rich assertion + * new definition (Bind to init_nest_); otherwise emits a rich assertion * (to asserts_) that the existing value matches the new one. * * When arg is a non-Var expression (e.g. batch_size + 1), the assertion is @@ -236,7 +236,7 @@ class TVMFFIABIBuilder { * * \param arg The argument expression to bind (typically a Var or constant). * \param value The value expression to bind to the argument. - * \param with_lets If true, emit LetStmt bindings into init_nest_. + * \param with_lets If true, emit Bind bindings into init_nest_. * \param path AccessPath for rich error message rendering. * \return True if this was the first bind (definition created), false otherwise. */ @@ -284,7 +284,7 @@ class TVMFFIABIBuilder { // ── DLTensor sub-helpers ─────────────────────────────────────── /*! - * \brief Get a DLTensor field pointer (shape or strides) and store it in a LetStmt. + * \brief Get a DLTensor field pointer (shape or strides) and store it in a Bind. * * \param handle The DLTensor handle variable. * \param field_kind kDLTensorShape or kDLTensorStrides. @@ -390,7 +390,7 @@ class TVMFFIABIBuilder { /*! \brief The definition map: VarNode* -> VarDefInfo (value + first_def_path). */ std::unordered_map var_defs_; - /*! \brief Variable definitions: LetStmts, shape/strides DeclBuffers, AttrStmts. */ + /*! \brief Variable definitions: Binds, shape/strides DeclBuffers, AttrStmts. */ std::vector init_nest_; /*! \brief Validation checks: all AssertStmts. 
*/ std::vector asserts_; diff --git a/src/tir/transform/unsupported_dtype_legalize.cc b/src/tir/transform/unsupported_dtype_legalize.cc index cca435d0de78..3c4fbc66241b 100644 --- a/src/tir/transform/unsupported_dtype_legalize.cc +++ b/src/tir/transform/unsupported_dtype_legalize.cc @@ -293,19 +293,18 @@ class ComputeLegalizer : public StmtExprMutator { DEFINE_BIOP_EXPR_LEGALIZE(EQNode, operator==); DEFINE_BIOP_EXPR_LEGALIZE(NENode, operator!=); - Stmt VisitStmt_(const LetStmtNode* op) final { + Stmt VisitStmt_(const BindNode* op) final { PrimExpr value = PromoteToTarget(op->value); Var var = op->var; if (value.dtype() != op->value.dtype()) { var = op->var.copy_with_dtype(op->value.dtype()); var_remap_[op->var] = var; } - Stmt body = VisitStmt(op->body); - if (value.same_as(op->value) && var.same_as(op->var) && body.same_as(op->body)) { + if (value.same_as(op->value) && var.same_as(op->var)) { return ffi::GetRef(op); } else { - return LetStmt(var, value, body); + return Bind(var, value); } } @@ -583,15 +582,14 @@ class StorageLegalizer : public StmtExprMutator { } } - Stmt VisitStmt_(const LetStmtNode* op) final { + Stmt VisitStmt_(const BindNode* op) final { PrimExpr value = VisitExpr(op->value); Var var = RemapVarDef(op->var); - Stmt body = VisitStmt(op->body); - if (value.same_as(op->value) && var.same_as(op->var) && body.same_as(op->body)) { + if (value.same_as(op->value) && var.same_as(op->var)) { return ffi::GetRef(op); } else { - return LetStmt(var, value, body); + return Bind(var, value); } } diff --git a/src/tir/transform/vectorize_loop.cc b/src/tir/transform/vectorize_loop.cc index 8bcf6078caf6..1862ceb1d480 100644 --- a/src/tir/transform/vectorize_loop.cc +++ b/src/tir/transform/vectorize_loop.cc @@ -815,9 +815,10 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->value); // if visit of value triggers need scalarize // we need to scalarize the let @@ -832,14 +833,13 @@ class Vectorizer : public StmtMutator, public 
ExprFunctorvalue.dtype().get_lanes_or_vscale_factor()) { Var new_var(op->var->name_hint, value.dtype()); let_binding_[op->var] = new_var; - return LetStmt(new_var, value, this->VisitStmt(op->body)); + return Bind(new_var, value); } else { let_binding_[op->var] = op->var; - Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && body.same_as(op->body)) { + if (value.same_as(op->value)) { return ffi::GetRef(op); } else { - return LetStmt(op->var, value, body); + return Bind(op->var, value); } } } diff --git a/tests/python/s_tir/schedule/test_tir_schedule_rfactor.py b/tests/python/s_tir/schedule/test_tir_schedule_rfactor.py index 499650345547..195d361a9a5e 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_rfactor.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_rfactor.py @@ -823,7 +823,7 @@ def argmax_split_init_buffer_duplicate( @T.prim_func -def argmax_split_letstmt_fewer_than_init( +def argmax_split_bind_fewer_than_init( idx: T.Buffer((128, 128), "int32"), val: T.Buffer((128, 128), "float32"), argmax_v0: T.Buffer((128,), "int32"), @@ -844,7 +844,7 @@ def argmax_split_letstmt_fewer_than_init( @T.prim_func -def argmax_split_letstmt_more_than_init( +def argmax_split_bind_more_than_init( idx: T.Buffer((128, 128), "int32"), val: T.Buffer((128, 128), "float32"), argmax_v0: T.Buffer((128,), "int32"), @@ -1544,16 +1544,16 @@ def test_reduction_rfactor_argmax_init_buffer_duplicate(): s.rfactor(ki, 1) -def test_reduction_rfactor_argmax_letstmt_fewer_than_init(): - s = tvm.s_tir.Schedule(argmax_split_letstmt_fewer_than_init, debug_mask="all") +def test_reduction_rfactor_argmax_bind_fewer_than_init(): + s = tvm.s_tir.Schedule(argmax_split_bind_fewer_than_init, debug_mask="all") argmax = s.get_sblock("argmax") _, _, ki = s.get_loops(argmax) with pytest.raises(tvm.s_tir.ScheduleError): s.rfactor(ki, 1) -def test_reduction_rfactor_argmax_letstmt_more_than_init(): - s = tvm.s_tir.Schedule(argmax_split_letstmt_more_than_init, debug_mask="all") +def 
test_reduction_rfactor_argmax_bind_more_than_init(): + s = tvm.s_tir.Schedule(argmax_split_bind_more_than_init, debug_mask="all") argmax = s.get_sblock("argmax") _, _, ki = s.get_loops(argmax) with pytest.raises(tvm.s_tir.ScheduleError): diff --git a/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py b/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py index ed82f3fecbb5..8bca1da57793 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py @@ -228,20 +228,22 @@ def before(A: T.Buffer((4, 4), "float32")): @T.prim_func(private=True) def expected(A: T.Buffer((4, 4), "float32")): for i in T.serial(4): - condition = i < 3 - if condition: + condition: T.bool = i < 3 # noqa: F841 + if i < 3: for j in T.serial(4): - A[i, j] = 0.0 + A[i, j] = T.float32(0.0) after = _run_transform(before, HoistedConditionals.All, HoistedLetBindings.All) tvm.ir.assert_structural_equal(after, expected) def test_hoist_disable_let(): - """As test_hoist_with_let, but forbid hoisting of LetStmt + """As test_hoist_with_let, but forbid hoisting of Bind Because the condition depends on the let binding, it should no - longer be hoisted. + longer be hoisted. With Bind lifecycle management, the condition + var is erased at sequence boundaries, so the if-condition uses + the raw expression. 
""" @T.prim_func(private=True) @@ -252,7 +254,12 @@ def before(A: T.Buffer((4, 4), "float32")): if condition: A[i, j] = 0.0 - expected = before + @T.prim_func(private=True) + def expected(A: T.Buffer((4, 4), "float32")): + for i, j in T.grid(4, 4): + condition: T.bool = i < 3 # noqa: F841 + if i < 3: + A[i, j] = T.float32(0.0) after = _run_transform(before, HoistedConditionals.All, HoistedLetBindings.Never) tvm.ir.assert_structural_equal(after, expected) @@ -512,9 +519,9 @@ def before(A: T.Buffer((4, 4), "float32")): @T.prim_func(private=True) def expected(A: T.Buffer((4, 4), "float32")): for i in T.serial(4): - x = T.cast(i + 1, "float32") + x: T.float32 = T.cast(i + 1, "float32") # noqa: F841 for j in T.serial(4): - A[i, j] = 5.0 * x + T.cast(j, "float32") + A[i, j] = T.float32(5.0) * T.cast(i + 1, "float32") + T.cast(j, "float32") after = _run_transform(before, HoistedConditionals.All, HoistedLetBindings.All) tvm.ir.assert_structural_equal(after, expected) diff --git a/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py index 974d4c25c8e6..f266ee539aae 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py @@ -243,6 +243,10 @@ def test_inject_async_copy_barrier(): tvm.testing.assert_allclose(B_nd.numpy(), A_np) +# Note: the expected output contains a dead CSE variable `cse_v1 = (i < 12)`. +# CSE extracts (i < 12) before inject_ptx_async_copy runs, but the latter +# replaces the original IfThenElse guards with new cast(int32, i < 12) +# expressions for predicated async copies, leaving cse_v1 unused. 
expected_cuda_script = r"""#include __forceinline__ __device__ unsigned int cast_smem_ptr_to_int(const void* const smem_ptr) @@ -335,7 +339,7 @@ def test_inject_async_copy_barrier(): { unsigned int addr = cast_smem_ptr_to_int(A_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x))); - int pred_guard = (int)cse_v1; + int pred_guard = (int)(i < 12); __asm__ __volatile__( "{ .reg .pred p;" " setp.ne.b32 p, %0, 0;" @@ -358,7 +362,7 @@ def test_inject_async_copy_barrier(): { unsigned int addr = cast_smem_ptr_to_int(B_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x))); - int pred_guard = (int)cse_v1; + int pred_guard = (int)(i < 12); __asm__ __volatile__( "{ .reg .pred p;" " setp.ne.b32 p, %0, 0;" diff --git a/tests/python/s_tir/transform/test_s_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/s_tir/transform/test_s_tir_transform_plan_update_buffer_allocation_location.py index 945f3d38f57e..c92528e5a2da 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_plan_update_buffer_allocation_location.py @@ -373,7 +373,7 @@ def before(A: T.handle("float32")): def test_dltensor_buffer_is_unlowered(): - """Buffers allocated with a LetStmt are unmodified + """Buffers allocated with a Bind are unmodified Confirm that the `tir.PlanAndUpdateBufferAllocationLocation` pass leaves (Buffer nodes corresponding to PrimFunc DLTensor arguments) diff --git a/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py b/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py index 0b8c314a83ed..f1cdde039740 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py @@ -100,7 +100,7 @@ def expected(A: T.Buffer((4, 4), "float32"), E: T.Buffer((4, 4), "float32")): @tvm.testing.requires_cuda -def test_sync_let_stmt(): +def test_sync_bind(): @T.prim_func(private=True) 
def func(A: T.Buffer((16 * 512), "float32")): blockIdx_x = T.launch_thread("blockIdx.x", 16) @@ -113,14 +113,14 @@ def func(A: T.Buffer((16 * 512), "float32")): A_shared_1[ax0] = A[blockIdx_x * 512 + ax0] in_thread_A_temp_1 = T.decl_buffer((1,), data=in_thread_A_temp, scope="local") in_thread_A_temp_1[0] = T.float32(0) - with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x]) as A_temp: - in_thread_A_temp_1[0] = A_temp - with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 128]) as A_temp: - in_thread_A_temp_1[0] = A_temp - with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 256]) as A_temp: - in_thread_A_temp_1[0] = A_temp - with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 384]) as A_temp: - in_thread_A_temp_1[0] = A_temp + A_temp_1 = T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x]) + in_thread_A_temp_1[0] = A_temp_1 + A_temp_2 = T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 128]) + in_thread_A_temp_1[0] = A_temp_2 + A_temp_3 = T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 256]) + in_thread_A_temp_1[0] = A_temp_3 + A_temp_4 = T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 384]) + in_thread_A_temp_1[0] = A_temp_4 cross_thread_A_temp_1 = T.decl_buffer((1,), data=cross_thread_A_temp, scope="local") with T.attr( T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), @@ -148,14 +148,14 @@ def expected(A: T.Buffer((8192,), "float32")): in_thread_A_temp_1_1 = T.decl_buffer((1,), data=in_thread_A_temp_1, scope="local") in_thread_A_temp_1_1[0] = T.float32(0) T.tvm_storage_sync("shared") - with T.LetStmt(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x]) as A_temp: - in_thread_A_temp_1_1[0] = A_temp - with T.LetStmt(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 128]) as A_temp: - in_thread_A_temp_1_1[0] = A_temp - with T.LetStmt(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 256]) as A_temp: - in_thread_A_temp_1_1[0] = A_temp - with T.LetStmt(in_thread_A_temp_1_1[0] + 
A_shared_1_1[threadIdx_x + 384]) as A_temp: - in_thread_A_temp_1_1[0] = A_temp + A_temp_1 = T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x]) + in_thread_A_temp_1_1[0] = A_temp_1 + A_temp_2 = T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 128]) + in_thread_A_temp_1_1[0] = A_temp_2 + A_temp_3 = T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 256]) + in_thread_A_temp_1_1[0] = A_temp_3 + A_temp_4 = T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 384]) + in_thread_A_temp_1_1[0] = A_temp_4 cross_thread_A_temp_1_1 = T.decl_buffer((1,), data=cross_thread_A_temp_1, scope="local") T.attr( T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), @@ -180,4 +180,4 @@ def expected(A: T.Buffer((8192,), "float32")): test_sync_else_branch() test_sync_read_thread_id_independent_location() test_sync_shared_dyn() - test_sync_let_stmt() + test_sync_bind() diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_ssa.py b/tests/python/tir-analysis/test_tir_analysis_verify_ssa.py index 6d559a81d2c6..07611e55ace7 100644 --- a/tests/python/tir-analysis/test_tir_analysis_verify_ssa.py +++ b/tests/python/tir-analysis/test_tir_analysis_verify_ssa.py @@ -23,7 +23,9 @@ def test_verify_ssa(): z = tvm.tir.Evaluate(x + y) assert tvm.tir.analysis.verify_ssa(tvm.tir.PrimFunc([x, y], z)) - assert not tvm.tir.analysis.verify_ssa(tvm.tir.PrimFunc([x, y], tvm.tir.LetStmt(x, 1, z))) + assert not tvm.tir.analysis.verify_ssa( + tvm.tir.PrimFunc([x, y], tvm.tir.SeqStmt([tvm.tir.Bind(x, 1), z])) + ) def test_verify_weak_let_ssa(): diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py index 9c1bcc545b7e..73347c0728dc 100644 --- a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py +++ b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py @@ -60,14 +60,25 @@ def element_wise( def test_error_for_out_of_scope_usage(): - """A variable may not 
be used after its scope ends""" + """A variable may not be used after its scope ends. - @T.prim_func(check_well_formed=False) - def func(): - i = T.int32() - with T.LetStmt(42, var=i): - T.evaluate(i) - T.evaluate(i) + With flat Bind semantics, Bind vars are visible to all subsequent + siblings in the same SeqStmt. True out-of-scope usage occurs when + the Bind is inside a child scope (e.g., ForNode body) and the + variable is used outside that scope. + """ + i = tvm.tir.Var("i", "int32") + # Bind i inside a For loop body + for_stmt = tvm.tir.For( + tvm.tir.Var("j", "int32"), + 0, + 1, + tvm.tir.ForKind.SERIAL, + tvm.tir.SeqStmt([tvm.tir.Bind(i, 42), tvm.tir.Evaluate(i)]), + ) + # Use i outside the For loop — this is out of scope + body = tvm.tir.SeqStmt([for_stmt, tvm.tir.Evaluate(i)]) + func = tvm.tir.PrimFunc([], body) with pytest.raises( ValueError, match="Invalid use of undefined variable i at .* no longer in-scope." @@ -81,9 +92,9 @@ def test_error_for_nested_rebind_usage(): @T.prim_func(check_well_formed=False) def func(): i = T.int32() - with T.LetStmt(42, var=i): - with T.LetStmt(42, var=i): - T.evaluate(i) + T.Bind(42, var=i) + T.Bind(42, var=i) + T.evaluate(i) with pytest.raises( ValueError, match="ill-formed, due to multiple nested definitions of variable i" @@ -92,17 +103,22 @@ def func(): def test_error_for_repeated_binding(): - """A variable may not be re-defined after the scope ends""" + """A variable may not be re-defined in the same flat scope. + + With flat Bind semantics, sequential Bind of the same variable in the + same SeqStmt is treated as a nested redefinition (since the first Bind's + scope extends to all subsequent siblings). 
+ """ @T.prim_func(check_well_formed=False) def func(): i = T.int32() - with T.LetStmt(42, var=i): - T.evaluate(i) - with T.LetStmt(17, var=i): - T.evaluate(i) + T.Bind(42, var=i) + T.evaluate(i) + T.Bind(17, var=i) + T.evaluate(i) - with pytest.raises(ValueError, match="multiple definitions of variable i"): + with pytest.raises(ValueError, match="multiple nested definitions of variable i"): tvm.tir.analysis.verify_well_formed(func) @@ -115,13 +131,13 @@ def test_error_for_cross_function_reuse(): class mod: @T.prim_func def func1(): - with T.LetStmt(42, var=i): - T.evaluate(i) + T.Bind(42, var=i) + T.evaluate(i) @T.prim_func def func2(): - with T.LetStmt(42, var=i): - T.evaluate(i) + T.Bind(42, var=i) + T.evaluate(i) with pytest.raises(ValueError, match="multiple definitions of variable i"): tvm.tir.analysis.verify_well_formed(mod) @@ -269,17 +285,21 @@ def test_error_message_without_previous_definition_location(): This tests the scenario where it == end(), so the error message should contain 'TIR is ill-formed, due to multiple definitions of variable' but should NOT contain 'It was first defined at' since the iterator is invalid. + + With flat Bind semantics, sequential redefinitions in the same SeqStmt + are treated as nested definitions, and the first definition location + IS known, so the message includes location info. 
""" @T.prim_func(check_well_formed=False) def func(): x = T.int32() - with T.LetStmt(42, var=x): - T.evaluate(x) + T.Bind(42, var=x) + T.evaluate(x) - with T.LetStmt(99, var=x): # This should trigger the error - T.evaluate(x) + T.Bind(99, var=x) # This should trigger the error + T.evaluate(x) with pytest.raises(ValueError) as exc_info: tvm.tir.analysis.verify_well_formed(func, assert_mode=True) @@ -287,7 +307,7 @@ def func(): error_msg = str(exc_info.value) assert "TIR is ill-formed" in error_msg - assert "multiple definitions of variable" in error_msg + assert "multiple nested definitions of variable" in error_msg def test_error_message_with_previous_definition_location(): @@ -302,9 +322,9 @@ def test_error_message_with_previous_definition_location(): def func(): x = T.int32() - with T.LetStmt(42, var=x): - with T.LetStmt(99, var=x): # This should trigger the error - T.evaluate(x) + T.Bind(42, var=x) + T.Bind(99, var=x) # This should trigger the error + T.evaluate(x) with pytest.raises(ValueError) as exc_info: tvm.tir.analysis.verify_well_formed(func, assert_mode=True) @@ -322,18 +342,20 @@ def func(): def test_sequential_redefinition_with_location(): """Test case 2b: Sequential redefinition that includes location info - This tests the previously_defined_ path where it != end() + This tests the previously_defined_ path where it != end(). + With flat Bind semantics, sequential redefinitions in the same SeqStmt + are treated as nested definitions with location info. 
""" @T.prim_func(check_well_formed=False) def func(): x = T.int32() - with T.LetStmt(1, var=x): - T.evaluate(x) + T.Bind(1, var=x) + T.evaluate(x) - with T.LetStmt(2, var=x): # This should trigger the error - T.evaluate(x) + T.Bind(2, var=x) # This should trigger the error + T.evaluate(x) with pytest.raises(ValueError) as exc_info: tvm.tir.analysis.verify_well_formed(func, assert_mode=True) @@ -341,9 +363,9 @@ def func(): error_msg = str(exc_info.value) assert "TIR is ill-formed" in error_msg - assert "multiple definitions of variable" in error_msg + assert "multiple nested definitions of variable" in error_msg assert "It was first defined at" in error_msg - assert "later re-defined at" in error_msg + assert "was re-defined at" in error_msg def test_buffer_in_buffer_map_is_well_formed(): diff --git a/tests/python/tir-base/test_tir_constructor.py b/tests/python/tir-base/test_tir_constructor.py index 5a7d516afcbe..1ebb80c038a4 100644 --- a/tests/python/tir-base/test_tir_constructor.py +++ b/tests/python/tir-base/test_tir_constructor.py @@ -131,11 +131,10 @@ def test_expr_constructor(): def test_stmt_constructor(): v = tvm.tir.Var("aa", "int32") nop = tvm.tir.Evaluate(1) - x = tvm.tir.LetStmt(v, 1, tvm.tir.Evaluate(1)) - assert isinstance(x, tvm.tir.LetStmt) + x = tvm.tir.Bind(v, 1) + assert isinstance(x, tvm.tir.Bind) assert x.var == v assert x.value.value == 1 - assert isinstance(x.body, tvm.tir.Evaluate) x = tvm.tir.AttrStmt(v == 1, "xx", 1, tvm.tir.Evaluate(1)) assert isinstance(x, tvm.tir.AttrStmt) diff --git a/tests/python/tir-base/test_tir_nodes.py b/tests/python/tir-base/test_tir_nodes.py index 4d66b929967a..a20e96b5668f 100644 --- a/tests/python/tir-base/test_tir_nodes.py +++ b/tests/python/tir-base/test_tir_nodes.py @@ -91,7 +91,7 @@ def test_ir2(): def test_let(): x = tvm.tir.Var("x", "int32") y = tvm.tir.Var("y", "int32") - stmt = tvm.tir.LetStmt(x, 10, tvm.tir.Evaluate(x + 1)) + stmt = tvm.tir.Bind(x, 10) def test_cast(): @@ -306,7 +306,7 @@ def 
test_prim_func(): x = tvm.tir.Var("x", "int32") y = tvm.tir.Var("y", "int32") b = tvm.tir.decl_buffer((x,), "float32") - stmt = tvm.tir.LetStmt(x, 10, tvm.tir.Evaluate(x + 1)) + stmt = tvm.tir.SeqStmt([tvm.tir.Bind(x, 10), tvm.tir.Evaluate(x + 1)]) func = tvm.tir.PrimFunc([x, y, b], stmt) # make sure we can print diff --git a/tests/python/tir-base/test_tir_structural_equal_hash.py b/tests/python/tir-base/test_tir_structural_equal_hash.py index bd0eb573b4b7..6f44138300b8 100644 --- a/tests/python/tir-base/test_tir_structural_equal_hash.py +++ b/tests/python/tir-base/test_tir_structural_equal_hash.py @@ -109,7 +109,7 @@ def test_prim_func(): # new cases b = tvm.tir.decl_buffer((x,), "float32") - stmt = tvm.tir.LetStmt(x, 10, tvm.tir.Evaluate(x + 1)) + stmt = tvm.tir.SeqStmt([tvm.tir.Bind(x, 10), tvm.tir.Evaluate(x + 1)]) func0 = tvm.tir.PrimFunc([x, y, b], stmt) # easiest way to deep copy is via save/load func1 = tvm.ir.load_json(tvm.ir.save_json(func0)) diff --git a/tests/python/tir-transform/test_tir_inline_private_functions.py b/tests/python/tir-transform/test_tir_inline_private_functions.py index 41edd410a008..e681073fa6f4 100644 --- a/tests/python/tir-transform/test_tir_inline_private_functions.py +++ b/tests/python/tir-transform/test_tir_inline_private_functions.py @@ -150,21 +150,21 @@ def subroutine(A_data: T.handle("float32"), B_data: T.handle("float32")): class Expected: @T.prim_func def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")): - with T.LetStmt(T.address_of(A[0, 0]), var=T.handle("float32")) as A_data_1: - A_1 = T.decl_buffer(16, "float32", data=A_data_1) - B_data_1: T.handle("float32") = T.address_of(B[0, 0]) - B_1 = T.decl_buffer(16, "float32", data=B_data_1) - for i in range(16): - with T.sblock("scalar_mul_1"): - B_1[i] = A_1[i] * 2.0 - - with T.LetStmt(T.address_of(A[1, 0]), var=T.handle("float32")) as A_data_2: - A_2 = T.decl_buffer(16, "float32", data=A_data_2) - B_data_2: T.handle("float32") = T.address_of(B[1, 0]) - 
B_2 = T.decl_buffer(16, "float32", data=B_data_2) - for i in range(16): - with T.sblock("scalar_mul_2"): - B_2[i] = A_2[i] * 2.0 + A_data_1 = T.Bind(T.address_of(A[0, 0]), T.handle("float32")) + A_1 = T.decl_buffer(16, "float32", data=A_data_1) + B_data_1: T.handle("float32") = T.address_of(B[0, 0]) + B_1 = T.decl_buffer(16, "float32", data=B_data_1) + for i in range(16): + with T.sblock("scalar_mul_1"): + B_1[i] = A_1[i] * 2.0 + + A_data_2 = T.Bind(T.address_of(A[1, 0]), T.handle("float32")) + A_2 = T.decl_buffer(16, "float32", data=A_data_2) + B_data_2: T.handle("float32") = T.address_of(B[1, 0]) + B_2 = T.decl_buffer(16, "float32", data=B_data_2) + for i in range(16): + with T.sblock("scalar_mul_2"): + B_2[i] = A_2[i] * 2.0 class TestInlineCallOccurringInExpression(BaseTestCase): diff --git a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py index 79cbdb91950f..c9c3703da210 100644 --- a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py @@ -40,44 +40,27 @@ def test_cse(): b = tvm.tir.Var("b", "int32") dtype = "int32" buffer = tvm.tir.decl_buffer((50,), dtype) - # Test prog : - # let z1=1 in let z2=2 in - # Mem[i1] = z1+z2; - # let x = 1 in let y = 1 in - # let a = (x+y) + (z1+z2) in - # let b = (x+y) + z3 in - # Mem[i2] = a+b; - body = tvm.tir.LetStmt( - z1, - 1, - tvm.tir.LetStmt( - z2, - 2, - tvm.tir.SeqStmt( - [ - tvm.tir.BufferStore(buffer, z1 + z2, [i1]), - tvm.tir.LetStmt( - x, - 1, - tvm.tir.LetStmt( - y, - 1, - tvm.tir.LetStmt( - a, - (x + y) + (z1 + z2), - tvm.tir.LetStmt( - b, (x + y) + z3, tvm.tir.BufferStore(buffer, a + b, [i2]) - ), - ), - ), - ), - ] - ), - ), + # Test prog (flat Bind style): + # z1 = 1; z2 = 2; + # Mem[i1] = z1+z2; + # x = 1; y = 1; + # a = (x+y) + (z1+z2); + # b = (x+y) + z3; + # Mem[i2] = a+b; + body = tvm.tir.SeqStmt( + [ + tvm.tir.Bind(z1, 1), + 
tvm.tir.Bind(z2, 2), + tvm.tir.BufferStore(buffer, z1 + z2, [i1]), + tvm.tir.Bind(x, 1), + tvm.tir.Bind(y, 1), + tvm.tir.Bind(a, (x + y) + (z1 + z2)), + tvm.tir.Bind(b, (x + y) + z3), + tvm.tir.BufferStore(buffer, a + b, [i2]), + ] ) - # This test program gives the opportunity to introduce two new variables, at two different - # levels and to perform replacements in the value of "a" and "b", using these new variables. - # We will check all of that underneath and more, making also sure that nothing else has changed + # This test program gives the opportunity to introduce two new variables, + # and to perform replacements in the value of "a" and "b", using these new variables. mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i1, i2, z3], body)) body = tvm.tir.transform.CommonSubexprElimTIR()(mod) @@ -86,61 +69,72 @@ def test_cse(): body = body["main"].body # Gets the body of the main, i.e. the full statement - assert body.var.name == "z1" - assert body.value == 1 - - body = body.body - - assert body.var.name == "z2" - assert body.value == 2 - # This is the let-in for the first variable generated cse_v1 - assert isinstance(body.body, tvm.tir.LetStmt) - - body = body.body - - # And this is the name and value of this variable - cse_v1 = body.var # Keep the variable accessible for later checking the replacements - assert body.var.name == "cse_v1" - tvm.ir.assert_structural_equal(body.value, z1 + z2) - assert isinstance(body.body, tvm.tir.SeqStmt) - - body = body.body - - assert isinstance(body[0], tvm.tir.BufferStore) - assert isinstance(body[1], tvm.tir.LetStmt) - - body = body[1] - - assert body.var.name == "x" - assert body.value == 1 - - body = body.body - - assert body.var.name == "y" - assert body.value == 1 - # This is the let-in for the second variable generated cse_v2 - assert isinstance(body.body, tvm.tir.LetStmt) - - body = body.body - - # And this is the name and value of this variable - cse_v2 = body.var # Keep the variable accessible for later checking the 
replacements - assert body.var.name == "cse_v2" - tvm.ir.assert_structural_equal(body.value, x + y) - - body = body.body - - body.var.name == "a" - # Check that the replacement has been done correctly! - tvm.ir.assert_structural_equal(body.value, cse_v2 + cse_v1) - - body = body.body - - body.var.name == "b" - # Check that the replacement has been done correctly! - tvm.ir.assert_structural_equal(body.value, cse_v2 + z3) + # The result should be a flat SeqStmt with Bind nodes for z1, z2, cse_v1 (z1+z2), + # the store, x, y, cse_v2 (x+y), a (using cse vars), b (using cse vars), store + assert isinstance(body, tvm.tir.SeqStmt) - assert isinstance(body.body, tvm.tir.BufferStore) + # Walk through the flat sequence and check the CSE-introduced bindings + stmts = list(body) + idx = 0 + + # z1 = 1 + assert isinstance(stmts[idx], tvm.tir.Bind) + assert stmts[idx].var.name == "z1" + assert stmts[idx].value == 1 + idx += 1 + + # z2 = 2 + assert isinstance(stmts[idx], tvm.tir.Bind) + assert stmts[idx].var.name == "z2" + assert stmts[idx].value == 2 + idx += 1 + + # CSE should introduce cse_v1 = z1 + z2 here + assert isinstance(stmts[idx], tvm.tir.Bind) + cse_v1 = stmts[idx].var + assert stmts[idx].var.name == "cse_v1" + tvm.ir.assert_structural_equal(stmts[idx].value, z1 + z2) + idx += 1 + + # Mem[i1] = cse_v1 (was z1+z2, now replaced) + assert isinstance(stmts[idx], tvm.tir.BufferStore) + tvm.ir.assert_structural_equal(stmts[idx].value, cse_v1) + idx += 1 + + # x = 1 + assert isinstance(stmts[idx], tvm.tir.Bind) + assert stmts[idx].var.name == "x" + assert stmts[idx].value == 1 + idx += 1 + + # y = 1 + assert isinstance(stmts[idx], tvm.tir.Bind) + assert stmts[idx].var.name == "y" + assert stmts[idx].value == 1 + idx += 1 + + # CSE should introduce cse_v2 = x + y here + assert isinstance(stmts[idx], tvm.tir.Bind) + cse_v2 = stmts[idx].var + assert stmts[idx].var.name == "cse_v2" + tvm.ir.assert_structural_equal(stmts[idx].value, x + y) + idx += 1 + + # a = cse_v2 + cse_v1 
(was (x+y) + (z1+z2), now replaced) + assert isinstance(stmts[idx], tvm.tir.Bind) + assert stmts[idx].var.name == "a" + tvm.ir.assert_structural_equal(stmts[idx].value, cse_v2 + cse_v1) + idx += 1 + + # b = cse_v2 + z3 (was (x+y) + z3, now replaced) + assert isinstance(stmts[idx], tvm.tir.Bind) + assert stmts[idx].var.name == "b" + tvm.ir.assert_structural_equal(stmts[idx].value, cse_v2 + z3) + idx += 1 + + # Mem[i2] = a + b + assert isinstance(stmts[idx], tvm.tir.BufferStore) + idx += 1 # ----------------------------------------------------- @@ -160,25 +154,24 @@ def test_cse_ifNode_1(): z = tvm.tir.Var("z", "int32") dtype = "int32" buffer = tvm.tir.decl_buffer((50,), dtype) - # Test prog : - # let b=1 in - # if(b) { - # Mem[i1] = y+z - # Mem[i2] = y+z - # } - # else { - # Mem[i3] = y - # } - body = tvm.tir.LetStmt( - b, - 1, - tvm.tir.IfThenElse( - b, - tvm.tir.SeqStmt( - [tvm.tir.BufferStore(buffer, y + z, [i1]), tvm.tir.BufferStore(buffer, y + z, [i2])] + # Test prog: + # b = 1; + # if(b) { Mem[i1] = y+z; Mem[i2] = y+z } + # else { Mem[i3] = y } + body = tvm.tir.SeqStmt( + [ + tvm.tir.Bind(b, 1), + tvm.tir.IfThenElse( + b, + tvm.tir.SeqStmt( + [ + tvm.tir.BufferStore(buffer, y + z, [i1]), + tvm.tir.BufferStore(buffer, y + z, [i2]), + ] + ), + tvm.tir.BufferStore(buffer, y, [i3]), ), - tvm.tir.BufferStore(buffer, y, [i3]), - ), + ] ) mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i1, i2, i3, y, z], body)) @@ -188,20 +181,23 @@ def test_cse_ifNode_1(): body = body["main"].body # Gets the body of the main, i.e. 
the full statement - assert body.var.name == "b" - assert body.value == 1 - assert isinstance(body.body, tvm.tir.IfThenElse) - - body = body.body + assert isinstance(body, tvm.tir.SeqStmt) + stmts = list(body) - assert isinstance(body.then_case, tvm.tir.LetStmt) + # b = 1 + assert isinstance(stmts[0], tvm.tir.Bind) + assert stmts[0].var.name == "b" + assert stmts[0].value == 1 - body = body.then_case + # The If node + assert isinstance(stmts[1], tvm.tir.IfThenElse) + if_node = stmts[1] - # The let-in introduced by the CSE should appear now, inside the Then branch of the If node - assert body.var.name == "cse_v1" - # and it should contain the expression (y+z) that was redundant - tvm.ir.assert_structural_equal(body.value, y + z) + # The CSE variable should be inside the Then branch + then_stmts = list(if_node.then_case) + assert isinstance(then_stmts[0], tvm.tir.Bind) + assert then_stmts[0].var.name == "cse_v1" + tvm.ir.assert_structural_equal(then_stmts[0].value, y + z) # Second test for if nodes : Some duplicated computations appear in both the Then and Else branch. 
@@ -216,28 +212,24 @@ def test_cse_ifNode_2(): z = tvm.tir.Var("z", "int32") dtype = "int32" buffer = tvm.tir.decl_buffer((50,), dtype) - # Test prog : - # let b=1 in - # if(b) { - # Mem[i1] = y+z - # Mem[i2] = y - # } - # else { - # Mem[i3] = y+z - # } - body = tvm.tir.LetStmt( - b, - 1, - tvm.tir.IfThenElse( - b, - tvm.tir.SeqStmt( - [ - tvm.tir.BufferStore(buffer, y + z, [i1]), # (y+z) is present in Then branch - tvm.tir.BufferStore(buffer, y, [i2]), - ] + # Test prog: + # b = 1; + # if(b) { Mem[i1] = y+z; Mem[i2] = y } + # else { Mem[i3] = y+z } + body = tvm.tir.SeqStmt( + [ + tvm.tir.Bind(b, 1), + tvm.tir.IfThenElse( + b, + tvm.tir.SeqStmt( + [ + tvm.tir.BufferStore(buffer, y + z, [i1]), + tvm.tir.BufferStore(buffer, y, [i2]), + ] + ), + tvm.tir.BufferStore(buffer, y + z, [i3]), ), - tvm.tir.BufferStore(buffer, y + z, [i3]), # and also present in the Else branch - ), + ] ) mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i1, i2, i3, y, z], body)) @@ -247,12 +239,18 @@ def test_cse_ifNode_2(): body = body["main"].body # Gets the body of the main, i.e. the full statement - assert isinstance(body, tvm.tir.LetStmt) + assert isinstance(body, tvm.tir.SeqStmt) + stmts = list(body) - # The let-in introduced by the CSE should appear now, at the toplevel (i.e. before the If) - assert body.var.name == "cse_v1" - # and it should contain the expression (y+z) that was redundant - tvm.ir.assert_structural_equal(body.value, y + z) + # CSE should introduce cse_v1 = y + z before the If + # Find the cse_v1 binding + found_cse = False + for s in stmts: + if isinstance(s, tvm.tir.Bind) and s.var.name == "cse_v1": + tvm.ir.assert_structural_equal(s.value, y + z) + found_cse = True + break + assert found_cse # ------------------------------------------------------------------------------------------------- @@ -288,38 +286,29 @@ def test_cse_cascade(): body = body["main"].body # Gets the body of the main, i.e. 
the full statement - assert isinstance(body, tvm.tir.LetStmt) - - # The second let-in (by order introduced) introduced by the CSE should appear first - cse_v2 = body.var # Keep the variable accessible for later checking the replacements - assert body.var.name == "cse_v2" - # and it should contain the expression (x+y) - tvm.ir.assert_structural_equal(body.value, (x + y)) - - body = body.body - - assert isinstance(body, tvm.tir.LetStmt) + assert isinstance(body, tvm.tir.SeqStmt) + stmts = list(body) - # The first let-in (by order introduced) introduced by the CSE should appear now, after the 2nd - cse_v1 = body.var # Keep the variable accessible for later checking the replacements - assert body.var.name == "cse_v1" - # and it should contain the expression cse_v2+z - tvm.ir.assert_structural_equal(body.value, cse_v2 + z) + # cse_v2 = x + y + assert isinstance(stmts[0], tvm.tir.Bind) + cse_v2 = stmts[0].var + assert stmts[0].var.name == "cse_v2" + tvm.ir.assert_structural_equal(stmts[0].value, (x + y)) - body = body.body + # cse_v1 = cse_v2 + z + assert isinstance(stmts[1], tvm.tir.Bind) + cse_v1 = stmts[1].var + assert stmts[1].var.name == "cse_v1" + tvm.ir.assert_structural_equal(stmts[1].value, cse_v2 + z) - assert isinstance(body, tvm.tir.SeqStmt) - assert isinstance(body[0], tvm.tir.BufferStore) - assert isinstance(body[1], tvm.tir.BufferStore) - assert isinstance(body[2], tvm.tir.BufferStore) + # Three stores + assert isinstance(stmts[2], tvm.tir.BufferStore) + assert isinstance(stmts[3], tvm.tir.BufferStore) + assert isinstance(stmts[4], tvm.tir.BufferStore) - store1 = body[0] - store2 = body[1] - store3 = body[2] - - tvm.ir.assert_structural_equal(store1.value, cse_v1) - tvm.ir.assert_structural_equal(store2.value, cse_v1) - tvm.ir.assert_structural_equal(store3.value, cse_v2) + tvm.ir.assert_structural_equal(stmts[2].value, cse_v1) + tvm.ir.assert_structural_equal(stmts[3].value, cse_v1) + tvm.ir.assert_structural_equal(stmts[4].value, cse_v2) # 
----------------------------------------------------------------------------------------- @@ -331,8 +320,8 @@ def test_no_normalization_without_commoning(): z = tvm.tir.Var("z", "int32") a = tvm.tir.Var("a", "int32") # Test prog : - # let a = x + (y + z) in a - body = tvm.tir.LetStmt(a, x + (y + z), tvm.tir.Evaluate(a)) + # a = x + (y + z); evaluate(a) + body = tvm.tir.SeqStmt([tvm.tir.Bind(a, x + (y + z)), tvm.tir.Evaluate(a)]) mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([x, y, z], body)) body = tvm.tir.transform.CommonSubexprElimTIR(identify_equiv_terms=True)(mod) @@ -341,8 +330,11 @@ def test_no_normalization_without_commoning(): body = body["main"].body # Gets the body of the main, i.e. the full statement - assert body.var.name == "a" - tvm.ir.assert_structural_equal(body.value, x + (y + z)) + assert isinstance(body, tvm.tir.SeqStmt) + stmts = list(body) + assert isinstance(stmts[0], tvm.tir.Bind) + assert stmts[0].var.name == "a" + tvm.ir.assert_structural_equal(stmts[0].value, x + (y + z)) # ------------------------------------------------- @@ -360,9 +352,9 @@ def func_distributivity( def func_distributivity_expected( B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 ) -> None: - with T.LetStmt((y + z) * x) as cse_v1: - B[i1] = cse_v1 - B[i2] = cse_v1 + cse_v1 = T.Bind((y + z) * x) + B[i1] = cse_v1 + B[i2] = cse_v1 @T.prim_func @@ -377,9 +369,9 @@ def func_associativity( def func_associativity_expected( B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 ) -> None: - with T.LetStmt(x + y + z) as cse_v1: - B[i1] = cse_v1 - B[i2] = cse_v1 + cse_v1 = T.Bind(x + y + z) + B[i1] = cse_v1 + B[i2] = cse_v1 def _check(original, transformed): @@ -428,8 +420,8 @@ def test_deterministic_cse(): expression = x for add in inc1 + inc2: expression = expression + add - let_stmt = tvm.tir.LetStmt(result, expression, tvm.tir.Evaluate(result)) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([x], 
let_stmt)) + body = tvm.tir.SeqStmt([tvm.tir.Bind(result, expression), tvm.tir.Evaluate(result)]) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([x], body)) initial_hash = None for _ in range(REPEATS): diff --git a/tests/python/tir-transform/test_tir_transform_convert_ssa.py b/tests/python/tir-transform/test_tir_transform_convert_ssa.py index 0d54d4fb048a..b37d337c28b0 100644 --- a/tests/python/tir-transform/test_tir_transform_convert_ssa.py +++ b/tests/python/tir-transform/test_tir_transform_convert_ssa.py @@ -23,7 +23,7 @@ from tvm.script import tir as T -def test_reuse_in_sequential_let_stmt(): +def test_reuse_in_sequential_bind(): """De-dup sequential variable bindings""" # Manually construct the PrimFunc body, as SSA violations are @@ -32,56 +32,71 @@ def test_reuse_in_sequential_let_stmt(): var = tir.Var("var", "int32") sequential_bindings = tir.SeqStmt( [ - tir.LetStmt(var, 16, tir.Evaluate(var)), - tir.LetStmt(var, 32, tir.Evaluate(var)), + tir.Bind(var, 16), + tir.Evaluate(var), + tir.Bind(var, 32), + tir.Evaluate(var), ] ) before = tir.PrimFunc([], sequential_bindings) @T.prim_func(private=True) def expected(): - with T.LetStmt(T.int32(16)) as var1: - T.evaluate(var1) - with T.LetStmt(T.int32(32)) as var2: - T.evaluate(var2) + var1 = T.Bind(T.int32(16)) + T.evaluate(var1) + var2 = T.Bind(T.int32(32)) + T.evaluate(var2) mod = tvm.IRModule.from_expr(before) mod = tvm.tir.transform.ConvertSSA()(mod) tvm.ir.assert_structural_equal(mod["main"], expected) -def test_reuse_in_nested_let_stmt(): - """De-dup nested bindings +def test_reuse_in_nested_bind(): + """De-dup sequential bindings of the same variable. - Use of a variable with nested bindings is de-duplicated to refer - to the inner-most binding that contains the use site. + In the flat Bind model, all Binds are siblings in a SeqStmt. A second + Bind of the same variable redefines it for all subsequent siblings. 
+ ConvertSSA should create a new variable for the second binding and + update all subsequent uses to refer to the new variable. """ # Manually construct the PrimFunc body, as SSA violations are # not valid TIR, and may not be expressible in future versions - # of TVMSCript. + # of TVMScript. var = tir.Var("var", "int32") - inner_let = tir.LetStmt(var, 16, tir.Evaluate(var)) - outer_let = tir.LetStmt( - var, - 32, - tir.SeqStmt( - [ - tir.Evaluate(var), - inner_let, - tir.Evaluate(var), - ] - ), + # Note: nested SeqStmt is flattened by the IR builder, so the input + # is actually a flat SeqStmt with 5 elements. + inner_seq = tir.SeqStmt( + [ + tir.Bind(var, 16), + tir.Evaluate(var), + ] + ) + outer_seq = tir.SeqStmt( + [ + tir.Bind(var, 32), + tir.Evaluate(var), + inner_seq, + tir.Evaluate(var), + ] ) - before = tir.PrimFunc([], outer_let) + before = tir.PrimFunc([], outer_seq) - @T.prim_func(private=True) - def expected(): - with T.LetStmt(T.int32(32)) as outer: - T.evaluate(outer) - with T.LetStmt(T.int32(16)) as inner: - T.evaluate(inner) - T.evaluate(outer) + # In the flat model, the second Bind(var, 16) redefines var for + # ALL subsequent siblings including the last Evaluate. 
+ var1 = tir.Var("var", "int32") + var2 = tir.Var("var", "int32") + expected_body = tir.SeqStmt( + [ + tir.Bind(var1, 32), + tir.Evaluate(var1), + tir.Bind(var2, 16), + tir.Evaluate(var2), + tir.Evaluate(var2), + ] + ) + expected = tir.PrimFunc([], expected_body) mod = tvm.IRModule.from_expr(before) mod = tvm.tir.transform.ConvertSSA()(mod) @@ -93,8 +108,8 @@ def test_reused_var_across_module(): @T.prim_func(private=True) def func(): - with T.LetStmt(10) as var: - T.evaluate(var) + var = T.Bind(10) + T.evaluate(var) before = tvm.IRModule( { diff --git a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py index f0f431b8ec44..0a64d8a176ab 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py @@ -135,10 +135,10 @@ def build_tir(): # 2. Let binding: Aptr_dup = packed_echo(Ab.data), then store const into Ab[1] Aptr_dup = tvm.tir.Var("Aptr_dup", "handle") store1 = tvm.tir.BufferStore(Ab, tvm.tir.const(expected_value[1], "float32"), [1]) - let_stmt = tvm.tir.LetStmt(Aptr_dup, packed_echo(Ab.data), store1) + bind_stmt = tvm.tir.Bind(Aptr_dup, packed_echo(Ab.data)) # Combine into sequence - stmt = tvm.tir.SeqStmt([store0, let_stmt]) + stmt = tvm.tir.SeqStmt([store0, bind_stmt, store1]) return tvm.IRModule.from_expr( tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "packed_test") diff --git a/tests/python/tir-transform/test_tir_transform_prim_func_pass.py b/tests/python/tir-transform/test_tir_transform_prim_func_pass.py index 8bee82f02c52..72b168414263 100644 --- a/tests/python/tir-transform/test_tir_transform_prim_func_pass.py +++ b/tests/python/tir-transform/test_tir_transform_prim_func_pass.py @@ -32,7 +32,7 @@ def transform_function(self, func, mod, ctx): x = tvm.tir.Var("x", "int32") y = tvm.tir.Var("y", "int32") b = tvm.tir.decl_buffer((x,), "float32") - stmt = tvm.tir.LetStmt(x, 
10, tvm.tir.Evaluate(x + 1)) + stmt = tvm.tir.SeqStmt([tvm.tir.Bind(x, 10), tvm.tir.Evaluate(x + 1)]) func = tvm.tir.PrimFunc([x, y, b], stmt) diff --git a/tests/python/tir-transform/test_tir_transform_remove_no_op.py b/tests/python/tir-transform/test_tir_transform_remove_no_op.py index 55574e838172..668b652d7087 100644 --- a/tests/python/tir-transform/test_tir_transform_remove_no_op.py +++ b/tests/python/tir-transform/test_tir_transform_remove_no_op.py @@ -132,7 +132,11 @@ def expected(A: T.Buffer(16, "int32")): def test_remove_unused_let(): - """A let statement that is never used is a no-op.""" + """With flat Bind, unused bindings are preserved by remove_no_op. + + Unused Bind elimination requires a separate two-pass approach + and is not handled by the current remove_no_op pass. + """ @T.prim_func(private=True) def before(A: T.Buffer(16, "int32")): @@ -142,6 +146,7 @@ def before(A: T.Buffer(16, "int32")): @T.prim_func(private=True) def expected(A: T.Buffer(16, "int32")): + x = 5 for i in T.serial(16): A[i] = 0 @@ -151,10 +156,10 @@ def expected(A: T.Buffer(16, "int32")): def test_remove_let_used_only_in_no_op(): - """A let statement that is never used is a no-op. + """With flat Bind, unused bindings are preserved by remove_no_op. - Similar to test_remove_unused_let, but the usage of the let binding - may have been removed by an earlier removal of another no-op. + The zero-extent for loop is removed, but the Bind for x remains + since unused Bind elimination is not handled by remove_no_op. 
""" @T.prim_func(private=True) @@ -165,6 +170,7 @@ def before(A: T.Buffer(16, "int32")): @T.prim_func(private=True) def expected(A: T.Buffer(16, "int32")): + x = 5 T.evaluate(0) mod = tvm.IRModule.from_expr(before) @@ -173,7 +179,7 @@ def expected(A: T.Buffer(16, "int32")): def test_keep_side_effects_of_let(): - """The side effects of a no-op let must be kept.""" + """Side-effect Bind is preserved as-is by remove_no_op.""" @T.prim_func(private=True) def before(): @@ -182,7 +188,8 @@ def before(): @T.prim_func(private=True) def expected(): - T.evaluate(T.call_extern("extern_func", dtype="int32")) + x = T.call_extern("extern_func", dtype="int32") + T.evaluate(0) mod = tvm.IRModule.from_expr(before) mod = _apply_remove_no_op(mod) @@ -549,8 +556,9 @@ def test_remove_read_write_same_index_different_expression(): """Writing a value to the same location as the read is a no-op. If the value of the index can be proven to be the same, then the - no-op can be removed, even if they have different forms of the - expression. + store can be removed. With flat Bind, the Bind for i and the + enclosing loops remain since unused Bind elimination is not + handled by remove_no_op. """ @T.prim_func(private=True) @@ -561,7 +569,9 @@ def before(A: T.Buffer(16, "int32")): @T.prim_func(private=True) def expected(A: T.Buffer(16, "int32")): - T.evaluate(0) + for io in range(4): + for ii in range(4): + i: T.int32 = 4 * io + ii mod = tvm.IRModule.from_expr(before) mod = _apply_remove_no_op(mod) diff --git a/tests/python/tir-transform/test_tir_transform_simplify.py b/tests/python/tir-transform/test_tir_transform_simplify.py index 46e094acfa20..b99e41b569e2 100644 --- a/tests/python/tir-transform/test_tir_transform_simplify.py +++ b/tests/python/tir-transform/test_tir_transform_simplify.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# ruff: noqa: E501 import tvm import tvm.testing from tvm.script import ir as I @@ -36,8 +35,14 @@ def func(A: T.handle("float32"), C: T.handle("float32"), n: T.int32): # Navigate through DeclBuffer nodes to reach the inner body while isinstance(body, tvm.tir.DeclBuffer): body = body.body - # After simplification, LetStmt -> For -> BufferStore (if is eliminated since i < 12 is always true for i in 0..10) - assert isinstance(body.body, tvm.tir.BufferStore) + # After simplification, Bind is kept (not inlined) but the if is eliminated + # since i < 12 is always true for i in 0..10. + # Body is SeqStmt(Bind(n_val, 10), For(i, ...)) + stmts = body if isinstance(body, tvm.tir.SeqStmt) else [body] + # Find the For loop in the sequence + for_stmt = [s for s in stmts if isinstance(s, tvm.tir.For)] + assert len(for_stmt) == 1, f"Expected one For loop, got {len(for_stmt)}" + assert isinstance(for_stmt[0].body, tvm.tir.BufferStore) def test_thread_extent_simplify(): @@ -56,11 +61,16 @@ def func(A: T.handle("float32"), C: T.handle("float32"), n: T.int32): # Navigate through DeclBuffer nodes to reach the inner body while isinstance(body, tvm.tir.DeclBuffer): body = body.body - # After simplification: For(tx) -> For(ty) -> BufferStore - # The LetStmt and if are eliminated since tx + ty < 12 is always true for tx in 0..10 and ty = 0 - assert isinstance(body, tvm.tir.For) # tx loop - assert isinstance(body.body, tvm.tir.For) # ty loop - assert isinstance(body.body.body, tvm.tir.BufferStore) # The if was eliminated + # After simplification: Bind is kept but the if is eliminated + # since tx + ty < 12 is always true for tx in 0..10 and ty = 0. 
+ stmts = list(body) if isinstance(body, tvm.tir.SeqStmt) else [body] + for_stmts = [s for s in stmts if isinstance(s, tvm.tir.For)] + assert len(for_stmts) >= 1, f"Expected For loop, got stmts: {[type(s).__name__ for s in stmts]}" + # The outermost For is the tx loop + tx_loop = for_stmts[0] + assert isinstance(tx_loop, tvm.tir.For) # tx loop + assert isinstance(tx_loop.body, tvm.tir.For) # ty loop + assert isinstance(tx_loop.body.body, tvm.tir.BufferStore) # The if was eliminated def test_if_likely(): @@ -385,6 +395,10 @@ def test_prove_condition_using_let(): Not all let bindings are inlined when they occur in later expressions. However, even if they are not inlined, they may be used to prove the value of a condition. + + With flat Bind, the analyzer binds the variable to its value for + constraint proving, which also substitutes the variable in later + expressions. """ @T.prim_func(private=True) @@ -397,8 +411,8 @@ def before(A: T.Buffer(4, "bool")): @T.prim_func(private=True) def expected(A: T.Buffer(4, "bool")): for i in T.serial(4): - condition = i < 3 - A[i] = condition + condition: T.bool = i < 3 # noqa: F841 + A[i] = i < 3 after = _apply_simplify(before) tvm.ir.assert_structural_equal(after, expected) @@ -407,9 +421,8 @@ def expected(A: T.Buffer(4, "bool")): def test_prove_let_condition(): """Simplify conditions using non-inlined let bindings - Not all let bindings are inlined when they occur in later - expressions. However, even if they are not inlined, they may be - used to prove the value of a condition. + With flat Bind, analyzer binds variable to value, which also + substitutes the variable in later expressions. 
""" @T.prim_func(private=True) @@ -423,9 +436,9 @@ def before(A: T.Buffer(4, "bool")): @T.prim_func(private=True) def expected(A: T.Buffer(4, "bool")): for i in T.serial(4): - condition = i < 3 + condition: T.bool = i < 3 # noqa: F841 if i < 3: - A[i] = condition + A[i] = T.bool(True) after = _apply_simplify(before) tvm.ir.assert_structural_equal(after, expected) @@ -434,8 +447,9 @@ def expected(A: T.Buffer(4, "bool")): def test_prove_repeated_let_condition(): """Simplify conditions using non-inlined let bindings - A variable may be used as a literal constraint, and be recognized - as being True within the context of the constraint. + With analyzer Bind, the variable is substituted with its value, + so `if condition` becomes `if i < 3`, and within that context + the inner `if condition` simplifies to True and is eliminated. """ @T.prim_func(private=True) @@ -449,9 +463,9 @@ def before(A: T.Buffer(4, "bool")): @T.prim_func(private=True) def expected(A: T.Buffer(4, "bool")): for i in T.serial(4): - condition = i < 3 - if condition: - A[i] = True + condition: T.bool = i < 3 # noqa: F841 + if i < 3: + A[i] = T.bool(True) after = _apply_simplify(before) tvm.ir.assert_structural_equal(after, expected) @@ -492,7 +506,11 @@ def expected(A: T.Buffer(1, "int32")): def test_left_ceil_log2_lower_bound(): - """Integer bounds are propagated through topi.math.ceil_log2""" + """Integer bounds are propagated through topi.math.ceil_log2 + + With flat Bind, the Bind is kept even when the variable is unused + after simplification. The if condition is still eliminated. 
+ """ @T.prim_func(private=True) def before(A: T.Buffer(16, "float32")): @@ -507,7 +525,11 @@ def before(A: T.Buffer(16, "float32")): @T.prim_func(private=True) def expected(A: T.Buffer(16, "float32")): for i in T.serial(16): - A[i] = 0.0 + x: T.int32 = T.Cast( # noqa: F841 + "int32", + T.ceil(T.log2(T.Cast("float64", i + 1025))), + ) + A[i] = T.float32(0) after = _apply_simplify(before) tvm.ir.assert_structural_equal(after, expected) @@ -1804,7 +1826,7 @@ def expected(A: T.Buffer(1, "int32")): def test_simplify_trivial_let_buffer_var(): - """A LetStmt used in a buffer definition should be retained""" + """A Bind used in a buffer definition should be retained""" @T.prim_func(private=True) def before(A_ptr: T.handle("float32")): @@ -1819,7 +1841,7 @@ def before(A_ptr: T.handle("float32")): def test_simplify_trivial_let_elem_offset(): - """A LetStmt used in a buffer definition should be retained, buffer fields unchanged""" + """A Bind used in a buffer definition should be retained""" @T.prim_func(private=True) def before(A_ptr: T.handle("float32"), A_offset: T.int32): @@ -1838,7 +1860,7 @@ def expected(A_ptr: T.handle("float32"), A_offset: T.int32): def test_simplify_trivial_let_shape(): - """A LetStmt used in a buffer definition should be retained, buffer fields unchanged""" + """A Bind used in a buffer definition should be retained""" @T.prim_func(private=True) def before(A_ptr: T.handle("float32"), A_size: T.int32): @@ -1857,7 +1879,7 @@ def expected(A_ptr: T.handle("float32"), A_size: T.int32): def test_simplify_trivial_let_stride(): - """A LetStmt used in a buffer definition should be retained, buffer fields unchanged""" + """A Bind used in a buffer definition should be retained""" @T.prim_func(private=True) def before(A_ptr: T.handle("float32"), A_stride: T.int32): diff --git a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py index 5e509956df8d..5d713e6c7fd5 100644 --- 
a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py +++ b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py @@ -410,9 +410,9 @@ def test_let_buffer_rewrite(): If StorageRewrite replaces the backing variable of an array, such as when vectorizing the storage type, the variable must be - replaced in the LetStmt that defines it. Currently, StmtMutator + replaced in the Bind that defines it. Currently, StmtMutator only visits usage of variables, and does not visit definitions of - variables, so the definition in a LetStmt must be explicitly + variables, so the definition in a Bind must be explicitly handled. """ diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py index 624f428e2a93..797217b8d328 100644 --- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py +++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py @@ -314,18 +314,20 @@ def test_ir_builder_tir_assert(): assert_structural_equal(assert_actual, assert_expected, map_free_vars=True) -def test_ir_builder_tir_let(): +def test_ir_builder_tir_bind(): + # Test that T.Bind emits a flat Bind statement and returns the Var. with IRBuilder() as ib: - with T.LetStmt(tir.IntImm("int32", 2)) as v: - T.evaluate(0) + v = T.Bind(tir.IntImm("int32", 2)) # the let binding generated by IRBuilder let_actual = ib.get() - # the expected Let statement - let_expected = tir.LetStmt(T.int32(), tir.IntImm("int32", 2), tir.Evaluate(0)) + # Bind is now flat (no body), so a single Bind stmt is emitted. 
+ let_expected = tir.Bind(T.int32(), tir.IntImm("int32", 2)) # Check if the generated ir is expected assert_structural_equal(let_actual, let_expected, map_free_vars=True) + # Check that the returned value is a Var + assert isinstance(v, tir.Var) def test_ir_builder_tir_thread(): diff --git a/tests/python/tvmscript/test_tvmscript_printer_annotation.py b/tests/python/tvmscript/test_tvmscript_printer_annotation.py index 08565ce074f7..13ace54ff7c2 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_annotation.py +++ b/tests/python/tvmscript/test_tvmscript_printer_annotation.py @@ -95,9 +95,11 @@ def _func(): y = x + 1 T.evaluate(y - 1) + # With flat Bind, the body is SeqStmt([Bind(x,1), Bind(y,x+1), Evaluate(y-1)]). + # Annotate the second Bind (y = x + 1). result = _func.with_attr("global_symbol", "main").script( obj_to_annotate={ - _func.body.body: "annotation 1", + _func.body.seq[1]: "annotation 1", } ) assert ( @@ -107,7 +109,6 @@ def _func(): @T.prim_func def main(): x: T.int32 = 1 - # annotation 1 - with T.LetStmt(x + 1) as y: - T.evaluate(y - 1)""" + y: T.int32 = x + 1 # annotation 1 + T.evaluate(y - 1)""" ) diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index 3409b731b364..8a814089661a 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -252,17 +252,22 @@ def test_for(): ) -def test_let_stmt(): +def test_bind(): with IRBuilder() as ib: - with T.LetStmt(T.float32(10)) as v: + with T.prim_func(): + v = T.Bind(T.float32(10)) ib.name("v", v) - T.evaluate(0) + T.evaluate(1) obj = ib.get() _assert_print( obj, """ -with T.LetStmt(T.float32(10.0)) as v: - T.evaluate(0) +# from tvm.script import tir as T + +@T.prim_func(private=True) +def main(): + v: T.float32 = T.float32(10.0) + T.evaluate(1) """, ) diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py 
index f1c27518c0b9..26cce98bf9d3 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -127,301 +127,6 @@ def main(inputs: T.Buffer((64, 2, 4), "float32")) -> None: return main -def opt_gemm_mod_host(): - @tvm.script.ir_module(check_well_formed=False) - class Module: - # packedB is treated as undefined - @T.prim_func - def mmult( - args: T.handle, - arg_type_ids: T.handle, - num_args: T.int32, - out_ret_value: T.handle, - out_ret_tcode: T.handle, - ) -> T.int32: - # function attr dict - T.func_attr( - { - "tir.noalias": True, - "tir.is_entry_func": True, - "calling_conv": 1, - } - ) - # buffer definition - buf_type_ids = T.match_buffer(arg_type_ids, [3], dtype="int32") - packedB = T.decl_buffer([32768], dtype="float32") - C_global = T.decl_buffer([1024], dtype="float32") - # body - assert num_args == 3, "mmult: num_args should be 3" - arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle") - arg0_code: T.int32 = buf_type_ids[0] - arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle") - arg1_code: T.int32 = buf_type_ids[1] - arg2: T.handle = T.tvm_struct_get(args, 2, 12, dtype="handle") - arg2_code: T.int32 = buf_type_ids[2] - - A_data: T.handle("int32") = T.tvm_struct_get(arg0, 0, 1, dtype="handle") - T.attr(A_data, "storage_alignment", 128) - A = T.decl_buffer([1024 * 1024], dtype="int32", data=A_data) - buf0_shape_data: T.handle("int32") = T.tvm_struct_get(arg0, 0, 2, dtype="handle") - buf0_shape = T.decl_buffer([2], dtype="int32", data=buf0_shape_data) - buf0_strides_data: T.handle("int32") = T.tvm_struct_get(arg0, 0, 3, dtype="handle") - buf0_strides = T.decl_buffer([2], dtype="int32", data=buf0_strides_data) - - dev_id: T.int32 = T.tvm_struct_get(arg0, 0, 9, dtype="int32") - - B_data: T.handle("int32") = T.tvm_struct_get(arg1, 0, 1, dtype="handle") - T.attr(B_data, "storage_alignment", 128) - B = T.decl_buffer([1024 * 1024], dtype="int32", data=B_data) - buf1_shape_data: 
T.handle("int32") = T.tvm_struct_get(arg1, 0, 2, dtype="handle") - buf1_shape = T.decl_buffer([2], dtype="int32", data=buf1_shape_data) - buf1_strides_data: T.handle("int32") = T.tvm_struct_get(arg1, 0, 3, dtype="handle") - buf1_strides = T.decl_buffer([2], dtype="int32", data=buf1_strides_data) - - C_data: T.handle("int32") = T.tvm_struct_get(arg2, 0, 1, dtype="handle") - T.attr(C_data, "storage_alignment", 128) - C = T.decl_buffer([1024 * 1024], dtype="int32", data=C_data) - buf2_shape_data: T.handle("int32") = T.tvm_struct_get(arg2, 0, 2, dtype="handle") - buf2_shape = T.decl_buffer([2], dtype="int32", data=buf2_shape_data) - buf2_strides_data: T.handle("int32") = T.tvm_struct_get(arg2, 0, 3, dtype="handle") - buf2_strides = T.decl_buffer([2], dtype="int32", data=buf2_strides_data) - - assert (((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or ( - arg0_code == 4 - ), "mmult: Expect arg[0] to be pointer" - assert (((arg1_code == 3) or (arg1_code == 13)) or (arg1_code == 7)) or ( - arg1_code == 4 - ), "mmult: Expect arg[1] to be pointer" - assert (((arg2_code == 3) or (arg2_code == 13)) or (arg2_code == 7)) or ( - arg2_code == 4 - ), "mmult: Expect arg[2] to be pointer" - assert 2 == T.tvm_struct_get(arg0, 0, 4, dtype="int32"), ( - "arg0.ndim is expected to equal 2" - ) - assert 2 == T.tvm_struct_get(arg0, 0, 4, dtype="int32"), ( - "arg0.ndim is expected to equal 2" - ) - assert ( - (T.tvm_struct_get(arg0, 0, 5, dtype="uint8") == T.uint8(2)) - and (T.tvm_struct_get(arg0, 0, 6, dtype="uint8") == T.uint8(32)) - ) and (T.tvm_struct_get(arg0, 0, 7, dtype="uint16") == T.uint16(1)), ( - "arg0.dtype is expected to be float32" - ) - assert 1024 == T.cast(buf0_shape[0], "int32"), ( - "Argument arg0.shape[0] has an unsatisfied constraint" - ) - assert 1024 == T.cast(buf0_shape[1], "int32"), ( - "Argument arg0.shape[1] has an unsatisfied constraint" - ) - if not (T.isnullptr(buf0_strides.data, dtype="bool")): - assert (1 == T.cast(buf0_strides[1], "int32")) and 
( - 1024 == T.cast(buf0_strides[0], "int32") - ), "arg0.strides: expected to be compact array" - T.evaluate(0) - assert T.uint64(0) == T.tvm_struct_get(arg0, 0, 8, dtype="uint64"), ( - "Argument arg0.byte_offset has an unsatisfied constraint" - ) - assert 1 == T.tvm_struct_get(arg0, 0, 10, dtype="int32"), ( - "Argument arg0.device_type has an unsatisfied constraint" - ) - assert 2 == T.tvm_struct_get(arg1, 0, 4, dtype="int32"), ( - "arg1.ndim is expected to equal 2" - ) - assert 2 == T.tvm_struct_get(arg1, 0, 4, dtype="int32"), ( - "arg1.ndim is expected to equal 2" - ) - assert ( - (T.tvm_struct_get(arg1, 0, 5, dtype="uint8") == T.uint8(2)) - and (T.tvm_struct_get(arg1, 0, 6, dtype="uint8") == T.uint8(32)) - ) and (T.tvm_struct_get(arg1, 0, 7, dtype="uint16") == T.uint16(1)), ( - "arg1.dtype is expected to be float32" - ) - assert 1024 == T.cast(buf1_shape[0], "int32"), ( - "Argument arg1.shape[0] has an unsatisfied constraint" - ) - assert 1024 == T.cast(buf1_shape[1], "int32"), ( - "Argument arg1.shape[1] has an unsatisfied constraint" - ) - if not (T.isnullptr(buf1_strides.data, dtype="bool")): - assert (1 == T.cast(buf1_strides[1], "int32")) and ( - 1024 == T.cast(buf1_strides[0], "int32") - ), "arg1.strides: expected to be compact array" - T.evaluate(0) - assert T.uint64(0) == T.tvm_struct_get(arg1, 0, 8, dtype="uint64"), ( - "Argument arg1.byte_offset has an unsatisfied constraint" - ) - assert 1 == T.tvm_struct_get(arg1, 0, 10, dtype="int32"), ( - "Argument arg1.device_type has an unsatisfied constraint" - ) - assert dev_id == T.tvm_struct_get(arg1, 0, 9, dtype="int32"), ( - "Argument arg1.device_id has an unsatisfied constraint" - ) - assert 2 == T.tvm_struct_get(arg2, 0, 4, dtype="int32"), ( - "arg2.ndim is expected to equal 2" - ) - assert 2 == T.tvm_struct_get(arg2, 0, 4, dtype="int32"), ( - "arg2.ndim is expected to equal 2" - ) - assert ( - (T.tvm_struct_get(arg2, 0, 5, dtype="uint8") == T.uint8(2)) - and (T.tvm_struct_get(arg2, 0, 6, dtype="uint8") 
== T.uint8(32)) - ) and (T.tvm_struct_get(arg2, 0, 7, dtype="uint16") == T.uint16(1)), ( - "arg2.dtype is expected to be float32" - ) - assert 1024 == T.cast(buf2_shape[0], "int32"), ( - "Argument arg2.shape[0] has an unsatisfied constraint" - ) - assert 1024 == T.cast(buf2_shape[1], "int32"), ( - "Argument arg2.shape[1] has an unsatisfied constraint" - ) - if not (T.isnullptr(buf2_strides.data, dtype="bool")): - assert (1 == T.cast(buf2_strides[1], "int32")) and ( - 1024 == T.cast(buf2_strides[0], "int32") - ), "arg2.strides: expected to be compact array" - T.evaluate(0) - assert T.uint64(0) == T.tvm_struct_get(arg2, 0, 8, dtype="uint64"), ( - "Argument arg2.byte_offset has an unsatisfied constraint" - ) - assert 1 == T.tvm_struct_get(arg2, 0, 10, dtype="int32"), ( - "Argument arg2.device_type has an unsatisfied constraint" - ) - assert dev_id == T.tvm_struct_get(arg2, 0, 9, dtype="int32"), ( - "Argument arg2.device_id has an unsatisfied constraint" - ) - T.attr(0, "compute_scope", "mmult_compute_") - T.attr(packedB.data, "storage_scope", "global") - T.attr(packedB.data, "storage_alignment", 128) - with T.LetStmt( - T.TVMBackendAllocWorkspace(1, dev_id, T.uint64(4194304), 2, 32, dtype="handle"), - var=packedB.data, - ): - if T.isnullptr(packedB.data, dtype="bool"): - T.evaluate(T.tvm_throw_last_error(dtype="int32")) - for x in T.parallel(0, 32): - for y in T.serial(0, 1024): - packedB[T.ramp(((x * 32768) + (y * 32)), 1, 32)] = B[ - T.ramp(((y * 1024) + (x * 32)), 1, 32) - ] - for x_outer in T.parallel(0, 32): - T.attr(C_global.data, "storage_scope", "global") - T.attr(C_global.data, "storage_alignment", 128) - with T.LetStmt( - T.TVMBackendAllocWorkspace( - 1, dev_id, T.uint64(4096), 2, 32, dtype="handle" - ), - var=C_global.data, - ): - if T.isnullptr(C_global.data, dtype="bool"): - T.evaluate(T.tvm_throw_last_error(dtype="int32")) - for y_outer in T.serial(0, 32): - for x_c_init in T.serial(0, 32): - C_global[T.ramp((x_c_init * 32), 1, 32)] = T.broadcast( - 
T.float32(0), 32 - ) - for k_outer in T.serial(0, 256): - for x_c in T.serial(0, 32): - C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( - T.uint32(97), - T.broadcast( - A[ - ( - ((x_outer * 32768) + (x_c * 1024)) - + (k_outer * 4) - ), - ], - 32, - ), - packedB[ - T.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32) - ], - C_global[T.ramp((x_c * 32), 1, 32)], - dtype="float32x32", - ) - C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( - T.uint32(97), - T.broadcast( - A[ - ( - ( - ((x_outer * 32768) + (x_c * 1024)) - + (k_outer * 4) - ) - + 1 - ), - ], - 32, - ), - packedB[ - T.ramp( - (((y_outer * 32768) + (k_outer * 128)) + 32), 1, 32 - ) - ], - C_global[T.ramp((x_c * 32), 1, 32)], - dtype="float32x32", - ) - C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( - T.uint32(97), - T.broadcast( - A[ - ( - ( - ((x_outer * 32768) + (x_c * 1024)) - + (k_outer * 4) - ) - + 2 - ), - ], - 32, - ), - packedB[ - T.ramp( - (((y_outer * 32768) + (k_outer * 128)) + 64), 1, 32 - ) - ], - C_global[T.ramp((x_c * 32), 1, 32)], - dtype="float32x32", - ) - C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( - T.uint32(97), - T.broadcast( - A[ - ( - ( - ((x_outer * 32768) + (x_c * 1024)) - + (k_outer * 4) - ) - + 3 - ), - ], - 32, - ), - packedB[ - T.ramp( - (((y_outer * 32768) + (k_outer * 128)) + 96), 1, 32 - ) - ], - C_global[T.ramp((x_c * 32), 1, 32)], - dtype="float32x32", - ) - for x_inner in T.serial(0, 32): - for y_inner in T.serial(0, 32): - C[ - ( - ( - ((x_outer * 32768) + (x_inner * 1024)) - + (y_outer * 32) - ) - + y_inner - ) - ] = C_global[((x_inner * 32) + y_inner)] - if T.TVMBackendFreeWorkspace(1, dev_id, C_global.data, dtype="int32") != 0: - T.evaluate(T.tvm_throw_last_error(dtype="int32")) - if T.TVMBackendFreeWorkspace(1, dev_id, packedB.data, dtype="int32") != 0: - T.evaluate(T.tvm_throw_last_error(dtype="int32")) - - return Module - - def opt_conv_tensorcore_lower(): @T.prim_func def func( @@ -3049,25 +2754,12 @@ 
def func(): return func -def let_stmt_var(): +def bind_var(): @T.prim_func def func(): - with T.LetStmt(0) as x: - with T.LetStmt(0) as y: - T.evaluate(0) + x = T.Bind(0) + y = T.Bind(0) T.evaluate(0) - - return func - - -def let_stmt_value(): - # uninitialized var - @T.prim_func(check_well_formed=False) - def func(): - y = T.int32() - with T.LetStmt(y) as x: - with T.LetStmt(0, var=y): - T.evaluate(0) T.evaluate(0) return func @@ -3612,7 +3304,6 @@ def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")): ir_generator = tvm.testing.parameter( launch_env_thread, opt_gemm_lower, - opt_gemm_mod_host, opt_conv_tensorcore_lower, opt_conv_tensorcore_mod_host, vthread_func, @@ -3669,8 +3360,7 @@ def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")): *nested_boolean_expressions(), multi_env_threads, intrinsic_pow, - let_stmt_var, - let_stmt_value, + bind_var, string_stride, string_stride_int64, merge_shape_var_def, diff --git a/tests/python/tvmscript/test_tvmscript_syntax_sugar.py b/tests/python/tvmscript/test_tvmscript_syntax_sugar.py index 08bf90123b13..f3d19f8ebad2 100644 --- a/tests/python/tvmscript/test_tvmscript_syntax_sugar.py +++ b/tests/python/tvmscript/test_tvmscript_syntax_sugar.py @@ -241,7 +241,7 @@ def func(a: T.handle): T.evaluate(0) -def test_letstmt_bufferload_without_type_annotation(): +def test_bind_bufferload_without_type_annotation(): # Variable assignment of PrimExpr types uses the dtype of the # PrimExpr to determine the variable's dtype. 
Parsing of # buf[indices] is done by generating a BufferSlice object, which @@ -255,7 +255,7 @@ def func_without_type_annotation(A: T.Buffer((1,), "int32")): T.evaluate(x) -def test_letstmt_bind_with_constant(): +def test_bind_with_constant(): @T.prim_func def constant_binds(): x = T.meta_var(1) @@ -410,7 +410,7 @@ def test_preserve_trivial_let_binding(): @T.prim_func def explicit(i: T.int32): j = T.int32() - T.LetStmt(i, var=j) + T.Bind(i, var=j) T.evaluate(j) @T.prim_func @@ -425,7 +425,7 @@ def test_preserve_trivial_let_binding_of_value(): @T.prim_func def explicit(i: T.int32): j = T.int32() - T.LetStmt(42, var=j) + T.Bind(42, var=j) T.evaluate(j) @T.prim_func @@ -447,7 +447,7 @@ def func(i: T.int32): def test_preserve_variable_name(): - """Use variable name when generating tir::LetStmt""" + """Use variable name when generating tir::Bind""" @T.prim_func def func(): @@ -455,7 +455,8 @@ def func(): j = i // 4 T.evaluate(j) - var_name = func.body.body.var.name + # With flat Bind, the for body is SeqStmt([Bind(j, i//4), Evaluate(j)]) + var_name = func.body.body.seq[0].var.name assert var_name == "j"