@@ -881,25 +881,27 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
881
881
PatternRewriter &rewriter) const override {
882
882
Location loc = op.getLoc ();
883
883
Value srcTensor = op.getSource ();
884
- const auto srcTp = getSparseTensorType (srcTensor);
885
- const auto dstTp = getSparseTensorType (op.getResult ());
884
+ const auto srcTp = tryGetSparseTensorType (srcTensor);
885
+ const auto dstTp = tryGetSparseTensorType (op.getResult ());
886
+ if (!srcTp || !dstTp)
887
+ return failure ();
886
888
887
- if (!srcTp. hasEncoding () || !dstTp. hasEncoding () ||
888
- !dstTp. hasStaticDimShape ())
889
+ if (!srcTp-> hasEncoding () || !dstTp-> hasEncoding () ||
890
+ !dstTp-> hasStaticDimShape ())
889
891
return failure ();
890
892
891
893
SmallVector<Value> srcSizes;
892
- sizesForTensor (rewriter, srcSizes, loc, srcTp, srcTensor);
894
+ sizesForTensor (rewriter, srcSizes, loc, * srcTp, srcTensor);
893
895
SmallVector<Value> dstSizes;
894
- for (Dimension d : dstTp. getDimShape ())
896
+ for (Dimension d : dstTp-> getDimShape ())
895
897
dstSizes.push_back (constantIndex (rewriter, loc, d));
896
898
897
899
Value nnz = rewriter.create <NumberOfEntriesOp>(loc, srcTensor);
898
900
// Only need an unordered COO buffer if input and output are not sorted
899
901
// in the same way.
900
902
Type bufferTp = getBufferType (
901
- dstTp. withoutDimToLvl (),
902
- !srcTp. isAllOrdered () || !srcTp. isIdentity () || !dstTp. isIdentity ());
903
+ dstTp-> withoutDimToLvl (),
904
+ !srcTp-> isAllOrdered () || !srcTp-> isIdentity () || !dstTp-> isIdentity ());
903
905
SmallVector<Value> dynSizes;
904
906
Value buffer = rewriter
905
907
.create <AllocTensorOp>(loc, bufferTp, dynSizes, Value (),
@@ -917,12 +919,12 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
917
919
// followed by an optional
918
920
// %t = sparse_tensor.cast %tmp
919
921
// depending on whether the input/output are sorted in the same way.
920
- const auto encSrc = srcTp. getEncoding ();
922
+ const auto encSrc = srcTp-> getEncoding ();
921
923
ForeachOp foreachOp = rewriter.create <ForeachOp>(
922
924
loc, srcTensor, buffer,
923
925
[&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
924
926
ValueRange reduc) {
925
- const Dimension srcRank = srcTp. getDimRank ();
927
+ const Dimension srcRank = srcTp-> getDimRank ();
926
928
SmallVector<Value> srcDcvs;
927
929
srcDcvs.reserve (srcRank);
928
930
for (Dimension d = 0 ; d < srcRank; d++) {
@@ -945,7 +947,7 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
945
947
collapsedSizes, collapsedDcvs);
946
948
947
949
ReassociationIndices expandIdx;
948
- for (Dimension i = 0 ; i < dstTp. getDimRank (); i++)
950
+ for (Dimension i = 0 ; i < dstTp-> getDimRank (); i++)
949
951
expandIdx.push_back (i);
950
952
SmallVector<ReassociationIndices, 1 > expandReass = {expandIdx};
951
953
SmallVector<Value> dstDcvs;
@@ -958,8 +960,8 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
958
960
});
959
961
960
962
Value t = rewriter.create <LoadOp>(loc, foreachOp.getResult (0 ), true );
961
- if (bufferTp != dstTp) {
962
- auto dstRTT = dstTp. getRankedTensorType ();
963
+ if (bufferTp != * dstTp) {
964
+ auto dstRTT = dstTp-> getRankedTensorType ();
963
965
Value converted = rewriter.create <ConvertOp>(loc, dstRTT, t).getResult ();
964
966
rewriter.create <DeallocTensorOp>(loc, t);
965
967
t = converted;
@@ -1139,13 +1141,13 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
1139
1141
LogicalResult matchAndRewrite (tensor::DimOp op,
1140
1142
PatternRewriter &rewriter) const override {
1141
1143
std::optional<int64_t > dim = op.getConstantIndex ();
1142
- auto stt = getSparseTensorType (op.getSource ());
1143
- if (!dim || !stt. hasEncoding ())
1144
+ auto stt = tryGetSparseTensorType (op.getSource ());
1145
+ if (!dim || !stt || !stt-> hasEncoding ())
1144
1146
return failure ();
1145
1147
1146
- if (stt. isPermutation ()) {
1148
+ if (stt-> isPermutation ()) {
1147
1149
rewriter.replaceOpWithNewOp <LvlOp>(op, op.getSource (),
1148
- toLvl (stt. getEncoding (), *dim));
1150
+ toLvl (stt-> getEncoding (), *dim));
1149
1151
return success ();
1150
1152
}
1151
1153
@@ -1157,16 +1159,16 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
1157
1159
// computed simply by lvl_size * block_size.
1158
1160
Location loc = op.getLoc ();
1159
1161
SmallVector<Value> maxLvlCrds;
1160
- for (Level l = 0 ; l < stt. getLvlRank (); l++) {
1162
+ for (Level l = 0 ; l < stt-> getLvlRank (); l++) {
1161
1163
Value lvlSz = rewriter.create <LvlOp>(loc, op.getSource (), l);
1162
1164
Value maxLvlCrd = rewriter.create <arith::SubIOp>(
1163
1165
loc, lvlSz, constantOne (rewriter, loc, rewriter.getIndexType ()));
1164
1166
maxLvlCrds.push_back (maxLvlCrd);
1165
1167
}
1166
1168
1167
- AffineExpr lvl2DimExp = stt. getLvlToDim ().getResult (*dim);
1169
+ AffineExpr lvl2DimExp = stt-> getLvlToDim ().getResult (*dim);
1168
1170
Value maxDimCrd = rewriter.create <affine::AffineApplyOp>(
1169
- op.getLoc (), AffineMap::get (stt. getLvlRank (), 0 , lvl2DimExp),
1171
+ op.getLoc (), AffineMap::get (stt-> getLvlRank (), 0 , lvl2DimExp),
1170
1172
maxLvlCrds);
1171
1173
1172
1174
Value dimSz = rewriter.create <arith::AddIOp>(
0 commit comments