//===- SparseTensorTypes.td - Sparse tensor dialect types ------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef SPARSETENSOR_TYPES
#define SPARSETENSOR_TYPES

include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorBase.td"

//===----------------------------------------------------------------------===//
// Base class.
//===----------------------------------------------------------------------===//

// Base class for SparseTensor dialect types.
class SparseTensor_Type<string name, list<Trait> traits = [],
                        string baseCppClass = "::mlir::Type">
    : TypeDef<SparseTensor_Dialect, name, traits, baseCppClass> {}

//===----------------------------------------------------------------------===//
// Sparse Tensor Dialect Types.
//===----------------------------------------------------------------------===//

def SparseTensor_StorageSpecifier : SparseTensor_Type<"StorageSpecifier"> {
  let mnemonic = "storage_specifier";
  let summary = "Structured metadata for sparse tensor low-level storage scheme";

  let description = [{
    Values with storage_specifier types represent aggregated storage scheme
    metadata for the given sparse tensor encoding. It currently holds
    a set of values for level-sizes, coordinate arrays, position arrays,
    and value array. Note that the type is not yet stable and subject to
    change in the near future.

    Example:

    ```mlir
    // A storage specifier that can be used to store storage scheme metadata
    // for a CSR matrix.
    !storage_specifier<#CSR>
    ```
  }];

  let parameters = (ins SparseTensorEncodingAttr : $encoding);
  let builders = [
    TypeBuilder<(ins "SparseTensorEncodingAttr":$encoding)>,
    TypeBuilderWithInferredContext<(ins "SparseTensorEncodingAttr":$encoding), [{
      return get(encoding.getContext(), encoding);
    }]>,
    TypeBuilderWithInferredContext<(ins "Type":$type), [{
      return get(getSparseTensorEncoding(type));
    }]>,
    TypeBuilderWithInferredContext<(ins "Value":$tensor), [{
      return get(tensor.getType());
    }]>
  ];

  // The default builder, which would simply take the input sparse tensor
  // encoding attribute as-is, is skipped: the encoding has to be normalized
  // first (normalizing the dimension level types and dropping fields that are
  // irrelevant to the sparse tensor storage scheme).
  let skipDefaultBuilders = 1;
  let assemblyFormat = "`<` qualified($encoding) `>`";
}

def IsSparseTensorStorageSpecifierTypePred
    : CPred<"::llvm::isa<::mlir::sparse_tensor::StorageSpecifierType>($_self)">;

// Reuses the predicate above instead of repeating its C++ expression, so the
// two definitions cannot drift apart.
def SparseTensorStorageSpecifier
    : Type<IsSparseTensorStorageSpecifierTypePred, "metadata",
           "::mlir::sparse_tensor::StorageSpecifierType">;

//===----------------------------------------------------------------------===//
// Sparse Tensor Iteration Types.
//===----------------------------------------------------------------------===//

def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
  let mnemonic = "iter_space";
  let summary = "Sparse iteration space extracted from a sparse tensor";

  let description = [{
    A sparse iteration space that represents an abstract N-D (sparse) iteration
    space extracted from a sparse tensor, i.e., a set of
    (crd_0, crd_1, ..., crd_{N-1}) for every stored element (usually nonzeros)
    in a sparse tensor between the specified [$loLvl, $hiLvl) levels, where
    N = $hiLvl - $loLvl.

    Example:

    ```mlir
    // An iteration space extracted from a CSR tensor between levels [0, 2).
    !iter_space<#CSR, lvls = 0 to 2>
    ```
  }];

  let parameters = (ins
    SparseTensorEncodingAttr : $encoding,
    "Level" : $loLvl,
    "Level" : $hiLvl
  );

  let extraClassDeclaration = [{
    /// Returns the dimension of the iteration space, i.e., the number of
    /// levels in the half-open range [loLvl, hiLvl).
    unsigned getSpaceDim() const {
      return getHiLvl() - getLoLvl();
    }

    /// Returns the level types of the levels in [loLvl, hiLvl).
    ArrayRef<LevelType> getLvlTypes() const {
      return getEncoding().getLvlTypes().slice(getLoLvl(), getSpaceDim());
    }

    /// Whether the iteration space is unique (i.e., no duplicated coordinate).
    /// Only the last level type in the range is inspected, so the level range
    /// must be non-empty (loLvl < hiLvl).
    bool isUnique() const {
      return !getLvlTypes().back().isa<LevelPropNonDefault::Nonunique>();
    }

    /// Get the corresponding iterator type.
    ::mlir::sparse_tensor::IteratorType getIteratorType() const;
  }];

  let assemblyFormat =
      "`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
}

def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
  let mnemonic = "iterator";
  let summary = "Iterator over a sparse iteration space";

  let description = [{
    An iterator that points to the current element in the corresponding
    iteration space.

    Example:

    ```mlir
    // An iterator that iterates over an iteration space of type
    // `!iter_space<#CSR, lvls = 0 to 2>`.
    !iterator<#CSR, lvls = 0 to 2>
    ```
  }];

  let parameters = (ins
    SparseTensorEncodingAttr : $encoding,
    "Level" : $loLvl,
    "Level" : $hiLvl
  );

  let extraClassDeclaration = [{
    /// Get the corresponding iteration space type.
    ::mlir::sparse_tensor::IterSpaceType getIterSpaceType() const;

    // The accessors below forward to the corresponding iteration space type.
    unsigned getSpaceDim() const { return getIterSpaceType().getSpaceDim(); }
    ArrayRef<LevelType> getLvlTypes() const {
      return getIterSpaceType().getLvlTypes();
    }
    bool isUnique() const { return getIterSpaceType().isUnique(); }
  }];

  let assemblyFormat =
      "`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
}

def IsSparseSparseIterSpaceTypePred
    : CPred<"::llvm::isa<::mlir::sparse_tensor::IterSpaceType>($_self)">;

def IsSparseSparseIteratorTypePred
    : CPred<"::llvm::isa<::mlir::sparse_tensor::IteratorType>($_self)">;

def AnySparseIterSpace
    : Type<IsSparseSparseIterSpaceTypePred, "sparse iteration space",
           "::mlir::sparse_tensor::IterSpaceType">;

def AnySparseIterator
    : Type<IsSparseSparseIteratorTypePred, "sparse iterator",
           "::mlir::sparse_tensor::IteratorType">;

#endif // SPARSETENSOR_TYPES