Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
48bb763
[TIR][Refactor] Introduce BindNode and Bind (PR 1+2/11: core IR + fun…
tqchen Mar 2, 2026
b5a545f
[TIR] Phase out LetStmtNode body field: migrate to flat BindNode
tqchen Mar 2, 2026
1c9f460
[TIR] Fix passes and tests for flat BindNode semantics
tqchen Mar 2, 2026
a596311
[REFACTOR][TIR] Complete LetStmt-to-Bind migration: fix remaining issues
tqchen Mar 2, 2026
72c4043
[REFACTOR][TIR] Cleanup: rename LetStmt references to Bind, use TVM_F…
tqchen Mar 2, 2026
9de2a72
[REFACTOR][TIR] Simplify CSE: remove VisitSeqStmtSlice, use flat Bind…
tqchen Mar 2, 2026
a978486
[REFACTOR][TIR] Simplify ConvertSSA: remove SeqStmt handler for flat …
tqchen Mar 2, 2026
adfcee8
[REFACTOR][TIR] Simplify hoist_expression: remove SeqStmt handler for…
tqchen Mar 2, 2026
9757900
[REFACTOR][TIR] Simplify sblock_access_region_detector: remove SeqStm…
tqchen Mar 2, 2026
8763d3c
[REFACTOR][TIR] Simplify remove_no_op: remove Bind elimination from S…
tqchen Mar 2, 2026
5cfc855
[REFACTOR][TIR] Simplify lower_tvm_builtin: flatten MakeNdMemAllocWit…
tqchen Mar 2, 2026
af7a730
[REFACTOR][TIR] Remove obsolete opt_gemm_mod_host and let_stmt_value …
tqchen Mar 2, 2026
e76fc81
[REFACTOR][TIR] Disable CanInlineLetStmt for flat Bind
tqchen Mar 2, 2026
f7689d7
[REFACTOR][TIR] Fix Bind scope management in hoist_expression, ir_uti…
tqchen Mar 2, 2026
703813b
[REFACTOR][TIR] Simplify tir_visitor_with_path: use scope-based Bind …
tqchen Mar 2, 2026
8575eaf
[REFACTOR][TIR] Fix remove_store_undef and inject_ptx_async_copy for …
tqchen Mar 2, 2026
6a52630
[REFACTOR][TIR] Restore hoist_expression SeqStmt handler lifecycle ma…
tqchen Mar 2, 2026
ebf5d76
[REFACTOR][TIR] Remove LetStmt/LetStmtNode backward-compat aliases
tqchen Mar 2, 2026
7f8886f
[REFACTOR][TIR] Refactor ConvertSSA to use ScopeStack for var remap m…
tqchen Mar 2, 2026
d5374ee
[REFACTOR][TIR] Refactor CSE pass to use ScopeStack for context manag…
tqchen Mar 2, 2026
db3c5da
[REFACTOR][TIR] Restore free_nd in MakeNdMemAllocWithScope for flat Bind
tqchen Mar 2, 2026
21ed144
[REFACTOR][TIR] Remove duplicate LOG(WARNING) in ir_docsifier_functor.h
tqchen Mar 2, 2026
05ed2d4
[REFACTOR][TIR] Rename stale letstmt references in test function names
tqchen Mar 2, 2026
ca589b0
[REFACTOR][TIR] Rename LetFrame to BindFrame in ir_builder
tqchen Mar 2, 2026
48e2f7a
[REFACTOR][TIR] Fix AllocateNode->AllocBufferNode references after re…
tqchen Mar 4, 2026
fe104fc
[REFACTOR][TIR] clang-format fixes
tqchen Mar 4, 2026
fcf21ff
[TIR] Replace pending_nd_frees_ with ScopeStack in lower_tvm_builtin
tqchen Mar 4, 2026
f5f28f6
[TIR] Add SSA invariant comment for non_inlined_bindings_ in simplify.cc
tqchen Mar 4, 2026
2d72320
[TIR] Optimize CSE SeqStmt handler to batch trivial Binds
tqchen Mar 4, 2026
c9fe4fe
[TIR] Fix RAII scope guard bugs for flat BindNode in control_flow_gra…
tqchen Mar 4, 2026
01f04cd
[TIR] Fix three LetStmt-to-Bind refactor bugs from second review
tqchen Mar 4, 2026
9877fba
[REFACTOR][TIR] Remove BindFrame, make Bind a flat non-frame statement
tqchen Mar 4, 2026
98b8bf1
[refactor] inject_virtual_thread: replace goto with break, improve co…
tqchen Mar 4, 2026
a89b7dc
[fixup] Remove redundant AllocBuffer/DeclBuffer overrides from IndexD…
tqchen Mar 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 0 additions & 42 deletions include/tvm/script/ir_builder/tir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -341,48 +341,6 @@ class AssertFrame : public TIRFrame {
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AssertFrame, TIRFrame, AssertFrameNode);
};

