xref: /llvm-project/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- LegalizeVectorStorage.cpp - Ensures SVE loads/stores are legal -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
10 #include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/Dialect/Vector/IR/VectorOps.h"
14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15 
16 namespace mlir::arm_sve {
17 #define GEN_PASS_DEF_LEGALIZEVECTORSTORAGE
18 #include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc"
19 } // namespace mlir::arm_sve
20 
21 using namespace mlir;
22 using namespace mlir::arm_sve;
23 
24 // A tag to mark unrealized_conversions produced by this pass. This is used to
25 // detect IR this pass failed to completely legalize, and report an error.
26 // If everything was successfully legalized, no tagged ops will remain after
27 // this pass.
28 constexpr StringLiteral kSVELegalizerTag("__arm_sve_legalize_vector_storage__");
29 
30 /// Definitions:
31 ///
32 /// [1] svbool = vector<...x[16]xi1>, which maps to some multiple of full SVE
33 /// predicate registers. A full predicate is the smallest quantity that can be
34 /// loaded/stored.
35 ///
/// [2] SVE mask = hardware-sized SVE predicate mask, i.e. its trailing
/// dimension matches a legal SVE vector size (such as vector<[4]xi1>), but it
/// is too small to be stored to memory (i.e. smaller than an svbool).
40 
41 namespace {
42 
43 /// Checks if a vector type is a SVE mask [2].
44 bool isSVEMaskType(VectorType type) {
45   return type.getRank() > 0 && type.getElementType().isInteger(1) &&
46          type.getScalableDims().back() && type.getShape().back() < 16 &&
47          llvm::isPowerOf2_32(type.getShape().back()) &&
48          !llvm::is_contained(type.getScalableDims().drop_back(), true);
49 }
50 
51 VectorType widenScalableMaskTypeToSvbool(VectorType type) {
52   assert(isSVEMaskType(type));
53   return VectorType::Builder(type).setDim(type.getRank() - 1, 16);
54 }
55 
56 /// A helper for cloning an op and replacing it will a new version, updated by a
57 /// callback.
58 template <typename TOp, typename TLegalizerCallback>
59 void replaceOpWithLegalizedOp(PatternRewriter &rewriter, TOp op,
60                               TLegalizerCallback callback) {
61   // Clone the previous op to preserve any properties/attributes.
62   auto newOp = op.clone();
63   rewriter.insert(newOp);
64   rewriter.replaceOp(op, callback(newOp));
65 }
66 
67 /// A helper for cloning an op and replacing it with a new version, updated by a
68 /// callback, and an unrealized conversion back to the type of the replaced op.
69 template <typename TOp, typename TLegalizerCallback>
70 void replaceOpWithUnrealizedConversion(PatternRewriter &rewriter, TOp op,
71                                        TLegalizerCallback callback) {
72   replaceOpWithLegalizedOp(rewriter, op, [&](TOp newOp) {
73     // Mark our `unrealized_conversion_casts` with a pass label.
74     return rewriter.create<UnrealizedConversionCastOp>(
75         op.getLoc(), TypeRange{op.getResult().getType()},
76         ValueRange{callback(newOp)},
77         NamedAttribute(rewriter.getStringAttr(kSVELegalizerTag),
78                        rewriter.getUnitAttr()));
79   });
80 }
81 
82 /// Extracts the widened SVE memref value (that's legal to store/load) from the
83 /// `unrealized_conversion_cast`s added by this pass.
84 static FailureOr<Value> getSVELegalizedMemref(Value illegalMemref) {
85   Operation *definingOp = illegalMemref.getDefiningOp();
86   if (!definingOp || !definingOp->hasAttr(kSVELegalizerTag))
87     return failure();
88   auto unrealizedConversion =
89       llvm::cast<UnrealizedConversionCastOp>(definingOp);
90   return unrealizedConversion.getOperand(0);
91 }
92 
93 /// The default alignment of an alloca in LLVM may request overaligned sizes for
94 /// SVE types, which will fail during stack frame allocation. This rewrite
95 /// explicitly adds a reasonable alignment to allocas of scalable types.
96 struct RelaxScalableVectorAllocaAlignment
97     : public OpRewritePattern<memref::AllocaOp> {
98   using OpRewritePattern::OpRewritePattern;
99 
100   LogicalResult matchAndRewrite(memref::AllocaOp allocaOp,
101                                 PatternRewriter &rewriter) const override {
102     auto memrefElementType = allocaOp.getType().getElementType();
103     auto vectorType = llvm::dyn_cast<VectorType>(memrefElementType);
104     if (!vectorType || !vectorType.isScalable() || allocaOp.getAlignment())
105       return failure();
106 
107     // Set alignment based on the defaults for SVE vectors and predicates.
108     unsigned aligment = vectorType.getElementType().isInteger(1) ? 2 : 16;
109     rewriter.modifyOpInPlace(allocaOp,
110                              [&] { allocaOp.setAlignment(aligment); });
111 
112     return success();
113   }
114 };
115 
116 /// Replaces allocations of SVE predicates smaller than an svbool [1] (_illegal_
117 /// to load/store) with a wider allocation of svbool (_legal_ to load/store)
118 /// followed by a tagged unrealized conversion to the original type.
119 ///
120 /// Example
121 /// ```
122 /// %alloca = memref.alloca() : memref<vector<[4]xi1>>
123 /// ```
124 /// is rewritten into:
125 /// ```
126 /// %widened = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
127 /// %alloca = builtin.unrealized_conversion_cast %widened
128 ///   : memref<vector<[16]xi1>> to memref<vector<[4]xi1>>
129 ///     {__arm_sve_legalize_vector_storage__}
130 /// ```
131 template <typename AllocLikeOp>
132 struct LegalizeSVEMaskAllocation : public OpRewritePattern<AllocLikeOp> {
133   using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
134 
135   LogicalResult matchAndRewrite(AllocLikeOp allocLikeOp,
136                                 PatternRewriter &rewriter) const override {
137     auto vectorType =
138         llvm::dyn_cast<VectorType>(allocLikeOp.getType().getElementType());
139 
140     if (!vectorType || !isSVEMaskType(vectorType))
141       return failure();
142 
143     // Replace this alloc-like op of an SVE mask [2] with one of a (storable)
144     // svbool mask [1]. A temporary unrealized_conversion_cast is added to the
145     // old type to allow local rewrites.
146     replaceOpWithUnrealizedConversion(
147         rewriter, allocLikeOp, [&](AllocLikeOp newAllocLikeOp) {
148           newAllocLikeOp.getResult().setType(
149               llvm::cast<MemRefType>(newAllocLikeOp.getType().cloneWith(
150                   {}, widenScalableMaskTypeToSvbool(vectorType))));
151           return newAllocLikeOp;
152         });
153 
154     return success();
155   }
156 };
157 
158 /// Replaces vector.type_casts of unrealized conversions to SVE predicate memref
159 /// types that are _illegal_ to load/store from (!= svbool [1]), with type casts
160 /// of memref types that are _legal_ to load/store, followed by unrealized
161 /// conversions.
162 ///
163 /// Example:
164 /// ```
165 /// %alloca = builtin.unrealized_conversion_cast %widened
166 ///   : memref<vector<[16]xi1>> to memref<vector<[8]xi1>>
167 ///     {__arm_sve_legalize_vector_storage__}
168 /// %cast = vector.type_cast %alloca
169 ///   : memref<vector<3x[8]xi1>> to memref<3xvector<[8]xi1>>
170 /// ```
171 /// is rewritten into:
172 /// ```
173 /// %widened_cast = vector.type_cast %widened
174 ///   : memref<vector<3x[16]xi1>> to memref<3xvector<[16]xi1>>
175 /// %cast = builtin.unrealized_conversion_cast %widened_cast
176 ///   : memref<3xvector<[16]xi1>> to memref<3xvector<[8]xi1>>
177 ///     {__arm_sve_legalize_vector_storage__}
178 /// ```
179 struct LegalizeSVEMaskTypeCastConversion
180     : public OpRewritePattern<vector::TypeCastOp> {
181   using OpRewritePattern::OpRewritePattern;
182 
183   LogicalResult matchAndRewrite(vector::TypeCastOp typeCastOp,
184                                 PatternRewriter &rewriter) const override {
185     auto resultType = typeCastOp.getResultMemRefType();
186     auto vectorType = llvm::dyn_cast<VectorType>(resultType.getElementType());
187 
188     if (!vectorType || !isSVEMaskType(vectorType))
189       return failure();
190 
191     auto legalMemref = getSVELegalizedMemref(typeCastOp.getMemref());
192     if (failed(legalMemref))
193       return failure();
194 
195     // Replace this vector.type_cast with one of a (storable) svbool mask [1].
196     replaceOpWithUnrealizedConversion(
197         rewriter, typeCastOp, [&](vector::TypeCastOp newTypeCast) {
198           newTypeCast.setOperand(*legalMemref);
199           newTypeCast.getResult().setType(
200               llvm::cast<MemRefType>(newTypeCast.getType().cloneWith(
201                   {}, widenScalableMaskTypeToSvbool(vectorType))));
202           return newTypeCast;
203         });
204 
205     return success();
206   }
207 };
208 
209 /// Replaces stores to unrealized conversions to SVE predicate memref types that
210 /// are _illegal_ to load/store from (!= svbool [1]), with
211 /// `arm_sve.convert_to_svbool`s followed by (legal) wider stores.
212 ///
213 /// Example:
214 /// ```
215 /// memref.store %mask, %alloca[] : memref<vector<[8]xi1>>
216 /// ```
217 /// is rewritten into:
218 /// ```
219 /// %svbool = arm_sve.convert_to_svbool %mask : vector<[8]xi1>
220 /// memref.store %svbool, %widened[] : memref<vector<[16]xi1>>
221 /// ```
222 struct LegalizeSVEMaskStoreConversion
223     : public OpRewritePattern<memref::StoreOp> {
224   using OpRewritePattern::OpRewritePattern;
225 
226   LogicalResult matchAndRewrite(memref::StoreOp storeOp,
227                                 PatternRewriter &rewriter) const override {
228     auto loc = storeOp.getLoc();
229 
230     Value valueToStore = storeOp.getValueToStore();
231     auto vectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
232 
233     if (!vectorType || !isSVEMaskType(vectorType))
234       return failure();
235 
236     auto legalMemref = getSVELegalizedMemref(storeOp.getMemref());
237     if (failed(legalMemref))
238       return failure();
239 
240     auto legalMaskType = widenScalableMaskTypeToSvbool(
241         llvm::cast<VectorType>(valueToStore.getType()));
242     auto convertToSvbool = rewriter.create<arm_sve::ConvertToSvboolOp>(
243         loc, legalMaskType, valueToStore);
244     // Replace this store with a conversion to a storable svbool mask [1],
245     // followed by a wider store.
246     replaceOpWithLegalizedOp(rewriter, storeOp,
247                              [&](memref::StoreOp newStoreOp) {
248                                newStoreOp.setOperand(0, convertToSvbool);
249                                newStoreOp.setOperand(1, *legalMemref);
250                                return newStoreOp;
251                              });
252 
253     return success();
254   }
255 };
256 
257 /// Replaces loads from unrealized conversions to SVE predicate memref types
258 /// that are _illegal_ to load/store from (!= svbool [1]), types with (legal)
259 /// wider loads, followed by `arm_sve.convert_from_svbool`s.
260 ///
261 /// Example:
262 /// ```
263 /// %reload = memref.load %alloca[] : memref<vector<[4]xi1>>
264 /// ```
265 /// is rewritten into:
266 /// ```
267 /// %svbool = memref.load %widened[] : memref<vector<[16]xi1>>
268 /// %reload = arm_sve.convert_from_svbool %reload : vector<[4]xi1>
269 /// ```
270 struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
271   using OpRewritePattern::OpRewritePattern;
272 
273   LogicalResult matchAndRewrite(memref::LoadOp loadOp,
274                                 PatternRewriter &rewriter) const override {
275     auto loc = loadOp.getLoc();
276 
277     Value loadedMask = loadOp.getResult();
278     auto vectorType = llvm::dyn_cast<VectorType>(loadedMask.getType());
279 
280     if (!vectorType || !isSVEMaskType(vectorType))
281       return failure();
282 
283     auto legalMemref = getSVELegalizedMemref(loadOp.getMemref());
284     if (failed(legalMemref))
285       return failure();
286 
287     auto legalMaskType = widenScalableMaskTypeToSvbool(vectorType);
288     // Replace this load with a legal load of an svbool type, followed by a
289     // conversion back to the original type.
290     replaceOpWithLegalizedOp(rewriter, loadOp, [&](memref::LoadOp newLoadOp) {
291       newLoadOp.setMemRef(*legalMemref);
292       newLoadOp.getResult().setType(legalMaskType);
293       return rewriter.create<arm_sve::ConvertFromSvboolOp>(
294           loc, loadedMask.getType(), newLoadOp);
295     });
296 
297     return success();
298   }
299 };
300 
301 } // namespace
302 
303 void mlir::arm_sve::populateLegalizeVectorStoragePatterns(
304     RewritePatternSet &patterns) {
305   patterns.add<RelaxScalableVectorAllocaAlignment,
306                LegalizeSVEMaskAllocation<memref::AllocaOp>,
307                LegalizeSVEMaskAllocation<memref::AllocOp>,
308                LegalizeSVEMaskTypeCastConversion,
309                LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion>(
310       patterns.getContext());
311 }
312 
313 namespace {
314 struct LegalizeVectorStorage
315     : public arm_sve::impl::LegalizeVectorStorageBase<LegalizeVectorStorage> {
316 
317   void runOnOperation() override {
318     RewritePatternSet patterns(&getContext());
319     populateLegalizeVectorStoragePatterns(patterns);
320     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
321       signalPassFailure();
322     }
323     ConversionTarget target(getContext());
324     target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
325         [](UnrealizedConversionCastOp unrealizedConversion) {
326           return !unrealizedConversion->hasAttr(kSVELegalizerTag);
327         });
328     // This detects if we failed to completely legalize the IR.
329     if (failed(applyPartialConversion(getOperation(), target, {})))
330       signalPassFailure();
331   }
332 };
333 
334 } // namespace
335 
336 std::unique_ptr<Pass> mlir::arm_sve::createLegalizeVectorStoragePass() {
337   return std::make_unique<LegalizeVectorStorage>();
338 }
339