Skip to content

Commit 14b594f

Browse files
committed
Move arm.passes to arm._passes (pytorch#5918)
Summary: Changing arm.passes to arm._passes to indicate that these passes are not covered under the API stability guarantee. Pull Request resolved: pytorch#5918 Reviewed By: malfet, helunwencser Differential Revision: D63926055 fbshipit-source-id: 141a5be9f3a81e75784825357bacbab91904620c (cherry picked from commit 83c95df)
1 parent 40358fa commit 14b594f

17 files changed

+226
-9
lines changed

backends/arm/TARGETS

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ python_library(
88
typing = True,
99
deps = [
1010
":arm_backend",
11-
"//executorch/backends/arm/passes:passes",
11+
"//executorch/backends/arm/_passes:passes",
1212
"//executorch/exir:lib",
1313
],
1414
)
@@ -27,7 +27,7 @@ python_library(
2727
":arm_vela",
2828
"//executorch/backends/arm/operators:lib",
2929
"//executorch/backends/arm/operators:node_visitor",
30-
"//executorch/backends/arm/passes:passes",
30+
"//executorch/backends/arm/_passes:passes",
3131
],
3232
)
3333

File renamed without changes.

backends/arm/passes/arm_pass_manager.py renamed to backends/arm/_passes/arm_pass_manager.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
# pyre-unsafe
99

