Skip to content

Commit 83f3b1c

Browse files
[mlir][sparse] Add verification for explicit/implicit value (llvm#90111)
1. Verify that the type of explicit/implicit values should be the same as the tensor element type. 2. Verify that implicit value could only be zero. 3. Verify that explicit/implicit values should be numeric. 4. Fix the type change issue caused by SparseTensorType(enc).
1 parent d4cf20c commit 83f3b1c

File tree

5 files changed

+169
-47
lines changed

5 files changed

+169
-47
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

+13
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,19 @@ using Level = uint64_t;
4141
/// including the value `ShapedType::kDynamic` (for shapes).
4242
using Size = int64_t;
4343

44+
/// A simple structure that encodes a range of levels in the sparse tensors
45+
/// that forms a COO segment.
46+
struct COOSegment {
47+
std::pair<Level, Level> lvlRange; // [low, high)
48+
bool isSoA;
49+
50+
bool isAoS() const { return !isSoA; }
51+
bool isSegmentStart(Level l) const { return l == lvlRange.first; }
52+
bool inSegment(Level l) const {
53+
return l >= lvlRange.first && l < lvlRange.second;
54+
}
55+
};
56+
4457
} // namespace sparse_tensor
4558
} // namespace mlir
4659

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

+15
Original file line numberDiff line numberDiff line change
@@ -502,9 +502,24 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
502502
//
503503
// Helper function to translate between level/dimension space.
504504
//
505+
505506
SmallVector<int64_t> translateShape(::mlir::ArrayRef<int64_t> srcShape, ::mlir::sparse_tensor::CrdTransDirectionKind) const;
506507
ValueRange translateCrds(::mlir::OpBuilder &builder, ::mlir::Location loc, ::mlir::ValueRange crds, ::mlir::sparse_tensor::CrdTransDirectionKind) const;
507508

509+
//
510+
// COO methods.
511+
//
512+
513+
/// Returns the starting level of this sparse tensor type for a
514+
/// trailing COO region that spans **at least** two levels. If
515+
/// no such COO region is found, then returns the level-rank.
516+
///
517+
/// DEPRECATED: use getCOOSegment instead;
518+
Level getAoSCOOStart() const;
519+
520+
/// Returns a list of COO segments in the sparse tensor types.
521+
SmallVector<COOSegment> getCOOSegments() const;
522+
508523
//
509524
// Printing methods.
510525
//

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h

