Skip to content

Commit e8a51bf

Browse files
CoTinker authored and puja2196 committed
[mlir][sparse] Replace getSparseTensorType with tryGetSparseTensorType (#109435)
This PR fixes a bug in `SparseTensorDimOpRewriter` when `tensor.dim` has an unranked tensor type. To prevent crashes, we now use `tryGetSparseTensorType` instead of `getSparseTensorType`. Fixes #107807.
1 parent 77b728d commit e8a51bf

File tree

2 files changed

+38
-20
lines changed

2 files changed

+38
-20
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

+22-20
Original file line numberDiff line numberDiff line change
@@ -881,25 +881,27 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
881881
PatternRewriter &rewriter) const override {
882882
Location loc = op.getLoc();
883883
Value srcTensor = op.getSource();
884-
const auto srcTp = getSparseTensorType(srcTensor);
885-
const auto dstTp = getSparseTensorType(op.getResult());
884+
const auto srcTp = tryGetSparseTensorType(srcTensor);
885+
const auto dstTp = tryGetSparseTensorType(op.getResult());
886+
if (!srcTp || !dstTp)
887+
return failure();
886888

887-
if (!srcTp.hasEncoding() || !dstTp.hasEncoding() ||
888-
!dstTp.hasStaticDimShape())
889+
if (!srcTp->hasEncoding() || !dstTp->hasEncoding() ||
890+
!dstTp->hasStaticDimShape())
889891
return failure();
890892

891893
SmallVector<Value> srcSizes;
892-
sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
894+
sizesForTensor(rewriter, srcSizes, loc, *srcTp, srcTensor);
893895
SmallVector<Value> dstSizes;
894-
for (Dimension d : dstTp.getDimShape())
896+
for (Dimension d : dstTp->getDimShape())
895897
dstSizes.push_back(constantIndex(rewriter, loc, d));
896898

897899
Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
898900
// Only need an unordered COO buffer if input and output are not sorted
899901
// in the same way.
900902
Type bufferTp = getBufferType(
901-
dstTp.withoutDimToLvl(),
902-
!srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
903+
dstTp->withoutDimToLvl(),
904+
!srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
903905
SmallVector<Value> dynSizes;
904906
Value buffer = rewriter
905907
.create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
@@ -917,12 +919,12 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
917919
// followed by an optional
918920
// %t = sparse_tensor.cast %tmp
919921
// depending on whether the input/output are sorted in the same way.
920-
const auto encSrc = srcTp.getEncoding();
922+
const auto encSrc = srcTp->getEncoding();
921923
ForeachOp foreachOp = rewriter.create<ForeachOp>(
922924
loc, srcTensor, buffer,
923925
[&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
924926
ValueRange reduc) {
925-
const Dimension srcRank = srcTp.getDimRank();
927+
const Dimension srcRank = srcTp->getDimRank();
926928
SmallVector<Value> srcDcvs;
927929
srcDcvs.reserve(srcRank);
928930
for (Dimension d = 0; d < srcRank; d++) {
@@ -945,7 +947,7 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
945947
collapsedSizes, collapsedDcvs);
946948

947949
ReassociationIndices expandIdx;
948-
for (Dimension i = 0; i < dstTp.getDimRank(); i++)
950+
for (Dimension i = 0; i < dstTp->getDimRank(); i++)
949951
expandIdx.push_back(i);
950952
SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
951953
SmallVector<Value> dstDcvs;
@@ -958,8 +960,8 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
958960
});
959961

960962
Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
961-
if (bufferTp != dstTp) {
962-
auto dstRTT = dstTp.getRankedTensorType();
963+
if (bufferTp != *dstTp) {
964+
auto dstRTT = dstTp->getRankedTensorType();
963965
Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
964966
rewriter.create<DeallocTensorOp>(loc, t);
965967
t = converted;
@@ -1139,13 +1141,13 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
11391141
LogicalResult matchAndRewrite(tensor::DimOp op,
11401142
PatternRewriter &rewriter) const override {
11411143
std::optional<int64_t> dim = op.getConstantIndex();
1142-
auto stt = getSparseTensorType(op.getSource());
1143-
if (!dim || !stt.hasEncoding())
1144+
auto stt = tryGetSparseTensorType(op.getSource());
1145+
if (!dim || !stt || !stt->hasEncoding())
11441146
return failure();
11451147

1146-
if (stt.isPermutation()) {
1148+
if (stt->isPermutation()) {
11471149
rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
1148-
toLvl(stt.getEncoding(), *dim));
1150+
toLvl(stt->getEncoding(), *dim));
11491151
return success();
11501152
}
11511153

@@ -1157,16 +1159,16 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
11571159
// computed simply by lvl_size * block_size.
11581160
Location loc = op.getLoc();
11591161
SmallVector<Value> maxLvlCrds;
1160-
for (Level l = 0; l < stt.getLvlRank(); l++) {
1162+
for (Level l = 0; l < stt->getLvlRank(); l++) {
11611163
Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
11621164
Value maxLvlCrd = rewriter.create<arith::SubIOp>(
11631165
loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType()));
11641166
maxLvlCrds.push_back(maxLvlCrd);
11651167
}
11661168

1167-
AffineExpr lvl2DimExp = stt.getLvlToDim().getResult(*dim);
1169+
AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim);
11681170
Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
1169-
op.getLoc(), AffineMap::get(stt.getLvlRank(), 0, lvl2DimExp),
1171+
op.getLoc(), AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp),
11701172
maxLvlCrds);
11711173

11721174
Value dimSz = rewriter.create<arith::AddIOp>(

mlir/test/Dialect/SparseTensor/codegen.mlir

+16
Original file line numberDiff line numberDiff line change
@@ -826,3 +826,19 @@ func.func @sparse_new_coo_permute_no(%arg0: !llvm.ptr) -> tensor<?x?xf32, #CooPN
826826
%0 = sparse_tensor.new %arg0 : !llvm.ptr to tensor<?x?xf32, #CooPNo>
827827
return %0 : tensor<?x?xf32, #CooPNo>
828828
}
829+
830+
// CHECK-LABEL: func.func @test_tensor_dim_unranked
831+
// CHECK: tensor.dim
832+
func.func @test_tensor_dim_unranked(%arg0: tensor<*xf32>) -> index {
833+
%c = arith.constant 0 : index
834+
%0 = tensor.dim %arg0, %c : tensor<*xf32>
835+
return %0 : index
836+
}
837+
838+
// CHECK-LABEL: func.func @test_tensor_reshape_unranked
839+
// CHECK: tensor.reshape
840+
func.func @test_tensor_reshape_unranked(%src: tensor<*xf32>, %shape: tensor<1xi32>) -> tensor<?xf32> {
841+
%dst = tensor.reshape %src(%shape)
842+
: (tensor<*xf32>, tensor<1xi32>) -> tensor<?xf32>
843+
return %dst : tensor<?xf32>
844+
}

0 commit comments

Comments (0)