From 9d1bd7cab009cdbd5f57a7c0fbb8fb5a4ad1ebd8 Mon Sep 17 00:00:00 2001
From: Chapaman <14204271+Chapaman@users.noreply.github.com>
Date: Fri, 30 Jan 2026 16:35:04 -0300
Subject: [PATCH 01/16] Basic all_gather implementation

---
 exla/lib/exla/defn.ex       | 18 ++++++++++++++++++
 exla/lib/exla/mlir/value.ex | 27 +++++++++++++++++++++++++++
 nx/lib/nx/defn/evaluator.ex |  5 +++++
 nx/lib/nx/defn/expr.ex      | 27 +++++++++++++++++++++++++++
 nx/lib/nx/defn/kernel.ex    | 26 ++++++++++++++++++++++++++
 5 files changed, 103 insertions(+)

diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex
index 592a8279b4..4cd674ab97 100644
--- a/exla/lib/exla/defn.ex
+++ b/exla/lib/exla/defn.ex
@@ -1471,6 +1471,24 @@ defmodule EXLA.Defn do
     EXLA.Lib.argsort(state.builder, tensor, dimension, stable, comp, ans.type)
   end
 
+## to_operator collective ops
+
+  defp to_operator(:all_gather, [%Value{} = tensor, opts], ans, _state) do
+    all_gather_dim = Keyword.fetch!(opts, :all_gather_dim)
+    replica_groups = Keyword.fetch!(opts, :replica_groups)
+    use_global_device_ids = Keyword.get(opts, :use_global_device_ids, false)
+
+    Value.all_gather(
+      [tensor],
+      expr_to_typespec(ans),
+      all_gather_dim,
+      replica_groups,
+      use_global_device_ids,
+      Keyword.take(opts, [:channel_id])
+    )
+    |> hd()
+  end
+
   defp fft(exla_op, [%Value{} = tensor, opts], %{type: type} = ans, state) do
     n = opts[:length]
     axis = opts[:axis]
diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex
index 393b6d57a8..9b6822c6dd 100644
--- a/exla/lib/exla/mlir/value.ex
+++ b/exla/lib/exla/mlir/value.ex
@@ -64,6 +64,33 @@ defmodule EXLA.MLIR.Value do
     end
   end
 
+  def all_gather([%Value{function: func} | _] = operands, typespec, all_gather_dim, replica_groups, use_global_device_ids, opts \\ []) do
+    result_types = typespecs_to_mlir_types([typespec])
+
+    opts =
+      Keyword.validate!(opts, [
+        channel_id: nil,
+      ])
+
+      num_groups = length(replica_groups)
+      group_size = if num_groups > 0, do: length(hd(replica_groups)), else: 0
+      flat_groups = List.flatten(replica_groups)
+
+      attributes = [
+        all_gather_dim: attr_i64(all_gather_dim),
+        replica_groups: attr_dense_elements(flat_groups, {:s, 64}, {num_groups, group_size}),
+        use_global_device_ids: attr_boolean(use_global_device_ids)
+      ]
+
+    attributes =
+      if opts[:channel_id] do
+        attributes ++ [channel_id: attr_i64(opts[:channel_id])]
+      else
+        attributes end
+
+    op(func, "stablehlo.all_gather", operands, result_types, attributes: attributes)
+  end
+
   defp compare_and_return_bool(func, lhs, rhs, typespec, direction, total_order? \\ false) do
     %{type: lhs_type} = get_typespec(lhs)
     %{type: rhs_type} = get_typespec(rhs)
diff --git a/nx/lib/nx/defn/evaluator.ex b/nx/lib/nx/defn/evaluator.ex
index 601f750942..5e33925625 100644
--- a/nx/lib/nx/defn/evaluator.ex
+++ b/nx/lib/nx/defn/evaluator.ex
@@ -478,6 +478,11 @@ defmodule Nx.Defn.Evaluator do
         {Nx.Shared.list_impl!(args), [ans | args]}
       end
 