+4-21
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,6 @@
1818
namespace mlir {
1919
namespace sparse_tensor {
2020

21-
/// A simple structure that encodes a range of levels in the sparse tensors that
22-
/// forms a COO segment.
23-
struct COOSegment {
24-
std::pair<Level, Level> lvlRange; // [low, high)
25-
bool isSoA;
26-
27-
bool isAoS() const { return !isSoA; }
28-
bool isSegmentStart(Level l) const { return l == lvlRange.first; }
29-
bool inSegment(Level l) const {
30-
return l >= lvlRange.first && l < lvlRange.second;
31-
}
32-
};
33-
3421
//===----------------------------------------------------------------------===//
3522
/// A wrapper around `RankedTensorType`, which has three goals:
3623
///
@@ -73,12 +60,6 @@ class SparseTensorType {
7360
: SparseTensorType(
7461
RankedTensorType::get(stp.getShape(), stp.getElementType(), enc)) {}
7562

76-
// TODO: remove?
77-
SparseTensorType(SparseTensorEncodingAttr enc)
78-
: SparseTensorType(RankedTensorType::get(
79-
SmallVector<Size>(enc.getDimRank(), ShapedType::kDynamic),
80-
Float32Type::get(enc.getContext()), enc)) {}
81-
8263
SparseTensorType &operator=(const SparseTensorType &) = delete;
8364
SparseTensorType(const SparseTensorType &) = default;
8465

@@ -369,13 +350,15 @@ class SparseTensorType {
369350
/// no such COO region is found, then returns the level-rank.
370351
///
371352
/// DEPRECATED: use getCOOSegment instead;
372-
Level getAoSCOOStart() const;
353+
Level getAoSCOOStart() const { return getEncoding().getAoSCOOStart(); };
373354

374355
/// Returns [un]ordered COO type for this sparse tensor type.
375356
RankedTensorType getCOOType(bool ordered) const;
376357

377358
/// Returns a list of COO segments in the sparse tensor types.
378-
SmallVector<COOSegment> getCOOSegments() const;
359+
SmallVector<COOSegment> getCOOSegments() const {
360+
return getEncoding().getCOOSegments();
361+
}
379362

380363
private:
381364
// These two must be const, to ensure coherence of the memoized fields.

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

+52-26
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ void StorageLayout::foreachField(
104104
callback) const {
105105
const auto lvlTypes = enc.getLvlTypes();
106106
const Level lvlRank = enc.getLvlRank();
107-
SmallVector<COOSegment> cooSegs = SparseTensorType(enc).getCOOSegments();
107+
SmallVector<COOSegment> cooSegs = enc.getCOOSegments();
108108
FieldIndex fieldIdx = kDataFieldStartingIdx;
109109

110110
ArrayRef cooSegsRef = cooSegs;
@@ -211,7 +211,7 @@ StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
211211
unsigned stride = 1;
212212
if (kind == SparseTensorFieldKind::CrdMemRef) {
213213
assert(lvl.has_value());
214-
const Level cooStart = SparseTensorType(enc).getAoSCOOStart();
214+
const Level cooStart = enc.getAoSCOOStart();
215215
const Level lvlRank = enc.getLvlRank();
216216
if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
217217
lvl = cooStart;
@@ -912,46 +912,53 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
912912
return emitError()
913913
<< "dimension-rank mismatch between encoding and tensor shape: "
914914
<< getDimRank() << " != " << dimRank;
915+
if (auto expVal = getExplicitVal()) {
916+
Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType();
917+
if (attrType != elementType) {
918+
return emitError() << "explicit value type mismatch between encoding and "
919+
<< "tensor element type: " << attrType
920+
<< " != " << elementType;
921+
}
922+
}
923+
if (auto impVal = getImplicitVal()) {
924+
Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType();
925+
if (attrType != elementType) {
926+
return emitError() << "implicit value type mismatch between encoding and "
927+
<< "tensor element type: " << attrType
928+
<< " != " << elementType;
929+
}
930+
// Currently, we only support zero as the implicit value.
931+
auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
932+
auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
933+
auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
934+
if ((impFVal && impFVal.getValue().isNonZero()) ||
935+
(impIntVal && !impIntVal.getValue().isZero()) ||
936+
(impComplexVal && (impComplexVal.getImag().isNonZero() ||
937+
impComplexVal.getReal().isNonZero()))) {
938+
return emitError() << "implicit value must be zero";
939+
}
940+
}
915941
return success();
916942
}
917943

918-
//===----------------------------------------------------------------------===//
919-
// SparseTensorType Methods.
920-
//===----------------------------------------------------------------------===//
921-
922-
bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
923-
bool isUnique) const {
924-
if (!hasEncoding())
925-
return false;
926-
if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
927-
return false;
928-
for (Level l = startLvl + 1; l < lvlRank; ++l)
929-
if (!isSingletonLvl(l))
930-
return false;
931-
// If isUnique is true, then make sure that the last level is unique,
932-
// that is, when lvlRank == 1, the only compressed level is unique,
933-
// and when lvlRank > 1, the last singleton is unique.
934-
return !isUnique || isUniqueLvl(lvlRank - 1);
935-
}
936-
937-
Level mlir::sparse_tensor::SparseTensorType::getAoSCOOStart() const {
944+
Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const {
938945
SmallVector<COOSegment> coo = getCOOSegments();
939946
assert(coo.size() == 1 || coo.empty());
940947
if (!coo.empty() && coo.front().isAoS()) {
941948
return coo.front().lvlRange.first;
942949
}
943-
return lvlRank;
950+
return getLvlRank();
944951
}
945952

946953
SmallVector<COOSegment>
947-
mlir::sparse_tensor::SparseTensorType::getCOOSegments() const {
954+
mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const {
948955
SmallVector<COOSegment> ret;
949-
if (!hasEncoding() || lvlRank <= 1)
956+
if (getLvlRank() <= 1)
950957
return ret;
951958

952959
ArrayRef<LevelType> lts = getLvlTypes();
953960
Level l = 0;
954-
while (l < lvlRank) {
961+
while (l < getLvlRank()) {
955962
auto lt = lts[l];
956963
if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
957964
auto cur = lts.begin() + l;
@@ -975,6 +982,25 @@ mlir::sparse_tensor::SparseTensorType::getCOOSegments() const {
975982
return ret;
976983
}
977984

985+
//===----------------------------------------------------------------------===//
986+
// SparseTensorType Methods.
987+
//===----------------------------------------------------------------------===//
988+
989+
bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
990+
bool isUnique) const {
991+
if (!hasEncoding())
992+
return false;
993+
if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
994+
return false;
995+
for (Level l = startLvl + 1; l < lvlRank; ++l)
996+
if (!isSingletonLvl(l))
997+
return false;
998+
// If isUnique is true, then make sure that the last level is unique,
999+
// that is, when lvlRank == 1, the only compressed level is unique,
1000+
// and when lvlRank > 1, the last singleton is unique.
1001+
return !isUnique || isUniqueLvl(lvlRank - 1);
1002+
}
1003+
9781004
RankedTensorType
9791005
mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
9801006
SmallVector<LevelType> lvlTypes;

mlir/test/Dialect/SparseTensor/invalid_encoding.mlir

+85
Original file line numberDiff line numberDiff line change
@@ -443,3 +443,88 @@ func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
443443
func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
444444
return
445445
}
446+
447+
// -----
448+
449+
#CSR_ExpType = #sparse_tensor.encoding<{
450+
map = (d0, d1) -> (d0 : dense, d1 : compressed),
451+
posWidth = 32,
452+
crdWidth = 32,
453+
explicitVal = 1 : i32,
454+
implicitVal = 0.0 : f32
455+
}>
456+
457+
// expected-error@+1 {{explicit value type mismatch between encoding and tensor element type: 'i32' != 'f32'}}
458+
func.func private @sparse_csr(tensor<?x?xf32, #CSR_ExpType>)
459+
460+
// -----
461+
462+
#CSR_ImpType = #sparse_tensor.encoding<{
463+
map = (d0, d1) -> (d0 : dense, d1 : compressed),
464+
posWidth = 32,
465+
crdWidth = 32,
466+
explicitVal = 1 : i32,
467+
implicitVal = 0.0 : f32
468+
}>
469+
470+
// expected-error@+1 {{implicit value type mismatch between encoding and tensor element type: 'f32' != 'i32'}}
471+
func.func private @sparse_csr(tensor<?x?xi32, #CSR_ImpType>)
472+
473+
// -----
474+
475+
// expected-error@+1 {{expected a numeric value for explicitVal}}
476+
#CSR_ExpType = #sparse_tensor.encoding<{
477+
map = (d0, d1) -> (d0 : dense, d1 : compressed),
478+
posWidth = 32,
479+
crdWidth = 32,
480+
explicitVal = "str"
481+
}>
482+
func.func private @sparse_csr(tensor<?x?xi32, #CSR_ExpType>)
483+
484+
// -----
485+
486+
// expected-error@+1 {{expected a numeric value for implicitVal}}
487+
#CSR_ImpType = #sparse_tensor.encoding<{
488+
map = (d0, d1) -> (d0 : dense, d1 : compressed),
489+
posWidth = 32,
490+
crdWidth = 32,
491+
implicitVal = "str"
492+
}>
493+
func.func private @sparse_csr(tensor<?x?xi32, #CSR_ImpType>)
494+
495+
// -----
496+
497+
#CSR_ImpVal = #sparse_tensor.encoding<{
498+
map = (d0, d1) -> (d0 : dense, d1 : compressed),
499+
posWidth = 32,
500+
crdWidth = 32,
501+
implicitVal = 1 : i32
502+
}>
503+
504+
// expected-error@+1 {{implicit value must be zero}}
505+
func.func private @sparse_csr(tensor<?x?xi32, #CSR_ImpVal>)
506+
507+
// -----
508+
509+
#CSR_ImpVal = #sparse_tensor.encoding<{
510+
map = (d0, d1) -> (d0 : dense, d1 : compressed),
511+
posWidth = 32,
512+
crdWidth = 32,
513+
implicitVal = 1.0 : f32
514+
}>
515+
516+
// expected-error@+1 {{implicit value must be zero}}
517+
func.func private @sparse_csr(tensor<?x?xf32, #CSR_ImpVal>)
518+
519+
// -----
520+
521+
#CSR_OnlyOnes = #sparse_tensor.encoding<{
522+
map = (d0, d1) -> (d0 : dense, d1 : compressed),
523+
posWidth = 64,
524+
crdWidth = 64,
525+
explicitVal = #complex.number<:f32 1.0, 0.0>,
526+
implicitVal = #complex.number<:f32 1.0, 0.0>
527+
}>
528+
529+
// expected-error@+1 {{implicit value must be zero}}
530+
func.func private @sparse_csr(tensor<?x?xcomplex<f32>, #CSR_OnlyOnes>)

0 commit comments

Comments
 (0)