1 //===- LowerVectorMask.cpp - Lower 'vector.mask' operation ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements target-independent rewrites and utilities to lower the 10 // 'vector.mask' operation. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Arith/IR/Arith.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/Vector/IR/VectorOps.h" 17 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 18 #include "mlir/Dialect/Vector/Transforms/Passes.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21 22 #define DEBUG_TYPE "lower-vector-mask" 23 24 namespace mlir { 25 namespace vector { 26 #define GEN_PASS_DEF_LOWERVECTORMASKPASS 27 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc" 28 } // namespace vector 29 } // namespace mlir 30 31 using namespace mlir; 32 using namespace mlir::vector; 33 34 //===----------------------------------------------------------------------===// 35 // populateVectorMaskOpLoweringPatterns 36 //===----------------------------------------------------------------------===// 37 38 namespace { 39 /// Progressive lowering of CreateMaskOp. 40 /// One: 41 /// %x = vector.create_mask %a, ... : vector<dx...> 42 /// is replaced by: 43 /// %l = vector.create_mask ... : vector<...> ; one lower rank 44 /// %0 = arith.cmpi "slt", %ci, %a | 45 /// %1 = select %0, %l, %zeroes | 46 /// %r = vector.insert %1, %pr [i] | d-times 47 /// %x = .... 48 /// until a one-dimensional vector is reached. 49 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> { 50 public: 51 using OpRewritePattern::OpRewritePattern; 52 53 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 54 PatternRewriter &rewriter) const override { 55 auto dstType = cast<VectorType>(op.getResult().getType()); 56 int64_t rank = dstType.getRank(); 57 if (rank <= 1) 58 return rewriter.notifyMatchFailure( 59 op, "0-D and 1-D vectors are handled separately"); 60 61 if (dstType.getScalableDims().front()) 62 return rewriter.notifyMatchFailure( 63 op, "Cannot unroll leading scalable dim in dstType"); 64 65 auto loc = op.getLoc(); 66 int64_t dim = dstType.getDimSize(0); 67 Value idx = op.getOperand(0); 68 69 VectorType lowType = VectorType::Builder(dstType).dropDim(0); 70 Value trueVal = rewriter.create<vector::CreateMaskOp>( 71 loc, lowType, op.getOperands().drop_front()); 72 Value falseVal = rewriter.create<arith::ConstantOp>( 73 loc, lowType, rewriter.getZeroAttr(lowType)); 74 Value result = rewriter.create<arith::ConstantOp>( 75 loc, dstType, rewriter.getZeroAttr(dstType)); 76 for (int64_t d = 0; d < dim; d++) { 77 Value bnd = 78 rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d)); 79 Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, 80 bnd, idx); 81 Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal); 82 result = rewriter.create<vector::InsertOp>(loc, sel, result, d); 83 } 84 rewriter.replaceOp(op, result); 85 return success(); 86 } 87 }; 88 89 /// Progressive lowering of ConstantMaskOp. 90 /// One: 91 /// %x = vector.constant_mask [a,b] 92 /// is replaced by: 93 /// %z = zero-result 94 /// %l = vector.constant_mask [b] 95 /// %4 = vector.insert %l, %z[0] 96 /// .. 97 /// %x = vector.insert %l, %..[a-1] 98 /// until a one-dimensional vector is reached. All these operations 99 /// will be folded at LLVM IR level. 100 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> { 101 public: 102 using OpRewritePattern::OpRewritePattern; 103 104 LogicalResult matchAndRewrite(vector::ConstantMaskOp op, 105 PatternRewriter &rewriter) const override { 106 auto loc = op.getLoc(); 107 auto dstType = op.getType(); 108 auto dimSizes = op.getMaskDimSizes(); 109 int64_t rank = dstType.getRank(); 110 111 if (rank == 0) { 112 assert(dimSizes.size() == 1 && 113 "Expected exactly one dim size for a 0-D vector"); 114 bool value = dimSizes.front() == 1; 115 rewriter.replaceOpWithNewOp<arith::ConstantOp>( 116 op, dstType, 117 DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()), 118 value)); 119 return success(); 120 } 121 122 int64_t trueDimSize = dimSizes.front(); 123 124 if (rank == 1) { 125 if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) { 126 // Use constant splat for 'all set' or 'none set' dims. 127 // This produces correct code for scalable dimensions (it will lower to 128 // a constant splat). 129 rewriter.replaceOpWithNewOp<arith::ConstantOp>( 130 op, DenseElementsAttr::get(dstType, trueDimSize != 0)); 131 } else { 132 // Express constant 1-D case in explicit vector form: 133 // [T,..,T,F,..,F]. 134 // Note: The verifier would reject this case for scalable vectors. 135 SmallVector<bool> values(dstType.getDimSize(0), false); 136 for (int64_t d = 0; d < trueDimSize; d++) 137 values[d] = true; 138 rewriter.replaceOpWithNewOp<arith::ConstantOp>( 139 op, dstType, rewriter.getBoolVectorAttr(values)); 140 } 141 return success(); 142 } 143 144 if (dstType.getScalableDims().front()) 145 return rewriter.notifyMatchFailure( 146 op, "Cannot unroll leading scalable dim in dstType"); 147 148 VectorType lowType = VectorType::Builder(dstType).dropDim(0); 149 Value trueVal = rewriter.create<vector::ConstantMaskOp>( 150 loc, lowType, dimSizes.drop_front()); 151 Value result = rewriter.create<arith::ConstantOp>( 152 loc, dstType, rewriter.getZeroAttr(dstType)); 153 for (int64_t d = 0; d < trueDimSize; d++) 154 result = rewriter.create<vector::InsertOp>(loc, trueVal, result, d); 155 156 rewriter.replaceOp(op, result); 157 return success(); 158 } 159 }; 160 } // namespace 161 162 void mlir::vector::populateVectorMaskOpLoweringPatterns( 163 RewritePatternSet &patterns, PatternBenefit benefit) { 164 patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>( 165 patterns.getContext(), benefit); 166 } 167 168 //===----------------------------------------------------------------------===// 169 // populateVectorMaskLoweringPatternsForSideEffectingOps 170 //===----------------------------------------------------------------------===// 171 172 namespace { 173 174 /// The `MaskOpRewritePattern` implements a pattern that follows a two-fold 175 /// matching: 176 /// 1. It matches a `vector.mask` operation. 177 /// 2. It invokes `matchAndRewriteMaskableOp` on `MaskableOpInterface` nested 178 /// in the matched `vector.mask` operation. 179 /// 180 /// It is required that the replacement op in the pattern replaces the 181 /// `vector.mask` operation and not the nested `MaskableOpInterface`. This 182 /// approach allows having patterns that "stop" at every `vector.mask` operation 183 /// and actually match the traits of its the nested `MaskableOpInterface`. 184 template <class SourceOp> 185 struct MaskOpRewritePattern : OpRewritePattern<MaskOp> { 186 using OpRewritePattern<MaskOp>::OpRewritePattern; 187 188 private: 189 LogicalResult matchAndRewrite(MaskOp maskOp, 190 PatternRewriter &rewriter) const final { 191 auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp()); 192 if (!maskableOp) 193 return failure(); 194 SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation()); 195 if (!sourceOp) 196 return failure(); 197 198 return matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter); 199 } 200 201 protected: 202 virtual LogicalResult 203 matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp, 204 PatternRewriter &rewriter) const = 0; 205 }; 206 207 /// Lowers a masked `vector.transfer_read` operation. 208 struct MaskedTransferReadOpPattern 209 : public MaskOpRewritePattern<TransferReadOp> { 210 public: 211 using MaskOpRewritePattern<TransferReadOp>::MaskOpRewritePattern; 212 213 LogicalResult 214 matchAndRewriteMaskableOp(TransferReadOp readOp, MaskingOpInterface maskingOp, 215 PatternRewriter &rewriter) const override { 216 // TODO: The 'vector.mask' passthru is a vector and 'vector.transfer_read' 217 // expects a scalar. We could only lower one to the other for cases where 218 // the passthru is a broadcast of a scalar. 219 if (maskingOp.hasPassthru()) 220 return rewriter.notifyMatchFailure( 221 maskingOp, "Can't lower passthru to vector.transfer_read"); 222 223 // Replace the `vector.mask` operation. 224 rewriter.replaceOpWithNewOp<TransferReadOp>( 225 maskingOp.getOperation(), readOp.getVectorType(), readOp.getSource(), 226 readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(), 227 maskingOp.getMask(), readOp.getInBounds()); 228 return success(); 229 } 230 }; 231 232 /// Lowers a masked `vector.transfer_write` operation. 233 struct MaskedTransferWriteOpPattern 234 : public MaskOpRewritePattern<TransferWriteOp> { 235 public: 236 using MaskOpRewritePattern<TransferWriteOp>::MaskOpRewritePattern; 237 238 LogicalResult 239 matchAndRewriteMaskableOp(TransferWriteOp writeOp, 240 MaskingOpInterface maskingOp, 241 PatternRewriter &rewriter) const override { 242 Type resultType = 243 writeOp.getResult() ? writeOp.getResult().getType() : Type(); 244 245 // Replace the `vector.mask` operation. 246 rewriter.replaceOpWithNewOp<TransferWriteOp>( 247 maskingOp.getOperation(), resultType, writeOp.getVector(), 248 writeOp.getSource(), writeOp.getIndices(), writeOp.getPermutationMap(), 249 maskingOp.getMask(), writeOp.getInBounds()); 250 return success(); 251 } 252 }; 253 254 /// Lowers a masked `vector.gather` operation. 255 struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> { 256 public: 257 using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern; 258 259 LogicalResult 260 matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp, 261 PatternRewriter &rewriter) const override { 262 Value passthru = maskingOp.hasPassthru() 263 ? maskingOp.getPassthru() 264 : rewriter.create<arith::ConstantOp>( 265 gatherOp.getLoc(), 266 rewriter.getZeroAttr(gatherOp.getVectorType())); 267 268 // Replace the `vector.mask` operation. 269 rewriter.replaceOpWithNewOp<GatherOp>( 270 maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(), 271 gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(), 272 passthru); 273 return success(); 274 } 275 }; 276 277 struct LowerVectorMaskPass 278 : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> { 279 using Base::Base; 280 281 void runOnOperation() override { 282 Operation *op = getOperation(); 283 MLIRContext *context = op->getContext(); 284 285 RewritePatternSet loweringPatterns(context); 286 populateVectorMaskLoweringPatternsForSideEffectingOps(loweringPatterns); 287 MaskOp::getCanonicalizationPatterns(loweringPatterns, context); 288 289 if (failed(applyPatternsGreedily(op, std::move(loweringPatterns)))) 290 signalPassFailure(); 291 } 292 293 void getDependentDialects(DialectRegistry ®istry) const override { 294 registry.insert<vector::VectorDialect>(); 295 } 296 }; 297 298 } // namespace 299 300 /// Populates instances of `MaskOpRewritePattern` to lower masked operations 301 /// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and 302 /// not its nested `MaskableOpInterface`. 303 void vector::populateVectorMaskLoweringPatternsForSideEffectingOps( 304 RewritePatternSet &patterns) { 305 patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern, 306 MaskedGatherOpPattern>(patterns.getContext()); 307 } 308 309 std::unique_ptr<Pass> mlir::vector::createLowerVectorMaskPass() { 310 return std::make_unique<LowerVectorMaskPass>(); 311 } 312