xref: /llvm-project/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td (revision 481bd5d416df7a1d24e18cc81ae782e8701de965)
1//===- SparseTensorTypes.td - Sparse tensor dialect types ------*- tablegen -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef SPARSETENSOR_TYPES
10#define SPARSETENSOR_TYPES
11
12include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td"
13include "mlir/Dialect/SparseTensor/IR/SparseTensorBase.td"
14
15//===----------------------------------------------------------------------===//
16// Base class.
17//===----------------------------------------------------------------------===//
18
// Base class for all types defined by the SparseTensor dialect.
//
//   name         - C++ class-name stem; the generated class is `<name>Type`
//                  (e.g. "IterSpace" yields `IterSpaceType`).
//   traits       - additional type traits.
//   baseCppClass - C++ class the generated type class derives from.
class SparseTensor_Type<string name, list<Trait> traits = [],
                        string baseCppClass = "::mlir::Type">
    : TypeDef<SparseTensor_Dialect, name, traits, baseCppClass> {}
23
24//===----------------------------------------------------------------------===//
25// Sparse Tensor Dialect Types.
26//===----------------------------------------------------------------------===//
27
def SparseTensor_StorageSpecifier : SparseTensor_Type<"StorageSpecifier"> {
  let mnemonic = "storage_specifier";
  let summary = "Structured metadata for sparse tensor low-level storage scheme";

  let description = [{
    Values with storage_specifier types represent aggregated storage scheme
    metadata for the given sparse tensor encoding.  It currently holds
    a set of values for the level sizes, the coordinate arrays, the position
    arrays, and the value array.  Note that the type is not yet stable and
    is subject to change in the near future.

    Examples:

    ```mlir
    // A storage specifier that can be used to store storage scheme metadata from a CSR matrix.
    !storage_specifier<#CSR>
    ```
  }];

  let parameters = (ins SparseTensorEncodingAttr : $encoding);

  // Builders:
  //  - from an encoding with an explicit context; this one has no inline body,
  //    so it is implemented in C++ outside this file;
  //  - from an encoding, taking the context from the encoding itself;
  //  - from a type, using the sparse tensor encoding attached to that type;
  //  - from a value, using the encoding of the value's type.
  // The Type/Value forms require the type to carry a sparse tensor encoding.
  let builders = [
    TypeBuilder<(ins "SparseTensorEncodingAttr":$encoding)>,
    TypeBuilderWithInferredContext<(ins "SparseTensorEncodingAttr":$encoding), [{
      return get(encoding.getContext(), encoding);
    }]>,
    TypeBuilderWithInferredContext<(ins "Type":$type), [{
      // A null encoding would be dereferenced by the context-inferring
      // builder (`encoding.getContext()`); fail loudly instead.
      auto enc = getSparseTensorEncoding(type);
      assert(enc && "expected a type with a sparse tensor encoding");
      return get(enc);
    }]>,
    TypeBuilderWithInferredContext<(ins "Value":$tensor), [{
      return get(tensor.getType());
    }]>
  ];

  // The default builder, which would simply take the given sparse tensor
  // encoding attribute as-is, is skipped: the encoding must first be
  // normalized (normalizing the level types and removing fields that are
  // irrelevant to the sparse tensor storage scheme).
  let skipDefaultBuilders = 1;
  let assemblyFormat="`<` qualified($encoding) `>`";
}
67
// C++ predicate that holds for any `StorageSpecifierType`.
def IsSparseTensorStorageSpecifierTypePred
    : CPred<"::llvm::isa<::mlir::sparse_tensor::StorageSpecifierType>($_self)">;

// Type constraint accepting any sparse tensor storage specifier. It reuses the
// predicate above instead of repeating its C++ condition, so the two cannot
// drift apart.
def SparseTensorStorageSpecifier
    : Type<IsSparseTensorStorageSpecifierTypePred, "metadata",
          "::mlir::sparse_tensor::StorageSpecifierType">;
