1 //=- VectorEmulateMaskedLoadStore.cpp - Emulate 'vector.maskedload/store' op =//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to emulate the
10 // 'vector.maskedload' and 'vector.maskedstore' operation.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/SCF/IR/SCF.h"
16 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
17
18 using namespace mlir;
19
20 namespace {
21
22 /// Convert vector.maskedload
23 ///
24 /// Before:
25 ///
26 /// vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
27 ///
28 /// After:
29 ///
30 /// %ivalue = %pass_thru
31 /// %m = vector.extract %mask[0]
32 /// %result0 = scf.if %m {
33 /// %v = memref.load %base[%idx_0, %idx_1]
34 /// %combined = vector.insert %v, %ivalue[0]
35 /// scf.yield %combined
36 /// } else {
37 /// scf.yield %ivalue
38 /// }
39 /// %m = vector.extract %mask[1]
40 /// %result1 = scf.if %m {
41 /// %v = memref.load %base[%idx_0, %idx_1 + 1]
42 /// %combined = vector.insert %v, %result0[1]
43 /// scf.yield %combined
44 /// } else {
45 /// scf.yield %result0
46 /// }
47 /// ...
48 ///
49 struct VectorMaskedLoadOpConverter final
50 : OpRewritePattern<vector::MaskedLoadOp> {
51 using OpRewritePattern::OpRewritePattern;
52
matchAndRewrite__anonef10eae30111::VectorMaskedLoadOpConverter53 LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
54 PatternRewriter &rewriter) const override {
55 VectorType maskVType = maskedLoadOp.getMaskVectorType();
56 if (maskVType.getShape().size() != 1)
57 return rewriter.notifyMatchFailure(
58 maskedLoadOp, "expected vector.maskedstore with 1-D mask");
59
60 Location loc = maskedLoadOp.getLoc();
61 int64_t maskLength = maskVType.getShape()[0];
62
63 Type indexType = rewriter.getIndexType();
64 Value mask = maskedLoadOp.getMask();
65 Value base = maskedLoadOp.getBase();
66 Value iValue = maskedLoadOp.getPassThru();
67 auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
68 Value one = rewriter.create<arith::ConstantOp>(
69 loc, indexType, IntegerAttr::get(indexType, 1));
70 for (int64_t i = 0; i < maskLength; ++i) {
71 auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);
72
73 auto ifOp = rewriter.create<scf::IfOp>(
74 loc, maskBit,
75 [&](OpBuilder &builder, Location loc) {
76 auto loadedValue =
77 builder.create<memref::LoadOp>(loc, base, indices);
78 auto combinedValue =
79 builder.create<vector::InsertOp>(loc, loadedValue, iValue, i);
80 builder.create<scf::YieldOp>(loc, combinedValue.getResult());
81 },
82 [&](OpBuilder &builder, Location loc) {
83 builder.create<scf::YieldOp>(loc, iValue);
84 });
85 iValue = ifOp.getResult(0);
86
87 indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
88 }
89
90 rewriter.replaceOp(maskedLoadOp, iValue);
91
92 return success();
93 }
94 };
95
96 /// Convert vector.maskedstore
97 ///
98 /// Before:
99 ///
100 /// vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
101 ///
102 /// After:
103 ///
104 /// %m = vector.extract %mask[0]
105 /// scf.if %m {
106 /// %extracted = vector.extract %value[0]
107 /// memref.store %extracted, %base[%idx_0, %idx_1]
108 /// }
109 /// %m = vector.extract %mask[1]
110 /// scf.if %m {
111 /// %extracted = vector.extract %value[1]
112 /// memref.store %extracted, %base[%idx_0, %idx_1 + 1]
113 /// }
114 /// ...
115 ///
116 struct VectorMaskedStoreOpConverter final
117 : OpRewritePattern<vector::MaskedStoreOp> {
118 using OpRewritePattern::OpRewritePattern;
119
matchAndRewrite__anonef10eae30111::VectorMaskedStoreOpConverter120 LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
121 PatternRewriter &rewriter) const override {
122 VectorType maskVType = maskedStoreOp.getMaskVectorType();
123 if (maskVType.getShape().size() != 1)
124 return rewriter.notifyMatchFailure(
125 maskedStoreOp, "expected vector.maskedstore with 1-D mask");
126
127 Location loc = maskedStoreOp.getLoc();
128 int64_t maskLength = maskVType.getShape()[0];
129
130 Type indexType = rewriter.getIndexType();
131 Value mask = maskedStoreOp.getMask();
132 Value base = maskedStoreOp.getBase();
133 Value value = maskedStoreOp.getValueToStore();
134 auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
135 Value one = rewriter.create<arith::ConstantOp>(
136 loc, indexType, IntegerAttr::get(indexType, 1));
137 for (int64_t i = 0; i < maskLength; ++i) {
138 auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);
139
140 auto ifOp = rewriter.create<scf::IfOp>(loc, maskBit, /*else=*/false);
141 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
142 auto extractedValue = rewriter.create<vector::ExtractOp>(loc, value, i);
143 rewriter.create<memref::StoreOp>(loc, extractedValue, base, indices);
144
145 rewriter.setInsertionPointAfter(ifOp);
146 indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
147 }
148
149 rewriter.eraseOp(maskedStoreOp);
150
151 return success();
152 }
153 };
154
155 } // namespace
156
populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet & patterns,PatternBenefit benefit)157 void mlir::vector::populateVectorMaskedLoadStoreEmulationPatterns(
158 RewritePatternSet &patterns, PatternBenefit benefit) {
159 patterns.add<VectorMaskedLoadOpConverter, VectorMaskedStoreOpConverter>(
160 patterns.getContext(), benefit);
161 }
162