xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- LowerVectorMask.cpp - Lower 'vector.mask' operation ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.mask' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/Vector/IR/VectorOps.h"
17 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
18 #include "mlir/Dialect/Vector/Transforms/Passes.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 #define DEBUG_TYPE "lower-vector-mask"
23 
24 namespace mlir {
25 namespace vector {
26 #define GEN_PASS_DEF_LOWERVECTORMASKPASS
27 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
28 } // namespace vector
29 } // namespace mlir
30 
31 using namespace mlir;
32 using namespace mlir::vector;
33 
34 //===----------------------------------------------------------------------===//
35 // populateVectorMaskOpLoweringPatterns
36 //===----------------------------------------------------------------------===//
37 
38 namespace {
/// Progressive lowering of CreateMaskOp.
/// One:
///   %x = vector.create_mask %a, ... : vector<dx...>
/// is replaced by:
///   %l = vector.create_mask ... : vector<...>  ; one lower rank
///   %0 = arith.cmpi "slt", %ci, %a |
///   %1 = select %0, %l, %zeroes    |
///   %r = vector.insert %1, %pr [i] | d-times
///   %x = ....
/// until a one-dimensional vector is reached.
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = cast<VectorType>(op.getResult().getType());
    int64_t rank = dstType.getRank();
    // Rank-0 and rank-1 masks are the base case of the progressive lowering
    // and are lowered by other patterns.
    if (rank <= 1)
      return rewriter.notifyMatchFailure(
          op, "0-D and 1-D vectors are handled separately");

    // The loop below unrolls the leading dimension, which is impossible when
    // its extent is only known at runtime (scalable).
    if (dstType.getScalableDims().front())
      return rewriter.notifyMatchFailure(
          op, "Cannot unroll leading scalable dim in dstType");

    auto loc = op.getLoc();
    int64_t dim = dstType.getDimSize(0);
    // Runtime bound for the leading dimension of the mask.
    Value idx = op.getOperand(0);

    // Build the rank-(n-1) mask from the remaining operands; rows of the
    // result are either this mask ("in bounds") or all-false ("out of
    // bounds").
    VectorType lowType = VectorType::Builder(dstType).dropDim(0);
    Value trueVal = rewriter.create<vector::CreateMaskOp>(
        loc, lowType, op.getOperands().drop_front());
    Value falseVal = rewriter.create<arith::ConstantOp>(
        loc, lowType, rewriter.getZeroAttr(lowType));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < dim; d++) {
      // Row d is active iff d < %idx (signed compare, matching the
      // create_mask semantics for negative bounds).
      Value bnd =
          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
      Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 bnd, idx);
      Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
      result = rewriter.create<vector::InsertOp>(loc, sel, result, d);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