74
75//===----------------------------------------------------------------------===//
76// Sparse Tensor Iteration Types.
77//===----------------------------------------------------------------------===//
78
def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
  let mnemonic = "iter_space";
  let summary = "Sparse iteration space over a range of sparse tensor levels";

  let description = [{
    A sparse iteration space that represents an abstract N-D (sparse) iteration space
    extracted from a sparse tensor, i.e., a set of (crd_0, crd_1, ..., crd_{N-1}) for
    every stored element (usually nonzeros) in a sparse tensor between the specified
    [$loLvl, $hiLvl) levels, where N = $hiLvl - $loLvl.

    Examples:

    ```mlir
    // An iteration space extracted from a CSR tensor between levels [0, 2).
    !iter_space<#CSR, lvls = 0 to 2>
    ```
  }];

  let parameters = (ins
     SparseTensorEncodingAttr : $encoding,
     "Level" : $loLvl,
     "Level" : $hiLvl
  );

  let extraClassDeclaration = [{
     /// Returns the dimension of the iteration space, i.e., the number of
     /// levels in the half-open range [loLvl, hiLvl).
     unsigned getSpaceDim() const {
       return getHiLvl() - getLoLvl();
     }

     /// Returns the level types of the levels in [loLvl, hiLvl).
     ArrayRef<LevelType> getLvlTypes() const {
       return getEncoding().getLvlTypes().slice(getLoLvl(), getSpaceDim());
     }

     /// Returns true if the iteration space is unique (i.e., has no duplicated
     /// coordinates). This holds when the last level in the range is not
     /// non-unique. Assumes a non-empty level range (loLvl < hiLvl), since
     /// `back()` is invalid on an empty range.
     bool isUnique() const {
       return !getLvlTypes().back().isa<LevelPropNonDefault::Nonunique>();
     }

     /// Returns the corresponding iterator type.
     ::mlir::sparse_tensor::IteratorType getIteratorType() const;
  }];

  let assemblyFormat="`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
}
124
def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
  let mnemonic = "iterator";
  let summary = "Iterator over a sparse iteration space";

  let description = [{
    An iterator that points to the current element in the corresponding iteration space.

    Examples:

    ```mlir
    // An iterator that iterates over an iteration space of type `!iter_space<#CSR, lvls = 0 to 2>`
    !iterator<#CSR, lvls = 0 to 2>
    ```
  }];

  let parameters = (ins
     SparseTensorEncodingAttr : $encoding,
     "Level" : $loLvl,
     "Level" : $hiLvl
  );

  let extraClassDeclaration = [{
     /// Returns the iteration space type this iterator traverses. Per the
     /// example above, presumably the one with the same encoding and level
     /// range.
     ::mlir::sparse_tensor::IterSpaceType getIterSpaceType() const;

     /// Convenience forwarders to the corresponding iteration space type.
     unsigned getSpaceDim() const { return getIterSpaceType().getSpaceDim(); }
     ArrayRef<LevelType> getLvlTypes() const { return getIterSpaceType().getLvlTypes(); }
     bool isUnique() const { return getIterSpaceType().isUnique(); }
  }];

  let assemblyFormat="`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
}
156
// Type constraints for the sparse iteration types. Each C++ predicate is
// declared directly before the constraint built on it. The def names
// (including the doubled "Sparse" in the predicate names) are kept unchanged
// because other TableGen files may refer to them.

// Matches any `IterSpaceType`.
def IsSparseSparseIterSpaceTypePred
    : CPred<"::llvm::isa<::mlir::sparse_tensor::IterSpaceType>($_self)">;
def AnySparseIterSpace
    : Type<IsSparseSparseIterSpaceTypePred, "sparse iteration space",
          "::mlir::sparse_tensor::IterSpaceType">;

// Matches any `IteratorType`.
def IsSparseSparseIteratorTypePred
    : CPred<"::llvm::isa<::mlir::sparse_tensor::IteratorType>($_self)">;
def AnySparseIterator
    : Type<IsSparseSparseIteratorTypePred, "sparse iterator",
          "::mlir::sparse_tensor::IteratorType">;
170
171
172#endif // SPARSETENSOR_TYPES
173