1010
import torch
11-
from executorch.backends.arm.passes.annotate_channels_last_dim_order_pass import (
11+
from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import (
1212
AnnotateChannelsLastDimOrder,
1313
)
14-
from executorch.backends.arm.passes.convert_expand_copy_to_repeat import (
14+
from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
1515
ConvertExpandCopyToRepeatPass,
1616
)
17-
from executorch.backends.arm.passes.convert_split_to_slice import (
17+
from executorch.backends.arm._passes.convert_split_to_slice import (
1818
ConvertSplitToSlicePass,
1919
)
2020
from executorch.backends.arm.passes.meandim_to_averagepool_pass import (
+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Optional
8+
9+
import torch
10+
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from torch._ops import OpOverload
13+
14+
15+
def create_node(
16+
graph: torch.fx.Graph,
17+
op_target: OpOverload,
18+
args: tuple = (),
19+
kwargs: Optional[dict] = None,
20+
quantize: bool = False,
21+
q_params: Optional[tuple] = None,
22+
):
23+
"""
24+
Adds a node to 'graph'. graph.inserting_before/after() should be used before the call to decide where to insert the node.
25+
If quantize is true and q_params is not None, a q dq pair is inserted after the newly created node.
26+
"""
27+
28+
node = graph.create_node(
29+
"call_function",
30+
op_target,
31+
args=args,
32+
kwargs=kwargs or {},
33+
)
34+
if quantize and q_params:
35+
return insert_q_dq_pair(graph, node, q_params)
36+
return node
37+
38+
39+
def insert_q_dq_pair(
40+
graph: torch.fx.Graph,
41+
anchor: torch.fx.Node,
42+
q_params: tuple,
43+
):
44+
"""
45+
Inserts a q dq node pair after the node 'anchor'.
46+
"""
47+
48+
with graph.inserting_after(anchor):
49+
q = create_node(
50+
graph=graph,
51+
op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
52+
args=(), # We add the argument last
53+
)
54+
q.meta = anchor.meta
55+
with graph.inserting_after(q):
56+
dq = create_node(
57+
graph=graph,
58+
op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
59+
args=(q,) + q_params,
60+
)
61+
dq.meta = q.meta
62+
anchor.replace_all_uses_with(dq)
63+
# We add this last so the replace all uses above does not replace the quantized
64+
# node's first use
65+
q.args = (anchor,) + q_params
66+
return dq
+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.exir.pass_base import ExportPass, PassResult
8+
9+
10+
class CastInt64ToInt32Pass(ExportPass):
11+
def __init__(self, exported_program: torch.export.ExportedProgram):
12+
super(CastInt64ToInt32Pass, self).__init__()
13+
self.exported_program = exported_program
14+
15+
def _to_int32(self, graph_module: torch.fx.GraphModule):
16+
for node in graph_module.graph.nodes:
17+
fake_tensor = node.meta["val"]
18+
if isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
19+
if node.meta["val"].dtype == torch.int64:
20+
node.meta["val"] = node.meta["val"].to(torch.int32)
21+
buffer_name = (
22+
self.exported_program.graph_signature.inputs_to_buffers[
23+
node.name
24+
]
25+
)
26+
new_tensor = self.exported_program.state_dict[buffer_name].to(
27+
torch.int32
28+
)
29+
self.exported_program.state_dict[buffer_name] = new_tensor
30+
31+
def call(self, graph_module: torch.fx.GraphModule):
32+
self._to_int32(graph_module)
33+
graph_module.recompile()
34+
graph_module = super().call(graph_module).graph_module
35+
return PassResult(graph_module, True)
+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.pass_base import ExportPass
10+
11+
12+
def get_div_decomposition(op) -> tuple:
13+
"""
14+
Returns the the (reciprocal_op, mul_op), where the ops depends on if
15+
the div op is in exir_ops torch.ops.aten.
16+
"""
17+
if op == exir_ops.edge.aten.div.Tensor:
18+
return (exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.mul.Tensor)
19+
if op == torch.ops.aten.div.Tensor:
20+
return (torch.ops.aten.reciprocal.default, torch.ops.aten.mul.Tensor)
21+
raise RuntimeError(f"Can't get div decomposition for op {op}")
22+
23+
24+
class DecomposeDivPass(ExportPass):
25+
"""
26+
This pass decomposes div into a mul and a reciprocal node.
27+
28+
Example:
29+
y = div(a,b)
30+
Becomes:
31+
x = reciprocal(b)
32+
y = mul(a,x)
33+
"""
34+
35+
def call_operator(self, op, args, kwargs, meta):
36+
if op not in (exir_ops.edge.aten.div.Tensor, torch.ops.aten.div.Tensor):
37+
return super().call_operator(op, args, kwargs, meta)
38+
39+
reciprocal_op, mul_op = get_div_decomposition(op)
40+
41+
numerator = args[0]
42+
denominator = args[1]
43+
reciprocal = super().call_operator(reciprocal_op, (denominator,), {}, meta)
44+
45+
return super().call_operator(mul_op, (numerator, reciprocal), {}, meta)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import cast, Union
8+
9+
import torch
10+
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
11+
12+
from executorch.exir.pass_base import ExportPass, PassResult
13+
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
14+
from torch.fx import GraphModule, Node
15+
16+
17+
class ScalarsToAttributePass(ExportPass):
18+
"""
19+
For ops in 'targeted_ops', convert inputs that are scalar values
20+
to attribute Nodes that output the same value.
21+
"""
22+
23+
targeted_ops = [
24+
torch.ops.aten.add.Tensor,
25+
torch.ops.aten.sub.Tensor,
26+
torch.ops.aten.sub_.Tensor,
27+
torch.ops.aten.mul.Tensor,
28+
torch.ops.aten.div.Tensor,
29+
]
30+
31+
def call(self, graph_module: GraphModule) -> PassResult:
32+
for n in graph_module.graph.nodes:
33+
n = cast(Node, n)
34+
if n.op != "call_function" or n.target not in self.targeted_ops:
35+
continue
36+
37+
biggest_rank = 1
38+
for arg in n.args:
39+
if isinstance(arg, Node):
40+
_, shape, _ = extract_tensor_meta(arg.meta)
41+
biggest_rank = max(biggest_rank, len(shape))
42+
43+
new_args = []
44+
for arg in n.args:
45+
if isinstance(arg, Node):
46+
new_args.append(arg)
47+
continue
48+
49+
prefix = "_tensor_constant_"
50+
get_new_attr_name = get_new_attr_name_with_prefix(prefix)
51+
tensor_constant_name = get_new_attr_name(graph_module)
52+
float_tensor = torch.tensor(
53+
float(cast(Union[int, float], arg))
54+
).reshape((1,) * biggest_rank)
55+
graph_module.register_buffer(tensor_constant_name, float_tensor)
56+
fake_mode = n.meta["val"].fake_mode
57+
58+
with graph_module.graph.inserting_before(n):
59+
get_attr_node = graph_module.graph.create_node(
60+
"get_attr", tensor_constant_name, (), {}
61+
)
62+
get_attr_node.meta["val"] = fake_mode.from_tensor(
63+
float_tensor, static_shapes=True
64+
)
65+
new_args.append(get_attr_node)
66+
n.args = tuple(new_args)
67+
68+
graph_module.recompile()
69+
return PassResult(graph_module, True)

backends/arm/arm_backend.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
from executorch.backends.arm.operators.node_visitor import get_node_visitors
2121
from executorch.backends.arm.operators.op_output import process_output
2222
from executorch.backends.arm.operators.op_placeholder import process_placeholder
23-
from executorch.backends.arm.passes.arm_pass_manager import ArmPassManager
23+
from executorch.backends.arm._passes.arm_pass_manager import (
24+
ArmPassManager,
25+
) # usort: skip
2426
from executorch.backends.arm.tosa_utils import (
2527
dbg_fail,
2628
dbg_tosa_dump,

backends/arm/arm_partitioner.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from typing import final, List
1212

1313
import torch
14-
from executorch.backends.arm.arm_backend import ArmBackend
15-
from executorch.backends.arm.passes.tag_io_quant_pass import TagIOQuantPass
14+
from executorch.backends.arm.arm_backend import ArmBackend # usort: skip
15+
from executorch.backends.arm._passes.tag_io_quant_pass import TagIOQuantPass
1616
from executorch.exir.backend.compile_spec_schema import CompileSpec
1717
from executorch.exir.backend.partitioner import (
1818
DelegationSpec,

backends/arm/test/passes/test_meandim_to_averagepool2d.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import unittest
88

99
import torch
10-
from executorch.backends.arm.passes.meandim_to_averagepool_pass import (
10+
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
1111
ConvertMeanDimToAveragePool,
1212
)
1313

0 commit comments

Comments
 (0)