196e040acSBenjamin Maxwell //===- LegalizeVectorStorage.cpp - Ensures SVE loads/stores are legal -----===// 296e040acSBenjamin Maxwell // 396e040acSBenjamin Maxwell // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 496e040acSBenjamin Maxwell // See https://llvm.org/LICENSE.txt for license information. 596e040acSBenjamin Maxwell // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 696e040acSBenjamin Maxwell // 796e040acSBenjamin Maxwell //===----------------------------------------------------------------------===// 896e040acSBenjamin Maxwell 996e040acSBenjamin Maxwell #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" 1096e040acSBenjamin Maxwell #include "mlir/Dialect/ArmSVE/Transforms/Passes.h" 1196e040acSBenjamin Maxwell #include "mlir/Dialect/Func/IR/FuncOps.h" 1296e040acSBenjamin Maxwell #include "mlir/Dialect/MemRef/IR/MemRef.h" 1396e040acSBenjamin Maxwell #include "mlir/Dialect/Vector/IR/VectorOps.h" 1496e040acSBenjamin Maxwell #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 1596e040acSBenjamin Maxwell 1696e040acSBenjamin Maxwell namespace mlir::arm_sve { 1796e040acSBenjamin Maxwell #define GEN_PASS_DEF_LEGALIZEVECTORSTORAGE 1896e040acSBenjamin Maxwell #include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc" 1996e040acSBenjamin Maxwell } // namespace mlir::arm_sve 2096e040acSBenjamin Maxwell 2196e040acSBenjamin Maxwell using namespace mlir; 2296e040acSBenjamin Maxwell using namespace mlir::arm_sve; 2396e040acSBenjamin Maxwell 2496e040acSBenjamin Maxwell // A tag to mark unrealized_conversions produced by this pass. This is used to 2596e040acSBenjamin Maxwell // detect IR this pass failed to completely legalize, and report an error. 2696e040acSBenjamin Maxwell // If everything was successfully legalized, no tagged ops will remain after 2796e040acSBenjamin Maxwell // this pass. 2896e040acSBenjamin Maxwell constexpr StringLiteral kSVELegalizerTag("__arm_sve_legalize_vector_storage__"); 2996e040acSBenjamin Maxwell 3096e040acSBenjamin Maxwell /// Definitions: 3196e040acSBenjamin Maxwell /// 3296e040acSBenjamin Maxwell /// [1] svbool = vector<...x[16]xi1>, which maps to some multiple of full SVE 3396e040acSBenjamin Maxwell /// predicate registers. A full predicate is the smallest quantity that can be 3496e040acSBenjamin Maxwell /// loaded/stored. 3596e040acSBenjamin Maxwell /// 3696e040acSBenjamin Maxwell /// [2] SVE mask = hardware-sized SVE predicate mask, i.e. its trailing 3796e040acSBenjamin Maxwell /// dimension matches the size of a legal SVE vector size (such as 3896e040acSBenjamin Maxwell /// vector<[4]xi1>), but is too small to be stored to memory (i.e smaller than 3996e040acSBenjamin Maxwell /// a svbool). 4096e040acSBenjamin Maxwell 4196e040acSBenjamin Maxwell namespace { 4296e040acSBenjamin Maxwell 4396e040acSBenjamin Maxwell /// Checks if a vector type is a SVE mask [2]. 4496e040acSBenjamin Maxwell bool isSVEMaskType(VectorType type) { 4596e040acSBenjamin Maxwell return type.getRank() > 0 && type.getElementType().isInteger(1) && 4696e040acSBenjamin Maxwell type.getScalableDims().back() && type.getShape().back() < 16 && 4796e040acSBenjamin Maxwell llvm::isPowerOf2_32(type.getShape().back()) && 4896e040acSBenjamin Maxwell !llvm::is_contained(type.getScalableDims().drop_back(), true); 4996e040acSBenjamin Maxwell } 5096e040acSBenjamin Maxwell 5196e040acSBenjamin Maxwell VectorType widenScalableMaskTypeToSvbool(VectorType type) { 5296e040acSBenjamin Maxwell assert(isSVEMaskType(type)); 5396e040acSBenjamin Maxwell return VectorType::Builder(type).setDim(type.getRank() - 1, 16); 5496e040acSBenjamin Maxwell } 5596e040acSBenjamin Maxwell 5696e040acSBenjamin Maxwell /// A helper for cloning an op and replacing it will a new version, updated by a 5796e040acSBenjamin Maxwell /// callback. 5896e040acSBenjamin Maxwell template <typename TOp, typename TLegalizerCallback> 5996e040acSBenjamin Maxwell void replaceOpWithLegalizedOp(PatternRewriter &rewriter, TOp op, 6096e040acSBenjamin Maxwell TLegalizerCallback callback) { 6196e040acSBenjamin Maxwell // Clone the previous op to preserve any properties/attributes. 6296e040acSBenjamin Maxwell auto newOp = op.clone(); 6396e040acSBenjamin Maxwell rewriter.insert(newOp); 6496e040acSBenjamin Maxwell rewriter.replaceOp(op, callback(newOp)); 6596e040acSBenjamin Maxwell } 6696e040acSBenjamin Maxwell 6796e040acSBenjamin Maxwell /// A helper for cloning an op and replacing it with a new version, updated by a 6896e040acSBenjamin Maxwell /// callback, and an unrealized conversion back to the type of the replaced op. 6996e040acSBenjamin Maxwell template <typename TOp, typename TLegalizerCallback> 7096e040acSBenjamin Maxwell void replaceOpWithUnrealizedConversion(PatternRewriter &rewriter, TOp op, 7196e040acSBenjamin Maxwell TLegalizerCallback callback) { 7296e040acSBenjamin Maxwell replaceOpWithLegalizedOp(rewriter, op, [&](TOp newOp) { 7396e040acSBenjamin Maxwell // Mark our `unrealized_conversion_casts` with a pass label. 7496e040acSBenjamin Maxwell return rewriter.create<UnrealizedConversionCastOp>( 7596e040acSBenjamin Maxwell op.getLoc(), TypeRange{op.getResult().getType()}, 7696e040acSBenjamin Maxwell ValueRange{callback(newOp)}, 7796e040acSBenjamin Maxwell NamedAttribute(rewriter.getStringAttr(kSVELegalizerTag), 7896e040acSBenjamin Maxwell rewriter.getUnitAttr())); 7996e040acSBenjamin Maxwell }); 8096e040acSBenjamin Maxwell } 8196e040acSBenjamin Maxwell 8296e040acSBenjamin Maxwell /// Extracts the widened SVE memref value (that's legal to store/load) from the 8396e040acSBenjamin Maxwell /// `unrealized_conversion_cast`s added by this pass. 8496e040acSBenjamin Maxwell static FailureOr<Value> getSVELegalizedMemref(Value illegalMemref) { 8596e040acSBenjamin Maxwell Operation *definingOp = illegalMemref.getDefiningOp(); 8696e040acSBenjamin Maxwell if (!definingOp || !definingOp->hasAttr(kSVELegalizerTag)) 8796e040acSBenjamin Maxwell return failure(); 8896e040acSBenjamin Maxwell auto unrealizedConversion = 8996e040acSBenjamin Maxwell llvm::cast<UnrealizedConversionCastOp>(definingOp); 9096e040acSBenjamin Maxwell return unrealizedConversion.getOperand(0); 9196e040acSBenjamin Maxwell } 9296e040acSBenjamin Maxwell 9396e040acSBenjamin Maxwell /// The default alignment of an alloca in LLVM may request overaligned sizes for 9496e040acSBenjamin Maxwell /// SVE types, which will fail during stack frame allocation. This rewrite 9596e040acSBenjamin Maxwell /// explicitly adds a reasonable alignment to allocas of scalable types. 9696e040acSBenjamin Maxwell struct RelaxScalableVectorAllocaAlignment 9796e040acSBenjamin Maxwell : public OpRewritePattern<memref::AllocaOp> { 9896e040acSBenjamin Maxwell using OpRewritePattern::OpRewritePattern; 9996e040acSBenjamin Maxwell 10096e040acSBenjamin Maxwell LogicalResult matchAndRewrite(memref::AllocaOp allocaOp, 10196e040acSBenjamin Maxwell PatternRewriter &rewriter) const override { 10296e040acSBenjamin Maxwell auto memrefElementType = allocaOp.getType().getElementType(); 10396e040acSBenjamin Maxwell auto vectorType = llvm::dyn_cast<VectorType>(memrefElementType); 10496e040acSBenjamin Maxwell if (!vectorType || !vectorType.isScalable() || allocaOp.getAlignment()) 10596e040acSBenjamin Maxwell return failure(); 10696e040acSBenjamin Maxwell 10796e040acSBenjamin Maxwell // Set alignment based on the defaults for SVE vectors and predicates. 10896e040acSBenjamin Maxwell unsigned aligment = vectorType.getElementType().isInteger(1) ? 2 : 16; 1095fcf907bSMatthias Springer rewriter.modifyOpInPlace(allocaOp, 110e2bb47caSMatthias Springer [&] { allocaOp.setAlignment(aligment); }); 11196e040acSBenjamin Maxwell 11296e040acSBenjamin Maxwell return success(); 11396e040acSBenjamin Maxwell } 11496e040acSBenjamin Maxwell }; 11596e040acSBenjamin Maxwell 11696e040acSBenjamin Maxwell /// Replaces allocations of SVE predicates smaller than an svbool [1] (_illegal_ 11796e040acSBenjamin Maxwell /// to load/store) with a wider allocation of svbool (_legal_ to load/store) 11896e040acSBenjamin Maxwell /// followed by a tagged unrealized conversion to the original type. 11996e040acSBenjamin Maxwell /// 12096e040acSBenjamin Maxwell /// Example 12196e040acSBenjamin Maxwell /// ``` 12296e040acSBenjamin Maxwell /// %alloca = memref.alloca() : memref<vector<[4]xi1>> 12396e040acSBenjamin Maxwell /// ``` 12496e040acSBenjamin Maxwell /// is rewritten into: 12596e040acSBenjamin Maxwell /// ``` 12696e040acSBenjamin Maxwell /// %widened = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>> 12796e040acSBenjamin Maxwell /// %alloca = builtin.unrealized_conversion_cast %widened 12896e040acSBenjamin Maxwell /// : memref<vector<[16]xi1>> to memref<vector<[4]xi1>> 12996e040acSBenjamin Maxwell /// {__arm_sve_legalize_vector_storage__} 13096e040acSBenjamin Maxwell /// ``` 13196e040acSBenjamin Maxwell template <typename AllocLikeOp> 13296e040acSBenjamin Maxwell struct LegalizeSVEMaskAllocation : public OpRewritePattern<AllocLikeOp> { 13396e040acSBenjamin Maxwell using OpRewritePattern<AllocLikeOp>::OpRewritePattern; 13496e040acSBenjamin Maxwell 13596e040acSBenjamin Maxwell LogicalResult matchAndRewrite(AllocLikeOp allocLikeOp, 13696e040acSBenjamin Maxwell PatternRewriter &rewriter) const override { 13796e040acSBenjamin Maxwell auto vectorType = 13896e040acSBenjamin Maxwell llvm::dyn_cast<VectorType>(allocLikeOp.getType().getElementType()); 13996e040acSBenjamin Maxwell 14096e040acSBenjamin Maxwell if (!vectorType || !isSVEMaskType(vectorType)) 14196e040acSBenjamin Maxwell return failure(); 14296e040acSBenjamin Maxwell 14396e040acSBenjamin Maxwell // Replace this alloc-like op of an SVE mask [2] with one of a (storable) 14496e040acSBenjamin Maxwell // svbool mask [1]. A temporary unrealized_conversion_cast is added to the 14596e040acSBenjamin Maxwell // old type to allow local rewrites. 14696e040acSBenjamin Maxwell replaceOpWithUnrealizedConversion( 14796e040acSBenjamin Maxwell rewriter, allocLikeOp, [&](AllocLikeOp newAllocLikeOp) { 14896e040acSBenjamin Maxwell newAllocLikeOp.getResult().setType( 14996e040acSBenjamin Maxwell llvm::cast<MemRefType>(newAllocLikeOp.getType().cloneWith( 15096e040acSBenjamin Maxwell {}, widenScalableMaskTypeToSvbool(vectorType)))); 15196e040acSBenjamin Maxwell return newAllocLikeOp; 15296e040acSBenjamin Maxwell }); 15396e040acSBenjamin Maxwell 15496e040acSBenjamin Maxwell return success(); 15596e040acSBenjamin Maxwell } 15696e040acSBenjamin Maxwell }; 15796e040acSBenjamin Maxwell 15896e040acSBenjamin Maxwell /// Replaces vector.type_casts of unrealized conversions to SVE predicate memref 15996e040acSBenjamin Maxwell /// types that are _illegal_ to load/store from (!= svbool [1]), with type casts 16096e040acSBenjamin Maxwell /// of memref types that are _legal_ to load/store, followed by unrealized 16196e040acSBenjamin Maxwell /// conversions. 16296e040acSBenjamin Maxwell /// 16396e040acSBenjamin Maxwell /// Example: 16496e040acSBenjamin Maxwell /// ``` 16596e040acSBenjamin Maxwell /// %alloca = builtin.unrealized_conversion_cast %widened 16696e040acSBenjamin Maxwell /// : memref<vector<[16]xi1>> to memref<vector<[8]xi1>> 16796e040acSBenjamin Maxwell /// {__arm_sve_legalize_vector_storage__} 16896e040acSBenjamin Maxwell /// %cast = vector.type_cast %alloca 16996e040acSBenjamin Maxwell /// : memref<vector<3x[8]xi1>> to memref<3xvector<[8]xi1>> 17096e040acSBenjamin Maxwell /// ``` 17196e040acSBenjamin Maxwell /// is rewritten into: 17296e040acSBenjamin Maxwell /// ``` 17396e040acSBenjamin Maxwell /// %widened_cast = vector.type_cast %widened 17496e040acSBenjamin Maxwell /// : memref<vector<3x[16]xi1>> to memref<3xvector<[16]xi1>> 17596e040acSBenjamin Maxwell /// %cast = builtin.unrealized_conversion_cast %widened_cast 17696e040acSBenjamin Maxwell /// : memref<3xvector<[16]xi1>> to memref<3xvector<[8]xi1>> 17796e040acSBenjamin Maxwell /// {__arm_sve_legalize_vector_storage__} 17896e040acSBenjamin Maxwell /// ``` 17996e040acSBenjamin Maxwell struct LegalizeSVEMaskTypeCastConversion 18096e040acSBenjamin Maxwell : public OpRewritePattern<vector::TypeCastOp> { 18196e040acSBenjamin Maxwell using OpRewritePattern::OpRewritePattern; 18296e040acSBenjamin Maxwell 18396e040acSBenjamin Maxwell LogicalResult matchAndRewrite(vector::TypeCastOp typeCastOp, 18496e040acSBenjamin Maxwell PatternRewriter &rewriter) const override { 18596e040acSBenjamin Maxwell auto resultType = typeCastOp.getResultMemRefType(); 18696e040acSBenjamin Maxwell auto vectorType = llvm::dyn_cast<VectorType>(resultType.getElementType()); 18796e040acSBenjamin Maxwell 18896e040acSBenjamin Maxwell if (!vectorType || !isSVEMaskType(vectorType)) 18996e040acSBenjamin Maxwell return failure(); 19096e040acSBenjamin Maxwell 19196e040acSBenjamin Maxwell auto legalMemref = getSVELegalizedMemref(typeCastOp.getMemref()); 19296e040acSBenjamin Maxwell if (failed(legalMemref)) 19396e040acSBenjamin Maxwell return failure(); 19496e040acSBenjamin Maxwell 19596e040acSBenjamin Maxwell // Replace this vector.type_cast with one of a (storable) svbool mask [1]. 19696e040acSBenjamin Maxwell replaceOpWithUnrealizedConversion( 19796e040acSBenjamin Maxwell rewriter, typeCastOp, [&](vector::TypeCastOp newTypeCast) { 19896e040acSBenjamin Maxwell newTypeCast.setOperand(*legalMemref); 19996e040acSBenjamin Maxwell newTypeCast.getResult().setType( 20096e040acSBenjamin Maxwell llvm::cast<MemRefType>(newTypeCast.getType().cloneWith( 20196e040acSBenjamin Maxwell {}, widenScalableMaskTypeToSvbool(vectorType)))); 20296e040acSBenjamin Maxwell return newTypeCast; 20396e040acSBenjamin Maxwell }); 20496e040acSBenjamin Maxwell 20596e040acSBenjamin Maxwell return success(); 20696e040acSBenjamin Maxwell } 20796e040acSBenjamin Maxwell }; 20896e040acSBenjamin Maxwell 20996e040acSBenjamin Maxwell /// Replaces stores to unrealized conversions to SVE predicate memref types that 21096e040acSBenjamin Maxwell /// are _illegal_ to load/store from (!= svbool [1]), with 21196e040acSBenjamin Maxwell /// `arm_sve.convert_to_svbool`s followed by (legal) wider stores. 21296e040acSBenjamin Maxwell /// 21396e040acSBenjamin Maxwell /// Example: 21496e040acSBenjamin Maxwell /// ``` 21596e040acSBenjamin Maxwell /// memref.store %mask, %alloca[] : memref<vector<[8]xi1>> 21696e040acSBenjamin Maxwell /// ``` 21796e040acSBenjamin Maxwell /// is rewritten into: 21896e040acSBenjamin Maxwell /// ``` 21996e040acSBenjamin Maxwell /// %svbool = arm_sve.convert_to_svbool %mask : vector<[8]xi1> 22096e040acSBenjamin Maxwell /// memref.store %svbool, %widened[] : memref<vector<[16]xi1>> 22196e040acSBenjamin Maxwell /// ``` 22296e040acSBenjamin Maxwell struct LegalizeSVEMaskStoreConversion 22396e040acSBenjamin Maxwell : public OpRewritePattern<memref::StoreOp> { 22496e040acSBenjamin Maxwell using OpRewritePattern::OpRewritePattern; 22596e040acSBenjamin Maxwell 22696e040acSBenjamin Maxwell LogicalResult matchAndRewrite(memref::StoreOp storeOp, 22796e040acSBenjamin Maxwell PatternRewriter &rewriter) const override { 22896e040acSBenjamin Maxwell auto loc = storeOp.getLoc(); 22996e040acSBenjamin Maxwell 23096e040acSBenjamin Maxwell Value valueToStore = storeOp.getValueToStore(); 23196e040acSBenjamin Maxwell auto vectorType = llvm::dyn_cast<VectorType>(valueToStore.getType()); 23296e040acSBenjamin Maxwell 23396e040acSBenjamin Maxwell if (!vectorType || !isSVEMaskType(vectorType)) 23496e040acSBenjamin Maxwell return failure(); 23596e040acSBenjamin Maxwell 23696e040acSBenjamin Maxwell auto legalMemref = getSVELegalizedMemref(storeOp.getMemref()); 23796e040acSBenjamin Maxwell if (failed(legalMemref)) 23896e040acSBenjamin Maxwell return failure(); 23996e040acSBenjamin Maxwell 24096e040acSBenjamin Maxwell auto legalMaskType = widenScalableMaskTypeToSvbool( 24196e040acSBenjamin Maxwell llvm::cast<VectorType>(valueToStore.getType())); 24296e040acSBenjamin Maxwell auto convertToSvbool = rewriter.create<arm_sve::ConvertToSvboolOp>( 24396e040acSBenjamin Maxwell loc, legalMaskType, valueToStore); 24496e040acSBenjamin Maxwell // Replace this store with a conversion to a storable svbool mask [1], 24596e040acSBenjamin Maxwell // followed by a wider store. 24696e040acSBenjamin Maxwell replaceOpWithLegalizedOp(rewriter, storeOp, 24796e040acSBenjamin Maxwell [&](memref::StoreOp newStoreOp) { 24896e040acSBenjamin Maxwell newStoreOp.setOperand(0, convertToSvbool); 24996e040acSBenjamin Maxwell newStoreOp.setOperand(1, *legalMemref); 25096e040acSBenjamin Maxwell return newStoreOp; 25196e040acSBenjamin Maxwell }); 25296e040acSBenjamin Maxwell 25396e040acSBenjamin Maxwell return success(); 25496e040acSBenjamin Maxwell } 25596e040acSBenjamin Maxwell }; 25696e040acSBenjamin Maxwell 25796e040acSBenjamin Maxwell /// Replaces loads from unrealized conversions to SVE predicate memref types 25896e040acSBenjamin Maxwell /// that are _illegal_ to load/store from (!= svbool [1]), types with (legal) 25996e040acSBenjamin Maxwell /// wider loads, followed by `arm_sve.convert_from_svbool`s. 26096e040acSBenjamin Maxwell /// 26196e040acSBenjamin Maxwell /// Example: 26296e040acSBenjamin Maxwell /// ``` 26396e040acSBenjamin Maxwell /// %reload = memref.load %alloca[] : memref<vector<[4]xi1>> 26496e040acSBenjamin Maxwell /// ``` 26596e040acSBenjamin Maxwell /// is rewritten into: 26696e040acSBenjamin Maxwell /// ``` 26796e040acSBenjamin Maxwell /// %svbool = memref.load %widened[] : memref<vector<[16]xi1>> 26896e040acSBenjamin Maxwell /// %reload = arm_sve.convert_from_svbool %reload : vector<[4]xi1> 26996e040acSBenjamin Maxwell /// ``` 27096e040acSBenjamin Maxwell struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> { 27196e040acSBenjamin Maxwell using OpRewritePattern::OpRewritePattern; 27296e040acSBenjamin Maxwell 27396e040acSBenjamin Maxwell LogicalResult matchAndRewrite(memref::LoadOp loadOp, 27496e040acSBenjamin Maxwell PatternRewriter &rewriter) const override { 27596e040acSBenjamin Maxwell auto loc = loadOp.getLoc(); 27696e040acSBenjamin Maxwell 27796e040acSBenjamin Maxwell Value loadedMask = loadOp.getResult(); 27896e040acSBenjamin Maxwell auto vectorType = llvm::dyn_cast<VectorType>(loadedMask.getType()); 27996e040acSBenjamin Maxwell 28096e040acSBenjamin Maxwell if (!vectorType || !isSVEMaskType(vectorType)) 28196e040acSBenjamin Maxwell return failure(); 28296e040acSBenjamin Maxwell 28396e040acSBenjamin Maxwell auto legalMemref = getSVELegalizedMemref(loadOp.getMemref()); 28496e040acSBenjamin Maxwell if (failed(legalMemref)) 28596e040acSBenjamin Maxwell return failure(); 28696e040acSBenjamin Maxwell 28796e040acSBenjamin Maxwell auto legalMaskType = widenScalableMaskTypeToSvbool(vectorType); 28896e040acSBenjamin Maxwell // Replace this load with a legal load of an svbool type, followed by a 28996e040acSBenjamin Maxwell // conversion back to the original type. 29096e040acSBenjamin Maxwell replaceOpWithLegalizedOp(rewriter, loadOp, [&](memref::LoadOp newLoadOp) { 29196e040acSBenjamin Maxwell newLoadOp.setMemRef(*legalMemref); 29296e040acSBenjamin Maxwell newLoadOp.getResult().setType(legalMaskType); 29396e040acSBenjamin Maxwell return rewriter.create<arm_sve::ConvertFromSvboolOp>( 29496e040acSBenjamin Maxwell loc, loadedMask.getType(), newLoadOp); 29596e040acSBenjamin Maxwell }); 29696e040acSBenjamin Maxwell 29796e040acSBenjamin Maxwell return success(); 29896e040acSBenjamin Maxwell } 29996e040acSBenjamin Maxwell }; 30096e040acSBenjamin Maxwell 30196e040acSBenjamin Maxwell } // namespace 30296e040acSBenjamin Maxwell 30396e040acSBenjamin Maxwell void mlir::arm_sve::populateLegalizeVectorStoragePatterns( 30496e040acSBenjamin Maxwell RewritePatternSet &patterns) { 30596e040acSBenjamin Maxwell patterns.add<RelaxScalableVectorAllocaAlignment, 30696e040acSBenjamin Maxwell LegalizeSVEMaskAllocation<memref::AllocaOp>, 30796e040acSBenjamin Maxwell LegalizeSVEMaskAllocation<memref::AllocOp>, 30896e040acSBenjamin Maxwell LegalizeSVEMaskTypeCastConversion, 30996e040acSBenjamin Maxwell LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion>( 31096e040acSBenjamin Maxwell patterns.getContext()); 31196e040acSBenjamin Maxwell } 31296e040acSBenjamin Maxwell 31396e040acSBenjamin Maxwell namespace { 31496e040acSBenjamin Maxwell struct LegalizeVectorStorage 31596e040acSBenjamin Maxwell : public arm_sve::impl::LegalizeVectorStorageBase<LegalizeVectorStorage> { 31696e040acSBenjamin Maxwell 31796e040acSBenjamin Maxwell void runOnOperation() override { 31896e040acSBenjamin Maxwell RewritePatternSet patterns(&getContext()); 31996e040acSBenjamin Maxwell populateLegalizeVectorStoragePatterns(patterns); 320*09dfc571SJacques Pienaar if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { 32196e040acSBenjamin Maxwell signalPassFailure(); 32296e040acSBenjamin Maxwell } 32396e040acSBenjamin Maxwell ConversionTarget target(getContext()); 32496e040acSBenjamin Maxwell target.addDynamicallyLegalOp<UnrealizedConversionCastOp>( 32596e040acSBenjamin Maxwell [](UnrealizedConversionCastOp unrealizedConversion) { 32696e040acSBenjamin Maxwell return !unrealizedConversion->hasAttr(kSVELegalizerTag); 32796e040acSBenjamin Maxwell }); 32896e040acSBenjamin Maxwell // This detects if we failed to completely legalize the IR. 32996e040acSBenjamin Maxwell if (failed(applyPartialConversion(getOperation(), target, {}))) 33096e040acSBenjamin Maxwell signalPassFailure(); 33196e040acSBenjamin Maxwell } 33296e040acSBenjamin Maxwell }; 33396e040acSBenjamin Maxwell 33496e040acSBenjamin Maxwell } // namespace 33596e040acSBenjamin Maxwell 33696e040acSBenjamin Maxwell std::unique_ptr<Pass> mlir::arm_sve::createLegalizeVectorStoragePass() { 33796e040acSBenjamin Maxwell return std::make_unique<LegalizeVectorStorage>(); 33896e040acSBenjamin Maxwell } 339