@@ -104,7 +104,7 @@ void StorageLayout::foreachField(
104
104
callback) const {
105
105
const auto lvlTypes = enc.getLvlTypes ();
106
106
const Level lvlRank = enc.getLvlRank ();
107
- SmallVector<COOSegment> cooSegs = SparseTensorType ( enc) .getCOOSegments ();
107
+ SmallVector<COOSegment> cooSegs = enc.getCOOSegments ();
108
108
FieldIndex fieldIdx = kDataFieldStartingIdx ;
109
109
110
110
ArrayRef cooSegsRef = cooSegs;
@@ -211,7 +211,7 @@ StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
211
211
unsigned stride = 1 ;
212
212
if (kind == SparseTensorFieldKind::CrdMemRef) {
213
213
assert (lvl.has_value ());
214
- const Level cooStart = SparseTensorType ( enc) .getAoSCOOStart ();
214
+ const Level cooStart = enc.getAoSCOOStart ();
215
215
const Level lvlRank = enc.getLvlRank ();
216
216
if (lvl.value () >= cooStart && lvl.value () < lvlRank) {
217
217
lvl = cooStart;
@@ -912,46 +912,53 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
912
912
return emitError ()
913
913
<< " dimension-rank mismatch between encoding and tensor shape: "
914
914
<< getDimRank () << " != " << dimRank;
915
+ if (auto expVal = getExplicitVal ()) {
916
+ Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType ();
917
+ if (attrType != elementType) {
918
+ return emitError () << " explicit value type mismatch between encoding and "
919
+ << " tensor element type: " << attrType
920
+ << " != " << elementType;
921
+ }
922
+ }
923
+ if (auto impVal = getImplicitVal ()) {
924
+ Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType ();
925
+ if (attrType != elementType) {
926
+ return emitError () << " implicit value type mismatch between encoding and "
927
+ << " tensor element type: " << attrType
928
+ << " != " << elementType;
929
+ }
930
+ // Currently, we only support zero as the implicit value.
931
+ auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
932
+ auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
933
+ auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
934
+ if ((impFVal && impFVal.getValue ().isNonZero ()) ||
935
+ (impIntVal && !impIntVal.getValue ().isZero ()) ||
936
+ (impComplexVal && (impComplexVal.getImag ().isNonZero () ||
937
+ impComplexVal.getReal ().isNonZero ()))) {
938
+ return emitError () << " implicit value must be zero" ;
939
+ }
940
+ }
915
941
return success ();
916
942
}
917
943
918
- // ===----------------------------------------------------------------------===//
919
- // SparseTensorType Methods.
920
- // ===----------------------------------------------------------------------===//
921
-
922
- bool mlir::sparse_tensor::SparseTensorType::isCOOType (Level startLvl,
923
- bool isUnique) const {
924
- if (!hasEncoding ())
925
- return false ;
926
- if (!isCompressedLvl (startLvl) && !isLooseCompressedLvl (startLvl))
927
- return false ;
928
- for (Level l = startLvl + 1 ; l < lvlRank; ++l)
929
- if (!isSingletonLvl (l))
930
- return false ;
931
- // If isUnique is true, then make sure that the last level is unique,
932
- // that is, when lvlRank == 1, the only compressed level is unique,
933
- // and when lvlRank > 1, the last singleton is unique.
934
- return !isUnique || isUniqueLvl (lvlRank - 1 );
935
- }
936
-
937
- Level mlir::sparse_tensor::SparseTensorType::getAoSCOOStart () const {
944
+ Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart () const {
938
945
SmallVector<COOSegment> coo = getCOOSegments ();
939
946
assert (coo.size () == 1 || coo.empty ());
940
947
if (!coo.empty () && coo.front ().isAoS ()) {
941
948
return coo.front ().lvlRange .first ;
942
949
}
943
- return lvlRank ;
950
+ return getLvlRank () ;
944
951
}
945
952
946
953
SmallVector<COOSegment>
947
- mlir::sparse_tensor::SparseTensorType ::getCOOSegments () const {
954
+ mlir::sparse_tensor::SparseTensorEncodingAttr ::getCOOSegments () const {
948
955
SmallVector<COOSegment> ret;
949
- if (! hasEncoding () || lvlRank <= 1 )
956
+ if (getLvlRank () <= 1 )
950
957
return ret;
951
958
952
959
ArrayRef<LevelType> lts = getLvlTypes ();
953
960
Level l = 0 ;
954
- while (l < lvlRank ) {
961
+ while (l < getLvlRank () ) {
955
962
auto lt = lts[l];
956
963
if (lt.isa <LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
957
964
auto cur = lts.begin () + l;
@@ -975,6 +982,25 @@ mlir::sparse_tensor::SparseTensorType::getCOOSegments() const {
975
982
return ret;
976
983
}
977
984
985
+ // ===----------------------------------------------------------------------===//
986
+ // SparseTensorType Methods.
987
+ // ===----------------------------------------------------------------------===//
988
+
989
+ bool mlir::sparse_tensor::SparseTensorType::isCOOType (Level startLvl,
990
+ bool isUnique) const {
991
+ if (!hasEncoding ())
992
+ return false ;
993
+ if (!isCompressedLvl (startLvl) && !isLooseCompressedLvl (startLvl))
994
+ return false ;
995
+ for (Level l = startLvl + 1 ; l < lvlRank; ++l)
996
+ if (!isSingletonLvl (l))
997
+ return false ;
998
+ // If isUnique is true, then make sure that the last level is unique,
999
+ // that is, when lvlRank == 1, the only compressed level is unique,
1000
+ // and when lvlRank > 1, the last singleton is unique.
1001
+ return !isUnique || isUniqueLvl (lvlRank - 1 );
1002
+ }
1003
+
978
1004
RankedTensorType
979
1005
mlir::sparse_tensor::SparseTensorType::getCOOType (bool ordered) const {
980
1006
SmallVector<LevelType> lvlTypes;
0 commit comments