1c3e09036SDiego Caballero //===- LowerVectorMask.cpp - Lower 'vector.mask' operation ----------------===// 2c3e09036SDiego Caballero // 3c3e09036SDiego Caballero // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4c3e09036SDiego Caballero // See https://llvm.org/LICENSE.txt for license information. 5c3e09036SDiego Caballero // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6c3e09036SDiego Caballero // 7c3e09036SDiego Caballero //===----------------------------------------------------------------------===// 8c3e09036SDiego Caballero // 92bc4c3e9SNicolas Vasilache // This file implements target-independent rewrites and utilities to lower the 10c3e09036SDiego Caballero // 'vector.mask' operation. 11c3e09036SDiego Caballero // 12c3e09036SDiego Caballero //===----------------------------------------------------------------------===// 13c3e09036SDiego Caballero 141ac874c9SDiego Caballero #include "mlir/Dialect/Arith/IR/Arith.h" 15c3e09036SDiego Caballero #include "mlir/Dialect/Func/IR/FuncOps.h" 16c3e09036SDiego Caballero #include "mlir/Dialect/Vector/IR/VectorOps.h" 172bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 18c3e09036SDiego Caballero #include "mlir/Dialect/Vector/Transforms/Passes.h" 19c3e09036SDiego Caballero #include "mlir/IR/PatternMatch.h" 20c3e09036SDiego Caballero #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21c3e09036SDiego Caballero 22c3e09036SDiego Caballero #define DEBUG_TYPE "lower-vector-mask" 23c3e09036SDiego Caballero 24c3e09036SDiego Caballero namespace mlir { 25c3e09036SDiego Caballero namespace vector { 26c3e09036SDiego Caballero #define GEN_PASS_DEF_LOWERVECTORMASKPASS 27c3e09036SDiego Caballero #include "mlir/Dialect/Vector/Transforms/Passes.h.inc" 28c3e09036SDiego Caballero } // namespace vector 29c3e09036SDiego Caballero } // namespace mlir 30c3e09036SDiego Caballero 31c3e09036SDiego Caballero using namespace mlir; 32c3e09036SDiego Caballero using namespace mlir::vector; 33c3e09036SDiego Caballero 342bc4c3e9SNicolas Vasilache //===----------------------------------------------------------------------===// 352bc4c3e9SNicolas Vasilache // populateVectorMaskOpLoweringPatterns 362bc4c3e9SNicolas Vasilache //===----------------------------------------------------------------------===// 372bc4c3e9SNicolas Vasilache 382bc4c3e9SNicolas Vasilache namespace { 392bc4c3e9SNicolas Vasilache /// Progressive lowering of CreateMaskOp. 402bc4c3e9SNicolas Vasilache /// One: 412bc4c3e9SNicolas Vasilache /// %x = vector.create_mask %a, ... : vector<dx...> 422bc4c3e9SNicolas Vasilache /// is replaced by: 432bc4c3e9SNicolas Vasilache /// %l = vector.create_mask ... : vector<...> ; one lower rank 442bc4c3e9SNicolas Vasilache /// %0 = arith.cmpi "slt", %ci, %a | 452bc4c3e9SNicolas Vasilache /// %1 = select %0, %l, %zeroes | 462bc4c3e9SNicolas Vasilache /// %r = vector.insert %1, %pr [i] | d-times 472bc4c3e9SNicolas Vasilache /// %x = .... 482bc4c3e9SNicolas Vasilache /// until a one-dimensional vector is reached. 492bc4c3e9SNicolas Vasilache class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> { 502bc4c3e9SNicolas Vasilache public: 512bc4c3e9SNicolas Vasilache using OpRewritePattern::OpRewritePattern; 522bc4c3e9SNicolas Vasilache 532bc4c3e9SNicolas Vasilache LogicalResult matchAndRewrite(vector::CreateMaskOp op, 542bc4c3e9SNicolas Vasilache PatternRewriter &rewriter) const override { 555550c821STres Popp auto dstType = cast<VectorType>(op.getResult().getType()); 562bc4c3e9SNicolas Vasilache int64_t rank = dstType.getRank(); 572bc4c3e9SNicolas Vasilache if (rank <= 1) 582bc4c3e9SNicolas Vasilache return rewriter.notifyMatchFailure( 592bc4c3e9SNicolas Vasilache op, "0-D and 1-D vectors are handled separately"); 602bc4c3e9SNicolas Vasilache 61ccef726dSBenjamin Maxwell if (dstType.getScalableDims().front()) 62ccef726dSBenjamin Maxwell return rewriter.notifyMatchFailure( 63ccef726dSBenjamin Maxwell op, "Cannot unroll leading scalable dim in dstType"); 64ccef726dSBenjamin Maxwell 652bc4c3e9SNicolas Vasilache auto loc = op.getLoc(); 662bc4c3e9SNicolas Vasilache int64_t dim = dstType.getDimSize(0); 672bc4c3e9SNicolas Vasilache Value idx = op.getOperand(0); 682bc4c3e9SNicolas Vasilache 69ccef726dSBenjamin Maxwell VectorType lowType = VectorType::Builder(dstType).dropDim(0); 702bc4c3e9SNicolas Vasilache Value trueVal = rewriter.create<vector::CreateMaskOp>( 712bc4c3e9SNicolas Vasilache loc, lowType, op.getOperands().drop_front()); 722bc4c3e9SNicolas Vasilache Value falseVal = rewriter.create<arith::ConstantOp>( 732bc4c3e9SNicolas Vasilache loc, lowType, rewriter.getZeroAttr(lowType)); 742bc4c3e9SNicolas Vasilache Value result = rewriter.create<arith::ConstantOp>( 752bc4c3e9SNicolas Vasilache loc, dstType, rewriter.getZeroAttr(dstType)); 762bc4c3e9SNicolas Vasilache for (int64_t d = 0; d < dim; d++) { 772bc4c3e9SNicolas Vasilache Value bnd = 782bc4c3e9SNicolas Vasilache rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d)); 792bc4c3e9SNicolas Vasilache Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, 802bc4c3e9SNicolas Vasilache bnd, idx); 812bc4c3e9SNicolas Vasilache Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal); 8298f6289aSDiego Caballero result = rewriter.create<vector::InsertOp>(loc, sel, result, d); 832bc4c3e9SNicolas Vasilache } 842bc4c3e9SNicolas Vasilache rewriter.replaceOp(op, result); 852bc4c3e9SNicolas Vasilache return success(); 862bc4c3e9SNicolas Vasilache } 872bc4c3e9SNicolas Vasilache }; 882bc4c3e9SNicolas Vasilache 892bc4c3e9SNicolas Vasilache /// Progressive lowering of ConstantMaskOp. 902bc4c3e9SNicolas Vasilache /// One: 912bc4c3e9SNicolas Vasilache /// %x = vector.constant_mask [a,b] 922bc4c3e9SNicolas Vasilache /// is replaced by: 932bc4c3e9SNicolas Vasilache /// %z = zero-result 942bc4c3e9SNicolas Vasilache /// %l = vector.constant_mask [b] 952bc4c3e9SNicolas Vasilache /// %4 = vector.insert %l, %z[0] 962bc4c3e9SNicolas Vasilache /// .. 972bc4c3e9SNicolas Vasilache /// %x = vector.insert %l, %..[a-1] 982bc4c3e9SNicolas Vasilache /// until a one-dimensional vector is reached. All these operations 992bc4c3e9SNicolas Vasilache /// will be folded at LLVM IR level. 1002bc4c3e9SNicolas Vasilache class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> { 1012bc4c3e9SNicolas Vasilache public: 1022bc4c3e9SNicolas Vasilache using OpRewritePattern::OpRewritePattern; 1032bc4c3e9SNicolas Vasilache 1042bc4c3e9SNicolas Vasilache LogicalResult matchAndRewrite(vector::ConstantMaskOp op, 1052bc4c3e9SNicolas Vasilache PatternRewriter &rewriter) const override { 1062bc4c3e9SNicolas Vasilache auto loc = op.getLoc(); 1072bc4c3e9SNicolas Vasilache auto dstType = op.getType(); 1082bc4c3e9SNicolas Vasilache auto dimSizes = op.getMaskDimSizes(); 1092bc4c3e9SNicolas Vasilache int64_t rank = dstType.getRank(); 1102bc4c3e9SNicolas Vasilache 1112bc4c3e9SNicolas Vasilache if (rank == 0) { 1122bc4c3e9SNicolas Vasilache assert(dimSizes.size() == 1 && 1132bc4c3e9SNicolas Vasilache "Expected exactly one dim size for a 0-D vector"); 1140d9b4394SBenjamin Maxwell bool value = dimSizes.front() == 1; 1152bc4c3e9SNicolas Vasilache rewriter.replaceOpWithNewOp<arith::ConstantOp>( 1162bc4c3e9SNicolas Vasilache op, dstType, 1172f11ce55SBenjamin Maxwell DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()), 1182f11ce55SBenjamin Maxwell value)); 1192bc4c3e9SNicolas Vasilache return success(); 1202bc4c3e9SNicolas Vasilache } 1212bc4c3e9SNicolas Vasilache 1220d9b4394SBenjamin Maxwell int64_t trueDimSize = dimSizes.front(); 1232bc4c3e9SNicolas Vasilache 1242bc4c3e9SNicolas Vasilache if (rank == 1) { 1252f11ce55SBenjamin Maxwell if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) { 1262f11ce55SBenjamin Maxwell // Use constant splat for 'all set' or 'none set' dims. 1272f11ce55SBenjamin Maxwell // This produces correct code for scalable dimensions (it will lower to 1282f11ce55SBenjamin Maxwell // a constant splat). 1292f11ce55SBenjamin Maxwell rewriter.replaceOpWithNewOp<arith::ConstantOp>( 1302f11ce55SBenjamin Maxwell op, DenseElementsAttr::get(dstType, trueDimSize != 0)); 1312f11ce55SBenjamin Maxwell } else { 1322bc4c3e9SNicolas Vasilache // Express constant 1-D case in explicit vector form: 1332bc4c3e9SNicolas Vasilache // [T,..,T,F,..,F]. 1342f11ce55SBenjamin Maxwell // Note: The verifier would reject this case for scalable vectors. 1352f11ce55SBenjamin Maxwell SmallVector<bool> values(dstType.getDimSize(0), false); 1362f11ce55SBenjamin Maxwell for (int64_t d = 0; d < trueDimSize; d++) 1372bc4c3e9SNicolas Vasilache values[d] = true; 1382bc4c3e9SNicolas Vasilache rewriter.replaceOpWithNewOp<arith::ConstantOp>( 1392bc4c3e9SNicolas Vasilache op, dstType, rewriter.getBoolVectorAttr(values)); 1402f11ce55SBenjamin Maxwell } 1412bc4c3e9SNicolas Vasilache return success(); 1422bc4c3e9SNicolas Vasilache } 1432bc4c3e9SNicolas Vasilache 1442f11ce55SBenjamin Maxwell if (dstType.getScalableDims().front()) 1452f11ce55SBenjamin Maxwell return rewriter.notifyMatchFailure( 1462f11ce55SBenjamin Maxwell op, "Cannot unroll leading scalable dim in dstType"); 1472f11ce55SBenjamin Maxwell 1482f11ce55SBenjamin Maxwell VectorType lowType = VectorType::Builder(dstType).dropDim(0); 1492bc4c3e9SNicolas Vasilache Value trueVal = rewriter.create<vector::ConstantMaskOp>( 1500d9b4394SBenjamin Maxwell loc, lowType, dimSizes.drop_front()); 1512bc4c3e9SNicolas Vasilache Value result = rewriter.create<arith::ConstantOp>( 1522bc4c3e9SNicolas Vasilache loc, dstType, rewriter.getZeroAttr(dstType)); 1532f11ce55SBenjamin Maxwell for (int64_t d = 0; d < trueDimSize; d++) 15498f6289aSDiego Caballero result = rewriter.create<vector::InsertOp>(loc, trueVal, result, d); 15598f6289aSDiego Caballero 1562bc4c3e9SNicolas Vasilache rewriter.replaceOp(op, result); 1572bc4c3e9SNicolas Vasilache return success(); 1582bc4c3e9SNicolas Vasilache } 1592bc4c3e9SNicolas Vasilache }; 1602bc4c3e9SNicolas Vasilache } // namespace 1612bc4c3e9SNicolas Vasilache 1622bc4c3e9SNicolas Vasilache void mlir::vector::populateVectorMaskOpLoweringPatterns( 1632bc4c3e9SNicolas Vasilache RewritePatternSet &patterns, PatternBenefit benefit) { 1642bc4c3e9SNicolas Vasilache patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>( 1652bc4c3e9SNicolas Vasilache patterns.getContext(), benefit); 1662bc4c3e9SNicolas Vasilache } 1672bc4c3e9SNicolas Vasilache 1682bc4c3e9SNicolas Vasilache //===----------------------------------------------------------------------===// 1692bc4c3e9SNicolas Vasilache // populateVectorMaskLoweringPatternsForSideEffectingOps 1702bc4c3e9SNicolas Vasilache //===----------------------------------------------------------------------===// 1712bc4c3e9SNicolas Vasilache 172c3e09036SDiego Caballero namespace { 173c3e09036SDiego Caballero 174c3e09036SDiego Caballero /// The `MaskOpRewritePattern` implements a pattern that follows a two-fold 175c3e09036SDiego Caballero /// matching: 176c3e09036SDiego Caballero /// 1. It matches a `vector.mask` operation. 177c3e09036SDiego Caballero /// 2. It invokes `matchAndRewriteMaskableOp` on `MaskableOpInterface` nested 178c3e09036SDiego Caballero /// in the matched `vector.mask` operation. 179c3e09036SDiego Caballero /// 180c3e09036SDiego Caballero /// It is required that the replacement op in the pattern replaces the 181c3e09036SDiego Caballero /// `vector.mask` operation and not the nested `MaskableOpInterface`. This 182c3e09036SDiego Caballero /// approach allows having patterns that "stop" at every `vector.mask` operation 183c3e09036SDiego Caballero /// and actually match the traits of its the nested `MaskableOpInterface`. 184c3e09036SDiego Caballero template <class SourceOp> 185c3e09036SDiego Caballero struct MaskOpRewritePattern : OpRewritePattern<MaskOp> { 186c3e09036SDiego Caballero using OpRewritePattern<MaskOp>::OpRewritePattern; 187c3e09036SDiego Caballero 188c3e09036SDiego Caballero private: 189abe2738bSAdrian Kuegel LogicalResult matchAndRewrite(MaskOp maskOp, 190abe2738bSAdrian Kuegel PatternRewriter &rewriter) const final { 191d5a0fb39SFelix Schneider auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp()); 192d5a0fb39SFelix Schneider if (!maskableOp) 193d5a0fb39SFelix Schneider return failure(); 194c3e09036SDiego Caballero SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation()); 195c3e09036SDiego Caballero if (!sourceOp) 196c3e09036SDiego Caballero return failure(); 197c3e09036SDiego Caballero 198c3e09036SDiego Caballero return matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter); 199c3e09036SDiego Caballero } 200c3e09036SDiego Caballero 201c3e09036SDiego Caballero protected: 202c3e09036SDiego Caballero virtual LogicalResult 203c3e09036SDiego Caballero matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp, 204c3e09036SDiego Caballero PatternRewriter &rewriter) const = 0; 205c3e09036SDiego Caballero }; 206c3e09036SDiego Caballero 207c3e09036SDiego Caballero /// Lowers a masked `vector.transfer_read` operation. 208c3e09036SDiego Caballero struct MaskedTransferReadOpPattern 209c3e09036SDiego Caballero : public MaskOpRewritePattern<TransferReadOp> { 210c3e09036SDiego Caballero public: 211c3e09036SDiego Caballero using MaskOpRewritePattern<TransferReadOp>::MaskOpRewritePattern; 212c3e09036SDiego Caballero 213c3e09036SDiego Caballero LogicalResult 214c3e09036SDiego Caballero matchAndRewriteMaskableOp(TransferReadOp readOp, MaskingOpInterface maskingOp, 215c3e09036SDiego Caballero PatternRewriter &rewriter) const override { 216c3e09036SDiego Caballero // TODO: The 'vector.mask' passthru is a vector and 'vector.transfer_read' 217c3e09036SDiego Caballero // expects a scalar. We could only lower one to the other for cases where 218c3e09036SDiego Caballero // the passthru is a broadcast of a scalar. 219c3e09036SDiego Caballero if (maskingOp.hasPassthru()) 220c3e09036SDiego Caballero return rewriter.notifyMatchFailure( 221c3e09036SDiego Caballero maskingOp, "Can't lower passthru to vector.transfer_read"); 222c3e09036SDiego Caballero 223c3e09036SDiego Caballero // Replace the `vector.mask` operation. 224c3e09036SDiego Caballero rewriter.replaceOpWithNewOp<TransferReadOp>( 225c3e09036SDiego Caballero maskingOp.getOperation(), readOp.getVectorType(), readOp.getSource(), 226c3e09036SDiego Caballero readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(), 2272ee5586aSAndrzej Warzyński maskingOp.getMask(), readOp.getInBounds()); 228c3e09036SDiego Caballero return success(); 229c3e09036SDiego Caballero } 230c3e09036SDiego Caballero }; 231c3e09036SDiego Caballero 232c3e09036SDiego Caballero /// Lowers a masked `vector.transfer_write` operation. 233c3e09036SDiego Caballero struct MaskedTransferWriteOpPattern 234c3e09036SDiego Caballero : public MaskOpRewritePattern<TransferWriteOp> { 235c3e09036SDiego Caballero public: 236c3e09036SDiego Caballero using MaskOpRewritePattern<TransferWriteOp>::MaskOpRewritePattern; 237c3e09036SDiego Caballero 238c3e09036SDiego Caballero LogicalResult 239c3e09036SDiego Caballero matchAndRewriteMaskableOp(TransferWriteOp writeOp, 240c3e09036SDiego Caballero MaskingOpInterface maskingOp, 241c3e09036SDiego Caballero PatternRewriter &rewriter) const override { 242c3e09036SDiego Caballero Type resultType = 243c3e09036SDiego Caballero writeOp.getResult() ? writeOp.getResult().getType() : Type(); 244c3e09036SDiego Caballero 245c3e09036SDiego Caballero // Replace the `vector.mask` operation. 246c3e09036SDiego Caballero rewriter.replaceOpWithNewOp<TransferWriteOp>( 247c3e09036SDiego Caballero maskingOp.getOperation(), resultType, writeOp.getVector(), 248c3e09036SDiego Caballero writeOp.getSource(), writeOp.getIndices(), writeOp.getPermutationMap(), 2492ee5586aSAndrzej Warzyński maskingOp.getMask(), writeOp.getInBounds()); 250c3e09036SDiego Caballero return success(); 251c3e09036SDiego Caballero } 252c3e09036SDiego Caballero }; 253c3e09036SDiego Caballero 2541ac874c9SDiego Caballero /// Lowers a masked `vector.gather` operation. 2551ac874c9SDiego Caballero struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> { 2561ac874c9SDiego Caballero public: 2571ac874c9SDiego Caballero using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern; 2581ac874c9SDiego Caballero 2591ac874c9SDiego Caballero LogicalResult 2601ac874c9SDiego Caballero matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp, 2611ac874c9SDiego Caballero PatternRewriter &rewriter) const override { 2621ac874c9SDiego Caballero Value passthru = maskingOp.hasPassthru() 2631ac874c9SDiego Caballero ? maskingOp.getPassthru() 2641ac874c9SDiego Caballero : rewriter.create<arith::ConstantOp>( 2651ac874c9SDiego Caballero gatherOp.getLoc(), 2661ac874c9SDiego Caballero rewriter.getZeroAttr(gatherOp.getVectorType())); 2671ac874c9SDiego Caballero 2681ac874c9SDiego Caballero // Replace the `vector.mask` operation. 2691ac874c9SDiego Caballero rewriter.replaceOpWithNewOp<GatherOp>( 2701ac874c9SDiego Caballero maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(), 2711ac874c9SDiego Caballero gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(), 2721ac874c9SDiego Caballero passthru); 2731ac874c9SDiego Caballero return success(); 2741ac874c9SDiego Caballero } 2751ac874c9SDiego Caballero }; 2761ac874c9SDiego Caballero 277c3e09036SDiego Caballero struct LowerVectorMaskPass 278c3e09036SDiego Caballero : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> { 279c3e09036SDiego Caballero using Base::Base; 280c3e09036SDiego Caballero 281c3e09036SDiego Caballero void runOnOperation() override { 282c3e09036SDiego Caballero Operation *op = getOperation(); 283c3e09036SDiego Caballero MLIRContext *context = op->getContext(); 284c3e09036SDiego Caballero 285c3e09036SDiego Caballero RewritePatternSet loweringPatterns(context); 286c3e09036SDiego Caballero populateVectorMaskLoweringPatternsForSideEffectingOps(loweringPatterns); 287d5a0fb39SFelix Schneider MaskOp::getCanonicalizationPatterns(loweringPatterns, context); 288c3e09036SDiego Caballero 289*09dfc571SJacques Pienaar if (failed(applyPatternsGreedily(op, std::move(loweringPatterns)))) 290c3e09036SDiego Caballero signalPassFailure(); 291c3e09036SDiego Caballero } 292c3e09036SDiego Caballero 293c3e09036SDiego Caballero void getDependentDialects(DialectRegistry ®istry) const override { 294c3e09036SDiego Caballero registry.insert<vector::VectorDialect>(); 295c3e09036SDiego Caballero } 296c3e09036SDiego Caballero }; 297c3e09036SDiego Caballero 298c3e09036SDiego Caballero } // namespace 299c3e09036SDiego Caballero 30072fd3644SDiego Caballero /// Populates instances of `MaskOpRewritePattern` to lower masked operations 30172fd3644SDiego Caballero /// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and 30272fd3644SDiego Caballero /// not its nested `MaskableOpInterface`. 30372fd3644SDiego Caballero void vector::populateVectorMaskLoweringPatternsForSideEffectingOps( 30472fd3644SDiego Caballero RewritePatternSet &patterns) { 3051ac874c9SDiego Caballero patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern, 3061ac874c9SDiego Caballero MaskedGatherOpPattern>(patterns.getContext()); 30772fd3644SDiego Caballero } 30872fd3644SDiego Caballero 309c3e09036SDiego Caballero std::unique_ptr<Pass> mlir::vector::createLowerVectorMaskPass() { 310c3e09036SDiego Caballero return std::make_unique<LowerVectorMaskPass>(); 311c3e09036SDiego Caballero } 312