[ET-VK][q8ta] Fix addmm arg indexing in QuantizedLinearMatch#17567
[ET-VK][q8ta] Fix addmm arg indexing in QuantizedLinearMatch#17567SS-JIA wants to merge 5 commits intogh/SS-JIA/441/basefrom
Conversation
QuantizedLinearMatch always used args[1] for the weight and args[0] for the input, which is correct for mm(input, weight) and linear(input, weight, bias?) but wrong for addmm(bias, input, weight) where the weight is at args[2] and the input is at args[1]. This was exposed by a torchao change (D69887498) that added Linear+BatchNorm fusion to prepare_pt2e(). The fusion adds a bias to Linear nodes that previously had none, causing them to decompose to addmm instead of mm in the edge dialect. The pattern matcher then read the input's per-tensor dequantize scale (a float literal) as if it were the weight's per-channel scale (a Node), causing an assertion failure. The fix determines the correct arg indices based on whether the anchor node is addmm. The bias handling at args[0] for addmm was already correct. Authored-by: Claude Differential Revision: [D93768640](https://our.internmc.facebook.com/intern/diff/D93768640/) [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17567
Note: Links to docs will display an error until the docs builds have been completed. ❌ 3 New Failures, 59 Pending, 1 Unrelated FailureAs of commit 2a99758 with merge base 1056c34 ( NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following job failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
QuantizedLinearMatch always used args[1] for the weight and args[0] for the input, which is correct for mm(input, weight) and linear(input, weight, bias?) but wrong for addmm(bias, input, weight) where the weight is at args[2] and the input is at args[1]. This was exposed by a torchao change (D69887498) that added Linear+BatchNorm fusion to prepare_pt2e(). The fusion adds a bias to Linear nodes that previously had none, causing them to decompose to addmm instead of mm in the edge dialect. The pattern matcher then read the input's per-tensor dequantize scale (a float literal) as if it were the weight's per-channel scale (a Node), causing an assertion failure. The fix determines the correct arg indices based on whether the anchor node is addmm. The bias handling at args[0] for addmm was already correct. Authored-by: Claude Differential Revision: [D93768640](https://our.internmc.facebook.com/intern/diff/D93768640/) [ghstack-poisoned]
QuantizedLinearMatch always used args[1] for the weight and args[0] for the input, which is correct for mm(input, weight) and linear(input, weight, bias?) but wrong for addmm(bias, input, weight) where the weight is at args[2] and the input is at args[1]. This was exposed by a torchao change (D69887498) that added Linear+BatchNorm fusion to prepare_pt2e(). The fusion adds a bias to Linear nodes that previously had none, causing them to decompose to addmm instead of mm in the edge dialect. The pattern matcher then read the input's per-tensor dequantize scale (a float literal) as if it were the weight's per-channel scale (a Node), causing an assertion failure. The fix determines the correct arg indices based on whether the anchor node is addmm. The bias handling at args[0] for addmm was already correct. Authored-by: Claude Differential Revision: [D93768640](https://our.internmc.facebook.com/intern/diff/D93768640/) [ghstack-poisoned]
QuantizedLinearMatch always used args[1] for the weight and args[0] for the input, which is correct for mm(input, weight) and linear(input, weight, bias?) but wrong for addmm(bias, input, weight) where the weight is at args[2] and the input is at args[1]. This was exposed by a torchao change (D69887498) that added Linear+BatchNorm fusion to prepare_pt2e(). The fusion adds a bias to Linear nodes that previously had none, causing them to decompose to addmm instead of mm in the edge dialect. The pattern matcher then read the input's per-tensor dequantize scale (a float literal) as if it were the weight's per-channel scale (a Node), causing an assertion failure. The fix determines the correct arg indices based on whether the anchor node is addmm. The bias handling at args[0] for addmm was already correct. Authored-by: Claude Differential Revision: [D93768640](https://our.internmc.facebook.com/intern/diff/D93768640/) [ghstack-poisoned]
QuantizedLinearMatch always used args[1] for the weight and args[0] for the input, which is correct for mm(input, weight) and linear(input, weight, bias?) but wrong for addmm(bias, input, weight) where the weight is at args[2] and the input is at args[1]. This was exposed by a torchao change (D69887498) that added Linear+BatchNorm fusion to prepare_pt2e(). The fusion adds a bias to Linear nodes that previously had none, causing them to decompose to addmm instead of mm in the edge dialect. The pattern matcher then read the input's per-tensor dequantize scale (a float literal) as if it were the weight's per-channel scale (a Node), causing an assertion failure. The fix determines the correct arg indices based on whether the anchor node is addmm. The bias handling at args[0] for addmm was already correct. Authored-by: Claude Differential Revision: [D93768640](https://our.internmc.facebook.com/intern/diff/D93768640/) [ghstack-poisoned]
Stack from ghstack (oldest at bottom):
QuantizedLinearMatch always used args[1] for the weight and args[0] for the
input, which is correct for mm(input, weight) and linear(input, weight, bias?)
but wrong for addmm(bias, input, weight) where the weight is at args[2] and the
input is at args[1].
This was exposed by a torchao change (D69887498) that added Linear+BatchNorm
fusion to prepare_pt2e(). The fusion adds a bias to Linear nodes that previously
had none, causing them to decompose to addmm instead of mm in the edge dialect.
The pattern matcher then read the input's per-tensor dequantize scale (a float
literal) as if it were the weight's per-channel scale (a Node), causing an
assertion failure.
The fix determines the correct arg indices based on whether the anchor node is
addmm. The bias handling at args[0] for addmm was already correct.
Authored-by: Claude
Differential Revision: D93768640