xref: /llvm-project/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h (revision fe8a62c46365f5ef0c15df2265bbf0026d0a4047)
1 //===- StructuredOpsUtils.h - Utilities used by structured ops --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This header file defines utilities that operate on builtin types and are
// useful across multiple dialects that use structured ops abstractions. These
// abstractions consist of custom operations that encode and transport
// information about their semantics (e.g. the types of iterators such as
// parallel, reduction, etc.) as attributes.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #ifndef MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
18 #define MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
19 
20 #include "mlir/IR/AffineMap.h"
21 #include "mlir/IR/BuiltinAttributes.h"
22 #include "mlir/IR/Location.h"
23 #include "mlir/IR/TypeRange.h"
24 #include "mlir/Support/LLVM.h"
25 
26 // Pull in all enum type definitions and utility function declarations.
27 #include "mlir/Dialect/Utils/DialectUtilsEnums.h.inc"
28 
29 namespace mlir {
30 
31 class OpBuilder;
32 class RewriterBase;
33 
34 /// Tests whether the given maps describe a row major matmul. The test is
35 /// permutation-invariant. Note that this only checks the affine maps from an
36 /// operation, so does not perform any checks on the math being performed within
37 /// the reduction.
38 bool isRowMajorMatmul(ArrayAttr indexingMaps);
39 
40 /// Tests whether the given maps describe a column major matmul. The test is
41 /// permutation-invariant. Note that this only checks the affine maps from an
42 /// operation, so does not perform any checks on the math being performed within
43 /// the reduction.
44 bool isColumnMajorMatmul(ArrayAttr indexingMaps);
45 
46 /// Tests whether the given maps describe a row major batch matmul. The test is
47 /// permutation-invariant. Note that this only checks the affine maps from an
48 /// operation, so does not perform any checks on the math being performed within
49 /// the reduction.
50 bool isRowMajorBatchMatmul(ArrayAttr indexingMaps);
51 
52 /// Tests whether the given maps describe a vector matrix multiplication. The
53 /// test is permutation-invariant. Note that this only checks the affine maps
54 /// from an operation, so does not perform any checks on the math being
55 /// performed within the reduction.
56 bool isVecmat(ArrayAttr indexingMaps);
57 
58 /// Tests whether the given maps describe a batch vector matrix multiplication.
59 /// The test is permutation-invariant. Note that this only checks the affine
60 /// maps from an operation, so does not perform any checks on the math being
61 /// performed within the reduction.
62 bool isBatchVecmat(ArrayAttr indexingMaps);
63 
64 /// Tests whether the given maps describe a matrix vector multiplication. The
65 /// test is permutation-invariant. Note that this only checks the affine maps
66 /// from an operation, so does not perform any checks on the math being
67 /// performed within the reduction.
68 bool isMatvec(ArrayAttr indexingMaps);
69 
70 /// Tests whether the given maps describe a batch matrix vector multiplication.
71 /// The test is permutation-invariant. Note that this only checks the affine
72 /// maps from an operation, so does not perform any checks on the math being
73 /// performed within the reduction.
74 bool isBatchMatvec(ArrayAttr indexingMaps);
75 
76 /// Return positions in `iteratorTypes` that match `iteratorTypeName`.
findPositionsOfType(ArrayRef<utils::IteratorType> iteratorTypes,utils::IteratorType iteratorTypeName,SmallVectorImpl<unsigned> & res)77 inline void findPositionsOfType(ArrayRef<utils::IteratorType> iteratorTypes,
78                                 utils::IteratorType iteratorTypeName,
79                                 SmallVectorImpl<unsigned> &res) {
80   for (const auto &en : llvm::enumerate(iteratorTypes)) {
81     if (en.value() == iteratorTypeName)
82       res.push_back(en.index());
83   }
84 }
85 
86 /// Helper StructuredGenerator class to manipulate and rewrite ops with
87 /// `StructuredOpInterface`. This is templated for now because VectorOps do not
88 /// yet implement the StructuredOpInterface itself.
89 template <typename StructuredOpInterface, typename IteratorTypeT>
90 class StructuredGenerator {
91 public:
92   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
93 
94   struct IteratorType {
IteratorTypeIteratorType95     IteratorType(IteratorTypeT iter) : iter(iter) {}
isOfTypeIteratorType96     bool isOfType(IteratorTypeT expectedIter) const {
97       return expectedIter == iter;
98     }
99     IteratorTypeT iter;
100   };
101   struct Par : public IteratorType {
ParPar102     Par() : IteratorType(IteratorTypeT::parallel) {}
103   };
104   struct Red : public IteratorType {
RedRed105     Red() : IteratorType(IteratorTypeT::reduction) {}
106   };
107 
StructuredGenerator(RewriterBase & rewriter,StructuredOpInterface op)108   StructuredGenerator(RewriterBase &rewriter, StructuredOpInterface op)
109       : rewriter(rewriter), ctx(op.getContext()), loc(op.getLoc()),
110         iterators(op.getIteratorTypesArray()), maps(op.getIndexingMapsArray()),
111         op(op) {}
112 
iters(ArrayRef<IteratorType> its)113   bool iters(ArrayRef<IteratorType> its) {
114     if (its.size() != iterators.size())
115       return false;
116     for (int i = 0, e = its.size(); i != e; ++i) {
117       if (!its[i].isOfType(iterators[i]))
118         return false;
119     }
120     return true;
121   }
122 
layout(MapList l)123   bool layout(MapList l) {
124     auto infer = [&](MapList m) {
125       return AffineMap::inferFromExprList(m, ctx);
126     };
127     return maps == infer(l);
128   }
129 
130 protected:
131   RewriterBase &rewriter;
132   MLIRContext *ctx;
133   Location loc;
134   SmallVector<IteratorTypeT> iterators;
135   SmallVector<AffineMap, 4> maps;
136   Operation *op;
137 };
138 
139 // Clone the current operation with the operands. This is used to abstract away
140 // the optional underlying region creation.
141 // Note: this is a true builder that notifies the OpBuilder listener.
142 Operation *clone(OpBuilder &b, Operation *op, TypeRange newResultTypes,
143                  ValueRange newOperands);
144 template <typename OpT>
clone(OpBuilder & b,OpT op,TypeRange newResultTypes,ValueRange newOperands)145 OpT clone(OpBuilder &b, OpT op, TypeRange newResultTypes,
146           ValueRange newOperands) {
147   return cast<OpT>(clone(b, op.getOperation(), newResultTypes, newOperands));
148 }
149 
150 // Clone the current operation with the operands but leave the regions empty.
151 // Note: this is a true builder that notifies the OpBuilder listener.
152 Operation *cloneWithoutRegions(OpBuilder &b, Operation *op,
153                                TypeRange newResultTypes,
154                                ValueRange newOperands);
155 
156 // Get the list of attributes associated with the op, ignoring
157 // those with the provided name.
158 SmallVector<NamedAttribute>
159 getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs);
160 
161 } // namespace mlir
162 
163 #endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
164