1 //===- VectorMaskElimination.cpp - Eliminate Vector Masks -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arith/IR/Arith.h" 10 #include "mlir/Dialect/Utils/StaticValueUtils.h" 11 #include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h" 12 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 13 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 14 #include "mlir/Interfaces/FunctionInterfaces.h" 15 16 using namespace mlir; 17 using namespace mlir::vector; 18 namespace { 19 20 /// Attempts to resolve a (scalable) CreateMaskOp to an all-true constant mask. 21 /// All-true masks can then be eliminated by simple folds. 22 LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter, 23 vector::CreateMaskOp createMaskOp, 24 VscaleRange vscaleRange) { 25 auto maskType = createMaskOp.getVectorType(); 26 auto maskTypeDimScalableFlags = maskType.getScalableDims(); 27 auto maskTypeDimSizes = maskType.getShape(); 28 29 struct UnknownMaskDim { 30 size_t position; 31 Value dimSize; 32 }; 33 34 // Loop over the CreateMaskOp operands and collect unknown dims (i.e. dims 35 // that are not obviously constant). If any constant dimension is not all-true 36 // bail out early (as this transform only trying to resolve all-true masks). 37 // This avoids doing value-bounds anaylis in cases like: 38 // `%mask = vector.create_mask %dynamicValue, %c2 : vector<8x4xi1>` 39 // ...where it is known the mask is not all-true by looking at `%c2`. 40 SmallVector<UnknownMaskDim> unknownDims; 41 for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) { 42 if (auto intSize = getConstantIntValue(dimSize)) { 43 // Mask not all-true for this dim. 44 if (maskTypeDimScalableFlags[i] || intSize < maskTypeDimSizes[i]) 45 return failure(); 46 } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) { 47 // Mask not all-true for this dim. 48 if (vscaleMultiplier < maskTypeDimSizes[i]) 49 return failure(); 50 } else { 51 // Unknown (without further analysis). 52 unknownDims.push_back(UnknownMaskDim{i, dimSize}); 53 } 54 } 55 56 for (auto [i, dimSize] : unknownDims) { 57 // Compute the lower bound for the unknown dimension (i.e. the smallest 58 // value it could be). 59 FailureOr<ConstantOrScalableBound> dimLowerBound = 60 vector::ScalableValueBoundsConstraintSet::computeScalableBound( 61 dimSize, {}, vscaleRange.vscaleMin, vscaleRange.vscaleMax, 62 presburger::BoundType::LB); 63 if (failed(dimLowerBound)) 64 return failure(); 65 auto dimLowerBoundSize = dimLowerBound->getSize(); 66 if (failed(dimLowerBoundSize)) 67 return failure(); 68 if (dimLowerBoundSize->scalable) { 69 // 1. The lower bound, LB, is scalable. If LB is < the mask dim size then 70 // this dim is not all-true. 71 if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i]) 72 return failure(); 73 } else { 74 // 2. The lower bound, LB, is a constant. 75 // - If the mask dim size is scalable then this dim is not all-true. 76 if (maskTypeDimScalableFlags[i]) 77 return failure(); 78 // - If LB < the _fixed-size_ mask dim size then this dim is not all-true. 79 if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i]) 80 return failure(); 81 } 82 } 83 84 // Replace createMaskOp with an all-true constant. This should result in the 85 // mask being removed in most cases (as xfer ops + vector.mask have folds to 86 // remove all-true masks). 87 auto allTrue = rewriter.create<vector::ConstantMaskOp>( 88 createMaskOp.getLoc(), maskType, ConstantMaskKind::AllTrue); 89 rewriter.replaceAllUsesWith(createMaskOp, allTrue); 90 return success(); 91 } 92 93 } // namespace 94 95 namespace mlir::vector { 96 97 void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function, 98 std::optional<VscaleRange> vscaleRange) { 99 // TODO: Support fixed-size case. This is less likely to be useful as for 100 // fixed-size code dimensions are all static so masks tend to fold away. 101 if (!vscaleRange) 102 return; 103 104 OpBuilder::InsertionGuard g(rewriter); 105 106 // Build worklist so we can safely insert new ops in 107 // `resolveAllTrueCreateMaskOp()`. 108 SmallVector<vector::CreateMaskOp> worklist; 109 function.walk([&](vector::CreateMaskOp createMaskOp) { 110 worklist.push_back(createMaskOp); 111 }); 112 113 rewriter.setInsertionPointToStart(&function.front()); 114 for (auto mask : worklist) 115 (void)resolveAllTrueCreateMaskOp(rewriter, mask, *vscaleRange); 116 } 117 118 } // namespace mlir::vector 119