xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp (revision 9b06e25e73470612d14f0e1e18fde82f62266216)
1 //===- VectorMaskElimination.cpp - Eliminate Vector Masks -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/Utils/StaticValueUtils.h"
11 #include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
12 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
13 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
14 #include "mlir/Interfaces/FunctionInterfaces.h"
15 
16 using namespace mlir;
17 using namespace mlir::vector;
18 namespace {
19 
20 /// Attempts to resolve a (scalable) CreateMaskOp to an all-true constant mask.
21 /// All-true masks can then be eliminated by simple folds.
22 LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter,
23                                          vector::CreateMaskOp createMaskOp,
24                                          VscaleRange vscaleRange) {
25   auto maskType = createMaskOp.getVectorType();
26   auto maskTypeDimScalableFlags = maskType.getScalableDims();
27   auto maskTypeDimSizes = maskType.getShape();
28 
29   struct UnknownMaskDim {
30     size_t position;
31     Value dimSize;
32   };
33 
34   // Loop over the CreateMaskOp operands and collect unknown dims (i.e. dims
35   // that are not obviously constant). If any constant dimension is not all-true
36   // bail out early (as this transform only trying to resolve all-true masks).
37   // This avoids doing value-bounds anaylis in cases like:
38   // `%mask = vector.create_mask %dynamicValue, %c2 : vector<8x4xi1>`
39   // ...where it is known the mask is not all-true by looking at `%c2`.
40   SmallVector<UnknownMaskDim> unknownDims;
41   for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
42     if (auto intSize = getConstantIntValue(dimSize)) {
43       // Mask not all-true for this dim.
44       if (maskTypeDimScalableFlags[i] || intSize < maskTypeDimSizes[i])
45         return failure();
46     } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
47       // Mask not all-true for this dim.
48       if (vscaleMultiplier < maskTypeDimSizes[i])
49         return failure();
50     } else {
51       // Unknown (without further analysis).
52       unknownDims.push_back(UnknownMaskDim{i, dimSize});
53     }
54   }
55 
56   for (auto [i, dimSize] : unknownDims) {
57     // Compute the lower bound for the unknown dimension (i.e. the smallest
58     // value it could be).
59     FailureOr<ConstantOrScalableBound> dimLowerBound =
60         vector::ScalableValueBoundsConstraintSet::computeScalableBound(
61             dimSize, {}, vscaleRange.vscaleMin, vscaleRange.vscaleMax,
62             presburger::BoundType::LB);
63     if (failed(dimLowerBound))
64       return failure();
65     auto dimLowerBoundSize = dimLowerBound->getSize();
66     if (failed(dimLowerBoundSize))
67       return failure();
68     if (dimLowerBoundSize->scalable) {
69       // 1. The lower bound, LB, is scalable. If LB is < the mask dim size then
70       // this dim is not all-true.
71       if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
72         return failure();
73     } else {
74       // 2. The lower bound, LB, is a constant.
75       // - If the mask dim size is scalable then this dim is not all-true.
76       if (maskTypeDimScalableFlags[i])
77         return failure();
78       // - If LB < the _fixed-size_ mask dim size then this dim is not all-true.
79       if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
80         return failure();
81     }
82   }
83 
84   // Replace createMaskOp with an all-true constant. This should result in the
85   // mask being removed in most cases (as xfer ops + vector.mask have folds to
86   // remove all-true masks).
87   auto allTrue = rewriter.create<vector::ConstantMaskOp>(
88       createMaskOp.getLoc(), maskType, ConstantMaskKind::AllTrue);
89   rewriter.replaceAllUsesWith(createMaskOp, allTrue);
90   return success();
91 }
92 
93 } // namespace
94 
95 namespace mlir::vector {
96 
97 void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
98                           std::optional<VscaleRange> vscaleRange) {
99   // TODO: Support fixed-size case. This is less likely to be useful as for
100   // fixed-size code dimensions are all static so masks tend to fold away.
101   if (!vscaleRange)
102     return;
103 
104   OpBuilder::InsertionGuard g(rewriter);
105 
106   // Build worklist so we can safely insert new ops in
107   // `resolveAllTrueCreateMaskOp()`.
108   SmallVector<vector::CreateMaskOp> worklist;
109   function.walk([&](vector::CreateMaskOp createMaskOp) {
110     worklist.push_back(createMaskOp);
111   });
112 
113   rewriter.setInsertionPointToStart(&function.front());
114   for (auto mask : worklist)
115     (void)resolveAllTrueCreateMaskOp(rewriter, mask, *vscaleRange);
116 }
117 
118 } // namespace mlir::vector
119