//===- VectorDistribution.h - Vector distribution patterns -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_
#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

namespace mlir {
class RewritePatternSet;
namespace vector {

struct WarpExecuteOnLane0LoweringOptions {
  /// Lambda function to let users allocate the memory needed for the lowering
  /// of WarpExecuteOnLane0Op.
  /// The function needs to return an allocation that the lowering can use as
  /// temporary memory. The allocation needs to match the shape of the type
  /// (the type may be VectorType or a scalar) and be available to the current
  /// warp. If there are several warps running in parallel, the allocation
  /// needs to be split so that each warp has its own allocation.
  using WarpAllocationFn = std::function<Value(
      Location, OpBuilder &, gpu::WarpExecuteOnLane0Op, Type)>;
  WarpAllocationFn warpAllocationFn = nullptr;

  /// Lambda function to let users emit an operation that synchronizes all the
  /// threads within a warp. After this operation, all the threads can see any
  /// memory written before the operation.
  using WarpSyncronizationFn =
      std::function<void(Location, OpBuilder &, gpu::WarpExecuteOnLane0Op)>;
  WarpSyncronizationFn warpSyncronizationFn = nullptr;
};

/// Collect the pattern lowering `gpu.warp_execute_on_lane_0` ops, using
/// `options` to allocate the temporary memory and to synchronize the threads
/// of a warp.
void populateWarpExecuteOnLane0OpToScfForPattern(
    RewritePatternSet &patterns,
    const WarpExecuteOnLane0LoweringOptions &options,
    PatternBenefit benefit = 1);

/// Lambda signature returning the affine map describing how a given value
/// should be distributed across the lanes of a warp.
using DistributionMapFn = std::function<AffineMap(Value)>;

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`. Writes of more than `maxNumElementsToExtract` elements
/// will not be distributed (the limit should be less than the warp size).
///
/// Example:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%id) {
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   gpu.yield
/// }
/// ```
/// To
/// ```
/// %r = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   gpu.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
///
/// When applied at the same time as the vector propagation patterns,
/// distribution of `vector.transfer_write` is expected to have the highest
/// priority (pattern benefit). By making propagation of `vector.transfer_read`
/// the lowest priority pattern, reads will be the last vector operations to
/// distribute, meaning writes should propagate first.
void populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    unsigned maxNumElementsToExtract, PatternBenefit benefit = 2);

/// Move scalar operations with no dependency on the warp op outside of the
/// region.
void moveScalarUniformCode(gpu::WarpExecuteOnLane0Op op);

/// Lambda signature to compute a warp shuffle of a given value from a given
/// lane within a given warp size.
using WarpShuffleFromIdxFn =
    std::function<Value(Location, OpBuilder &, Value, Value, int64_t)>;

/// Collect patterns to propagate warp distribution. `distributionMapFn` is
/// used to decide how a value should be distributed when this cannot be
/// inferred from its uses.
///
/// Separate control over the `vector.transfer_read` op pattern benefit is
/// given to ensure the order of reads/writes before and after distribution is
/// consistent. As noted above, writes are expected to have the highest
/// priority for distribution, but they are only ever distributed if adjacent
/// to the yield. By making reads the lowest priority pattern, they will be the
/// last vector operations to distribute, meaning writes should propagate
/// first. This is relatively brittle when ops fail to distribute, but that is
/// a limitation of these propagation patterns when there is a dependency not
/// modeled by SSA.
void populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
    PatternBenefit benefit = 1, PatternBenefit readBenefit = 0);

/// Lambda signature to compute a reduction of a distributed value for the
/// given reduction kind and size.
using DistributedReductionFn =
    std::function<Value(Location, OpBuilder &, Value, CombiningKind, uint32_t)>;

/// Collect patterns to distribute vector reduction ops using the given lambda
/// to distribute the reduction op.
void populateDistributeReduction(
    RewritePatternSet &patterns,
    const DistributedReductionFn &distributedReductionFn,
    PatternBenefit benefit = 1);
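
// A minimal sketch of how a pass might wire these entry points together. It
// assumes a single warp (so one `memref.alloc` is a valid per-warp
// allocation), uses `gpu.barrier` as the synchronization primitive, and needs
// the MemRef and Arith dialect headers in addition to the includes above;
// `patterns` is an in-scope RewritePatternSet. This is illustrative only, not
// part of this header's API:
//
//   WarpExecuteOnLane0LoweringOptions options;
//   options.warpAllocationFn = [](Location loc, OpBuilder &b,
//                                 gpu::WarpExecuteOnLane0Op warpOp,
//                                 Type type) -> Value {
//     // The buffer must match the shape of `type` (vector or scalar).
//     MemRefType memrefType;
//     if (auto vecType = dyn_cast<VectorType>(type))
//       memrefType =
//           MemRefType::get(vecType.getShape(), vecType.getElementType());
//     else
//       memrefType = MemRefType::get({1}, type);
//     return b.create<memref::AllocOp>(loc, memrefType);
//   };
//   options.warpSyncronizationFn = [](Location loc, OpBuilder &b,
//                                     gpu::WarpExecuteOnLane0Op warpOp) {
//     b.create<gpu::BarrierOp>(loc);
//   };
//
//   // Distribute values along an identity map: lane i gets slice i.
//   auto distributionMapFn = [](Value val) {
//     auto vecType = dyn_cast<VectorType>(val.getType());
//     int64_t rank = vecType ? vecType.getRank() : 0;
//     if (rank == 0)
//       return AffineMap::get(val.getContext());
//     return AffineMap::getMultiDimIdentityMap(rank, val.getContext());
//   };
//
//   // Read a 32-bit value from lane `srcIdx` with `gpu.shuffle idx`.
//   auto shuffleFn = [](Location loc, OpBuilder &b, Value val, Value srcIdx,
//                       int64_t warpSz) -> Value {
//     Type i32 = b.getIntegerType(32);
//     Value srcIdxI32 = b.create<arith::IndexCastOp>(loc, i32, srcIdx);
//     Value warpSzI32 =
//         b.create<arith::ConstantOp>(loc, b.getIntegerAttr(i32, warpSz));
//     return b.create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32,
//                                     gpu::ShuffleMode::IDX)
//         .getResult(0);
//   };
//
//   populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
//   populateDistributeTransferWriteOpPatterns(patterns, distributionMapFn,
//                                             /*maxNumElementsToExtract=*/1);
//   populatePropagateWarpVectorDistributionPatterns(patterns,
//                                                   distributionMapFn,
//                                                   shuffleFn);
//
// A `DistributedReductionFn` for `populateDistributeReduction` can be built
// in the same style, e.g. a `gpu.shuffle` xor butterfly followed by the
// combining op for the given CombiningKind.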

} // namespace vector
} // namespace mlir

#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_