xref: /llvm-project/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
196e040acSBenjamin Maxwell //===- LegalizeVectorStorage.cpp - Ensures SVE loads/stores are legal -----===//
296e040acSBenjamin Maxwell //
396e040acSBenjamin Maxwell // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
496e040acSBenjamin Maxwell // See https://llvm.org/LICENSE.txt for license information.
596e040acSBenjamin Maxwell // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
696e040acSBenjamin Maxwell //
796e040acSBenjamin Maxwell //===----------------------------------------------------------------------===//
896e040acSBenjamin Maxwell 
996e040acSBenjamin Maxwell #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
1096e040acSBenjamin Maxwell #include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
1196e040acSBenjamin Maxwell #include "mlir/Dialect/Func/IR/FuncOps.h"
1296e040acSBenjamin Maxwell #include "mlir/Dialect/MemRef/IR/MemRef.h"
1396e040acSBenjamin Maxwell #include "mlir/Dialect/Vector/IR/VectorOps.h"
1496e040acSBenjamin Maxwell #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1596e040acSBenjamin Maxwell 
1696e040acSBenjamin Maxwell namespace mlir::arm_sve {
1796e040acSBenjamin Maxwell #define GEN_PASS_DEF_LEGALIZEVECTORSTORAGE
1896e040acSBenjamin Maxwell #include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc"
1996e040acSBenjamin Maxwell } // namespace mlir::arm_sve
2096e040acSBenjamin Maxwell 
2196e040acSBenjamin Maxwell using namespace mlir;
2296e040acSBenjamin Maxwell using namespace mlir::arm_sve;
2396e040acSBenjamin Maxwell 
2496e040acSBenjamin Maxwell // A tag to mark unrealized_conversions produced by this pass. This is used to
2596e040acSBenjamin Maxwell // detect IR this pass failed to completely legalize, and report an error.
2696e040acSBenjamin Maxwell // If everything was successfully legalized, no tagged ops will remain after
2796e040acSBenjamin Maxwell // this pass.
2896e040acSBenjamin Maxwell constexpr StringLiteral kSVELegalizerTag("__arm_sve_legalize_vector_storage__");
2996e040acSBenjamin Maxwell 
3096e040acSBenjamin Maxwell /// Definitions:
3196e040acSBenjamin Maxwell ///
3296e040acSBenjamin Maxwell /// [1] svbool = vector<...x[16]xi1>, which maps to some multiple of full SVE
3396e040acSBenjamin Maxwell /// predicate registers. A full predicate is the smallest quantity that can be
3496e040acSBenjamin Maxwell /// loaded/stored.
3596e040acSBenjamin Maxwell ///
/// [2] SVE mask = hardware-sized SVE predicate mask, i.e. its trailing
/// dimension matches a legal SVE vector size (such as vector<[4]xi1>), but it
/// is too small to be stored to memory (i.e. smaller than a svbool).
4096e040acSBenjamin Maxwell 
4196e040acSBenjamin Maxwell namespace {
4296e040acSBenjamin Maxwell 
4396e040acSBenjamin Maxwell /// Checks if a vector type is a SVE mask [2].
4496e040acSBenjamin Maxwell bool isSVEMaskType(VectorType type) {
4596e040acSBenjamin Maxwell   return type.getRank() > 0 && type.getElementType().isInteger(1) &&
4696e040acSBenjamin Maxwell          type.getScalableDims().back() && type.getShape().back() < 16 &&
4796e040acSBenjamin Maxwell          llvm::isPowerOf2_32(type.getShape().back()) &&
4896e040acSBenjamin Maxwell          !llvm::is_contained(type.getScalableDims().drop_back(), true);
4996e040acSBenjamin Maxwell }
5096e040acSBenjamin Maxwell 
5196e040acSBenjamin Maxwell VectorType widenScalableMaskTypeToSvbool(VectorType type) {
5296e040acSBenjamin Maxwell   assert(isSVEMaskType(type));
5396e040acSBenjamin Maxwell   return VectorType::Builder(type).setDim(type.getRank() - 1, 16);
5496e040acSBenjamin Maxwell }
5596e040acSBenjamin Maxwell 
5696e040acSBenjamin Maxwell /// A helper for cloning an op and replacing it will a new version, updated by a
5796e040acSBenjamin Maxwell /// callback.
5896e040acSBenjamin Maxwell template <typename TOp, typename TLegalizerCallback>
5996e040acSBenjamin Maxwell void replaceOpWithLegalizedOp(PatternRewriter &rewriter, TOp op,
6096e040acSBenjamin Maxwell                               TLegalizerCallback callback) {
6196e040acSBenjamin Maxwell   // Clone the previous op to preserve any properties/attributes.
6296e040acSBenjamin Maxwell   auto newOp = op.clone();
6396e040acSBenjamin Maxwell   rewriter.insert(newOp);
6496e040acSBenjamin Maxwell   rewriter.replaceOp(op, callback(newOp));
6596e040acSBenjamin Maxwell }
6696e040acSBenjamin Maxwell 
6796e040acSBenjamin Maxwell /// A helper for cloning an op and replacing it with a new version, updated by a
6896e040acSBenjamin Maxwell /// callback, and an unrealized conversion back to the type of the replaced op.
6996e040acSBenjamin Maxwell template <typename TOp, typename TLegalizerCallback>
7096e040acSBenjamin Maxwell void replaceOpWithUnrealizedConversion(PatternRewriter &rewriter, TOp op,
7196e040acSBenjamin Maxwell                                        TLegalizerCallback callback) {
7296e040acSBenjamin Maxwell   replaceOpWithLegalizedOp(rewriter, op, [&](TOp newOp) {
7396e040acSBenjamin Maxwell     // Mark our `unrealized_conversion_casts` with a pass label.
7496e040acSBenjamin Maxwell     return rewriter.create<UnrealizedConversionCastOp>(
7596e040acSBenjamin Maxwell         op.getLoc(), TypeRange{op.getResult().getType()},
7696e040acSBenjamin Maxwell         ValueRange{callback(newOp)},
7796e040acSBenjamin Maxwell         NamedAttribute(rewriter.getStringAttr(kSVELegalizerTag),
7896e040acSBenjamin Maxwell                        rewriter.getUnitAttr()));
7996e040acSBenjamin Maxwell   });
8096e040acSBenjamin Maxwell }
8196e040acSBenjamin Maxwell 
8296e040acSBenjamin Maxwell /// Extracts the widened SVE memref value (that's legal to store/load) from the
8396e040acSBenjamin Maxwell /// `unrealized_conversion_cast`s added by this pass.
8496e040acSBenjamin Maxwell static FailureOr<Value> getSVELegalizedMemref(Value illegalMemref) {
8596e040acSBenjamin Maxwell   Operation *definingOp = illegalMemref.getDefiningOp();
8696e040acSBenjamin Maxwell   if (!definingOp || !definingOp->hasAttr(kSVELegalizerTag))
8796e040acSBenjamin Maxwell     return failure();
8896e040acSBenjamin Maxwell   auto unrealizedConversion =
8996e040acSBenjamin Maxwell       llvm::cast<UnrealizedConversionCastOp>(definingOp);
9096e040acSBenjamin Maxwell   return unrealizedConversion.getOperand(0);
9196e040acSBenjamin Maxwell }
9296e040acSBenjamin Maxwell 
9396e040acSBenjamin Maxwell /// The default alignment of an alloca in LLVM may request overaligned sizes for
9496e040acSBenjamin Maxwell /// SVE types, which will fail during stack frame allocation. This rewrite
9596e040acSBenjamin Maxwell /// explicitly adds a reasonable alignment to allocas of scalable types.
9696e040acSBenjamin Maxwell struct RelaxScalableVectorAllocaAlignment
9796e040acSBenjamin Maxwell     : public OpRewritePattern<memref::AllocaOp> {
9896e040acSBenjamin Maxwell   using OpRewritePattern::OpRewritePattern;
9996e040acSBenjamin Maxwell 
10096e040acSBenjamin Maxwell   LogicalResult matchAndRewrite(memref::AllocaOp allocaOp,
10196e040acSBenjamin Maxwell                                 PatternRewriter &rewriter) const override {
10296e040acSBenjamin Maxwell     auto memrefElementType = allocaOp.getType().getElementType();
10396e040acSBenjamin Maxwell     auto vectorType = llvm::dyn_cast<VectorType>(memrefElementType);
10496e040acSBenjamin Maxwell     if (!vectorType || !vectorType.isScalable() || allocaOp.getAlignment())
10596e040acSBenjamin Maxwell       return failure();
10696e040acSBenjamin Maxwell 
10796e040acSBenjamin Maxwell     // Set alignment based on the defaults for SVE vectors and predicates.
10896e040acSBenjamin Maxwell     unsigned aligment = vectorType.getElementType().isInteger(1) ? 2 : 16;
1095fcf907bSMatthias Springer     rewriter.modifyOpInPlace(allocaOp,
110e2bb47caSMatthias Springer                              [&] { allocaOp.setAlignment(aligment); });
11196e040acSBenjamin Maxwell 
11296e040acSBenjamin Maxwell     return success();
11396e040acSBenjamin Maxwell   }
11496e040acSBenjamin Maxwell };
11596e040acSBenjamin Maxwell 
11696e040acSBenjamin Maxwell /// Replaces allocations of SVE predicates smaller than an svbool [1] (_illegal_
11796e040acSBenjamin Maxwell /// to load/store) with a wider allocation of svbool (_legal_ to load/store)
11896e040acSBenjamin Maxwell /// followed by a tagged unrealized conversion to the original type.
11996e040acSBenjamin Maxwell ///
12096e040acSBenjamin Maxwell /// Example
12196e040acSBenjamin Maxwell /// ```
12296e040acSBenjamin Maxwell /// %alloca = memref.alloca() : memref<vector<[4]xi1>>
12396e040acSBenjamin Maxwell /// ```
12496e040acSBenjamin Maxwell /// is rewritten into:
12596e040acSBenjamin Maxwell /// ```
12696e040acSBenjamin Maxwell /// %widened = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
12796e040acSBenjamin Maxwell /// %alloca = builtin.unrealized_conversion_cast %widened
12896e040acSBenjamin Maxwell ///   : memref<vector<[16]xi1>> to memref<vector<[4]xi1>>
12996e040acSBenjamin Maxwell ///     {__arm_sve_legalize_vector_storage__}
13096e040acSBenjamin Maxwell /// ```
13196e040acSBenjamin Maxwell template <typename AllocLikeOp>
13296e040acSBenjamin Maxwell struct LegalizeSVEMaskAllocation : public OpRewritePattern<AllocLikeOp> {
13396e040acSBenjamin Maxwell   using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
13496e040acSBenjamin Maxwell 
13596e040acSBenjamin Maxwell   LogicalResult matchAndRewrite(AllocLikeOp allocLikeOp,
13696e040acSBenjamin Maxwell                                 PatternRewriter &rewriter) const override {
13796e040acSBenjamin Maxwell     auto vectorType =
13896e040acSBenjamin Maxwell         llvm::dyn_cast<VectorType>(allocLikeOp.getType().getElementType());
13996e040acSBenjamin Maxwell 
14096e040acSBenjamin Maxwell     if (!vectorType || !isSVEMaskType(vectorType))
14196e040acSBenjamin Maxwell       return failure();
14296e040acSBenjamin Maxwell 
14396e040acSBenjamin Maxwell     // Replace this alloc-like op of an SVE mask [2] with one of a (storable)
14496e040acSBenjamin Maxwell     // svbool mask [1]. A temporary unrealized_conversion_cast is added to the
14596e040acSBenjamin Maxwell     // old type to allow local rewrites.
14696e040acSBenjamin Maxwell     replaceOpWithUnrealizedConversion(
14796e040acSBenjamin Maxwell         rewriter, allocLikeOp, [&](AllocLikeOp newAllocLikeOp) {
14896e040acSBenjamin Maxwell           newAllocLikeOp.getResult().setType(
14996e040acSBenjamin Maxwell               llvm::cast<MemRefType>(newAllocLikeOp.getType().cloneWith(
15096e040acSBenjamin Maxwell                   {}, widenScalableMaskTypeToSvbool(vectorType))));
15196e040acSBenjamin Maxwell           return newAllocLikeOp;
15296e040acSBenjamin Maxwell         });
15396e040acSBenjamin Maxwell 
15496e040acSBenjamin Maxwell     return success();
15596e040acSBenjamin Maxwell   }
15696e040acSBenjamin Maxwell };
15796e040acSBenjamin Maxwell 
15896e040acSBenjamin Maxwell /// Replaces vector.type_casts of unrealized conversions to SVE predicate memref
15996e040acSBenjamin Maxwell /// types that are _illegal_ to load/store from (!= svbool [1]), with type casts
16096e040acSBenjamin Maxwell /// of memref types that are _legal_ to load/store, followed by unrealized
16196e040acSBenjamin Maxwell /// conversions.
16296e040acSBenjamin Maxwell ///
16396e040acSBenjamin Maxwell /// Example:
16496e040acSBenjamin Maxwell /// ```
16596e040acSBenjamin Maxwell /// %alloca = builtin.unrealized_conversion_cast %widened
16696e040acSBenjamin Maxwell ///   : memref<vector<[16]xi1>> to memref<vector<[8]xi1>>
16796e040acSBenjamin Maxwell ///     {__arm_sve_legalize_vector_storage__}
16896e040acSBenjamin Maxwell /// %cast = vector.type_cast %alloca
16996e040acSBenjamin Maxwell ///   : memref<vector<3x[8]xi1>> to memref<3xvector<[8]xi1>>
17096e040acSBenjamin Maxwell /// ```
17196e040acSBenjamin Maxwell /// is rewritten into:
17296e040acSBenjamin Maxwell /// ```
17396e040acSBenjamin Maxwell /// %widened_cast = vector.type_cast %widened
17496e040acSBenjamin Maxwell ///   : memref<vector<3x[16]xi1>> to memref<3xvector<[16]xi1>>
17596e040acSBenjamin Maxwell /// %cast = builtin.unrealized_conversion_cast %widened_cast
17696e040acSBenjamin Maxwell ///   : memref<3xvector<[16]xi1>> to memref<3xvector<[8]xi1>>
17796e040acSBenjamin Maxwell ///     {__arm_sve_legalize_vector_storage__}
17896e040acSBenjamin Maxwell /// ```
17996e040acSBenjamin Maxwell struct LegalizeSVEMaskTypeCastConversion
18096e040acSBenjamin Maxwell     : public OpRewritePattern<vector::TypeCastOp> {
18196e040acSBenjamin Maxwell   using OpRewritePattern::OpRewritePattern;
18296e040acSBenjamin Maxwell 
18396e040acSBenjamin Maxwell   LogicalResult matchAndRewrite(vector::TypeCastOp typeCastOp,
18496e040acSBenjamin Maxwell                                 PatternRewriter &rewriter) const override {
18596e040acSBenjamin Maxwell     auto resultType = typeCastOp.getResultMemRefType();
18696e040acSBenjamin Maxwell     auto vectorType = llvm::dyn_cast<VectorType>(resultType.getElementType());
18796e040acSBenjamin Maxwell 
18896e040acSBenjamin Maxwell     if (!vectorType || !isSVEMaskType(vectorType))
18996e040acSBenjamin Maxwell       return failure();
19096e040acSBenjamin Maxwell 
19196e040acSBenjamin Maxwell     auto legalMemref = getSVELegalizedMemref(typeCastOp.getMemref());
19296e040acSBenjamin Maxwell     if (failed(legalMemref))
19396e040acSBenjamin Maxwell       return failure();
19496e040acSBenjamin Maxwell 
19596e040acSBenjamin Maxwell     // Replace this vector.type_cast with one of a (storable) svbool mask [1].
19696e040acSBenjamin Maxwell     replaceOpWithUnrealizedConversion(
19796e040acSBenjamin Maxwell         rewriter, typeCastOp, [&](vector::TypeCastOp newTypeCast) {
19896e040acSBenjamin Maxwell           newTypeCast.setOperand(*legalMemref);
19996e040acSBenjamin Maxwell           newTypeCast.getResult().setType(
20096e040acSBenjamin Maxwell               llvm::cast<MemRefType>(newTypeCast.getType().cloneWith(
20196e040acSBenjamin Maxwell                   {}, widenScalableMaskTypeToSvbool(vectorType))));
20296e040acSBenjamin Maxwell           return newTypeCast;
20396e040acSBenjamin Maxwell         });
20496e040acSBenjamin Maxwell 
20596e040acSBenjamin Maxwell     return success();
20696e040acSBenjamin Maxwell   }
20796e040acSBenjamin Maxwell };
20896e040acSBenjamin Maxwell 
20996e040acSBenjamin Maxwell /// Replaces stores to unrealized conversions to SVE predicate memref types that
21096e040acSBenjamin Maxwell /// are _illegal_ to load/store from (!= svbool [1]), with
21196e040acSBenjamin Maxwell /// `arm_sve.convert_to_svbool`s followed by (legal) wider stores.
21296e040acSBenjamin Maxwell ///
21396e040acSBenjamin Maxwell /// Example:
21496e040acSBenjamin Maxwell /// ```
21596e040acSBenjamin Maxwell /// memref.store %mask, %alloca[] : memref<vector<[8]xi1>>
21696e040acSBenjamin Maxwell /// ```
21796e040acSBenjamin Maxwell /// is rewritten into:
21896e040acSBenjamin Maxwell /// ```
21996e040acSBenjamin Maxwell /// %svbool = arm_sve.convert_to_svbool %mask : vector<[8]xi1>
22096e040acSBenjamin Maxwell /// memref.store %svbool, %widened[] : memref<vector<[16]xi1>>
22196e040acSBenjamin Maxwell /// ```
22296e040acSBenjamin Maxwell struct LegalizeSVEMaskStoreConversion
22396e040acSBenjamin Maxwell     : public OpRewritePattern<memref::StoreOp> {
22496e040acSBenjamin Maxwell   using OpRewritePattern::OpRewritePattern;
22596e040acSBenjamin Maxwell 
22696e040acSBenjamin Maxwell   LogicalResult matchAndRewrite(memref::StoreOp storeOp,
22796e040acSBenjamin Maxwell                                 PatternRewriter &rewriter) const override {
22896e040acSBenjamin Maxwell     auto loc = storeOp.getLoc();
22996e040acSBenjamin Maxwell 
23096e040acSBenjamin Maxwell     Value valueToStore = storeOp.getValueToStore();
23196e040acSBenjamin Maxwell     auto vectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
23296e040acSBenjamin Maxwell 
23396e040acSBenjamin Maxwell     if (!vectorType || !isSVEMaskType(vectorType))
23496e040acSBenjamin Maxwell       return failure();
23596e040acSBenjamin Maxwell 
23696e040acSBenjamin Maxwell     auto legalMemref = getSVELegalizedMemref(storeOp.getMemref());
23796e040acSBenjamin Maxwell     if (failed(legalMemref))
23896e040acSBenjamin Maxwell       return failure();
23996e040acSBenjamin Maxwell 
24096e040acSBenjamin Maxwell     auto legalMaskType = widenScalableMaskTypeToSvbool(
24196e040acSBenjamin Maxwell         llvm::cast<VectorType>(valueToStore.getType()));
24296e040acSBenjamin Maxwell     auto convertToSvbool = rewriter.create<arm_sve::ConvertToSvboolOp>(
24396e040acSBenjamin Maxwell         loc, legalMaskType, valueToStore);
24496e040acSBenjamin Maxwell     // Replace this store with a conversion to a storable svbool mask [1],
24596e040acSBenjamin Maxwell     // followed by a wider store.
24696e040acSBenjamin Maxwell     replaceOpWithLegalizedOp(rewriter, storeOp,
24796e040acSBenjamin Maxwell                              [&](memref::StoreOp newStoreOp) {
24896e040acSBenjamin Maxwell                                newStoreOp.setOperand(0, convertToSvbool);
24996e040acSBenjamin Maxwell                                newStoreOp.setOperand(1, *legalMemref);
25096e040acSBenjamin Maxwell                                return newStoreOp;
25196e040acSBenjamin Maxwell                              });
25296e040acSBenjamin Maxwell 
25396e040acSBenjamin Maxwell     return success();
25496e040acSBenjamin Maxwell   }
25596e040acSBenjamin Maxwell };
25696e040acSBenjamin Maxwell 
25796e040acSBenjamin Maxwell /// Replaces loads from unrealized conversions to SVE predicate memref types
25896e040acSBenjamin Maxwell /// that are _illegal_ to load/store from (!= svbool [1]), types with (legal)
25996e040acSBenjamin Maxwell /// wider loads, followed by `arm_sve.convert_from_svbool`s.
26096e040acSBenjamin Maxwell ///
26196e040acSBenjamin Maxwell /// Example:
26296e040acSBenjamin Maxwell /// ```
26396e040acSBenjamin Maxwell /// %reload = memref.load %alloca[] : memref<vector<[4]xi1>>
26496e040acSBenjamin Maxwell /// ```
26596e040acSBenjamin Maxwell /// is rewritten into:
26696e040acSBenjamin Maxwell /// ```
26796e040acSBenjamin Maxwell /// %svbool = memref.load %widened[] : memref<vector<[16]xi1>>
26896e040acSBenjamin Maxwell /// %reload = arm_sve.convert_from_svbool %reload : vector<[4]xi1>
26996e040acSBenjamin Maxwell /// ```
27096e040acSBenjamin Maxwell struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
27196e040acSBenjamin Maxwell   using OpRewritePattern::OpRewritePattern;
27296e040acSBenjamin Maxwell 
27396e040acSBenjamin Maxwell   LogicalResult matchAndRewrite(memref::LoadOp loadOp,
27496e040acSBenjamin Maxwell                                 PatternRewriter &rewriter) const override {
27596e040acSBenjamin Maxwell     auto loc = loadOp.getLoc();
27696e040acSBenjamin Maxwell 
27796e040acSBenjamin Maxwell     Value loadedMask = loadOp.getResult();
27896e040acSBenjamin Maxwell     auto vectorType = llvm::dyn_cast<VectorType>(loadedMask.getType());
27996e040acSBenjamin Maxwell 
28096e040acSBenjamin Maxwell     if (!vectorType || !isSVEMaskType(vectorType))
28196e040acSBenjamin Maxwell       return failure();
28296e040acSBenjamin Maxwell 
28396e040acSBenjamin Maxwell     auto legalMemref = getSVELegalizedMemref(loadOp.getMemref());
28496e040acSBenjamin Maxwell     if (failed(legalMemref))
28596e040acSBenjamin Maxwell       return failure();
28696e040acSBenjamin Maxwell 
28796e040acSBenjamin Maxwell     auto legalMaskType = widenScalableMaskTypeToSvbool(vectorType);
28896e040acSBenjamin Maxwell     // Replace this load with a legal load of an svbool type, followed by a
28996e040acSBenjamin Maxwell     // conversion back to the original type.
29096e040acSBenjamin Maxwell     replaceOpWithLegalizedOp(rewriter, loadOp, [&](memref::LoadOp newLoadOp) {
29196e040acSBenjamin Maxwell       newLoadOp.setMemRef(*legalMemref);
29296e040acSBenjamin Maxwell       newLoadOp.getResult().setType(legalMaskType);
29396e040acSBenjamin Maxwell       return rewriter.create<arm_sve::ConvertFromSvboolOp>(
29496e040acSBenjamin Maxwell           loc, loadedMask.getType(), newLoadOp);
29596e040acSBenjamin Maxwell     });
29696e040acSBenjamin Maxwell 
29796e040acSBenjamin Maxwell     return success();
29896e040acSBenjamin Maxwell   }
29996e040acSBenjamin Maxwell };
30096e040acSBenjamin Maxwell 
30196e040acSBenjamin Maxwell } // namespace
30296e040acSBenjamin Maxwell 
30396e040acSBenjamin Maxwell void mlir::arm_sve::populateLegalizeVectorStoragePatterns(
30496e040acSBenjamin Maxwell     RewritePatternSet &patterns) {
30596e040acSBenjamin Maxwell   patterns.add<RelaxScalableVectorAllocaAlignment,
30696e040acSBenjamin Maxwell                LegalizeSVEMaskAllocation<memref::AllocaOp>,
30796e040acSBenjamin Maxwell                LegalizeSVEMaskAllocation<memref::AllocOp>,
30896e040acSBenjamin Maxwell                LegalizeSVEMaskTypeCastConversion,
30996e040acSBenjamin Maxwell                LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion>(
31096e040acSBenjamin Maxwell       patterns.getContext());
31196e040acSBenjamin Maxwell }
31296e040acSBenjamin Maxwell 
31396e040acSBenjamin Maxwell namespace {
31496e040acSBenjamin Maxwell struct LegalizeVectorStorage
31596e040acSBenjamin Maxwell     : public arm_sve::impl::LegalizeVectorStorageBase<LegalizeVectorStorage> {
31696e040acSBenjamin Maxwell 
31796e040acSBenjamin Maxwell   void runOnOperation() override {
31896e040acSBenjamin Maxwell     RewritePatternSet patterns(&getContext());
31996e040acSBenjamin Maxwell     populateLegalizeVectorStoragePatterns(patterns);
320*09dfc571SJacques Pienaar     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
32196e040acSBenjamin Maxwell       signalPassFailure();
32296e040acSBenjamin Maxwell     }
32396e040acSBenjamin Maxwell     ConversionTarget target(getContext());
32496e040acSBenjamin Maxwell     target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
32596e040acSBenjamin Maxwell         [](UnrealizedConversionCastOp unrealizedConversion) {
32696e040acSBenjamin Maxwell           return !unrealizedConversion->hasAttr(kSVELegalizerTag);
32796e040acSBenjamin Maxwell         });
32896e040acSBenjamin Maxwell     // This detects if we failed to completely legalize the IR.
32996e040acSBenjamin Maxwell     if (failed(applyPartialConversion(getOperation(), target, {})))
33096e040acSBenjamin Maxwell       signalPassFailure();
33196e040acSBenjamin Maxwell   }
33296e040acSBenjamin Maxwell };
33396e040acSBenjamin Maxwell 
33496e040acSBenjamin Maxwell } // namespace
33596e040acSBenjamin Maxwell 
33696e040acSBenjamin Maxwell std::unique_ptr<Pass> mlir::arm_sve::createLegalizeVectorStoragePass() {
33796e040acSBenjamin Maxwell   return std::make_unique<LegalizeVectorStorage>();
33896e040acSBenjamin Maxwell }
339