+    if op == :all_gather and not function_exported?(mod, :all_gather, 3) do
+      raise ArgumentError,
+            "all_gather/3 is not supported by backend #{inspect(mod)}."
+ end + {apply(mod, op, args), caches} end diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index 899b430da4..ab11c46af1 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -1166,6 +1166,33 @@ defmodule Nx.Defn.Expr do expr(out, context, :gather, [tensor, indices, opts]) end + def all_gather(tensor, opts) do + {[tensor], context} = to_exprs([tensor]) + + _all_gather_dim = opts[:all_gather_dim] + replica_groups = opts[:replica_groups] + + # Calculate group size (number of replicas per group) + _group_size = + case replica_groups do + [first_group | _] -> length(first_group) + [] -> 1 + end + + # Calculate output shape by multiplying the gather dimension by group_size + input_shape = tensor.shape + output_shape = + input_shape +# |> Tuple.to_list() +# |> List.update_at(all_gather_dim, &(&1 * group_size)) +# |> List.to_tuple() + + # Create output tensor with the new shape + out = %{tensor | shape: output_shape} + + expr(out, context, :all_gather, [tensor, opts]) + end + @impl true def reverse(out, tensor, axes) do tensor = to_expr(tensor) diff --git a/nx/lib/nx/defn/kernel.ex b/nx/lib/nx/defn/kernel.ex index ab913ab61f..a0cf4f4493 100644 --- a/nx/lib/nx/defn/kernel.ex +++ b/nx/lib/nx/defn/kernel.ex @@ -1669,6 +1669,32 @@ defmodule Nx.Defn.Kernel do end end + @doc """ + Gathers tensors from all replicas along a specified dimension. + + This operation concatenates tensors from multiple replicas/devices along + the specified dimension. Requires a backend that supports multi-device operations. + + ## Parameters + + * `tensor` - The input tensor to gather + * `all_gather_dim` - The dimension along which to gather + * `replica_groups` - 2D list defining how replicas are grouped (required) + * `opts` - Optional keyword list: + * `:use_global_device_ids` - Whether to use global device IDs (default: false) + * `:channel_id` - Channel ID for communication (optional) + + ## Examples + + all_gather(tensor, 0, [[0, 1, 2, 3]]) + all_gather(tensor, 1, [[0, 1], [2, 3]], use_global_device_ids: true) + """ + def all_gather(tensor, all_gather_dim, replica_groups, opts \\ []) do + opts = Keyword.put(opts, :all_gather_dim, all_gather_dim) + opts = Keyword.put(opts, :replica_groups, replica_groups) + Nx.Defn.Expr.all_gather(tensor, opts) + end + @definitions (Module.definitions_in(__MODULE__, :def) ++ Module.definitions_in(__MODULE__, :defmacro)) -- [ From cc8761d4fa10441bfbe658dd30610a9d1d1c74c1 Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Mon, 2 Feb 2026 19:33:10 -0300 Subject: [PATCH 02/16] changes due to code review by @polvalente --- exla/lib/exla.ex | 15 ++++++ exla/lib/exla/defn.ex | 21 ++++---- exla/lib/exla/mlir/value.ex | 30 +++++------ exla/test/exla/defn/sharding_test.exs | 73 ++++++++++++++++++++++++++- nx/lib/nx/defn/evaluator.ex | 5 -- nx/lib/nx/defn/kernel.ex | 14 ++--- nx/test/nx/defn_test.exs | 13 +++++ 7 files changed, 128 insertions(+), 43 deletions(-) diff --git a/exla/lib/exla.ex b/exla/lib/exla.ex index 78c9016361..403c6fbe76 100644 --- a/exla/lib/exla.ex +++ b/exla/lib/exla.ex @@ -215,6 +215,21 @@ defmodule EXLA do The metadata is: * `:key` - the compilation key for debugging + + ## Sharding + + EXLA supports sharding, which is a way to partition a computation across multiple devices. + There are a number of collective operations that are supported by sharding. 
+ + ### [`all_gather`](https://openxla.org/stablehlo/spec#all_gather) + + #### Options + + * `:all_gather_dim` - the dimension along which to gather + * `:replica_groups` - 2D list defining how replicas are grouped + * `:use_global_device_ids` - Whether to use global device IDs (default: `false`) + * `:channel_id` - Channel ID for communication (optional) + """ @behaviour Nx.Defn.Compiler diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 4cd674ab97..65239373ba 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -1478,15 +1478,18 @@ defmodule EXLA.Defn do replica_groups = Keyword.fetch!(opts, :replica_groups) use_global_device_ids = Keyword.get(opts, :use_global_device_ids, false) - Value.all_gather( - [tensor], - expr_to_typespec(ans), - all_gather_dim, - replica_groups, - use_global_device_ids, - Keyword.take(opts, [:channel_id]) - ) - |> hd() + # We might want to surface all_gather as an operation that takes a container of operands instead of a single one. + [result] = + Value.all_gather( + [tensor], + expr_to_typespec(ans), + all_gather_dim, + replica_groups, + use_global_device_ids, + opts[:channel_id] + ) + + result end defp fft(exla_op, [%Value{} = tensor, opts], %{type: type} = ans, state) do diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 9b6822c6dd..e548693497 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -64,29 +64,25 @@ defmodule EXLA.MLIR.Value do end end - def all_gather([%Value{function: func} | _] = operands, typespec, all_gather_dim, replica_groups, use_global_device_ids, opts \\ []) do + def all_gather([%Value{function: func} | _] = operands, typespec, all_gather_dim, replica_groups, use_global_device_ids, channel_id \\ nil) do result_types = typespecs_to_mlir_types([typespec]) - opts = - Keyword.validate!(opts, [ - channel_id: nil, - ]) + num_groups = length(replica_groups) + group_size = if num_groups > 0, do: length(hd(replica_groups)), else: 0 + flat_groups = List.flatten(replica_groups) - num_groups = length(replica_groups) - group_size = if num_groups > 0, do: length(hd(replica_groups)), else: 0 - flat_groups = List.flatten(replica_groups) - - attributes = [ - all_gather_dim: attr_i64(all_gather_dim), - replica_groups: attr_dense_elements(flat_groups, {:s, 64}, {num_groups, group_size}), - use_global_device_ids: attr_boolean(use_global_device_ids) - ] + attributes = [ + all_gather_dim: attr_i64(all_gather_dim), + replica_groups: attr_dense_elements(flat_groups, {:s, 64}, {num_groups, group_size}), + use_global_device_ids: attr_boolean(use_global_device_ids) + ] attributes = - if opts[:channel_id] do - attributes ++ [channel_id: attr_i64(opts[:channel_id])] + if channel_id do + Keyword.put(attributes, :channel_id, attr_i64(channel_id)) else - attributes end + attributes + end op(func, "stablehlo.all_gather", operands, result_types, attributes: attributes) end diff --git a/exla/test/exla/defn/sharding_test.exs b/exla/test/exla/defn/sharding_test.exs index ed46f76b6a..058e2683b6 100644 --- a/exla/test/exla/defn/sharding_test.exs +++ b/exla/test/exla/defn/sharding_test.exs @@ -6,7 +6,8 @@ defmodule EXLA.Defn.ShardingTest do describe "MLIR module generation with sharding" do @moduletag :multi_device test "generates correct MLIR with simple 2D mesh and sharding" do - fun = fn x, y -> Nx.add(x, y) end + fun = fn x, y -> Nx.add(x, y) + end mesh = %Mesh{name: "mesh", shape: {2, 2}} # First arg: shard dim 0 on mesh axis 0, dim 1 on mesh axis 1 @@ -737,5 +738,75 @@ defmodule 
EXLA.Defn.ShardingTest do assert result.mlir_module =~ ~r/"axis_0"/ assert result.mlir_module =~ ~r/"axis_1"/ end + + @moduletag :multi_device + test "generates correct MLIR with all_gather" do + fun = fn x, y -> Nx.add(x, y) + |> Nx.Defn.Kernel.all_gather(all_gather_dim: 0, replica_groups: [[0]]) + |> Nx.Defn.Kernel.all_gather(all_gather_dim: 1, replica_groups: [[0]]) + end + + mesh = %Mesh{name: "mesh", shape: {2, 2}} + # First arg: shard dim 0 on mesh axis 0, dim 1 on mesh axis 1 + # Second arg: shard dim 0 on mesh axis 0, dim 1 not sharded + input_shardings = [%{0 => [0], 1 => [1]}, %{0 => [0]}] + + # For mesh {2, 2}, we have 4 partitions + # Each partition gets a shard of the inputs + # First input: shape {8, 2} sharded as [[0], [1]] -> each partition gets {4, 1} + # Second input: shape {8, 1} sharded as [[0], []] -> each partition gets {4, 1} + args = [ + # partition 0 + [Nx.iota({4, 1}), Nx.iota({4, 1})], + # partition 1 + [Nx.iota({4, 1}), Nx.iota({4, 1})], + # partition 2 + [Nx.iota({4, 1}), Nx.iota({4, 1})], + # partition 3 + [Nx.iota({4, 1}), Nx.iota({4, 1})] + ] + + result = EXLA.to_mlir_module(fun, args, mesh: mesh, input_shardings: input_shardings) + + expected_mlir = """ + module { + sdy.mesh @mesh = <["axis_0"=2, "axis_1"=2]> + func.func public @main(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"axis_0", ?}p0, {"axis_1", ?}p0]>}, %arg1: tensor<8x1xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"axis_0", ?}p0, {?}p0]>}) -> tensor<8x2xi32> { + %0 = stablehlo.broadcast_in_dim %arg1, dims = [0, 1] : (tensor<8x1xi32>) -> tensor<8x2xi32> + %1 = stablehlo.add %arg0, %0 : tensor<8x2xi32> + %2 = "stablehlo.all_gather"(%1) <{all_gather_dim = 0 : i64, replica_groups = dense<0> : tensor<1x1xi64>}> : (tensor<8x2xi32>) -> tensor<8x2xi32> + %3 = "stablehlo.all_gather"(%2) <{all_gather_dim = 1 : i64, replica_groups = dense<0> : tensor<1x1xi64>}> : (tensor<8x2xi32>) -> tensor<8x2xi32> + return %3 : tensor<8x2xi32> + } + } + """ + + assert expected_mlir == result.mlir_module + + results = EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args) + + assert length(results) == 4 + + # After all_gather on both dims, each partition has the full tensor: add(iota, iota) -> 2*iota + # Each shard had iota({4,1}) = [[0],[1],[2],[3]], so add gives [[0],[2],[4],[6]] + # After gathering: replicated 8x2 with pattern [[0,0],[2,2],[4,4],[6,6],[0,0],[2,2],[4,4],[6,6]] + expected_result = + Nx.tensor([ + [0, 0], + [2, 2], + [4, 4], + [6, 6], + [0, 0], + [2, 2], + [4, 4], + [6, 6] + ]) + + for r <- results do + assert_equal(r, expected_result) + end + end + + end end diff --git a/nx/lib/nx/defn/evaluator.ex b/nx/lib/nx/defn/evaluator.ex index 5e33925625..601f750942 100644 --- a/nx/lib/nx/defn/evaluator.ex +++ b/nx/lib/nx/defn/evaluator.ex @@ -478,11 +478,6 @@ defmodule Nx.Defn.Evaluator do {Nx.Shared.list_impl!(args), [ans | args]} end - if op == :all_gather and not function_exported?(mod, :all_gather, 3) do - raise ArgumentError, - "all_gather/3 is not supported by backend #{inspect(mod)}." 
- end - {apply(mod, op, args), caches} end diff --git a/nx/lib/nx/defn/kernel.ex b/nx/lib/nx/defn/kernel.ex index a0cf4f4493..809a66b480 100644 --- a/nx/lib/nx/defn/kernel.ex +++ b/nx/lib/nx/defn/kernel.ex @@ -1678,20 +1678,12 @@ defmodule Nx.Defn.Kernel do ## Parameters * `tensor` - The input tensor to gather - * `all_gather_dim` - The dimension along which to gather - * `replica_groups` - 2D list defining how replicas are grouped (required) - * `opts` - Optional keyword list: - * `:use_global_device_ids` - Whether to use global device IDs (default: false) - * `:channel_id` - Channel ID for communication (optional) - ## Examples + * `opts` - Optional keyword list. These are backend- and compiler-specific; + see your backend or compiler docs for supported options. - all_gather(tensor, 0, [[0, 1, 2, 3]]) - all_gather(tensor, 1, [[0, 1], [2, 3]], use_global_device_ids: true) """ - def all_gather(tensor, all_gather_dim, replica_groups, opts \\ []) do - opts = Keyword.put(opts, :all_gather_dim, all_gather_dim) - opts = Keyword.put(opts, :replica_groups, replica_groups) + def all_gather(tensor, opts \\ []) do Nx.Defn.Expr.all_gather(tensor, opts) end diff --git a/nx/test/nx/defn_test.exs b/nx/test/nx/defn_test.exs index 62993b07a3..621b4f4e77 100644 --- a/nx/test/nx/defn_test.exs +++ b/nx/test/nx/defn_test.exs @@ -2952,4 +2952,17 @@ defmodule Nx.DefnTest do assert vectorized_metadata_tuple(x, z) == vec_nonvec_result end end + + describe "sharding" do + defn all_gather_test(tensor) do + Nx.Defn.Kernel.all_gather(tensor, all_gather_dim: 0, replica_groups: [[0]]) + end + + @tag compiler: Evaluator + test "all_gather works" do + assert_raise UndefinedFunctionError, fn -> + all_gather_test(Nx.tensor([1, 2, 3, 4])) + end + end + end end From 0a50a3e5fb457b8fb84a60d553b1e1ebccb325e9 Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Mon, 9 Feb 2026 18:26:32 -0300 Subject: [PATCH 03/16] added test in defn to guarantee output format --- nx/test/nx/defn_test.exs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/nx/test/nx/defn_test.exs b/nx/test/nx/defn_test.exs index 621b4f4e77..d15951ce71 100644 --- a/nx/test/nx/defn_test.exs +++ b/nx/test/nx/defn_test.exs @@ -2958,6 +2958,19 @@ defmodule Nx.DefnTest do Nx.Defn.Kernel.all_gather(tensor, all_gather_dim: 0, replica_groups: [[0]]) end + test "all_gather produces correct expr format for compiler" do + # Uses debug_expr to inspect the expression without compiling. + # Guarantees the format passed to compilers (e.g. EXLA) stays stable. 
+ assert %T{data: %Expr{op: :all_gather, args: [tensor, opts]}} = + Nx.Defn.debug_expr(&all_gather_test/1).(Nx.tensor([1, 2, 3, 4])) + + assert %T{data: %Expr{op: :parameter, args: [0]}} = tensor + + # Compilers expect opts with :all_gather_dim and :replica_groups + assert opts[:all_gather_dim] == 0 + assert opts[:replica_groups] == [[0]] + end + @tag compiler: Evaluator test "all_gather works" do assert_raise UndefinedFunctionError, fn -> From 2daaad8dc3f64918cae50f040a031e35a373be78 Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Mon, 9 Feb 2026 18:37:28 -0300 Subject: [PATCH 04/16] added sharding confirmation to test --- exla/test/exla/defn/sharding_test.exs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/exla/test/exla/defn/sharding_test.exs b/exla/test/exla/defn/sharding_test.exs index 058e2683b6..7d1eb993e4 100644 --- a/exla/test/exla/defn/sharding_test.exs +++ b/exla/test/exla/defn/sharding_test.exs @@ -802,11 +802,13 @@ defmodule EXLA.Defn.ShardingTest do [6, 6] ]) - for r <- results do - assert_equal(r, expected_result) - end - end - + device_ids = + for r <- results do + assert_equal(r, expected_result) + r.data.buffer.device_id + end + assert Enum.sort(device_ids) == [0, 1, 2, 3] + end end end From 1be394a1e7dec11666e14728e5cd4f734485f74e Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Wed, 11 Feb 2026 18:47:15 -0300 Subject: [PATCH 05/16] updated the test so the result on the test is clearer --- exla/test/exla/defn/sharding_test.exs | 60 +++++++++++++-------------- 1 file changed, 28 insertions(+), 32 deletions(-) diff --git a/exla/test/exla/defn/sharding_test.exs b/exla/test/exla/defn/sharding_test.exs index 7d1eb993e4..ccf3ce2583 100644 --- a/exla/test/exla/defn/sharding_test.exs +++ b/exla/test/exla/defn/sharding_test.exs @@ -747,23 +747,22 @@ defmodule EXLA.Defn.ShardingTest do end mesh = %Mesh{name: "mesh", shape: {2, 2}} - # First arg: shard dim 0 on mesh axis 0, dim 1 on mesh axis 1 - # Second arg: shard dim 0 on mesh axis 0, dim 1 not sharded - input_shardings = [%{0 => [0], 1 => [1]}, %{0 => [0]}] + # First arg: 0..15 (8x2), shard dim 0 on mesh axis 0, dim 1 on mesh axis 1 + # Second arg: 100..115 (8x2), same sharding — makes sharded results easy to read + input_shardings = [%{0 => [0], 1 => [1]}, %{0 => [0], 1 => [1]}] - # For mesh {2, 2}, we have 4 partitions - # Each partition gets a shard of the inputs - # First input: shape {8, 2} sharded as [[0], [1]] -> each partition gets {4, 1} - # Second input: shape {8, 1} sharded as [[0], []] -> each partition gets {4, 1} + # For mesh {2, 2}, 4 partitions. Each gets {4, 1}. Full 8x2 row-major: [[0,1],[2,3],...,[14,15]]. + # Partition (axis_0, axis_1): (0,0)=rows 0-3 col 0, (0,1)=rows 0-3 col 1, (1,0)=rows 4-7 col 0, (1,1)=rows 4-7 col 1. + # So partition 0 gets (0,0),(1,0),(2,0),(3,0) = 0,2,4,6; partition 1 gets (0,1),(1,1),... = 1,3,5,7; etc. 
args = [ - # partition 0 - [Nx.iota({4, 1}), Nx.iota({4, 1})], - # partition 1 - [Nx.iota({4, 1}), Nx.iota({4, 1})], - # partition 2 - [Nx.iota({4, 1}), Nx.iota({4, 1})], - # partition 3 - [Nx.iota({4, 1}), Nx.iota({4, 1})] + # partition 0: rows 0–3 col 0 -> 0,2,4,6 and 100,102,104,106 + [Nx.tensor([[0], [2], [4], [6]]), Nx.tensor([[100], [102], [104], [106]])], + # partition 1: rows 0–3 col 1 -> 1,3,5,7 and 101,103,105,107 + [Nx.tensor([[1], [3], [5], [7]]), Nx.tensor([[101], [103], [105], [107]])], + # partition 2: rows 4–7 col 0 -> 8,10,12,14 and 108,110,112,114 + [Nx.tensor([[8], [10], [12], [14]]), Nx.tensor([[108], [110], [112], [114]])], + # partition 3: rows 4–7 col 1 -> 9,11,13,15 and 109,111,113,115 + [Nx.tensor([[9], [11], [13], [15]]), Nx.tensor([[109], [111], [113], [115]])] ] result = EXLA.to_mlir_module(fun, args, mesh: mesh, input_shardings: input_shardings) @@ -771,12 +770,11 @@ defmodule EXLA.Defn.ShardingTest do expected_mlir = """ module { sdy.mesh @mesh = <["axis_0"=2, "axis_1"=2]> - func.func public @main(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"axis_0", ?}p0, {"axis_1", ?}p0]>}, %arg1: tensor<8x1xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"axis_0", ?}p0, {?}p0]>}) -> tensor<8x2xi32> { - %0 = stablehlo.broadcast_in_dim %arg1, dims = [0, 1] : (tensor<8x1xi32>) -> tensor<8x2xi32> - %1 = stablehlo.add %arg0, %0 : tensor<8x2xi32> - %2 = "stablehlo.all_gather"(%1) <{all_gather_dim = 0 : i64, replica_groups = dense<0> : tensor<1x1xi64>}> : (tensor<8x2xi32>) -> tensor<8x2xi32> - %3 = "stablehlo.all_gather"(%2) <{all_gather_dim = 1 : i64, replica_groups = dense<0> : tensor<1x1xi64>}> : (tensor<8x2xi32>) -> tensor<8x2xi32> - return %3 : tensor<8x2xi32> + func.func public @main(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"axis_0", ?}p0, {"axis_1", ?}p0]>}, %arg1: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"axis_0", ?}p0, {"axis_1", ?}p0]>}) -> tensor<8x2xi32> { + %0 = stablehlo.add %arg0, %arg1 : tensor<8x2xi32> + %1 = "stablehlo.all_gather"(%0) <{all_gather_dim = 0 : i64, replica_groups = dense<0> : tensor<1x1xi64>}> : (tensor<8x2xi32>) -> tensor<8x2xi32> + %2 = "stablehlo.all_gather"(%1) <{all_gather_dim = 1 : i64, replica_groups = dense<0> : tensor<1x1xi64>}> : (tensor<8x2xi32>) -> tensor<8x2xi32> + return %2 : tensor<8x2xi32> } } """ @@ -787,19 +785,17 @@ defmodule EXLA.Defn.ShardingTest do assert length(results) == 4 - # After all_gather on both dims, each partition has the full tensor: add(iota, iota) -> 2*iota - # Each shard had iota({4,1}) = [[0],[1],[2],[3]], so add gives [[0],[2],[4],[6]] - # After gathering: replicated 8x2 with pattern [[0,0],[2,2],[4,4],[6,6],[0,0],[2,2],[4,4],[6,6]] + # After all_gather: full first arg 0..15 + full second 100..115 -> 100,102,...,130 expected_result = Nx.tensor([ - [0, 0], - [2, 2], - [4, 4], - [6, 6], - [0, 0], - [2, 2], - [4, 4], - [6, 6] + [100, 102], + [104, 106], + [108, 110], + [112, 114], + [116, 118], + [120, 122], + [124, 126], + [128, 130] ]) device_ids = From eaef13671e3be66c6c7945e0204aa6992eea234c Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Wed, 11 Feb 2026 19:46:48 -0300 Subject: [PATCH 06/16] intermediate commit --- exla/test/exla/defn/sharding_test.exs | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/exla/test/exla/defn/sharding_test.exs b/exla/test/exla/defn/sharding_test.exs index ccf3ce2583..3052fcc30d 100644 --- a/exla/test/exla/defn/sharding_test.exs +++ 
b/exla/test/exla/defn/sharding_test.exs @@ -798,13 +798,10 @@ defmodule EXLA.Defn.ShardingTest do [128, 130] ]) - device_ids = - for r <- results do - assert_equal(r, expected_result) - r.data.buffer.device_id - end - - assert Enum.sort(device_ids) == [0, 1, 2, 3] + for {r, partition_idx} <- Enum.with_index(results) do + assert_equal(r, expected_result) + assert r.data.buffer.device_id == partition_idx + end end end end From d1c683869e8964e4f36012af8c3d4ee4080bd106 Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Wed, 11 Feb 2026 20:10:16 -0300 Subject: [PATCH 07/16] added partially sharded test + Enum.zip_with on previous test --- exla/test/exla/defn/sharding_test.exs | 50 ++++++++++++++++++++++++--- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/exla/test/exla/defn/sharding_test.exs b/exla/test/exla/defn/sharding_test.exs index 4154d71808..b22e20a97a 100644 --- a/exla/test/exla/defn/sharding_test.exs +++ b/exla/test/exla/defn/sharding_test.exs @@ -950,10 +950,52 @@ defmodule EXLA.Defn.ShardingTest do [128, 130] ]) - for {r, partition_idx} <- Enum.with_index(results) do - assert_equal(r, expected_result) - assert r.data.buffer.device_id == partition_idx - end + Enum.zip_with([results, 0..3], fn [result, i] -> + assert_equal(result, expected_result) + assert result.data.buffer.device_id == i + end) + end + + @moduletag :multi_device + test "can return partially sharded results" do + fun = fn x, y -> Nx.add(x, y) end + + mesh = %Mesh{name: "mesh", shape: {2, 2}} + # Inputs sharded on both axes + input_shardings = [%{0 => [0], 1 => [1]}, %{0 => [0], 1 => [1]}] + # Output: sharded only on axis 0 (dim 1 replicated) -> each partition gets {4, 2} + output_shardings = [%{0 => [0]}] + + # Logical x: 8x2, y: 8x2. 
Each partition gets {4, 1} of each + args = [ + [Nx.tensor([[0], [1], [2], [3]]), Nx.tensor([[100], [101], [102], [103]])], + [Nx.tensor([[10], [11], [12], [13]]), Nx.tensor([[110], [111], [112], [113]])], + [Nx.tensor([[4], [5], [6], [7]]), Nx.tensor([[104], [105], [106], [107]])], + [Nx.tensor([[14], [15], [16], [17]]), Nx.tensor([[114], [115], [116], [117]])] + ] + + results = + EXLA.shard_jit(fun, mesh, + input_shardings: input_shardings, + output_shardings: output_shardings + ).(args) + + assert length(results) == 4 + + # Partially sharded output: dim 0 sharded on axis 0, dim 1 not in output spec + # Each device returns its local shard {4, 1} (x+y computed locally) + # Dev0: col0 rows 0-3, Dev1: col1 rows 0-3, Dev2: col0 rows 4-7, Dev3: col1 rows 4-7 + expected_results = [ + Nx.tensor([[100], [102], [104], [106]]), + Nx.tensor([[120], [122], [124], [126]]), + Nx.tensor([[108], [110], [112], [114]]), + Nx.tensor([[128], [130], [132], [134]]) + ] + + Enum.zip_with([results, expected_results, 0..3], fn [result, expected, i] -> + assert_equal(result, expected) + assert result.data.buffer.device_id == i + end) end end end From 0b300c46d1fe7bfeb9a4f61e70aa8f26dc3322cb Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 12 Feb 2026 22:50:33 -0300 Subject: [PATCH 08/16] tests: improve tests --- exla/test/exla/defn/sharding_test.exs | 123 +++++++++----------------- 1 file changed, 40 insertions(+), 83 deletions(-) diff --git a/exla/test/exla/defn/sharding_test.exs b/exla/test/exla/defn/sharding_test.exs index b22e20a97a..d9fdb5f113 100644 --- a/exla/test/exla/defn/sharding_test.exs +++ b/exla/test/exla/defn/sharding_test.exs @@ -290,9 +290,6 @@ defmodule EXLA.Defn.ShardingTest do %{0 => [0]} ] - # Output: shard dim 0 on axis 0, dim 1 on axis 1 (like x) - output_shardings = [%{0 => [0], 1 => [1]}] - # For mesh {2, 2}, we have 4 partitions # x: {8, 2} sharded [[0], [1]] -> each partition gets {4, 1} # y: {8, 1} sharded [[0], []] -> each partition gets {4, 1} @@ -333,8 +330,7 @@ defmodule EXLA.Defn.ShardingTest do result = EXLA.to_mlir_module(fun, args, mesh: mesh, - input_shardings: input_shardings, - output_shardings: output_shardings + input_shardings: input_shardings ) expected_mlir = """ @@ -354,8 +350,7 @@ defmodule EXLA.Defn.ShardingTest do results = EXLA.shard_jit(fun, mesh, - input_shardings: input_shardings, - output_shardings: output_shardings + input_shardings: input_shardings ).(args) assert length(results) == 4 @@ -427,7 +422,6 @@ defmodule EXLA.Defn.ShardingTest do mesh = %Mesh{name: "mesh", shape: {2, 2}} # Only one sharding spec for two arguments input_shardings = [%{0 => [0]}] - output_shardings = [%{}] # For mesh {2, 2}, we have 4 partitions # Each partition has 2 inputs @@ -438,8 +432,7 @@ defmodule EXLA.Defn.ShardingTest do fn -> EXLA.to_mlir_module(fun, args, mesh: mesh, - input_shardings: input_shardings, - output_shardings: output_shardings + input_shardings: input_shardings ) end end @@ -452,15 +445,13 @@ defmodule EXLA.Defn.ShardingTest do mesh = %Mesh{name: "mesh", shape: {2, 2}} # Mesh has 2 axes (0 and 1), but we reference axis 2 input_shardings = [%{0 => [2]}] - output_shardings = [%{}] # For mesh {2, 2}, we have 4 partitions args = List.duplicate([Nx.iota({4, 2})], 4) assert_raise ArgumentError, fn -> EXLA.shard_jit(fun, mesh, - input_shardings: input_shardings, - output_shardings: output_shardings + input_shardings: input_shardings ).(args) end end @@ -471,7 +462,6 @@ defmodule EXLA.Defn.ShardingTest do mesh = 
%Mesh{name: "mesh", shape: {2, 2}} # Axis 0 used for both dimensions input_shardings = [%{0 => [0], 1 => [0]}] - output_shardings = [%{}] # For mesh {2, 2}, we have 4 partitions args = List.duplicate([Nx.iota({4, 1})], 4) @@ -480,8 +470,7 @@ defmodule EXLA.Defn.ShardingTest do ~r/axis 0 was used twice in the same input sharding/, fn -> EXLA.shard_jit(fun, mesh, - input_shardings: input_shardings, - output_shardings: output_shardings + input_shardings: input_shardings ).(args) end end @@ -492,15 +481,13 @@ defmodule EXLA.Defn.ShardingTest do mesh = %Mesh{name: "mesh", shape: {2, 2}} # Tensor is rank 2, but sharding spec has 3 dimensions input_shardings = [%{0 => [0], 1 => [1], 2 => []}] - output_shardings = [%{}] # For mesh {2, 2}, we have 4 partitions args = List.duplicate([Nx.iota({4, 2})], 4) assert_raise ArgumentError, fn -> EXLA.shard_jit(fun, mesh, - input_shardings: input_shardings, - output_shardings: output_shardings + input_shardings: input_shardings ).(args) end end @@ -511,7 +498,6 @@ defmodule EXLA.Defn.ShardingTest do mesh = %Mesh{name: "mesh", shape: {2}} # Tensor is rank 2, but -3 is out of bounds (only -1 and -2 are valid) input_shardings = [%{-3 => [0]}] - output_shardings = [%{}] args = List.duplicate([Nx.iota({4, 2})], 2) @@ -519,8 +505,7 @@ defmodule EXLA.Defn.ShardingTest do ~r/given axis \(-3\) invalid for shape with rank 2/, fn -> EXLA.shard_jit(fun, mesh, - input_shardings: input_shardings, - output_shardings: output_shardings + input_shardings: input_shardings ).(args) end end @@ -531,7 +516,6 @@ defmodule EXLA.Defn.ShardingTest do mesh = %Mesh{name: "mesh", shape: {2}} # Tensor is rank 2, but axis 3 is out of bounds input_shardings = [%{3 => [0]}] - output_shardings = [%{}] args = List.duplicate([Nx.iota({4, 2})], 2) @@ -539,8 +523,7 @@ defmodule EXLA.Defn.ShardingTest do ~r/given axis \(3\) invalid for shape with rank 2/, fn -> EXLA.shard_jit(fun, mesh, - input_shardings: input_shardings, - output_shardings: output_shardings + input_shardings: input_shardings ).(args) end end @@ -582,8 +565,6 @@ defmodule EXLA.Defn.ShardingTest do # All dimensions replicated input_shardings = [%{}] - # Output: replicated (all-gathered) across all devices - output_shardings = [%{}] # For mesh {2, 2}, we have 4 partitions # Input fully replicated -> each partition gets full {8, 4} @@ -592,8 +573,7 @@ defmodule EXLA.Defn.ShardingTest do result = EXLA.to_mlir_module(fun, args, mesh: mesh, - input_shardings: input_shardings, - output_shardings: output_shardings + input_shardings: input_shardings ) assert is_binary(result.mlir_module) @@ -606,8 +586,6 @@ defmodule EXLA.Defn.ShardingTest do # Scalar has no dimensions to shard input_shardings = [%{}] - # Output: replicated (all-gathered) across all devices - output_shardings = [%{}] # For mesh {2}, we have 2 partitions # Scalar is replicated across all partitions @@ -616,8 +594,7 @@ defmodule EXLA.Defn.ShardingTest do result = EXLA.to_mlir_module(fun, args, mesh: mesh, - input_shardings: input_shardings, - output_shardings: output_shardings + input_shardings: input_shardings ) assert is_binary(result.mlir_module) @@ -637,9 +614,6 @@ defmodule EXLA.Defn.ShardingTest do %{} ] - # Outputs: both replicated (tuple with two elements) - output_shardings = [%{}, %{}] - # For mesh {2, 2}, we have 4 partitions # x: {8, 4} sharded [[0], [1]] -> each partition gets {4, 2} # y: {8, 4} sharded [[0], []] -> each partition gets {4, 4} @@ -658,8 +632,7 @@ defmodule EXLA.Defn.ShardingTest do result = EXLA.to_mlir_module(fun, args, mesh: mesh, - 
input_shardings: input_shardings, - output_shardings: output_shardings + input_shardings: input_shardings ) assert is_binary(result.mlir_module) @@ -673,8 +646,6 @@ defmodule EXLA.Defn.ShardingTest do mesh = %Mesh{name: "test_mesh", shape: {2, 2}} input_shardings = [%{0 => [0]}] - # Output: replicated (all-gathered) across all devices - output_shardings = [%{}] # For mesh {2, 2}, we have 4 partitions # Input sharded [[0], []] -> each partition gets {4, 2} @@ -683,8 +654,7 @@ defmodule EXLA.Defn.ShardingTest do result = EXLA.to_mlir_module(fun, args, mesh: mesh, - input_shardings: input_shardings, - output_shardings: output_shardings + input_shardings: input_shardings ) mlir = result.mlir_module @@ -702,8 +672,6 @@ defmodule EXLA.Defn.ShardingTest do mesh = %Mesh{name: "mesh", shape: {2, 2}} input_shardings = [%{0 => [0], 1 => [1]}] - # Output: replicated (all-gathered) across all devices - output_shardings = [%{}] # For mesh {2, 2}, we have 4 partitions # Input sharded [[0], [1]] -> each partition gets {4, 1} @@ -712,8 +680,7 @@ defmodule EXLA.Defn.ShardingTest do result = EXLA.to_mlir_module(fun, args, mesh: mesh, - input_shardings: input_shardings, - output_shardings: output_shardings + input_shardings: input_shardings ) mlir = result.mlir_module @@ -732,8 +699,6 @@ defmodule EXLA.Defn.ShardingTest do # Use named dimensions instead of indices input_shardings = [%{:batch => [0]}] - # Output: replicated (all-gathered) across all devices - output_shardings = [%{}] # For mesh {2}, we have 2 partitions # Named dimension :batch should map to first dimension (index 0) @@ -742,8 +707,7 @@ defmodule EXLA.Defn.ShardingTest do result = EXLA.to_mlir_module(fun, args, mesh: mesh, - input_shardings: input_shardings, - output_shardings: output_shardings + input_shardings: input_shardings ) # Should generate valid MLIR with sharding on first dimension @@ -890,9 +854,10 @@ defmodule EXLA.Defn.ShardingTest do end) end end - +end +describe "all_gather" do @moduletag :multi_device - test "generates correct MLIR with all_gather" do + test "in all dims results in the same tensor in all devices" do fun = fn x, y -> Nx.add(x, y) |> Nx.Defn.Kernel.all_gather(all_gather_dim: 0, replica_groups: [[0]]) |> Nx.Defn.Kernel.all_gather(all_gather_dim: 1, replica_groups: [[0]]) @@ -950,7 +915,7 @@ defmodule EXLA.Defn.ShardingTest do [128, 130] ]) - Enum.zip_with([results, 0..3], fn [result, i] -> + Enum.with_index(results, fn result, i -> assert_equal(result, expected_result) assert result.data.buffer.device_id == i end) @@ -958,44 +923,36 @@ defmodule EXLA.Defn.ShardingTest do @moduletag :multi_device test "can return partially sharded results" do - fun = fn x, y -> Nx.add(x, y) end + fun = fn x, y -> + x + |> Nx.Defn.Kernel.all_gather(all_gather_dim: 1, replica_groups: [[0]]) + |> Nx.add(y) + end mesh = %Mesh{name: "mesh", shape: {2, 2}} # Inputs sharded on both axes - input_shardings = [%{0 => [0], 1 => [1]}, %{0 => [0], 1 => [1]}] - # Output: sharded only on axis 0 (dim 1 replicated) -> each partition gets {4, 2} - output_shardings = [%{0 => [0]}] + input_shardings = [%{0 => [0], 1 => [1]}, %{0 => [0]}] - # Logical x: 8x2, y: 8x2. Each partition gets {4, 1} of each + # Logical x: 8x2, y: 8x2. 
Each partition gets {4, 1} of x and {4, 2} of y args = [ - [Nx.tensor([[0], [1], [2], [3]]), Nx.tensor([[100], [101], [102], [103]])], - [Nx.tensor([[10], [11], [12], [13]]), Nx.tensor([[110], [111], [112], [113]])], - [Nx.tensor([[4], [5], [6], [7]]), Nx.tensor([[104], [105], [106], [107]])], - [Nx.tensor([[14], [15], [16], [17]]), Nx.tensor([[114], [115], [116], [117]])] + [Nx.tensor([[0], [1], [2], [3]]), Nx.tensor([[100, 101], [102, 103], [104, 105], [106, 107]])], + [Nx.tensor([[4], [5], [6], [7]]), Nx.tensor([[100, 101], [102, 103], [104, 105], [106, 107]])], + [Nx.tensor([[8], [9], [10], [11]]), Nx.tensor([[110, 111], [112, 113], [114, 115], [116, 117]])], + [Nx.tensor([[12], [13], [14], [15]]), Nx.tensor([[110, 111], [112, 113], [114, 115], [116, 117]])] ] - results = - EXLA.shard_jit(fun, mesh, - input_shardings: input_shardings, - output_shardings: output_shardings - ).(args) - - assert length(results) == 4 - - # Partially sharded output: dim 0 sharded on axis 0, dim 1 not in output spec - # Each device returns its local shard {4, 1} (x+y computed locally) - # Dev0: col0 rows 0-3, Dev1: col1 rows 0-3, Dev2: col0 rows 4-7, Dev3: col1 rows 4-7 - expected_results = [ - Nx.tensor([[100], [102], [104], [106]]), - Nx.tensor([[120], [122], [124], [126]]), - Nx.tensor([[108], [110], [112], [114]]), - Nx.tensor([[128], [130], [132], [134]]) - ] + assert [result0, result1, result2, result3] = + EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args) - Enum.zip_with([results, expected_results, 0..3], fn [result, expected, i] -> - assert_equal(result, expected) - assert result.data.buffer.device_id == i - end) + # After gathering, devices 0 and 1 have the same data as each other, likewise for devices 2 and 3 + assert_equal(result0, Nx.tensor([[100, 105], [103, 108], [106, 111], [109, 114]])) + assert result0.data.buffer.device_id == 0 + assert_equal(result0, Nx.tensor([[100, 105], [103, 108], [106, 111], [109, 114]])) + assert result1.data.buffer.device_id == 1 + assert_equal(result2, Nx.tensor([[118, 123], [121, 126], [124, 129], [127, 132]])) + assert result2.data.buffer.device_id == 2 + assert_equal(result3, Nx.tensor([[118, 123], [121, 126], [124, 129], [127, 132]])) + assert result3.data.buffer.device_id == 3 end end end From ebf4b9f3cbfb76adc723206898fdbd59b81ed3d9 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 12 Feb 2026 22:52:11 -0300 Subject: [PATCH 09/16] chore: format --- exla/lib/exla/defn.ex | 2 +- exla/lib/exla/mlir/value.ex | 9 +++- exla/test/exla/defn/sharding_test.exs | 65 +++++++++++++-------------- 3 files changed, 40 insertions(+), 36 deletions(-) diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index f4447bb1ee..48593c22aa 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -1474,7 +1474,7 @@ defmodule EXLA.Defn do EXLA.Lib.argsort(state.builder, tensor, dimension, stable, comp, ans.type) end -## to_operator collective ops + ## to_operator collective ops defp to_operator(:all_gather, [%Value{} = tensor, opts], ans, _state) do all_gather_dim = Keyword.fetch!(opts, :all_gather_dim) diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index e548693497..ec6527e091 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -64,7 +64,14 @@ defmodule EXLA.MLIR.Value do end end - def all_gather([%Value{function: func} | _] = operands, typespec, all_gather_dim, replica_groups, use_global_device_ids, channel_id \\ nil) do + def all_gather( + 
[%Value{function: func} | _] = operands, + typespec, + all_gather_dim, + replica_groups, + use_global_device_ids, + channel_id \\ nil + ) do result_types = typespecs_to_mlir_types([typespec]) num_groups = length(replica_groups) diff --git a/exla/test/exla/defn/sharding_test.exs b/exla/test/exla/defn/sharding_test.exs index d9fdb5f113..e11bab0981 100644 --- a/exla/test/exla/defn/sharding_test.exs +++ b/exla/test/exla/defn/sharding_test.exs @@ -349,9 +349,7 @@ defmodule EXLA.Defn.ShardingTest do assert expected_mlir == result.mlir_module results = - EXLA.shard_jit(fun, mesh, - input_shardings: input_shardings - ).(args) + EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args) assert length(results) == 4 @@ -450,9 +448,7 @@ defmodule EXLA.Defn.ShardingTest do args = List.duplicate([Nx.iota({4, 2})], 4) assert_raise ArgumentError, fn -> - EXLA.shard_jit(fun, mesh, - input_shardings: input_shardings - ).(args) + EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args) end end @@ -469,9 +465,7 @@ defmodule EXLA.Defn.ShardingTest do assert_raise ArgumentError, ~r/axis 0 was used twice in the same input sharding/, fn -> - EXLA.shard_jit(fun, mesh, - input_shardings: input_shardings - ).(args) + EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args) end end @@ -486,9 +480,7 @@ defmodule EXLA.Defn.ShardingTest do args = List.duplicate([Nx.iota({4, 2})], 4) assert_raise ArgumentError, fn -> - EXLA.shard_jit(fun, mesh, - input_shardings: input_shardings - ).(args) + EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args) end end @@ -504,9 +496,7 @@ defmodule EXLA.Defn.ShardingTest do assert_raise ArgumentError, ~r/given axis \(-3\) invalid for shape with rank 2/, fn -> - EXLA.shard_jit(fun, mesh, - input_shardings: input_shardings - ).(args) + EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args) end end @@ -522,9 +512,7 @@ defmodule EXLA.Defn.ShardingTest do assert_raise ArgumentError, ~r/given axis \(3\) invalid for shape with rank 2/, fn -> - EXLA.shard_jit(fun, mesh, - input_shardings: input_shardings - ).(args) + EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args) end end end @@ -565,7 +553,6 @@ defmodule EXLA.Defn.ShardingTest do # All dimensions replicated input_shardings = [%{}] - # For mesh {2, 2}, we have 4 partitions # Input fully replicated -> each partition gets full {8, 4} args = List.duplicate([Nx.iota({8, 4})], 4) @@ -586,7 +573,6 @@ defmodule EXLA.Defn.ShardingTest do # Scalar has no dimensions to shard input_shardings = [%{}] - # For mesh {2}, we have 2 partitions # Scalar is replicated across all partitions args = List.duplicate([Nx.tensor(5.0)], 2) @@ -646,7 +632,6 @@ defmodule EXLA.Defn.ShardingTest do mesh = %Mesh{name: "test_mesh", shape: {2, 2}} input_shardings = [%{0 => [0]}] - # For mesh {2, 2}, we have 4 partitions # Input sharded [[0], []] -> each partition gets {4, 2} args = List.duplicate([Nx.iota({4, 2})], 4) @@ -672,7 +657,6 @@ defmodule EXLA.Defn.ShardingTest do mesh = %Mesh{name: "mesh", shape: {2, 2}} input_shardings = [%{0 => [0], 1 => [1]}] - # For mesh {2, 2}, we have 4 partitions # Input sharded [[0], [1]] -> each partition gets {4, 1} args = List.duplicate([Nx.iota({4, 1})], 4) @@ -699,7 +683,6 @@ defmodule EXLA.Defn.ShardingTest do # Use named dimensions instead of indices input_shardings = [%{:batch => [0]}] - # For mesh {2}, we have 2 partitions # Named dimension :batch should map to first dimension (index 0) args = List.duplicate([Nx.iota({4, 2}, names: [:batch, :features])], 2) @@ 
-854,14 +837,16 @@ defmodule EXLA.Defn.ShardingTest do end) end end -end -describe "all_gather" do + end + + describe "all_gather" do @moduletag :multi_device test "in all dims results in the same tensor in all devices" do - fun = fn x, y -> Nx.add(x, y) - |> Nx.Defn.Kernel.all_gather(all_gather_dim: 0, replica_groups: [[0]]) - |> Nx.Defn.Kernel.all_gather(all_gather_dim: 1, replica_groups: [[0]]) - end + fun = fn x, y -> + Nx.add(x, y) + |> Nx.Defn.Kernel.all_gather(all_gather_dim: 0, replica_groups: [[0]]) + |> Nx.Defn.Kernel.all_gather(all_gather_dim: 1, replica_groups: [[0]]) + end mesh = %Mesh{name: "mesh", shape: {2, 2}} # First arg: 0..15 (8x2), shard dim 0 on mesh axis 0, dim 1 on mesh axis 1 @@ -935,14 +920,26 @@ describe "all_gather" do # Logical x: 8x2, y: 8x2. Each partition gets {4, 1} of x and {4, 2} of y args = [ - [Nx.tensor([[0], [1], [2], [3]]), Nx.tensor([[100, 101], [102, 103], [104, 105], [106, 107]])], - [Nx.tensor([[4], [5], [6], [7]]), Nx.tensor([[100, 101], [102, 103], [104, 105], [106, 107]])], - [Nx.tensor([[8], [9], [10], [11]]), Nx.tensor([[110, 111], [112, 113], [114, 115], [116, 117]])], - [Nx.tensor([[12], [13], [14], [15]]), Nx.tensor([[110, 111], [112, 113], [114, 115], [116, 117]])] + [ + Nx.tensor([[0], [1], [2], [3]]), + Nx.tensor([[100, 101], [102, 103], [104, 105], [106, 107]]) + ], + [ + Nx.tensor([[4], [5], [6], [7]]), + Nx.tensor([[100, 101], [102, 103], [104, 105], [106, 107]]) + ], + [ + Nx.tensor([[8], [9], [10], [11]]), + Nx.tensor([[110, 111], [112, 113], [114, 115], [116, 117]]) + ], + [ + Nx.tensor([[12], [13], [14], [15]]), + Nx.tensor([[110, 111], [112, 113], [114, 115], [116, 117]]) + ] ] assert [result0, result1, result2, result3] = - EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args) + EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args) # After gathering, devices 0 and 1 have the same data as each other, likewise for devices 2 and 3 assert_equal(result0, Nx.tensor([[100, 105], [103, 108], [106, 111], [109, 114]])) From 023b1a8e09956aac393ee6c6f47f176e1d0d8c30 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 12 Feb 2026 22:55:07 -0300 Subject: [PATCH 10/16] simplify Expr.all_gather --- nx/lib/nx/defn/expr.ex | 26 ++------------------------ 1 file changed, 2 insertions(+), 24 deletions(-) diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index ab11c46af1..e753b9aebb 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -1167,30 +1167,8 @@ defmodule Nx.Defn.Expr do end def all_gather(tensor, opts) do - {[tensor], context} = to_exprs([tensor]) - - _all_gather_dim = opts[:all_gather_dim] - replica_groups = opts[:replica_groups] - - # Calculate group size (number of replicas per group) - _group_size = - case replica_groups do - [first_group | _] -> length(first_group) - [] -> 1 - end - - # Calculate output shape by multiplying the gather dimension by group_size - input_shape = tensor.shape - output_shape = - input_shape -# |> Tuple.to_list() -# |> List.update_at(all_gather_dim, &(&1 * group_size)) -# |> List.to_tuple() - - # Create output tensor with the new shape - out = %{tensor | shape: output_shape} - - expr(out, context, :all_gather, [tensor, opts]) + {[expr], context} = to_exprs([tensor]) + expr(expr, context, :all_gather, [expr, opts]) end @impl true From 74ae783d4b80ef9cd215cbe7d113d9e2d8a13982 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 12 Feb 2026 22:58:25 -0300 
Subject: [PATCH 11/16] docs --- nx/lib/nx/defn/kernel.ex | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/nx/lib/nx/defn/kernel.ex b/nx/lib/nx/defn/kernel.ex index 809a66b480..4530d79e16 100644 --- a/nx/lib/nx/defn/kernel.ex +++ b/nx/lib/nx/defn/kernel.ex @@ -1670,20 +1670,15 @@ defmodule Nx.Defn.Kernel do end @doc """ - Gathers tensors from all replicas along a specified dimension. + Gathers tensors along a specified axis across an `Nx.Mesh`. - This operation concatenates tensors from multiple replicas/devices along - the specified dimension. Requires a backend that supports multi-device operations. + Requires a backend that supports collective operations. - ## Parameters - - * `tensor` - The input tensor to gather - - * `opts` - Optional keyword list. These are backend- and compiler-specific; - see your backend or compiler docs for supported options. + ## Options + Refer to the chosen backend/compiler documentation for supported options. """ - def all_gather(tensor, opts \\ []) do + def all_gather(tensor, opts) do Nx.Defn.Expr.all_gather(tensor, opts) end From eae07df6f64d47812740e232d0c5da9214ef49c4 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 12 Feb 2026 23:06:10 -0300 Subject: [PATCH 12/16] fix: floating --- nx/lib/nx/floating.ex | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/nx/lib/nx/floating.ex b/nx/lib/nx/floating.ex index 4ded248c91..54cec81378 100644 --- a/nx/lib/nx/floating.ex +++ b/nx/lib/nx/floating.ex @@ -209,8 +209,16 @@ defmodule Nx.Floating do # E4M3FN: 1 sign, 4 exponent (bias 7), 3 mantissa # Max value: 448.0 (0x7E), Min value: -448.0 (0xFE) def dump_f8_e4m3fn(0), do: <<0b0000_0000>> - def dump_f8_e4m3fn(+0.0), do: <<0b0000_0000>> - def dump_f8_e4m3fn(-0.0), do: <<0b1000_0000>> + + if +0.0 === -0.0 do + # OTP versions <= 28.0 have a bug where +0.0 === -0.0, + # so we need to special-case it to avoid compiler errors + # related to the +0.0 clause shadowing the -0.0 clause + def dump_f8_e4m3fn(x) when x == 0.0, do: <<0b0000_0000>> + else + def dump_f8_e4m3fn(+0.0), do: <<0b0000_0000>> + def dump_f8_e4m3fn(-0.0), do: <<0b1000_0000>> + end def dump_f8_e4m3fn(x) when is_number(x) do # Clamp to E4M3FN range and convert From 187b81db045b4afd6be5386fd2c06c529dacd0b4 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 12 Feb 2026 23:07:53 -0300 Subject: [PATCH 13/16] fix: floating --- nx/test/nx/floating_test.exs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/nx/test/nx/floating_test.exs b/nx/test/nx/floating_test.exs index 7e0a088fe0..5b6f050e32 100644 --- a/nx/test/nx/floating_test.exs +++ b/nx/test/nx/floating_test.exs @@ -349,7 +349,12 @@ defmodule Nx.FloatingTest do test "pretty printing" do # Zeroes assert Nx.tensor([0.0], type: :f8) |> inspect() =~ "[0.0]" - assert Nx.tensor([-0.0], type: :f8) |> inspect() =~ "[-0.0]" + + if +0.0 === -0.0 do + assert Nx.tensor([-0.0], type: :f8) |> inspect() =~ "[0.0]" + else + assert Nx.tensor([-0.0], type: :f8) |> inspect() =~ "[-0.0]" + end # Infinity assert Nx.tensor([:infinity], type: :f8) |> inspect() =~ "[Inf]" @@ -399,7 +404,12 @@ defmodule Nx.FloatingTest do test "pretty printing" do # Zeroes assert Nx.tensor([0.0], type: :bf16) |> inspect() =~ "[0.0]" - assert Nx.tensor([-0.0], type: :bf16) |> inspect() =~ "[-0.0]" + + if +0.0 === -0.0 do + assert Nx.tensor([-0.0], type: :bf16) |> inspect() =~ "[0.0]" + else + assert 
Nx.tensor([-0.0], type: :bf16) |> inspect() =~ "[-0.0]" + end # Infinity assert Nx.tensor([:infinity], type: :bf16) |> inspect() =~ "[Inf]" From 92825971bb0dfe18d6d9821567a07bb2c1d6442a Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 12 Feb 2026 23:08:36 -0300 Subject: [PATCH 14/16] format --- nx/lib/nx/binary_backend.ex | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index cae7c94997..974d558b0d 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -522,11 +522,13 @@ defmodule Nx.BinaryBackend do right_batch_item_bits = right_batch_item_length * right_size <<_::bitstring-size(^left_offset_bits), - left_batch_item_binary::bitstring-size(^left_batch_item_bits), _::bitstring>> = + left_batch_item_binary::bitstring-size(^left_batch_item_bits), + _::bitstring>> = left_binary <<_::bitstring-size(^right_offset_bits), - right_batch_item_binary::bitstring-size(^right_batch_item_bits), _::bitstring>> = + right_batch_item_binary::bitstring-size(^right_batch_item_bits), + _::bitstring>> = right_binary bin_dot( @@ -1756,7 +1758,8 @@ defmodule Nx.BinaryBackend do before_slice_size = current - previous <> = + current_bitstring::bitstring-size(^target_chunk), + to_traverse::bitstring>> = to_traverse updated_elements = From 2dca531a831f7f9e7a8a342e6eb13e4b46e2c63b Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 12 Feb 2026 23:12:12 -0300 Subject: [PATCH 15/16] format --- nx/test/nx/defn/composite_test.exs | 3 ++- nx/test/nx/floating_test.exs | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/nx/test/nx/defn/composite_test.exs b/nx/test/nx/defn/composite_test.exs index 2618e29170..6dc8fc1d52 100644 --- a/nx/test/nx/defn/composite_test.exs +++ b/nx/test/nx/defn/composite_test.exs @@ -24,7 +24,8 @@ defmodule Nx.Defn.CompositeTest do Nx.tensor(1), Nx.tensor(3, type: {:c, 64}), Nx.tensor(4, type: {:c, 64}) - }, Nx.tensor(2, type: {:c, 64})} == + }, + Nx.tensor(2, type: {:c, 64})} == Composite.traverse( {1, Complex.new(2), Nx.tensor(3)}, 0, diff --git a/nx/test/nx/floating_test.exs b/nx/test/nx/floating_test.exs index 5b6f050e32..a9399ff244 100644 --- a/nx/test/nx/floating_test.exs +++ b/nx/test/nx/floating_test.exs @@ -156,7 +156,7 @@ defmodule Nx.FloatingTest do {0x7F, :nan}, # Negative values (sign bit = 1) # Denormalized (exponent = 0): value = -mantissa/8 * 2^-6 - {0x80, -0.0}, + if(+0.0 === -0.0, do: {0x80, 0.0}, else: {0x80, -0.0}), {0x81, -0.001953125}, {0x82, -0.00390625}, {0x83, -0.005859375}, From 90410bbeea740dd38edf58539daeead9d6429652 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 12 Feb 2026 23:14:34 -0300 Subject: [PATCH 16/16] fix test --- nx/test/nx/floating_test.exs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nx/test/nx/floating_test.exs b/nx/test/nx/floating_test.exs index a9399ff244..4720c94d0d 100644 --- a/nx/test/nx/floating_test.exs +++ b/nx/test/nx/floating_test.exs @@ -156,7 +156,7 @@ defmodule Nx.FloatingTest do {0x7F, :nan}, # Negative values (sign bit = 1) # Denormalized (exponent = 0): value = -mantissa/8 * 2^-6 - if(+0.0 === -0.0, do: {0x80, 0.0}, else: {0x80, -0.0}), + if(+0.0 === -0.0, do: {0x00, -0.0}, else: {0x80, -0.0}), {0x81, -0.001953125}, {0x82, -0.00390625}, {0x83, -0.005859375},