Skip to content

Commit 2c3ca3b

Browse files
committedAug 12, 2022
[MLIR] Add utility function to create values for all dimensions of a tensor value
This is a variant of the already provided `createDynamicDimValues` helper. Differential Revision: https://reviews.llvm.org/D131798
1 parent 6826682 commit 2c3ca3b

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed
 

‎mlir/include/mlir/Dialect/Tensor/Utils/Utils.h

+6-1
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,16 @@ PadOp createPadScalarOp(Type type, Value source, Value pad,
2828
ArrayRef<OpFoldResult> low, ArrayRef<OpFoldResult> high,
2929
bool nofold, Location loc, OpBuilder &builder);
3030

31-
// Creates dim ops for each dynamic dimension of the raked tensor argument and
31+
// Creates dim ops for each dynamic dimension of the ranked tensor argument and
3232
// returns these as values.
3333
SmallVector<Value> createDynamicDimValues(OpBuilder &b, Location loc,
3434
Value rankedTensor);
3535

36+
// Creates dim ops or constant ops for each dimension of the ranked tensor
37+
// argument and returns these as values.
38+
SmallVector<Value> createDimValues(OpBuilder &b, Location loc,
39+
Value rankedTensor);
40+
3641
} // namespace tensor
3742
} // namespace mlir
3843

‎mlir/lib/Dialect/Tensor/Utils/Utils.cpp

+11
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,14 @@ SmallVector<Value> mlir::tensor::createDynamicDimValues(OpBuilder &b,
6868
}
6969
return dynamicDims;
7070
}
71+
72+
SmallVector<Value> mlir::tensor::createDimValues(OpBuilder &b, Location loc,
73+
Value rankedTensor) {
74+
auto tensorTy = rankedTensor.getType().cast<RankedTensorType>();
75+
SmallVector<Value> dims;
76+
for (const auto &en : llvm::enumerate(tensorTy.getShape())) {
77+
dims.push_back(
78+
b.createOrFold<tensor::DimOp>(loc, rankedTensor, en.index()));
79+
}
80+
return dims;
81+
}

0 commit comments

Comments
 (0)
Please sign in to comment.