xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp (revision f643eec892954653b1c9bde42407560caf660b8b)
1 //=- VectorEmulateMaskedLoadStore.cpp - Emulate 'vector.maskedload/store' op =//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to emulate the
10 // 'vector.maskedload' and 'vector.maskedstore' operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/SCF/IR/SCF.h"
16 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
17 
18 using namespace mlir;
19 
20 namespace {
21 
22 /// Convert vector.maskedload
23 ///
24 /// Before:
25 ///
26 ///   vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
27 ///
28 /// After:
29 ///
30 ///   %ivalue = %pass_thru
31 ///   %m = vector.extract %mask[0]
32 ///   %result0 = scf.if %m {
33 ///     %v = memref.load %base[%idx_0, %idx_1]
34 ///     %combined = vector.insert %v, %ivalue[0]
35 ///     scf.yield %combined
36 ///   } else {
37 ///     scf.yield %ivalue
38 ///   }
39 ///   %m = vector.extract %mask[1]
40 ///   %result1 = scf.if %m {
41 ///     %v = memref.load %base[%idx_0, %idx_1 + 1]
42 ///     %combined = vector.insert %v, %result0[1]
43 ///     scf.yield %combined
44 ///   } else {
45 ///     scf.yield %result0
46 ///   }
47 ///   ...
48 ///
49 struct VectorMaskedLoadOpConverter final
50     : OpRewritePattern<vector::MaskedLoadOp> {
51   using OpRewritePattern::OpRewritePattern;
52 
matchAndRewrite__anonef10eae30111::VectorMaskedLoadOpConverter53   LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
54                                 PatternRewriter &rewriter) const override {
55     VectorType maskVType = maskedLoadOp.getMaskVectorType();
56     if (maskVType.getShape().size() != 1)
57       return rewriter.notifyMatchFailure(
58           maskedLoadOp, "expected vector.maskedstore with 1-D mask");
59 
60     Location loc = maskedLoadOp.getLoc();
61     int64_t maskLength = maskVType.getShape()[0];
62 
63     Type indexType = rewriter.getIndexType();
64     Value mask = maskedLoadOp.getMask();
65     Value base = maskedLoadOp.getBase();
66     Value iValue = maskedLoadOp.getPassThru();
67     auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
68     Value one = rewriter.create<arith::ConstantOp>(
69         loc, indexType, IntegerAttr::get(indexType, 1));
70     for (int64_t i = 0; i < maskLength; ++i) {
71       auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);
72 
73       auto ifOp = rewriter.create<scf::IfOp>(
74           loc, maskBit,
75           [&](OpBuilder &builder, Location loc) {
76             auto loadedValue =
77                 builder.create<memref::LoadOp>(loc, base, indices);
78             auto combinedValue =
79                 builder.create<vector::InsertOp>(loc, loadedValue, iValue, i);
80             builder.create<scf::YieldOp>(loc, combinedValue.getResult());
81           },
82           [&](OpBuilder &builder, Location loc) {
83             builder.create<scf::YieldOp>(loc, iValue);
84           });
85       iValue = ifOp.getResult(0);
86 
87       indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
88     }
89 
90     rewriter.replaceOp(maskedLoadOp, iValue);
91 
92     return success();
93   }
94 };
95 
96 /// Convert vector.maskedstore
97 ///
98 /// Before:
99 ///
100 ///   vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
101 ///
102 /// After:
103 ///
104 ///   %m = vector.extract %mask[0]
105 ///   scf.if %m {
106 ///     %extracted = vector.extract %value[0]
107 ///     memref.store %extracted, %base[%idx_0, %idx_1]
108 ///   }
109 ///   %m = vector.extract %mask[1]
110 ///   scf.if %m {
111 ///     %extracted = vector.extract %value[1]
112 ///     memref.store %extracted, %base[%idx_0, %idx_1 + 1]
113 ///   }
114 ///   ...
115 ///
116 struct VectorMaskedStoreOpConverter final
117     : OpRewritePattern<vector::MaskedStoreOp> {
118   using OpRewritePattern::OpRewritePattern;
119 
matchAndRewrite__anonef10eae30111::VectorMaskedStoreOpConverter120   LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
121                                 PatternRewriter &rewriter) const override {
122     VectorType maskVType = maskedStoreOp.getMaskVectorType();
123     if (maskVType.getShape().size() != 1)
124       return rewriter.notifyMatchFailure(
125           maskedStoreOp, "expected vector.maskedstore with 1-D mask");
126 
127     Location loc = maskedStoreOp.getLoc();
128     int64_t maskLength = maskVType.getShape()[0];
129 
130     Type indexType = rewriter.getIndexType();
131     Value mask = maskedStoreOp.getMask();
132     Value base = maskedStoreOp.getBase();
133     Value value = maskedStoreOp.getValueToStore();
134     auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
135     Value one = rewriter.create<arith::ConstantOp>(
136         loc, indexType, IntegerAttr::get(indexType, 1));
137     for (int64_t i = 0; i < maskLength; ++i) {
138       auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);
139 
140       auto ifOp = rewriter.create<scf::IfOp>(loc, maskBit, /*else=*/false);
141       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
142       auto extractedValue = rewriter.create<vector::ExtractOp>(loc, value, i);
143       rewriter.create<memref::StoreOp>(loc, extractedValue, base, indices);
144 
145       rewriter.setInsertionPointAfter(ifOp);
146       indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
147     }
148 
149     rewriter.eraseOp(maskedStoreOp);
150 
151     return success();
152   }
153 };
154 
155 } // namespace
156 
populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet & patterns,PatternBenefit benefit)157 void mlir::vector::populateVectorMaskedLoadStoreEmulationPatterns(
158     RewritePatternSet &patterns, PatternBenefit benefit) {
159   patterns.add<VectorMaskedLoadOpConverter, VectorMaskedStoreOpConverter>(
160       patterns.getContext(), benefit);
161 }
162