1 //===- LowerVectorMask.cpp - Lower 'vector.mask' operation ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements target-independent rewrites and utilities to lower the 10 // 'vector.mask' operation. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Arith/IR/Arith.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/Vector/IR/VectorOps.h" 17 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 18 #include "mlir/Dialect/Vector/Transforms/Passes.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21 22 #define DEBUG_TYPE "lower-vector-mask" 23 24 namespace mlir { 25 namespace vector { 26 #define GEN_PASS_DEF_LOWERVECTORMASKPASS 27 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc" 28 } // namespace vector 29 } // namespace mlir 30 31 using namespace mlir; 32 using namespace mlir::vector; 33 34 //===----------------------------------------------------------------------===// 35 // populateVectorMaskOpLoweringPatterns 36 //===----------------------------------------------------------------------===// 37 38 namespace { 39 /// Progressive lowering of CreateMaskOp. 40 /// One: 41 /// %x = vector.create_mask %a, ... : vector<dx...> 42 /// is replaced by: 43 /// %l = vector.create_mask ... : vector<...> ; one lower rank 44 /// %0 = arith.cmpi "slt", %ci, %a | 45 /// %1 = select %0, %l, %zeroes | 46 /// %r = vector.insert %1, %pr [i] | d-times 47 /// %x = .... 48 /// until a one-dimensional vector is reached. 
49 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> { 50 public: 51 using OpRewritePattern::OpRewritePattern; 52 53 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 54 PatternRewriter &rewriter) const override { 55 auto dstType = cast<VectorType>(op.getResult().getType()); 56 int64_t rank = dstType.getRank(); 57 if (rank <= 1) 58 return rewriter.notifyMatchFailure( 59 op, "0-D and 1-D vectors are handled separately"); 60 61 if (dstType.getScalableDims().front()) 62 return rewriter.notifyMatchFailure( 63 op, "Cannot unroll leading scalable dim in dstType"); 64 65 auto loc = op.getLoc(); 66 int64_t dim = dstType.getDimSize(0); 67 Value idx = op.getOperand(0); 68 69 VectorType lowType = VectorType::Builder(dstType).dropDim(0); 70 Value trueVal = rewriter.create<vector::CreateMaskOp>( 71 loc, lowType, op.getOperands().drop_front()); 72 Value falseVal = rewriter.create<arith::ConstantOp>( 73 loc, lowType, rewriter.getZeroAttr(lowType)); 74 Value result = rewriter.create<arith::ConstantOp>( 75 loc, dstType, rewriter.getZeroAttr(dstType)); 76 for (int64_t d = 0; d < dim; d++) { 77 Value bnd = 78 rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d)); 79 Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, 80 bnd, idx); 81 Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal); 82 result = rewriter.create<vector::InsertOp>(loc, dstType, sel, result, d); 83 } 84 rewriter.replaceOp(op, result); 85 return success(); 86 } 87 }; 88 89 /// Progressive lowering of ConstantMaskOp. 90 /// One: 91 /// %x = vector.constant_mask [a,b] 92 /// is replaced by: 93 /// %z = zero-result 94 /// %l = vector.constant_mask [b] 95 /// %4 = vector.insert %l, %z[0] 96 /// .. 97 /// %x = vector.insert %l, %..[a-1] 98 /// until a one-dimensional vector is reached. All these operations 99 /// will be folded at LLVM IR level. 
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto dstType = op.getType();
    auto eltType = dstType.getElementType();
    auto dimSizes = op.getMaskDimSizes();
    int64_t rank = dstType.getRank();

    // 0-D case: a single dim size of 1 denotes an all-true mask, anything
    // else all-false; emit the corresponding 0-D i1 constant directly.
    if (rank == 0) {
      assert(dimSizes.size() == 1 &&
             "Expected exactly one dim size for a 0-D vector");
      bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType,
          DenseIntElementsAttr::get(
              VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
              ArrayRef<bool>{value}));
      return success();
    }

    // Scalable constant masks can only be lowered for the "none set" case.
    // NOTE(review): dimSizes are not checked to be all-zero here -- presumably
    // the op verifier guarantees that for scalable masks; confirm.
    if (cast<VectorType>(dstType).isScalable()) {
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, DenseElementsAttr::get(dstType, false));
      return success();
    }

    // Clamp the leading "set" extent to the actual vector size.
    int64_t trueDim = std::min(dstType.getDimSize(0),
                               cast<IntegerAttr>(dimSizes[0]).getInt());

    if (rank == 1) {
      // Express constant 1-D case in explicit vector form:
      //   [T,..,T,F,..,F].
      SmallVector<bool> values(dstType.getDimSize(0));
      for (int64_t d = 0; d < trueDim; d++)
        values[d] = true;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType, rewriter.getBoolVectorAttr(values));
      return success();
    }

    // Rank >= 2: build a constant mask of one rank lower for the trailing
    // dims, then insert it into the first `trueDim` outer positions of an
    // all-false accumulator. Lowering of that inner mask recurses via this
    // same pattern until the 1-D case above is reached.
    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    SmallVector<int64_t> newDimSizes;
    for (int64_t r = 1; r < rank; r++)
      newDimSizes.push_back(cast<IntegerAttr>(dimSizes[r]).getInt());
    Value trueVal = rewriter.create<vector::ConstantMaskOp>(
        loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < trueDim; d++)
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, d);
    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