88 
/// Progressive lowering of ConstantMaskOp.
/// One:
///   %x = vector.constant_mask [a,b]
/// is replaced by:
///   %z = zero-result
///   %l = vector.constant_mask [b]
///   %4 = vector.insert %l, %z[0]
///   ..
///   %x = vector.insert %l, %..[a-1]
/// until a one-dimensional vector is reached. All these operations
/// will be folded at LLVM IR level.
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto dstType = op.getType();
    auto dimSizes = op.getMaskDimSizes();
    int64_t rank = dstType.getRank();

    // 0-D case: the mask is a single scalar i1 element; it is all-true iff
    // the single mask dim size is 1.
    if (rank == 0) {
      assert(dimSizes.size() == 1 &&
             "Expected exactly one dim size for a 0-D vector");
      bool value = dimSizes.front() == 1;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType,
          DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
                                    value));
      return success();
    }

    // Number of "true" rows along the leading dimension.
    int64_t trueDimSize = dimSizes.front();

    if (rank == 1) {
      if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
        // Use constant splat for 'all set' or 'none set' dims.
        // This produces correct code for scalable dimensions (it will lower to
        // a constant splat).
        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
            op, DenseElementsAttr::get(dstType, trueDimSize != 0));
      } else {
        // Express constant 1-D case in explicit vector form:
        //   [T,..,T,F,..,F].
        // Note: The verifier would reject this case for scalable vectors.
        SmallVector<bool> values(dstType.getDimSize(0), false);
        for (int64_t d = 0; d < trueDimSize; d++)
          values[d] = true;
        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
            op, dstType, rewriter.getBoolVectorAttr(values));
      }
      return success();
    }

    // The unrolling below requires a static leading dimension.
    if (dstType.getScalableDims().front())
      return rewriter.notifyMatchFailure(
          op, "Cannot unroll leading scalable dim in dstType");

    // Rank >= 2: build the rank-(n-1) constant mask once and insert it into
    // the first `trueDimSize` rows of an all-false result; remaining rows
    // keep the all-false value from the zero constant.
    VectorType lowType = VectorType::Builder(dstType).dropDim(0);
    Value trueVal = rewriter.create<vector::ConstantMaskOp>(
        loc, lowType, dimSizes.drop_front());
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < trueDimSize; d++)
      result = rewriter.create<vector::InsertOp>(loc, trueVal, result, d);

    rewriter.replaceOp(op, result);
    return success();
  }
};
160 } // namespace
161 
162 void mlir::vector::populateVectorMaskOpLoweringPatterns(
163     RewritePatternSet &patterns, PatternBenefit benefit) {
164   patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
165       patterns.getContext(), benefit);
166 }
167 
168 //===----------------------------------------------------------------------===//
169 // populateVectorMaskLoweringPatternsForSideEffectingOps
170 //===----------------------------------------------------------------------===//
171 
172 namespace {
173 
/// The `MaskOpRewritePattern` implements a pattern that follows a two-fold
/// matching:
///   1. It matches a `vector.mask` operation.
///   2. It invokes `matchAndRewriteMaskableOp` on the `MaskableOpInterface`
///      nested in the matched `vector.mask` operation.
///
/// It is required that the replacement op in the pattern replaces the
/// `vector.mask` operation and not the nested `MaskableOpInterface`. This
/// approach allows having patterns that "stop" at every `vector.mask`
/// operation and actually match the traits of the nested
/// `MaskableOpInterface`.
template <class SourceOp>
struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
  using OpRewritePattern<MaskOp>::OpRewritePattern;

private:
  // Final dispatcher: extracts the op nested in the `vector.mask` region and
  // forwards to the subclass hook when it is a `SourceOp`; fails otherwise.
  LogicalResult matchAndRewrite(MaskOp maskOp,
                                PatternRewriter &rewriter) const final {
    // `getMaskableOp()` may return null (e.g., an empty mask region), hence
    // the null-tolerant cast.
    auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
    if (!maskableOp)
      return failure();
    SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
    if (!sourceOp)
      return failure();

    return matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
  }

protected:
  /// Subclass hook: rewrite `sourceOp`, replacing `maskingOp` (the enclosing
  /// `vector.mask`) rather than `sourceOp` itself.
  virtual LogicalResult
  matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const = 0;
};
206 
/// Lowers a masked `vector.transfer_read` operation by folding the mask into
/// the `vector.transfer_read`'s own mask operand.
struct MaskedTransferReadOpPattern
    : public MaskOpRewritePattern<TransferReadOp> {
public:
  using MaskOpRewritePattern<TransferReadOp>::MaskOpRewritePattern;

  LogicalResult
  matchAndRewriteMaskableOp(TransferReadOp readOp, MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    // TODO: The 'vector.mask' passthru is a vector and 'vector.transfer_read'
    // expects a scalar. We could only lower one to the other for cases where
    // the passthru is a broadcast of a scalar.
    if (maskingOp.hasPassthru())
      return rewriter.notifyMatchFailure(
          maskingOp, "Can't lower passthru to vector.transfer_read");

    // Replace the `vector.mask` operation (not the nested read), reusing all
    // of the read's attributes/operands and adding the mask as its mask
    // operand.
    rewriter.replaceOpWithNewOp<TransferReadOp>(
        maskingOp.getOperation(), readOp.getVectorType(), readOp.getSource(),
        readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(),
        maskingOp.getMask(), readOp.getInBounds());
    return success();
  }
};
231 
/// Lowers a masked `vector.transfer_write` operation by folding the mask into
/// the `vector.transfer_write`'s own mask operand.
struct MaskedTransferWriteOpPattern
    : public MaskOpRewritePattern<TransferWriteOp> {
public:
  using MaskOpRewritePattern<TransferWriteOp>::MaskOpRewritePattern;

  LogicalResult
  matchAndRewriteMaskableOp(TransferWriteOp writeOp,
                            MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    // A transfer_write on a tensor yields the updated tensor; on a memref it
    // has no result, so propagate a null Type in that case.
    Type resultType =
        writeOp.getResult() ? writeOp.getResult().getType() : Type();

    // Replace the `vector.mask` operation (not the nested write).
    rewriter.replaceOpWithNewOp<TransferWriteOp>(
        maskingOp.getOperation(), resultType, writeOp.getVector(),
        writeOp.getSource(), writeOp.getIndices(), writeOp.getPermutationMap(),
        maskingOp.getMask(), writeOp.getInBounds());
    return success();
  }
};
253 
/// Lowers a masked `vector.gather` operation by folding the mask (and
/// optional passthru) into the `vector.gather`'s own operands.
struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> {
public:
  using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern;

  LogicalResult
  matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    // `vector.gather` requires a passthru; default to an all-zeroes vector
    // when the `vector.mask` did not provide one.
    Value passthru = maskingOp.hasPassthru()
                         ? maskingOp.getPassthru()
                         : rewriter.create<arith::ConstantOp>(
                               gatherOp.getLoc(),
                               rewriter.getZeroAttr(gatherOp.getVectorType()));

    // Replace the `vector.mask` operation (not the nested gather).
    rewriter.replaceOpWithNewOp<GatherOp>(
        maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
        gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
        passthru);
    return success();
  }
};
276 
/// Pass that lowers `vector.mask` operations wrapping side-effecting ops by
/// greedily applying the mask-folding patterns above together with the
/// `vector.mask` canonicalization patterns.
struct LowerVectorMaskPass
    : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> {
  using Base::Base;

  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();

    RewritePatternSet loweringPatterns(context);
    populateVectorMaskLoweringPatternsForSideEffectingOps(loweringPatterns);
    // Also run MaskOp canonicalizations — presumably to clean up/unwrap
    // `vector.mask` ops the lowering patterns do not handle; confirm against
    // MaskOp::getCanonicalizationPatterns.
    MaskOp::getCanonicalizationPatterns(loweringPatterns, context);

    if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
      signalPassFailure();
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect>();
  }
};
297 
298 } // namespace
299 
300 /// Populates instances of `MaskOpRewritePattern` to lower masked operations
301 /// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
302 /// not its nested `MaskableOpInterface`.
303 void vector::populateVectorMaskLoweringPatternsForSideEffectingOps(
304     RewritePatternSet &patterns) {
305   patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern,
306                MaskedGatherOpPattern>(patterns.getContext());
307 }
308 
309 std::unique_ptr<Pass> mlir::vector::createLowerVectorMaskPass() {
310   return std::make_unique<LowerVectorMaskPass>();
311 }
312