
Commit ac99d63

zhxchen17 authored and facebook-github-bot committed
[jit] Make operation call accept Stack& instead Stack* (pytorch#63414)
Summary: Pull Request resolved: pytorch#63414
Misuse of a raw pointer here, where the stack is never nullable.
ghstack-source-id: 136938318
Test Plan: compiles. Imported from OSS
Reviewed By: ejguan
Differential Revision: D30375410
fbshipit-source-id: 9d65b620bb76d90d886c800f54308520095d58ee
1 parent 93d2e50 commit ac99d63

34 files changed: +451 −409 lines
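
For context, a minimal sketch (not part of the commit) of what the signature change means when registering a JIT operator; the "example::add_one" schema and lambda body are hypothetical:

  #include <torch/csrc/jit/runtime/custom_operator.h>

  using namespace torch::jit;

  // Before this change, the callable received Stack* and dereferenced a pointer
  // that was never allowed to be null:
  //   [](Stack* stack) { push(stack, pop(stack).toInt() + 1); }

  // After this change, the callable receives the stack by reference:
  static RegisterOperators example_reg({Operator(
      "example::add_one(int a) -> int",
      [](Stack& stack) { push(stack, pop(stack).toInt() + 1); },
      c10::AliasAnalysisKind::FROM_SCHEMA)});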

aten/src/ATen/core/dispatch/Dispatcher.h (+4)

@@ -344,6 +344,10 @@ class TORCH_API OperatorHandle {
     c10::Dispatcher::singleton().callBoxed(*this, stack);
   }
 
+  void callBoxed(Stack& stack) const {
+    callBoxed(&stack);
+  }
+
   void redispatchBoxed(DispatchKeySet ks, Stack* stack) const {
     c10::Dispatcher::singleton().redispatchBoxed(*this, ks, stack);
   }
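
A minimal caller sketch (an assumption, not part of the diff) of how the new overload is used; run_boxed is a hypothetical helper:

  #include <ATen/core/dispatch/Dispatcher.h>
  #include <ATen/core/stack.h>

  void run_boxed(const c10::OperatorHandle& op, torch::jit::Stack& stack) {
    // New reference overload; forwards to the existing callBoxed(&stack).
    op.callBoxed(stack);
  }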

aten/src/ATen/core/stack.h (+39 −1)

@@ -1,6 +1,9 @@
 #pragma once
 
+#include <type_traits>
+
 #include <ATen/core/ivalue.h>
+#include <c10/util/Deprecated.h>
 
 // TODO move this to c10 namespace
 
@@ -9,7 +12,42 @@ namespace jit {
 
 using c10::IValue;
 using Stack = std::vector<IValue>;
-using Operation = std::function<void(Stack*)>;
+
+class Operation {
+  template <typename F, typename Arg>
+  using accepts = std::is_constructible<std::function<void(Arg)>, F&&>;
+
+ public:
+  template <typename F,
+            std::enable_if_t<accepts<F, Stack*>::value, int> = 0>
+  C10_DEPRECATED_MESSAGE("Please use void(Stack&) to register operator instead.")
+  Operation(F&& raw): op_([raw = std::forward<F>(raw)](Stack& stack) {
+    raw(&stack);
+  }) {}
+
+  template <typename F,
+            std::enable_if_t<accepts<F, Stack&>::value &&
+                !std::is_same<std::decay_t<F>, Operation>::value, int> = 0>
+  Operation(F&& op): op_(std::forward<F>(op)) {}
+
+  Operation(std::nullptr_t) noexcept {}
+
+  explicit operator bool() const noexcept {
+    return op_ ? true : false;
+  }
+
+  void operator()(Stack& stack) {
+    op_(stack);
+  }
+
+  template <typename T>
+  T* target() noexcept {
+    return op_.target<T>();
+  }
+
+ private:
+  std::function<void(Stack&)> op_;
+};
 
 // An operation with N inputs and M outputs pops the last N inputs off
 // the stack and pushes its M inputs onto the stack
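
A short sketch of the behavior the new Operation class provides (the lambdas below are hypothetical, not from the commit):

  #include <ATen/core/stack.h>

  using torch::jit::Operation;
  using torch::jit::Stack;

  // A Stack& callable is stored directly in op_ and invoked as op_(stack).
  Operation preferred([](Stack& stack) {
    torch::jit::push(stack, torch::jit::pop(stack).toInt() + 1);
  });

  // A legacy Stack* callable still compiles, but it selects the constructor
  // marked C10_DEPRECATED_MESSAGE, which wraps it and calls raw(&stack).
  // Operation legacy([](Stack* stack) { /* ... */ });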

test/cpp/jit/test_alias_analysis.cpp (+7 −7)

@@ -1,11 +1,11 @@
 #include <gtest/gtest.h>
 
 #include <torch/csrc/autograd/generated/variable_factories.h>
+#include <torch/csrc/jit/frontend/ir_emitter.h>
+#include <torch/csrc/jit/ir/alias_analysis.h>
 #include <torch/csrc/jit/ir/irparser.h>
-#include "torch/csrc/jit/frontend/ir_emitter.h"
-#include "torch/csrc/jit/ir/alias_analysis.h"
-#include "torch/csrc/jit/runtime/custom_operator.h"
-#include "torch/csrc/utils/memory.h"
+#include <torch/csrc/jit/runtime/custom_operator.h>
+#include <torch/csrc/utils/memory.h>
 
 namespace torch {
 namespace jit {
@@ -484,7 +484,7 @@ TEST(AliasAnalysisTest, SafeToChangeAliasingRelationship) {
 TEST(WriteTrackingTest, Basic) {
   RegisterOperators reg({Operator(
       "prim::creates_alias(Tensor(a) x) -> Tensor(a)",
-      [](Stack* s) {},
+      [](Stack&) {},
       aliasAnalysisFromSchema())});
   const auto creates_alias = Symbol::fromQualString("prim::creates_alias");
   auto graph = std::make_shared<Graph>();
@@ -949,11 +949,11 @@ TEST(WildcardsTest, Basic) {
   RegisterOperators reg(
       {Operator(
            "prim::returns_wildcard(Tensor a) -> Tensor(*)",
-           [](Stack* stack) {},
+           [](Stack&) {},
            aliasAnalysisFromSchema()),
        Operator(
            "prim::writes(Tensor(z!) a) -> Tensor(a)",
-           [](Stack* stack) {},
+           [](Stack&) {},
            aliasAnalysisFromSchema())});
   const auto returns_wildcard =
       Symbol::fromQualString("prim::returns_wildcard");

test/cpp/jit/test_custom_operators.cpp (+7 −7)

@@ -31,7 +31,7 @@ TEST(CustomOperatorTest, InferredSchema) {
 
   Stack stack;
   push(stack, 2.0f, at::ones(5));
-  op->getOperation()(&stack);
+  op->getOperation()(stack);
   at::Tensor output;
   pop(stack, output);
 
@@ -61,7 +61,7 @@ TEST(CustomOperatorTest, ExplicitSchema) {
 
   Stack stack;
   push(stack, 2.0f, at::ones(5));
-  op->getOperation()(&stack);
+  op->getOperation()(stack);
   at::Tensor output;
   pop(stack, output);
 
@@ -109,7 +109,7 @@ TEST(CustomOperatorTest, ListParameters) {
       c10::List<c10::complex<double>>(
          {c10::complex<double>(2.4, -5.5), c10::complex<double>(-1.3, 2)}));
   push(stack, c10::List<at::Tensor>({at::ones(5)}));
-  op->getOperation()(&stack);
+  op->getOperation()(stack);
   c10::List<double> output;
   pop(stack, output);
 
@@ -140,7 +140,7 @@ TEST(CustomOperatorTest, ListParameters2) {
 
   Stack stack;
   push(stack, c10::List<at::Tensor>({at::ones(5)}));
-  op->getOperation()(&stack);
+  op->getOperation()(stack);
   c10::List<at::Tensor> output;
   pop(stack, output);
 
@@ -204,7 +204,7 @@ TEST(TestCustomOperator, OperatorGeneratorUndeclared) {
   torch::jit::RegisterOperators reg({OperatorGenerator(
       TORCH_SELECTIVE_NAME_IN_SCHEMA(
           op_list, "foofoo::not_exist(float a, Tensor b) -> Tensor"),
-      [](Stack* stack) {
+      [](Stack& stack) {
        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
        double a;
        at::Tensor b;
@@ -223,7 +223,7 @@ TEST(TestCustomOperator, OperatorGeneratorBasic) {
   torch::jit::RegisterOperators reg({OperatorGenerator(
       TORCH_SELECTIVE_NAME_IN_SCHEMA(
          op_list, "foofoo::bar.template(float a, Tensor b) -> Tensor"),
-      [](Stack* stack) {
+      [](Stack& stack) {
        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
        double a;
        at::Tensor b;
@@ -249,7 +249,7 @@ TEST(TestCustomOperator, OperatorGeneratorBasic) {
 
   Stack stack;
   push(stack, 2.0f, at::ones(5));
-  op->getOperation()(&stack);
+  op->getOperation()(stack);
   at::Tensor output;
   pop(stack, output);

test/cpp/jit/test_misc.cpp (+2 −2)

@@ -1493,11 +1493,11 @@ TEST(NoneSchemaMatchTest, Basic) {
   RegisterOperators reg({
       Operator(
           "prim::test_none() -> int?",
-          [](Stack* stack) { push(stack, IValue()); },
+          [](Stack& stack) { push(stack, IValue()); },
          aliasAnalysisFromSchema()),
       Operator(
           "prim::is_none(int? a) -> bool",
-          [](Stack* stack) {
+          [](Stack& stack) {
            IValue a = pop(stack);
            if (a.isNone()) {
              push(stack, true);

test/cpp/jit/test_schema_matching.cpp (+2 −2)

@@ -15,7 +15,7 @@ TEST(SchemaMatchingTest, VarType) {
   RegisterOperators reg({
       Operator(
           "aten::test_vartype(t[] a, t b) -> (t)",
-          [](Stack* stack) {
+          [](Stack& stack) {
            c10::List<double> list;
            // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
            double a;
@@ -54,7 +54,7 @@ TEST(SchemaMatchingTest, VarType2) {
   RegisterOperators reg({
       Operator(
           "aten::test_vartype2(t a, t[] b) -> (t[])",
-          [](Stack* stack) {
+          [](Stack& stack) {
            // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
            double a;
            c10::List<double> list;

test/cpp/jit/test_utils.cpp (+1 −1)

@@ -273,7 +273,7 @@ RegisterOperators reg({
     // because it always produces empty Tensors.
     Operator(
         "prim::MakeTestTensor() -> Tensor",
-        [](Stack* stack) { push(stack, at::Tensor()); },
+        [](Stack& stack) { push(stack, at::Tensor()); },
        aliasAnalysisFromSchema()),
 });
 } // namespace

test/custom_operator/test_custom_ops.cpp (+1 −1)

@@ -30,7 +30,7 @@ Result get_operator_from_registry_and_execute(const char* op_name, Args&&... arg
 
   torch::jit::Stack stack;
   torch::jit::push(stack, std::forward<Args>(args)...);
-  op->getOperation()(&stack);
+  op->getOperation()(stack);
 
   TORCH_INTERNAL_ASSERT(1 == stack.size());
   return torch::jit::pop(stack).to<Result>();

torch/csrc/autograd/record_function_ops.cpp (+1 −1)

@@ -79,7 +79,7 @@ c10::AliasAnalysisKind aliasAnalysisFromSchema() {
 jit::RegisterOperators reg_fut_ops({
     jit::Operator(
         "profiler::_call_end_callbacks_on_jit_fut(Tensor x, Future(t) y) -> Future(t)",
-        [](jit::Stack* stack) {
+        [](jit::Stack& stack) {
          // Pop inputs, which should be a future and a tensor
          auto fut = jit::pop(stack).toFuture();
          auto tensor = jit::pop(stack).toTensor();

torch/csrc/distributed/rpc/request_callback_no_python.cpp (+1 −1)

@@ -582,7 +582,7 @@ c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::runJitOperator(
     std::vector<c10::Stream> streams) const {
   c10::MultiStreamGuard guard(streams);
   try {
-    op.getOperation()(&stack);
+    op.getOperation()(stack);
   } catch (const std::exception&) {
     return asFuture(std::current_exception());
   }

torch/csrc/jit/codegen/cuda/interface.cpp (+3 −3)

@@ -182,8 +182,8 @@ RegisterOperators reg_fusion({
     Operator(
         prim::CudaFusionGroup,
         [](const Node* node) -> Operation {
-          return [node](Stack* stack) {
-            fuser::cuda::runFusionGroup(node, *stack);
+          return [node](Stack& stack) {
+            fuser::cuda::runFusionGroup(node, stack);
           };
         },
         aliasAnalysisSpecialCase()),
@@ -196,7 +196,7 @@ RegisterOperators reg_guard({
     // if we would ever return refined tensor, which would change aliasing
     // analysis, we should update aliasdb pass.
     [](const Node* node) -> Operation {
-      return [node](Stack* stack) {
+      return [node](Stack& stack) {
        // TODO: check latency here!!!!
        std::vector<TypePtr> types = node->tys(attr::types);
        const auto num_inputs = types.size();

torch/csrc/jit/codegen/fuser/fallback.cpp (+1 −1)

@@ -26,7 +26,7 @@ RegisterOperators reg_fused_operators({Operator(
     [](const Node* node) -> Operation {
       int64_t dim = node->i(attr::dim);
       int64_t num_inputs = node->inputs().size();
-      return [dim, num_inputs](Stack* stack) {
+      return [dim, num_inputs](Stack& stack) {
        auto result = at::cat(
            fmap(
                last(stack, num_inputs),

torch/csrc/jit/mobile/function.cpp (+1 −1)

@@ -67,7 +67,7 @@ bool Function::append_operator(
   auto jit_op = findOperatorFor(opname);
   std::vector<c10::Argument> args;
   if (jit_op) {
-    fn = [jit_op](Stack& stack) { jit_op->getOperation()(&stack); };
+    fn = [jit_op](Stack& stack) { jit_op->getOperation()(stack); };
     args = jit_op->schema().arguments();
   } else {
     auto op = c10::Dispatcher::singleton().findSchema(opname_c10);

torch/csrc/jit/passes/batch_mm.cpp (+8 −8)

@@ -109,11 +109,11 @@ bool shape_is_fast_for_reduce(const at::Tensor& lhs, const at::Tensor& rhs) {
 
 RegisterOperators mm_tree_reduction_reg({Operator(
     "prim::MMTreeReduce(...) -> Tensor",
-    [](Stack* stack) {
+    [](Stack& stack) {
      auto num_inputs = pop(stack).toInt();
      std::vector<at::Tensor> inputs;
      inputs.reserve(num_inputs);
-      for (auto it = stack->end() - num_inputs; it != stack->end(); ++it) {
+      for (auto it = stack.end() - num_inputs; it != stack.end(); ++it) {
        inputs.push_back(std::move(*it).toTensor());
      }
      drop(stack, num_inputs);
@@ -320,11 +320,11 @@ RegisterOperators mm_batch_side_reg({Operator(
     [](const Node* node) -> Operation {
       size_t num_other_side_inputs = node->inputs().size() - 1;
       Side single_side = static_cast<Side>(node->i(Symbol::attr("side")));
-      return [num_other_side_inputs, single_side](Stack* stack) {
+      return [num_other_side_inputs, single_side](Stack& stack) {
        at::Tensor side_input;
        std::vector<at::Tensor> other_side_inputs;
        other_side_inputs.reserve(num_other_side_inputs);
-        for (auto it = stack->end() - num_other_side_inputs; it != stack->end();
+        for (auto it = stack.end() - num_other_side_inputs; it != stack.end();
             ++it) {
          other_side_inputs.push_back(std::move(*it).toTensor());
        }
@@ -343,18 +343,18 @@ RegisterOperators mm_batch_side_reg({Operator(
            mm_out,
            num_other_side_inputs,
            /*dim=*/single_side == Side::LHS ? 1 : 0);
-        stack->insert(
-            stack->end(),
+        stack.insert(
+            stack.end(),
            std::make_move_iterator(outputs.begin()),
            std::make_move_iterator(outputs.end()));
      } else {
        if (single_side == Side::LHS) {
          for (at::Tensor& other : other_side_inputs) {
-            stack->emplace_back(side_input.mm(other));
+            stack.emplace_back(side_input.mm(other));
          }
        } else {
          for (at::Tensor& other : other_side_inputs) {
-            stack->emplace_back(other.mm(side_input));
+            stack.emplace_back(other.mm(side_input));
          }
        }
      }

torch/csrc/jit/passes/constant_propagation.cpp (+1 −1)

@@ -78,7 +78,7 @@ c10::optional<std::vector<IValue>> runNodeIfInputsAreConstant(
 
   try {
     auto op = n->getOperation();
-    op(&stack);
+    op(stack);
   } catch (...) {
     return c10::nullopt;
   }

torch/csrc/jit/passes/decompose_ops.cpp (+2 −2)

@@ -59,7 +59,7 @@ bool isDecomposableNorm(Node* normalize_op) {
 RegisterOperators reg_ops(
     {Operator(
         "aten::_ncf_unsqueeze(Tensor(a) self, int ndim) -> Tensor(a)",
-        [](Stack* stack) {
+        [](Stack& stack) {
          const int64_t ndim = pop(stack).toInt();
          auto self = pop(stack).toTensor();
          c10::SmallVector<int64_t, 8> sizes(ndim, 1);
@@ -70,7 +70,7 @@ RegisterOperators reg_ops(
        aliasAnalysisFromSchema()),
     Operator(
         "aten::_ncf_view(Tensor(a) self, int[] input_shape, int normalized_ndim) -> Tensor(a)",
-        [](Stack* stack) {
+        [](Stack& stack) {
          const int64_t normalized_ndim = pop(stack).toInt();
          auto input_shape = pop(stack).toIntList();
          auto self = pop(stack).toTensor();