/// Populates `patterns` with the progressive create_mask/constant_mask
/// lowerings defined above, all registered with the same `benefit`.
void mlir::vector::populateVectorMaskOpLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
      patterns.getContext(), benefit);
}

//===----------------------------------------------------------------------===//
// populateVectorMaskLoweringPatternsForSideEffectingOps
//===----------------------------------------------------------------------===//

namespace {

/// The `MaskOpRewritePattern` implements a pattern that follows a two-fold
/// matching:
///   1. It matches a `vector.mask` operation.
///   2. It invokes `matchAndRewriteMaskableOp` on `MaskableOpInterface` nested
///      in the matched `vector.mask` operation.
///
/// It is required that the replacement op in the pattern replaces the
/// `vector.mask` operation and not the nested `MaskableOpInterface`.
/// This approach allows having patterns that "stop" at every `vector.mask`
/// operation and actually match the traits of the nested
/// `MaskableOpInterface`.
template <class SourceOp>
struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
  using OpRewritePattern<MaskOp>::OpRewritePattern;

private:
  /// Matches the `vector.mask` op itself and forwards to the subclass hook
  /// only when the nested maskable op is of type `SourceOp`.
  LogicalResult matchAndRewrite(MaskOp maskOp,
                                PatternRewriter &rewriter) const final {
    auto maskableOp = cast<MaskableOpInterface>(maskOp.getMaskableOp());
    SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
    if (!sourceOp)
      return failure();

    return matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
  }

protected:
  /// Subclass hook: rewrite `maskingOp` (the `vector.mask`) given its nested
  /// `sourceOp`. Implementations must replace `maskingOp`, not `sourceOp`.
  virtual LogicalResult
  matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const = 0;
};

/// Lowers a masked `vector.transfer_read` operation.
struct MaskedTransferReadOpPattern
    : public MaskOpRewritePattern<TransferReadOp> {
public:
  using MaskOpRewritePattern<TransferReadOp>::MaskOpRewritePattern;

  LogicalResult
  matchAndRewriteMaskableOp(TransferReadOp readOp, MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    // TODO: The 'vector.mask' passthru is a vector and 'vector.transfer_read'
    // expects a scalar. We could only lower one to the other for cases where
    // the passthru is a broadcast of a scalar.
    if (maskingOp.hasPassthru())
      return rewriter.notifyMatchFailure(
          maskingOp, "Can't lower passthru to vector.transfer_read");

    // Replace the `vector.mask` operation with a transfer_read that carries
    // the mask as its own operand; all other operands come from `readOp`.
    rewriter.replaceOpWithNewOp<TransferReadOp>(
        maskingOp.getOperation(), readOp.getVectorType(), readOp.getSource(),
        readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(),
        maskingOp.getMask(), readOp.getInBounds().value_or(ArrayAttr()));
    return success();
  }
};

/// Lowers a masked `vector.transfer_write` operation.
struct MaskedTransferWriteOpPattern
    : public MaskOpRewritePattern<TransferWriteOp> {
public:
  using MaskOpRewritePattern<TransferWriteOp>::MaskOpRewritePattern;

  LogicalResult
  matchAndRewriteMaskableOp(TransferWriteOp writeOp,
                            MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    // The nested transfer_write may or may not produce a result; propagate
    // its type only when present.
    Type resultType =
        writeOp.getResult() ? writeOp.getResult().getType() : Type();

    // Replace the `vector.mask` operation with a transfer_write that carries
    // the mask as its own operand.
    rewriter.replaceOpWithNewOp<TransferWriteOp>(
        maskingOp.getOperation(), resultType, writeOp.getVector(),
        writeOp.getSource(), writeOp.getIndices(), writeOp.getPermutationMap(),
        maskingOp.getMask(), writeOp.getInBounds().value_or(ArrayAttr()));
    return success();
  }
};

/// Lowers a masked `vector.gather` operation.
struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> {
public:
  using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern;

  LogicalResult
  matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    // vector.gather requires a passthru operand; synthesize an all-zero one
    // when the `vector.mask` did not supply any.
    Value passthru = maskingOp.hasPassthru()
                         ? maskingOp.getPassthru()
                         : rewriter.create<arith::ConstantOp>(
                               gatherOp.getLoc(),
                               rewriter.getZeroAttr(gatherOp.getVectorType()));

    // Replace the `vector.mask` operation with a gather that carries the mask
    // as its own operand.
    rewriter.replaceOpWithNewOp<GatherOp>(
        maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
        gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
        passthru);
    return success();
  }
};

/// Pass that greedily applies the side-effecting-op mask lowerings above to
/// the whole operation it is run on.
struct LowerVectorMaskPass
    : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> {
  using Base::Base;

  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();

    RewritePatternSet loweringPatterns(context);
    populateVectorMaskLoweringPatternsForSideEffectingOps(loweringPatterns);

    if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
      signalPassFailure();
  }

  /// The lowered ops stay in the vector dialect, so it must be loaded.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect>();
  }
};

} // namespace

/// Populates instances of `MaskOpRewritePattern` to lower masked operations
/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
/// not its nested `MaskableOpInterface`.
void vector::populateVectorMaskLoweringPatternsForSideEffectingOps(
    RewritePatternSet &patterns) {
  patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern,
               MaskedGatherOpPattern>(patterns.getContext());
}

/// Factory for the pass above, used by pass-registration boilerplate.
std::unique_ptr<Pass> mlir::vector::createLowerVectorMaskPass() {
  return std::make_unique<LowerVectorMaskPass>();
}