/*!
* \brief A frame represents the let binding expression, which binds a var.
*
* \sa LetFrameNode
*/
class LetFrameNode : public TIRFrameNode {
public:
/*! \brief The variable we bind to */
tvm::tir::Var var;
/*! \brief The value we bind var to */
PrimExpr value;

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<LetFrameNode>()
.def_ro("var", &LetFrameNode::var)
.def_ro("value", &LetFrameNode::value);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.LetFrame", LetFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to LetFrameNode.
*
* \sa LetFrameNode
*/
class LetFrame : public TIRFrame {
public:
explicit LetFrame(ObjectPtr<LetFrameNode> data) : TIRFrame(ffi::UnsafeInit{}) {
TVM_FFI_ICHECK(data != nullptr);
data_ = std::move(data);
}
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(LetFrame, TIRFrame, LetFrameNode);
};

/*!
* \brief The LaunchThreadFrameNode.
* \note It is used only inside a PrimFunc.
Expand Down
13 changes: 8 additions & 5 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,16 +294,19 @@ AssertFrame Assert(PrimExpr condition, ffi::String error_kind,
ffi::Array<ffi::String> message_parts);

/*!
* \brief The let binding.
* \brief Create a Bind (variable binding).
*
* Emits a flat Bind statement to the current frame and returns the bound variable.
*
* \param value The value to be bound.
* \param type_annotation The type annotation of the let binding.
* \param type_annotation The type annotation of the binding.
* Usually it is used for fine-grained var typing,
* particularly, PointerType.
* \param var The variable to be bound. If not specified, a new variable will be created.
* \return The created LetFrame.
* \return The bound Var.
*/
LetFrame LetStmt(PrimExpr value, ffi::Optional<Type> type_annotation = std::nullopt,
ffi::Optional<Var> var = std::nullopt);
Var Bind(PrimExpr value, ffi::Optional<Type> type_annotation = std::nullopt,
ffi::Optional<Var> var = std::nullopt);

/*!
* \brief The allocate node.
Expand Down
1 change: 1 addition & 0 deletions include/tvm/script/printer/ir_docsifier_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class IRDocsifierFunctor {
<< runtime::Object::TypeIndex2Key(type_index) << " (token: " << token
<< ")"
<< ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj;
TVM_FFI_UNREACHABLE();
}

/*!
Expand Down
36 changes: 19 additions & 17 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,37 +68,38 @@ class Stmt : public ObjectRef {
};

/*!
* \brief Let binding, bind var to value, then run body.
* \brief Bind a variable to a value in the enclosing scope.
*
* BindNode has no body field. The bound variable is visible
* in all subsequent statements within the same enclosing scope (SeqStmt,
* ForNode.body, etc.). This enables flat (non-nested) IR sequences.
*/
class LetStmtNode : public StmtNode {
class BindNode : public StmtNode {
public:
/*! \brief The variable. */
/*! \brief The variable being bound. */
Var var;
/*! \brief The value to be bound. */
/*! \brief The value to bind to the variable. */
PrimExpr value;
/*! \brief The body block. */
Stmt body;

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<LetStmtNode>()
.def_ro("var", &LetStmtNode::var, refl::AttachFieldFlag::SEqHashDef())
.def_ro("value", &LetStmtNode::value)
.def_ro("body", &LetStmtNode::body);
refl::ObjectDef<BindNode>()
.def_ro("var", &BindNode::var, refl::AttachFieldFlag::SEqHashDef())
.def_ro("value", &BindNode::value);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.LetStmt", LetStmtNode, StmtNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Bind", BindNode, StmtNode);
};

/*!
* \brief Managed reference to LetStmtNode.
* \sa LetStmtNode
* \brief Managed reference to BindNode.
* \sa BindNode
*/
class LetStmt : public Stmt {
class Bind : public Stmt {
public:
TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body, Span span = Span());
TVM_DLL Bind(Var var, PrimExpr value, Span span = Span());

TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LetStmt, Stmt, LetStmtNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(LetStmtNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Bind, Stmt, BindNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BindNode);
};

/*!
Expand Down Expand Up @@ -978,6 +979,7 @@ inline const char* ForKind2String(ForKind t) {
return "thread_binding";
}
TVM_FFI_THROW(InternalError) << "Unknown ForKind" << t;
TVM_FFI_UNREACHABLE();
}

} // namespace tir
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
return vtable(n, this, std::forward<Args>(args)...);
}
// Functions that can be overriden by subclass
virtual R VisitStmt_(const LetStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BindNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AttrStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const IfThenElseNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
Expand All @@ -106,7 +106,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
// initialize the vtable.
static FType InitVTable() {
FType vtable;
IR_STMT_FUNCTOR_DISPATCH(LetStmtNode);
IR_STMT_FUNCTOR_DISPATCH(BindNode);
IR_STMT_FUNCTOR_DISPATCH(AttrStmtNode);
IR_STMT_FUNCTOR_DISPATCH(IfThenElseNode);
IR_STMT_FUNCTOR_DISPATCH(ForNode);
Expand Down Expand Up @@ -159,9 +159,9 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
*/
virtual void VisitBufferUse(const Buffer& buffer);
// statement visitor
void VisitStmt_(const BindNode* op) override;
void VisitStmt_(const AttrStmtNode* op) override;
void VisitStmt_(const IfThenElseNode* op) override;
void VisitStmt_(const LetStmtNode* op) override;
void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const WhileNode* op) override;
void VisitStmt_(const AllocBufferNode* op) override;
Expand Down Expand Up @@ -273,9 +273,9 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
*/
virtual Buffer VisitBufferUse(const Buffer& buffer);
// statement visitor
Stmt VisitStmt_(const BindNode* op) override;
Stmt VisitStmt_(const AttrStmtNode* op) override;
Stmt VisitStmt_(const IfThenElseNode* op) override;
Stmt VisitStmt_(const LetStmtNode* op) override;
Stmt VisitStmt_(const ForNode* op) override;
Stmt VisitStmt_(const WhileNode* op) override;
Stmt VisitStmt_(const AllocBufferNode* op) override;
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/var.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ namespace tir {
* - Allocate
* - For
* - Let
* - LetStmt
* - Bind
*/
class VarNode : public PrimExprNode {
public:
Expand Down
7 changes: 0 additions & 7 deletions python/tvm/script/ir_builder/tir/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,6 @@ def __enter__(self) -> Var | list[Var]: # type: ignore[override]
class AssertFrame(TIRFrame): ...


@_register_object("script.ir_builder.tir.LetFrame")
class LetFrame(TIRFrame):
def __enter__(self) -> Var:
super().__enter__()
return self.var


@_register_object("script.ir_builder.tir.AllocateFrame")
class AllocateFrame(TIRFrame):
def __enter__(self) -> Buffer:
Expand Down
30 changes: 16 additions & 14 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,35 +983,37 @@ def Assert(condition: PrimExpr, message, error_kind: str = "RuntimeError") -> fr
return _ffi_api.Assert(condition, error_kind, message) # type: ignore[attr-defined] # pylint: disable=no-member


def LetStmt( # pylint: disable=invalid-name
def Bind( # pylint: disable=invalid-name
value: PrimExpr,
type_annotation: Type | None = None, # pylint: disable=redefined-outer-name
*,
var: Var | None = None, # pylint: disable=redefined-outer-name
) -> frame.LetFrame:
"""Create a LetStmt binding
) -> Var:
"""Create a Bind (variable binding).

Emits a flat Bind statement to the current frame and returns the bound variable.

Parameters
----------
value : PrimExpr
The value to be bound.
type_annotation : Optional[Type] = None
The type annotation of the let binding. Usually it is used for fine-grained var typing,
The type annotation of the binding. Usually it is used for fine-grained var typing,
particularly, PointerType.
var : Optional[Var] = None
The variable to bind. If not specified, a new variable will be created.

Returns
-------
let_frame : frame.LetFrame
The result LetFrame.
var : Var
The bound variable.
"""
if type_annotation is not None:
if callable(type_annotation):
type_annotation = type_annotation()
if isinstance(type_annotation, Var):
type_annotation = type_annotation.type_annotation
return _ffi_api.LetStmt(value, type_annotation, var) # type: ignore[attr-defined] # pylint: disable=no-member
return _ffi_api.Bind(value, type_annotation, var) # type: ignore[attr-defined] # pylint: disable=no-member


def Let( # pylint: disable=invalid-name
Expand All @@ -1028,7 +1030,7 @@ def let(
v: Var,
value: PrimExpr,
body: PrimExpr = None,
) -> frame.LetFrame:
) -> Var:
"""Create a new let binding.

Parameters
Expand All @@ -1044,17 +1046,17 @@ def let(

Returns
-------
res : frame.LetFrame
The result LetFrame.
res : Var
The bound variable.
"""

@deprecated("T.let", "T.Let")
def let_expr(v: Var, value: PrimExpr, body: PrimExpr) -> PrimExpr:
return tir.Let(v, value, body)

@deprecated("T.let", "T.LetStmt")
def let_stmt(v: Var, value: PrimExpr) -> frame.LetFrame:
return _ffi_api.LegacyLetStmt(v, value) # type: ignore[attr-defined] # pylint: disable=no-member
@deprecated("T.let", "T.Bind")
def let_stmt(v: Var, value: PrimExpr) -> Var:
return Bind(value, var=v)

if body is None:
return let_stmt(v, value)
Expand Down Expand Up @@ -2343,7 +2345,7 @@ def wrapped(*args, **kwargs):
"Call",
"CallEffectKind",
"let",
"LetStmt",
"Bind",
"Let",
"IterVar",
"CommReducer",
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/script/parser/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,9 @@ def _duplicate_lhs_check(self, target: doc.expr) -> bool | set[str]:
return {target.id}
elif isinstance(target, doc.Starred):
return self._duplicate_lhs_check(target.value)
elif isinstance(target, doc.Attribute):
# Attribute assignment like packedB.data = ..., treated as rebinding.
return {target.attr}
else:
self.report_error(target, "Invalid type in assign statement")
raise NotImplementedError
Expand Down
14 changes: 7 additions & 7 deletions python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,8 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -
return value
else:
value = tvm.runtime.convert(value)
frame = T.LetStmt(value)
var = frame.var
var = T.Bind(value)
IRBuilder.name(var_name, var)
frame.add_callback(partial(frame.__exit__, None, None, None))
frame.__enter__()
return var


Expand Down Expand Up @@ -352,9 +349,7 @@ def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None:
if not isinstance(ann_var, Var):
self.report_error(node.annotation, "Annotation should be Var")
self.eval_assign(target=lhs, source=ann_var, bind_value=bind_assign_value)
frame = T.LetStmt(rhs, var=ann_var)
frame.add_callback(partial(frame.__exit__, None, None, None))
frame.__enter__()
T.Bind(rhs, var=ann_var)


@dispatch.register(token="tir", type_name="With")
Expand Down Expand Up @@ -471,6 +466,11 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
elif isinstance(res, Frame):
res.add_callback(partial(res.__exit__, None, None, None))
res.__enter__()
elif isinstance(res, Var):
# Standalone Var expression (e.g. from T.Bind(value, var=v)) --
# the Bind statement was already emitted to the parent frame by the FFI call,
# so just discard the returned Var.
pass
elif isinstance(res, PrimExpr):
T.evaluate(res)
elif isinstance(res, int | bool):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from .expr import Select, BufferLoad, ProducerLoad, Ramp, Broadcast, Shuffle
from .expr import Call, CallEffectKind, Let, IterVar, CommReducer

from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While
from .stmt import Stmt, Bind, AssertStmt, ForKind, For, While
from .stmt import (
BufferStore,
AllocBuffer,
Expand Down
Loading
Loading