xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp (revision ccef726d09b1ffadfae6b1d1d986ae2f6d25a6a6)
1 //===- LowerVectorMask.cpp - Lower 'vector.mask' operation ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.mask' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/Vector/IR/VectorOps.h"
17 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
18 #include "mlir/Dialect/Vector/Transforms/Passes.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 #define DEBUG_TYPE "lower-vector-mask"
23 
24 namespace mlir {
25 namespace vector {
26 #define GEN_PASS_DEF_LOWERVECTORMASKPASS
27 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
28 } // namespace vector
29 } // namespace mlir
30 
31 using namespace mlir;
32 using namespace mlir::vector;
33 
34 //===----------------------------------------------------------------------===//
35 // populateVectorMaskOpLoweringPatterns
36 //===----------------------------------------------------------------------===//
37 
38 namespace {
39 /// Progressive lowering of CreateMaskOp.
40 /// One:
41 ///   %x = vector.create_mask %a, ... : vector<dx...>
42 /// is replaced by:
43 ///   %l = vector.create_mask ... : vector<...>  ; one lower rank
44 ///   %0 = arith.cmpi "slt", %ci, %a       |
45 ///   %1 = select %0, %l, %zeroes    |
46 ///   %r = vector.insert %1, %pr [i] | d-times
47 ///   %x = ....
48 /// until a one-dimensional vector is reached.
49 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
50 public:
51   using OpRewritePattern::OpRewritePattern;
52 
53   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
54                                 PatternRewriter &rewriter) const override {
55     auto dstType = cast<VectorType>(op.getResult().getType());
56     int64_t rank = dstType.getRank();
57     if (rank <= 1)
58       return rewriter.notifyMatchFailure(
59           op, "0-D and 1-D vectors are handled separately");
60 
61     if (dstType.getScalableDims().front())
62       return rewriter.notifyMatchFailure(
63           op, "Cannot unroll leading scalable dim in dstType");
64 
65     auto loc = op.getLoc();
66     int64_t dim = dstType.getDimSize(0);
67     Value idx = op.getOperand(0);
68 
69     VectorType lowType = VectorType::Builder(dstType).dropDim(0);
70     Value trueVal = rewriter.create<vector::CreateMaskOp>(
71         loc, lowType, op.getOperands().drop_front());
72     Value falseVal = rewriter.create<arith::ConstantOp>(
73         loc, lowType, rewriter.getZeroAttr(lowType));
74     Value result = rewriter.create<arith::ConstantOp>(
75         loc, dstType, rewriter.getZeroAttr(dstType));
76     for (int64_t d = 0; d < dim; d++) {
77       Value bnd =
78           rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
79       Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
80                                                  bnd, idx);
81       Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
82       result = rewriter.create<vector::InsertOp>(loc, dstType, sel, result, d);
83     }
84     rewriter.replaceOp(op, result);
85     return success();
86   }
87 };
88 
89 /// Progressive lowering of ConstantMaskOp.
90 /// One:
91 ///   %x = vector.constant_mask [a,b]
92 /// is replaced by:
93 ///   %z = zero-result
94 ///   %l = vector.constant_mask [b]
95 ///   %4 = vector.insert %l, %z[0]
96 ///   ..
97 ///   %x = vector.insert %l, %..[a-1]
98 /// until a one-dimensional vector is reached. All these operations
99 /// will be folded at LLVM IR level.
100 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
101 public:
102   using OpRewritePattern::OpRewritePattern;
103 
104   LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
105                                 PatternRewriter &rewriter) const override {
106     auto loc = op.getLoc();
107     auto dstType = op.getType();
108     auto eltType = dstType.getElementType();
109     auto dimSizes = op.getMaskDimSizes();
110     int64_t rank = dstType.getRank();
111 
112     if (rank == 0) {
113       assert(dimSizes.size() == 1 &&
114              "Expected exactly one dim size for a 0-D vector");
115       bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
116       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
117           op, dstType,
118           DenseIntElementsAttr::get(
119               VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
120               ArrayRef<bool>{value}));
121       return success();
122     }
123 
124     // Scalable constant masks can only be lowered for the "none set" case.
125     if (cast<VectorType>(dstType).isScalable()) {
126       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
127           op, DenseElementsAttr::get(dstType, false));
128       return success();
129     }
130 
131     int64_t trueDim = std::min(dstType.getDimSize(0),
132                                cast<IntegerAttr>(dimSizes[0]).getInt());
133 
134     if (rank == 1) {
135       // Express constant 1-D case in explicit vector form:
136       //   [T,..,T,F,..,F].
137       SmallVector<bool> values(dstType.getDimSize(0));
138       for (int64_t d = 0; d < trueDim; d++)
139         values[d] = true;
140       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
141           op, dstType, rewriter.getBoolVectorAttr(values));
142       return success();
143     }
144 
145     VectorType lowType =
146         VectorType::get(dstType.getShape().drop_front(), eltType);
147     SmallVector<int64_t> newDimSizes;
148     for (int64_t r = 1; r < rank; r++)
149       newDimSizes.push_back(cast<IntegerAttr>(dimSizes[r]).getInt());
150     Value trueVal = rewriter.create<vector::ConstantMaskOp>(
151         loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
152     Value result = rewriter.create<arith::ConstantOp>(
153         loc, dstType, rewriter.getZeroAttr(dstType));
154     for (int64_t d = 0; d < trueDim; d++)
155       result =
156           rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, d);
157     rewriter.replaceOp(op, result);
158     return success();
159   }
160 };
161 } // namespace
162 
163 void mlir::vector::populateVectorMaskOpLoweringPatterns(
164     RewritePatternSet &patterns, PatternBenefit benefit) {
165   patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
166       patterns.getContext(), benefit);
167 }
168 
169 //===----------------------------------------------------------------------===//
170 // populateVectorMaskLoweringPatternsForSideEffectingOps
171 //===----------------------------------------------------------------------===//
172 
173 namespace {
174 
175 /// The `MaskOpRewritePattern` implements a pattern that follows a two-fold
176 /// matching:
177 ///   1. It matches a `vector.mask` operation.
178 ///   2. It invokes `matchAndRewriteMaskableOp` on `MaskableOpInterface` nested
179 ///      in the matched `vector.mask` operation.
180 ///
181 /// It is required that the replacement op in the pattern replaces the
182 /// `vector.mask` operation and not the nested `MaskableOpInterface`. This
183 /// approach allows having patterns that "stop" at every `vector.mask` operation
184 /// and actually match the traits of its the nested `MaskableOpInterface`.
185 template <class SourceOp>
186 struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
187   using OpRewritePattern<MaskOp>::OpRewritePattern;
188 
189 private:
190   LogicalResult matchAndRewrite(MaskOp maskOp,
191                                 PatternRewriter &rewriter) const final {
192     auto maskableOp = cast<MaskableOpInterface>(maskOp.getMaskableOp());
193     SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
194     if (!sourceOp)
195       return failure();
196 
197     return matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
198   }
199 
200 protected:
201   virtual LogicalResult
202   matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
203                             PatternRewriter &rewriter) const = 0;
204 };
205 
206 /// Lowers a masked `vector.transfer_read` operation.
207 struct MaskedTransferReadOpPattern
208     : public MaskOpRewritePattern<TransferReadOp> {
209 public:
210   using MaskOpRewritePattern<TransferReadOp>::MaskOpRewritePattern;
211 
212   LogicalResult
213   matchAndRewriteMaskableOp(TransferReadOp readOp, MaskingOpInterface maskingOp,
214                             PatternRewriter &rewriter) const override {
215     // TODO: The 'vector.mask' passthru is a vector and 'vector.transfer_read'
216     // expects a scalar. We could only lower one to the other for cases where
217     // the passthru is a broadcast of a scalar.
218     if (maskingOp.hasPassthru())
219       return rewriter.notifyMatchFailure(
220           maskingOp, "Can't lower passthru to vector.transfer_read");
221 
222     // Replace the `vector.mask` operation.
223     rewriter.replaceOpWithNewOp<TransferReadOp>(
224         maskingOp.getOperation(), readOp.getVectorType(), readOp.getSource(),
225         readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(),
226         maskingOp.getMask(), readOp.getInBounds().value_or(ArrayAttr()));
227     return success();
228   }
229 };
230 
231 /// Lowers a masked `vector.transfer_write` operation.
232 struct MaskedTransferWriteOpPattern
233     : public MaskOpRewritePattern<TransferWriteOp> {
234 public:
235   using MaskOpRewritePattern<TransferWriteOp>::MaskOpRewritePattern;
236 
237   LogicalResult
238   matchAndRewriteMaskableOp(TransferWriteOp writeOp,
239                             MaskingOpInterface maskingOp,
240                             PatternRewriter &rewriter) const override {
241     Type resultType =
242         writeOp.getResult() ? writeOp.getResult().getType() : Type();
243 
244     // Replace the `vector.mask` operation.
245     rewriter.replaceOpWithNewOp<TransferWriteOp>(
246         maskingOp.getOperation(), resultType, writeOp.getVector(),
247         writeOp.getSource(), writeOp.getIndices(), writeOp.getPermutationMap(),
248         maskingOp.getMask(), writeOp.getInBounds().value_or(ArrayAttr()));
249     return success();
250   }
251 };
252 
253 /// Lowers a masked `vector.gather` operation.
254 struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> {
255 public:
256   using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern;
257 
258   LogicalResult
259   matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp,
260                             PatternRewriter &rewriter) const override {
261     Value passthru = maskingOp.hasPassthru()
262                          ? maskingOp.getPassthru()
263                          : rewriter.create<arith::ConstantOp>(
264                                gatherOp.getLoc(),
265                                rewriter.getZeroAttr(gatherOp.getVectorType()));
266 
267     // Replace the `vector.mask` operation.
268     rewriter.replaceOpWithNewOp<GatherOp>(
269         maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
270         gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
271         passthru);
272     return success();
273   }
274 };
275 
276 struct LowerVectorMaskPass
277     : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> {
278   using Base::Base;
279 
280   void runOnOperation() override {
281     Operation *op = getOperation();
282     MLIRContext *context = op->getContext();
283 
284     RewritePatternSet loweringPatterns(context);
285     populateVectorMaskLoweringPatternsForSideEffectingOps(loweringPatterns);
286 
287     if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
288       signalPassFailure();
289   }
290 
291   void getDependentDialects(DialectRegistry &registry) const override {
292     registry.insert<vector::VectorDialect>();
293   }
294 };
295 
296 } // namespace
297 
298 /// Populates instances of `MaskOpRewritePattern` to lower masked operations
299 /// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
300 /// not its nested `MaskableOpInterface`.
301 void vector::populateVectorMaskLoweringPatternsForSideEffectingOps(
302     RewritePatternSet &patterns) {
303   patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern,
304                MaskedGatherOpPattern>(patterns.getContext());
305 }
306 
307 std::unique_ptr<Pass> mlir::vector::createLowerVectorMaskPass() {
308   return std::make_unique<LowerVectorMaskPass>();
309 }
310