//===-- Passes.td - ArmSVE pass definition file ------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_TD
#define MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_TD

include "mlir/Pass/PassBase.td"

// Pass anchored on `func.func` that rewrites memref-level loads, stores, and
// allocations of SVE vector types so they stay legal when lowered to LLVM.
// The summary below must describe all three, not just stores.
def LegalizeVectorStorage
    : Pass<"arm-sve-legalize-vector-storage", "mlir::func::FuncOp"> {
  let summary = "Legalizes loads, stores, and allocations of SVE vector types";
  let description = [{
    This pass ensures that loads, stores, and allocations of SVE vector types
    will be legal in the LLVM backend. It does this at the memref level, so this
    pass must be applied before lowering all the way to LLVM.

    This pass currently addresses two issues.

    #### Loading and storing predicate types

    It is only legal to load/store predicate types equal to (or greater than) a
    full predicate register, which in MLIR is `vector<[16]xi1>`. Smaller
    predicate types (`vector<[1|2|4|8]xi1>`) must be converted to/from a full
    predicate type (referred to as an `svbool`) before and after storing and
    loading respectively. This pass does this by widening allocations and
    inserting conversion intrinsics. Note: Non-power-of-two masks (e.g.
    `vector<[7]xi1>`), which are not SVE predicates, are ignored.

    For example:

    ```mlir
    %alloca = memref.alloca() : memref<vector<[4]xi1>>
    %mask = vector.constant_mask [4] : vector<[4]xi1>
    memref.store %mask, %alloca[] : memref<vector<[4]xi1>>
    %reload = memref.load %alloca[] : memref<vector<[4]xi1>>
    ```
    Becomes:
    ```mlir
    %alloca = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
    %mask = vector.constant_mask [4] : vector<[4]xi1>
    %svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1>
    memref.store %svbool, %alloca[] : memref<vector<[16]xi1>>
    %reload_svbool = memref.load %alloca[] : memref<vector<[16]xi1>>
    %reload = arm_sve.convert_from_svbool %reload_svbool : vector<[4]xi1>
    ```

    #### Relax alignments for SVE vector allocas

    The storage for SVE vector types only needs to have an alignment that
    matches the element type (for example 4 byte alignment for `f32`s). However,
    the LLVM backend currently defaults to aligning to `base size` x
    `element size` bytes. For non-legal vector types like `vector<[8]xf32>` this
    results in 8 x 4 = 32-byte alignment, but the backend only supports up to
    16-byte alignment for SVE vectors on the stack. Explicitly setting a smaller
    alignment prevents this issue.
  }];
  let constructor = "mlir::arm_sve::createLegalizeVectorStoragePass()";
  let dependentDialects = ["func::FuncDialect",
    "memref::MemRefDialect", "vector::VectorDialect",
    "arm_sve::ArmSVEDialect"];
}

#endif // MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_TD