1 //===- LegalizeVectorStorage.cpp - Ensures SVE loads/stores are legal -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" 10 #include "mlir/Dialect/ArmSVE/Transforms/Passes.h" 11 #include "mlir/Dialect/Func/IR/FuncOps.h" 12 #include "mlir/Dialect/MemRef/IR/MemRef.h" 13 #include "mlir/Dialect/Vector/IR/VectorOps.h" 14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 15 16 namespace mlir::arm_sve { 17 #define GEN_PASS_DEF_LEGALIZEVECTORSTORAGE 18 #include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc" 19 } // namespace mlir::arm_sve 20 21 using namespace mlir; 22 using namespace mlir::arm_sve; 23 24 // A tag to mark unrealized_conversions produced by this pass. This is used to 25 // detect IR this pass failed to completely legalize, and report an error. 26 // If everything was successfully legalized, no tagged ops will remain after 27 // this pass. 28 constexpr StringLiteral kSVELegalizerTag("__arm_sve_legalize_vector_storage__"); 29 30 /// Definitions: 31 /// 32 /// [1] svbool = vector<...x[16]xi1>, which maps to some multiple of full SVE 33 /// predicate registers. A full predicate is the smallest quantity that can be 34 /// loaded/stored. 35 /// 36 /// [2] SVE mask = hardware-sized SVE predicate mask, i.e. its trailing 37 /// dimension matches the size of a legal SVE vector size (such as 38 /// vector<[4]xi1>), but is too small to be stored to memory (i.e smaller than 39 /// a svbool). 40 41 namespace { 42 43 /// Checks if a vector type is a SVE mask [2]. 44 bool isSVEMaskType(VectorType type) { 45 return type.getRank() > 0 && type.getElementType().isInteger(1) && 46 type.getScalableDims().back() && type.getShape().back() < 16 && 47 llvm::isPowerOf2_32(type.getShape().back()) && 48 !llvm::is_contained(type.getScalableDims().drop_back(), true); 49 } 50 51 VectorType widenScalableMaskTypeToSvbool(VectorType type) { 52 assert(isSVEMaskType(type)); 53 return VectorType::Builder(type).setDim(type.getRank() - 1, 16); 54 } 55 56 /// A helper for cloning an op and replacing it will a new version, updated by a 57 /// callback. 58 template <typename TOp, typename TLegalizerCallback> 59 void replaceOpWithLegalizedOp(PatternRewriter &rewriter, TOp op, 60 TLegalizerCallback callback) { 61 // Clone the previous op to preserve any properties/attributes. 62 auto newOp = op.clone(); 63 rewriter.insert(newOp); 64 rewriter.replaceOp(op, callback(newOp)); 65 } 66 67 /// A helper for cloning an op and replacing it with a new version, updated by a 68 /// callback, and an unrealized conversion back to the type of the replaced op. 69 template <typename TOp, typename TLegalizerCallback> 70 void replaceOpWithUnrealizedConversion(PatternRewriter &rewriter, TOp op, 71 TLegalizerCallback callback) { 72 replaceOpWithLegalizedOp(rewriter, op, [&](TOp newOp) { 73 // Mark our `unrealized_conversion_casts` with a pass label. 74 return rewriter.create<UnrealizedConversionCastOp>( 75 op.getLoc(), TypeRange{op.getResult().getType()}, 76 ValueRange{callback(newOp)}, 77 NamedAttribute(rewriter.getStringAttr(kSVELegalizerTag), 78 rewriter.getUnitAttr())); 79 }); 80 } 81 82 /// Extracts the widened SVE memref value (that's legal to store/load) from the 83 /// `unrealized_conversion_cast`s added by this pass. 84 static FailureOr<Value> getSVELegalizedMemref(Value illegalMemref) { 85 Operation *definingOp = illegalMemref.getDefiningOp(); 86 if (!definingOp || !definingOp->hasAttr(kSVELegalizerTag)) 87 return failure(); 88 auto unrealizedConversion = 89 llvm::cast<UnrealizedConversionCastOp>(definingOp); 90 return unrealizedConversion.getOperand(0); 91 } 92 93 /// The default alignment of an alloca in LLVM may request overaligned sizes for 94 /// SVE types, which will fail during stack frame allocation. This rewrite 95 /// explicitly adds a reasonable alignment to allocas of scalable types. 96 struct RelaxScalableVectorAllocaAlignment 97 : public OpRewritePattern<memref::AllocaOp> { 98 using OpRewritePattern::OpRewritePattern; 99 100 LogicalResult matchAndRewrite(memref::AllocaOp allocaOp, 101 PatternRewriter &rewriter) const override { 102 auto memrefElementType = allocaOp.getType().getElementType(); 103 auto vectorType = llvm::dyn_cast<VectorType>(memrefElementType); 104 if (!vectorType || !vectorType.isScalable() || allocaOp.getAlignment()) 105 return failure(); 106 107 // Set alignment based on the defaults for SVE vectors and predicates. 108 unsigned aligment = vectorType.getElementType().isInteger(1) ? 2 : 16; 109 rewriter.modifyOpInPlace(allocaOp, 110 [&] { allocaOp.setAlignment(aligment); }); 111 112 return success(); 113 } 114 }; 115 116 /// Replaces allocations of SVE predicates smaller than an svbool [1] (_illegal_ 117 /// to load/store) with a wider allocation of svbool (_legal_ to load/store) 118 /// followed by a tagged unrealized conversion to the original type. 119 /// 120 /// Example 121 /// ``` 122 /// %alloca = memref.alloca() : memref<vector<[4]xi1>> 123 /// ``` 124 /// is rewritten into: 125 /// ``` 126 /// %widened = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>> 127 /// %alloca = builtin.unrealized_conversion_cast %widened 128 /// : memref<vector<[16]xi1>> to memref<vector<[4]xi1>> 129 /// {__arm_sve_legalize_vector_storage__} 130 /// ``` 131 template <typename AllocLikeOp> 132 struct LegalizeSVEMaskAllocation : public OpRewritePattern<AllocLikeOp> { 133 using OpRewritePattern<AllocLikeOp>::OpRewritePattern; 134 135 LogicalResult matchAndRewrite(AllocLikeOp allocLikeOp, 136 PatternRewriter &rewriter) const override { 137 auto vectorType = 138 llvm::dyn_cast<VectorType>(allocLikeOp.getType().getElementType()); 139 140 if (!vectorType || !isSVEMaskType(vectorType)) 141 return failure(); 142 143 // Replace this alloc-like op of an SVE mask [2] with one of a (storable) 144 // svbool mask [1]. A temporary unrealized_conversion_cast is added to the 145 // old type to allow local rewrites. 146 replaceOpWithUnrealizedConversion( 147 rewriter, allocLikeOp, [&](AllocLikeOp newAllocLikeOp) { 148 newAllocLikeOp.getResult().setType( 149 llvm::cast<MemRefType>(newAllocLikeOp.getType().cloneWith( 150 {}, widenScalableMaskTypeToSvbool(vectorType)))); 151 return newAllocLikeOp; 152 }); 153 154 return success(); 155 } 156 }; 157 158 /// Replaces vector.type_casts of unrealized conversions to SVE predicate memref 159 /// types that are _illegal_ to load/store from (!= svbool [1]), with type casts 160 /// of memref types that are _legal_ to load/store, followed by unrealized 161 /// conversions. 162 /// 163 /// Example: 164 /// ``` 165 /// %alloca = builtin.unrealized_conversion_cast %widened 166 /// : memref<vector<[16]xi1>> to memref<vector<[8]xi1>> 167 /// {__arm_sve_legalize_vector_storage__} 168 /// %cast = vector.type_cast %alloca 169 /// : memref<vector<3x[8]xi1>> to memref<3xvector<[8]xi1>> 170 /// ``` 171 /// is rewritten into: 172 /// ``` 173 /// %widened_cast = vector.type_cast %widened 174 /// : memref<vector<3x[16]xi1>> to memref<3xvector<[16]xi1>> 175 /// %cast = builtin.unrealized_conversion_cast %widened_cast 176 /// : memref<3xvector<[16]xi1>> to memref<3xvector<[8]xi1>> 177 /// {__arm_sve_legalize_vector_storage__} 178 /// ``` 179 struct LegalizeSVEMaskTypeCastConversion 180 : public OpRewritePattern<vector::TypeCastOp> { 181 using OpRewritePattern::OpRewritePattern; 182 183 LogicalResult matchAndRewrite(vector::TypeCastOp typeCastOp, 184 PatternRewriter &rewriter) const override { 185 auto resultType = typeCastOp.getResultMemRefType(); 186 auto vectorType = llvm::dyn_cast<VectorType>(resultType.getElementType()); 187 188 if (!vectorType || !isSVEMaskType(vectorType)) 189 return failure(); 190 191 auto legalMemref = getSVELegalizedMemref(typeCastOp.getMemref()); 192 if (failed(legalMemref)) 193 return failure(); 194 195 // Replace this vector.type_cast with one of a (storable) svbool mask [1]. 196 replaceOpWithUnrealizedConversion( 197 rewriter, typeCastOp, [&](vector::TypeCastOp newTypeCast) { 198 newTypeCast.setOperand(*legalMemref); 199 newTypeCast.getResult().setType( 200 llvm::cast<MemRefType>(newTypeCast.getType().cloneWith( 201 {}, widenScalableMaskTypeToSvbool(vectorType)))); 202 return newTypeCast; 203 }); 204 205 return success(); 206 } 207 }; 208 209 /// Replaces stores to unrealized conversions to SVE predicate memref types that 210 /// are _illegal_ to load/store from (!= svbool [1]), with 211 /// `arm_sve.convert_to_svbool`s followed by (legal) wider stores. 212 /// 213 /// Example: 214 /// ``` 215 /// memref.store %mask, %alloca[] : memref<vector<[8]xi1>> 216 /// ``` 217 /// is rewritten into: 218 /// ``` 219 /// %svbool = arm_sve.convert_to_svbool %mask : vector<[8]xi1> 220 /// memref.store %svbool, %widened[] : memref<vector<[16]xi1>> 221 /// ``` 222 struct LegalizeSVEMaskStoreConversion 223 : public OpRewritePattern<memref::StoreOp> { 224 using OpRewritePattern::OpRewritePattern; 225 226 LogicalResult matchAndRewrite(memref::StoreOp storeOp, 227 PatternRewriter &rewriter) const override { 228 auto loc = storeOp.getLoc(); 229 230 Value valueToStore = storeOp.getValueToStore(); 231 auto vectorType = llvm::dyn_cast<VectorType>(valueToStore.getType()); 232 233 if (!vectorType || !isSVEMaskType(vectorType)) 234 return failure(); 235 236 auto legalMemref = getSVELegalizedMemref(storeOp.getMemref()); 237 if (failed(legalMemref)) 238 return failure(); 239 240 auto legalMaskType = widenScalableMaskTypeToSvbool( 241 llvm::cast<VectorType>(valueToStore.getType())); 242 auto convertToSvbool = rewriter.create<arm_sve::ConvertToSvboolOp>( 243 loc, legalMaskType, valueToStore); 244 // Replace this store with a conversion to a storable svbool mask [1], 245 // followed by a wider store. 246 replaceOpWithLegalizedOp(rewriter, storeOp, 247 [&](memref::StoreOp newStoreOp) { 248 newStoreOp.setOperand(0, convertToSvbool); 249 newStoreOp.setOperand(1, *legalMemref); 250 return newStoreOp; 251 }); 252 253 return success(); 254 } 255 }; 256 257 /// Replaces loads from unrealized conversions to SVE predicate memref types 258 /// that are _illegal_ to load/store from (!= svbool [1]), types with (legal) 259 /// wider loads, followed by `arm_sve.convert_from_svbool`s. 260 /// 261 /// Example: 262 /// ``` 263 /// %reload = memref.load %alloca[] : memref<vector<[4]xi1>> 264 /// ``` 265 /// is rewritten into: 266 /// ``` 267 /// %svbool = memref.load %widened[] : memref<vector<[16]xi1>> 268 /// %reload = arm_sve.convert_from_svbool %reload : vector<[4]xi1> 269 /// ``` 270 struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> { 271 using OpRewritePattern::OpRewritePattern; 272 273 LogicalResult matchAndRewrite(memref::LoadOp loadOp, 274 PatternRewriter &rewriter) const override { 275 auto loc = loadOp.getLoc(); 276 277 Value loadedMask = loadOp.getResult(); 278 auto vectorType = llvm::dyn_cast<VectorType>(loadedMask.getType()); 279 280 if (!vectorType || !isSVEMaskType(vectorType)) 281 return failure(); 282 283 auto legalMemref = getSVELegalizedMemref(loadOp.getMemref()); 284 if (failed(legalMemref)) 285 return failure(); 286 287 auto legalMaskType = widenScalableMaskTypeToSvbool(vectorType); 288 // Replace this load with a legal load of an svbool type, followed by a 289 // conversion back to the original type. 290 replaceOpWithLegalizedOp(rewriter, loadOp, [&](memref::LoadOp newLoadOp) { 291 newLoadOp.setMemRef(*legalMemref); 292 newLoadOp.getResult().setType(legalMaskType); 293 return rewriter.create<arm_sve::ConvertFromSvboolOp>( 294 loc, loadedMask.getType(), newLoadOp); 295 }); 296 297 return success(); 298 } 299 }; 300 301 } // namespace 302 303 void mlir::arm_sve::populateLegalizeVectorStoragePatterns( 304 RewritePatternSet &patterns) { 305 patterns.add<RelaxScalableVectorAllocaAlignment, 306 LegalizeSVEMaskAllocation<memref::AllocaOp>, 307 LegalizeSVEMaskAllocation<memref::AllocOp>, 308 LegalizeSVEMaskTypeCastConversion, 309 LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion>( 310 patterns.getContext()); 311 } 312 313 namespace { 314 struct LegalizeVectorStorage 315 : public arm_sve::impl::LegalizeVectorStorageBase<LegalizeVectorStorage> { 316 317 void runOnOperation() override { 318 RewritePatternSet patterns(&getContext()); 319 populateLegalizeVectorStoragePatterns(patterns); 320 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { 321 signalPassFailure(); 322 } 323 ConversionTarget target(getContext()); 324 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>( 325 [](UnrealizedConversionCastOp unrealizedConversion) { 326 return !unrealizedConversion->hasAttr(kSVELegalizerTag); 327 }); 328 // This detects if we failed to completely legalize the IR. 329 if (failed(applyPartialConversion(getOperation(), target, {}))) 330 signalPassFailure(); 331 } 332 }; 333 334 } // namespace 335 336 std::unique_ptr<Pass> mlir::arm_sve::createLegalizeVectorStoragePass() { 337 return std::make_unique<LegalizeVectorStorage>(); 338 } 339