[REFACTOR][TIR] Rename LetStmt to Bind and flatten to sequential semantics#18874
[REFACTOR][TIR] Rename LetStmt to Bind and flatten to sequential semantics#18874tqchen wants to merge 34 commits intoapache:mainfrom
Conversation
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a significant refactoring of the TIR (Tensor Intermediate Representation) Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Changelog
Activity
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request is a large-scale refactoring that renames LetStmt to Bind and changes its semantics to be a sequential binding rather than a scoped one, aiming to flatten the IR and simplify many transformations. While the refactoring is generally well-executed and improves the TIR infrastructure, a significant security vulnerability was identified. Specifically, in the LowerTVMBuiltin pass, the use of a global pending_nd_frees_ list for nd_mem_alloc_with_scope deallocations can lead to incorrect hoisting of free calls out of conditional branches. This issue could cause memory corruption or crashes on the target device and requires immediate attention.
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request refactors the nested LetStmt with a flattened Bind statement in TIR, aiming to simplify the IR structure for easier analysis and transformation. While this refactoring is extensive and largely consistent with its goal, a security audit identified critical regressions in variable scoping and memory management within transformation passes. Specifically, multiple passes now leak variable bindings across scopes, and the lower_tvm_builtin pass fails to correctly scope memory deallocations, leading to memory leaks and potential Denial of Service (DoS) vulnerabilities in the generated code. Furthermore, there is a potential performance issue in the common subexpression elimination pass and concerns regarding reduced test coverage due to the removal of a large test case.
I am having trouble creating individual review comments. Click here to see my feedback.
src/tir/transform/simplify.cc (162-180)
The StmtSimplifier now leaks variable bindings across scopes. In the previous implementation, LetStmt bindings were naturally scoped by the recursion into the body. In the new flat IR, VisitStmt_(const BindNode* op) adds bindings to the non_inlined_bindings_ map, but these bindings are never removed when the enclosing scope (e.g., a For loop or IfThenElse branch) exits. This is because StmtSimplifier does not override scope-defining nodes to manage the state of this map. This can lead to incorrect simplifications based on out-of-scope variables, potentially resulting in incorrect code generation and bypassing safety checks.
src/tir/transform/common_subexpr_elim.cc (575-620)
The implementation of VisitStmt_(const SeqStmtNode* op) appears to have a time complexity that is quadratic in the number of statements in the sequence. For each BindNode encountered, it recursively calls VisitStmt on the remainder of the sequence. While this approach seems correct for enabling common subexpression elimination across sibling statements, it could become a performance bottleneck for functions with very long sequences of Bind nodes. A more performant single-pass approach should be considered if this becomes a bottleneck.
tests/python/tvmscript/test_tvmscript_roundtrip.py (130-423)
The removal of the large test case opt_gemm_mod_host is concerning as it might lead to a gap in test coverage. Please either restore the test and adapt it to the new IR structure, or add new tests that provide equivalent coverage for the host-side code generation and packed function ABI handling that opt_gemm_mod_host was testing.
…ctor + analysis infrastructure)
Introduces `BindNode`/`Bind`, a new TIR statement node that binds a variable
to a value with flat (no-body) scope semantics, as the first step of the
LetStmt-to-Bind refactor. Unlike `LetStmtNode`, `BindNode` has no body field;
the bound variable is visible in subsequent siblings of the enclosing SeqStmt.
PR 1 — Core IR node + functor infrastructure:
- `include/tvm/tir/stmt.h`: Define BindNode (var, value, no body) and Bind ref class
- `src/tir/ir/stmt.cc`: Implement Bind constructor, RegisterReflection, GlobalDef
- `include/tvm/tir/stmt_functor.h`: Add VisitStmt_(BindNode*) to StmtFunctor vtable,
StmtVisitor, StmtMutator
- `src/tir/ir/stmt_functor.cc`: Implement StmtVisitor and StmtMutator for BindNode
- `src/tir/ir/py_functor.cc`: Add BindNode dispatch entries for Python functors
- `src/tir/ir/tir_visitor_with_path.{h,cc}`: Add BindNode visitor (visits value only)
PR 2 — Base visitors/mutators + arithmetic + analysis:
- `src/arith/ir_mutator_with_analyzer.{h,cc}`: BindNode handler binds in analyzer
- `src/arith/ir_visitor_with_analyzer.{h,cc}`: BindNode handler visits value + binds
- `src/tir/ir/data_type_rewriter.{h,cc}`: BindNode support in DataTypeLegalizer and
IndexDataTypeRewriter
- `src/tir/analysis/var_use_def_analysis.{h,cc}`: BindNode registers HandleDef
- `src/tir/analysis/verify_ssa.cc`: BindNode calls MarkDef
- `src/tir/analysis/verify_memory.cc`: BindNode books defs_ map
- `src/tir/analysis/control_flow_graph.cc`: BindNode checks UsesLoopVar
LetStmtNode/LetStmt are kept intact (deprecated aliases come in PR 11).
Tests: TIR base (269 passed, 2 skipped), all-platform-minimal (75 passed, 77 skipped)
This commit completes the migration from tree-nested LetStmtNode
(with body field) to flat BindNode (no body) across the entire TVM
codebase. BindNode binds a variable visible to subsequent siblings
in the enclosing SeqStmt scope, replacing the old nested scoping.
Key changes:
- LetStmtNode is now a `using` alias for BindNode
- All C++ LetStmt(var,val,body) constructions -> SeqStmt({Bind(var,val), body})
- All VisitStmt_(LetStmtNode*) handlers -> VisitStmt_(BindNode*)
with op->body access removed (parent SeqStmt handles traversal)
- TIRVisitorWithPath::SeqStmt handler tracks Bind-defined vars
for well-formed verification
- CSE pass: new SeqStmtNode handler + VisitSeqStmtSlice to process
flat Bind sequences (mirrors old nested LetStmt CSE behavior)
- Python: added Bind class, LetStmt = Bind alias
- Updated ~65 C++ files and test files
Update TIR passes and tests to work correctly with the flat BindNode model (no body field) where variable scoping is managed via SeqStmt siblings instead of nested tree structure. Pass fixes: - RemoveNoOp: Add SeqStmt handler for back-to-front unused Bind scan - ConvertSSA: Add SeqStmt handler to maintain ScopedRedefine across siblings - StorageRewrite: Push/pop scope entry in BindNode handler - HoistExpression: Merge Bind lifecycle management into SeqStmt handler; only set reached_sequential_node for truly sequential (non-Bind) stmts - SBlockAccessRegionDetector: Defer let_bindings_ erasure to SeqStmt end - TVMScript printer: Add AsDocBodySeqSlice for scoped T.LetStmt form when printing already-defined-var Binds - TVMScript parser: Support doc.Attribute in _duplicate_lhs_check Test updates: - verify_well_formed: Adjust for flat scope semantics - convert_ssa: Update for flattened SeqStmt behavior - tvmscript printer/annotation/syntax_sugar: Update access paths - loop_partition: Fix pre-existing test with incorrect expected output
- Update comments in var.h and Python functor.py to reference BindNode instead of LetStmtNode - Apply clang-format fixes to files modified by the BindNode migration - Remove unused Bind import in functor.py (LetStmt alias is used instead) - Remove extra blank lines left over from migration in analysis/rewriter files
…FI_UNREACHABLE - Replace all 8 `__builtin_unreachable()` calls with `TVM_FFI_UNREACHABLE()`: src/s_tir/transform/inject_virtual_thread.cc, src/script/printer/relax/distributed.cc, src/script/printer/tir/stmt.cc, src/target/source/codegen_c.cc, src/tir/transform/vectorize_loop.cc, src/tir/transform/lower_tvm_builtin.cc, include/tvm/script/printer/ir_docsifier_functor.h, include/tvm/tir/stmt.h - Rename `kLetStmt` enum value → `kBind` in HoistedLetBindings (C++ and Python) (src/s_tir/transform/hoist_expression.cc, python/tvm/tir/transform/transform.py) - Rename `LetStmt()` → `Bind()` in script/ir_builder/tir: - C++ function in ir.h and ir.cc; keep `LetStmt` as a deprecated inline alias - Register `"script.ir_builder.tir.Bind"` as primary; keep `LetStmt` as alias - Python ir.py: add `Bind()` as primary function; `LetStmt()` delegates to it - Update stale `LetStmt` mentions in comments and docstrings to `Bind`: src/s_tir/schedule/analysis/reducer.cc, src/s_tir/schedule/primitive/reduction.cc, src/s_tir/transform/hoist_expression.cc, src/tir/ir/specialize.cc, src/tir/transform/common_subexpr_elim.cc, src/tir/transform/tvm_ffi_binder.h, src/tir/transform/ir_utils.cc, src/te/operation/create_primfunc.cc, include/tvm/tir/stmt.h, python/tvm/tir/stmt.py, python/tvm/tir/functor.py - Clean up `src/script/printer/tir/utils.h`: remove `AsDocBodySeqSlice` helper that used `TIR(d, "LetStmt")` scoped form; inline loop directly in `AsDocBody` (Bind is flat-assignment, no scoped form needed)
… semantics Replace the recursive VisitSeqStmtSlice helper with an iterative SeqStmt handler that processes children directly: Bind nodes augment the context and trigger cross-sibling CSE on remaining siblings, while non-Bind nodes are processed individually.
…Bind Remove the custom SeqStmt handler that maintained ScopedRedefine entries for Bind nodes. Instead, the simplified BindNode handler adds persistent remappings via function_scope_var_remap_ directly, which don't need scoped cleanup. The default StmtMutator processes SeqStmt children sequentially, so remappings from Bind nodes are automatically visible to subsequent siblings.
… flat Bind Simplify the SeqStmt handler to only perform sequential detection (counting non-Bind statements) and delegate visitation to the parent. Remove the Bind-var lifecycle management (tracking and erasing from let_var_to_loop_vars/let_var_to_let_vars maps at sequence boundaries). Bind vars now persist in the tracking maps for the duration of the HoistInfoCollector instance.
…t handler for flat Bind Remove the custom SeqStmt handler that tracked and erased Bind-defined let_bindings_ at sequence boundaries. The BindNode handler now just adds to let_bindings_ and relies on the BlockReadWriteDetector instance scope for cleanup. The default StmtVisitor processes SeqStmt children sequentially, so bindings are visible to subsequent siblings.
…eqStmt Remove the custom SeqStmt handler and dead-Bind-variable backward scan from remove_no_op. The VisitStmt_(BindNode*) handler now simply mutates the value and returns. Unused Bind elimination can be added back later via a separate two-pass approach.
…hScope Remove the custom SeqStmt handler that captured remaining siblings as body for nd_mem_alloc_with_scope processing. MakeNdMemAllocWithScope now rewrites the Bind value inline (lowering to tvm_call_packed) and adds a null check, without body capture.
…roundtrip tests Remove opt_gemm_mod_host and let_stmt_value test functions from test_tvmscript_roundtrip.py. Both use non-SSA re-binds (with T.LetStmt var= pattern) that cannot roundtrip with flat Bind semantics.
With flat Bind there is no body to inspect for usage patterns, so Bind inlining (removing the Bind and substituting its value) is disabled. The analyzer still records variable bindings for constraint proving, but the Bind statement is always kept. Remove the CollectVarsUsedInBufferDefinition utility and used_in_buffer_def_ tracking which were only needed for the inlining codepath. Update tests to reflect that Binds are no longer eliminated.
…ls, and tests Update hoist_expression to manage Bind lifecycle in SeqStmt, fix IRConvertSSA to handle Bind redefinitions across SeqStmt siblings, and update test expectations for flat Bind semantics.
…defs Use ScopeStack to manage Bind variable definitions. Body-carrying statements (For, IfThenElse, Allocate, DeclBuffer, AttrStmt, While, SBlock) push a new scope; BindNode pushes its WithDef into the current scope. When the scope exits all Bind defs are cleaned up automatically, removing the need for custom SeqStmt handling.
…flat Bind - remove_store_undef: eagerly check buffer indices for undef in the locator phase (flat Bind means the undef Bind is a sibling, not an ancestor, so post-validation alone cannot catch it). Also remove Bind nodes whose value contains undef alongside the removed stores. - inject_ptx_async_copy test: update expected CUDA to reflect that analyzer->Bind substitutes the variable with its value.
…nagement Restore the original SeqStmt handler logic that tracks Bind vars defined in a sequence and erases them from let_var_to_loop_vars/let_var_to_let_vars maps when the sequence ends. Keep the refactor simple per user feedback.
The migration from LetStmt to Bind is complete. Remove all backward- compatibility aliases and deprecated wrappers: - Remove `using LetStmtNode = BindNode` and `using LetStmt = Bind` from include/tvm/tir/stmt.h - Remove `LetStmt()` wrapper and `LegacyLetStmt()` from C++ and Python script ir_builder - Remove `tir.LetStmt` FFI factory from stmt.cc - Remove `LetStmt = Bind` alias from python/tvm/tir/stmt.py - Rename `visit_let_stmt_` to `visit_bind_` in Python functor metadata and method names, matching the C++ `f_visit_bind` field - Rename `f_visit_let_stmt` parameters in py_functor.cc to `f_visit_bind` - Update all test files: T.LetStmt -> T.Bind, comments, function names
…anagement Replace the ScopedRedefine RAII struct and custom SeqStmt handler with ScopeStack<ScopeLevel> for cleaner scope management: - Body-carrying statements (For, IfThenElse, AttrStmt, DeclBuffer, While, Allocate, SBlock) push a new scope via scope_.WithNewScope() - Bind pushes var remaps to the current scope level, persisting across SeqStmt siblings - Scope exit automatically pops all remaps via ScopeLevel destructor - Remove the custom VisitStmt_(SeqStmtNode*) -- default sequential iteration works because Bind remaps persist in the enclosing scope - Add IfThenElse handler with separate scopes per branch to prevent remap leakage between then/else cases
…ement Replace manual save/restore of context_ in the Common Subexpression Elimination pass with ScopeStack-based automatic scope management. Key changes: - Add ScopeStack<ContextScopeLevel> where each scope level records the context size on entry and truncates it back on exit via destructor - ForNode, LetNode: WithNewScope replaces manual context save/restore - New scope-boundary overrides for IfThenElse, AttrStmt, Allocate, DeclBuffer, While to prevent context leaks across scope boundaries - SeqStmtNode: remove manual context save/restore (enclosing scope handles cleanup), retain wrap-remaining-siblings pattern for cross-sibling CSE after Bind nodes - BindNode: entries persist across SeqStmt siblings, cleaned up automatically when enclosing body-carrying statement's scope exits
The LetStmt-to-Bind migration dropped the free_nd call that was previously wrapped after the LetStmt body, causing a memory leak for nd allocations (Hexagon VTCM, Adreno textures). With flat Bind semantics, the free is pushed to a pending_nd_frees_ vector and appended at the end of the enclosing SeqStmt by a new VisitStmt_(SeqStmtNode*) override.
The LetStmt-to-Bind refactor accidentally duplicated the LOG(WARNING) call in IRDocsifierFunctor::operator(). Remove the extra one.
Rename test fixtures and functions that still use "letstmt" to "bind"
to match the LetStmt-to-Bind refactor:
- argmax_split_letstmt_{fewer,more}_than_init -> argmax_split_bind_*
- test_letstmt_bufferload_without_type_annotation -> test_bind_*
- test_letstmt_bind_with_constant -> test_bind_with_constant
Rename LetFrameNode/LetFrame to BindFrameNode/BindFrame across C++ headers, implementation, and Python bindings to align with the LetStmt-to-Bind refactor. Updates FFI type key from "script.ir_builder.tir.LetFrame" to "script.ir_builder.tir.BindFrame".
…base Resolve remaining AllocateNode references that should be AllocBufferNode after rebasing onto the AllocBuffer commit. Also add TVM_FFI_UNREACHABLE after throw in blockize_tensorize.
The pending_nd_frees_ approach hoisted free_nd calls to the nearest SeqStmt boundary, which could incorrectly escape conditional branches. Use ScopeStack instead: register free_nd in the current scope when Bind allocates via nd_mem_alloc_with_scope, and emit frees on scope exit. This matches the old LetStmt body semantics structurally.
The old handler wrapped remaining siblings after each individual Bind node and re-ran VisitStmt, causing O(n^2) complexity for sequences with many consecutive Bind nodes. The new hybrid approach batches consecutive trivial Binds (constant or variable values) and defers the cross-sibling CSE until the batch ends, reducing the common case to O(n). Non-trivial Binds (whose values may contain eligible computations) still use the per-Bind wrap pattern to preserve full CSE effectiveness.
…ph and layout_transformation BindLetVar and BindVariableDefinition RAII guards erased map entries on destruction, but flat BindNode has no body -- the guard is destroyed when the handler returns, making the binding invisible to subsequent sibling statements. Under SSA each variable is bound exactly once, so the maps grow monotonically and cleanup is unnecessary. Remove the cleanup from both destructors to fix the bug. Also add a comment explaining the dead cse_v1 variable in test_s_tir_transform_inject_ptx_async_copy: CSE extracts (i < 12) before inject_ptx_async_copy replaces IfThenElse guards with new cast(int32, ...) expressions for predicated copies, leaving the CSE variable unused.
Bug 1 (inject_virtual_thread.cc): When a Bind in a SeqStmt touches vt_var, the VT loop must wrap the Bind together with all remaining siblings (which may reference the bound variable). Previously, the Bind handler wrapped only itself, breaking semantics. Rewrite the SeqStmt handler to pre-check Bind children and group them with remaining siblings before wrapping with InjectVTLoop. Bug 2 (lower_tvm_builtin.cc): MakeNdMemAllocWithScope was returning without re-visiting via StmtExprMutator::VisitStmt, leaving tvm_call_packed builtins in both the Bind value and the free_stmt unlowered. Re-wrap with VisitStmt and visit free_stmt before pushing to pending_frees. Bug 3 (frame.cc): BindFrameNode::ExitWithScope used SeqStmt constructor (which does not flatten) instead of SeqStmt::Flatten, creating nested SeqStmts. Also, when stmts is empty, emit just the Bind without wrapping in a SeqStmt with a spurious Evaluate(0).
Bind is now a direct statement like Evaluate -- it emits a Bind stmt to the current frame and returns the Var, with no context manager or RAII scope needed. Changes: - C++ ir_builder: Bind() creates var, calls AddToParent(tir::Bind(...)), returns var instead of BindFrame - Remove BindFrameNode/BindFrame classes from frame.h and frame.cc - Python ir_builder: Bind() returns Var instead of BindFrame - Parser: bind_assign_value and visit_ann_assign simplified to call T.Bind() directly without frame lifecycle management - Parser: visit_expr_stmt skips standalone Var results (from T.Bind()) instead of wrapping them in T.evaluate() - Remove BindFrame Python class from frame.py - Update all tests from `with T.Bind(...) as v:` to `v = T.Bind(...)`
…ataTypeRewriter Base StmtMutator now handles these via VisitBufferDef (from upstream PR apache#18873).
4259c3d to
a89b7dc
Compare
Summary
Rename
LetStmtNode/LetStmttoBindNode/Bindand remove thebodyfield.The variable defined by
Bind(var, value)is now visible in all subsequentstatements within the same enclosing scope, rather than being scoped to a nested body.
This flattens deeply nested let-chains into sequential
SeqStmt([Bind(...), Bind(...), ...]),making the IR easier to read, transform, and analyze.
Key Changes
BindNode:{var, value}— no body field. Variable scope is the enclosingstatement's body (For, IfThenElse, AllocBuffer, etc.)
tir_visitor_with_path) use
ScopeStackinstead of manual save/restore or RAII wrappersS-TIR transforms, analyses, TVMScript printer/parser/ir_builder