//===- VectorDistribution.h - Vector distribution patterns -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_
#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

namespace mlir {
class RewritePatternSet;
namespace vector {

struct WarpExecuteOnLane0LoweringOptions {
  /// Lambda function to let users allocate memory needed for the lowering of
  /// WarpExecuteOnLane0Op.
  /// The function needs to return an allocation that the lowering can use as
  /// temporary memory. The allocation needs to match the shape of the type
  /// (the type may be VectorType or a scalar) and be available to the current
  /// warp. If several warps run in parallel, the allocation needs to be split
  /// so that each warp has its own allocation.
  using WarpAllocationFn = std::function<Value(
      Location, OpBuilder &, gpu::WarpExecuteOnLane0Op, Type)>;
  WarpAllocationFn warpAllocationFn = nullptr;

  /// Lambda function to let users emit an operation that synchronizes all the
  /// threads within a warp. After this operation, all the threads can see any
  /// memory written before the operation.
  using WarpSyncronizationFn =
      std::function<void(Location, OpBuilder &, gpu::WarpExecuteOnLane0Op)>;
  WarpSyncronizationFn warpSyncronizationFn = nullptr;
};

void populateWarpExecuteOnLane0OpToScfForPattern(
    RewritePatternSet &patterns,
    const WarpExecuteOnLane0LoweringOptions &options,
    PatternBenefit benefit = 1);

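// As a usage illustration, a minimal sketch of wiring up these options in
// client code. The allocation strategy below (one `memref.alloc` per value)
// is an assumption for brevity: a production allocator must place the buffer
// in workgroup-shared memory and give each warp its own slice. `patterns` is
// a RewritePatternSet assumed to be in scope, and the memref dialect header
// ("mlir/Dialect/MemRef/IR/MemRef.h") is assumed to be included.
// ```
// vector::WarpExecuteOnLane0LoweringOptions options;
// options.warpAllocationFn = [](Location loc, OpBuilder &b,
//                               gpu::WarpExecuteOnLane0Op warpOp,
//                               Type type) -> Value {
//   // Shape the scratch buffer after the distributed type (vector or
//   // scalar).
//   MemRefType memrefType;
//   if (auto vecType = dyn_cast<VectorType>(type))
//     memrefType =
//         MemRefType::get(vecType.getShape(), vecType.getElementType());
//   else
//     memrefType = MemRefType::get({1}, type);
//   return b.create<memref::AllocOp>(loc, memrefType);
// };
// options.warpSyncronizationFn = [](Location loc, OpBuilder &b,
//                                   gpu::WarpExecuteOnLane0Op warpOp) {
//   // gpu.barrier makes memory written before it visible to all lanes.
//   b.create<gpu::BarrierOp>(loc);
// };
// vector::populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
// ```
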
using DistributionMapFn = std::function<AffineMap(Value)>;

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`. Writes with more than `maxNumElementsToExtract`
/// elements will not be distributed (the limit should be less than the warp
/// size).
///
/// Example:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%id) {
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   gpu.yield
/// }
/// ```
/// To
/// ```
/// %r = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   gpu.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
///
/// When applied at the same time as the vector propagation patterns,
/// distribution of `vector.transfer_write` is expected to have the highest
/// priority (pattern benefit). By making propagation of `vector.transfer_read`
/// the lowest-priority pattern, reads are the last vector operations to be
/// distributed, meaning writes should propagate first.
void populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    unsigned maxNumElementsToExtract, PatternBenefit benefit = 2);

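// As an illustration, a minimal sketch of a `DistributionMapFn` (one
// hypothetical policy among many): distribute the innermost vector dimension
// across the lanes of the warp. `patterns` is assumed to be in scope.
// ```
// DistributionMapFn distributionMapFn = [](Value val) -> AffineMap {
//   auto vecType = dyn_cast<VectorType>(val.getType());
//   if (!vecType) // Scalars carry no distributed dimensions.
//     return AffineMap::get(val.getContext());
//   int64_t rank = vecType.getRank();
//   // (d0, ..., d{rank-1}) -> (d{rank-1}): lanes split the last dimension.
//   return AffineMap::get(rank, /*symbolCount=*/0,
//                         getAffineDimExpr(rank - 1, val.getContext()));
// };
// vector::populateDistributeTransferWriteOpPatterns(
//     patterns, distributionMapFn, /*maxNumElementsToExtract=*/1);
// ```
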
/// Move scalar operations with no dependency on the warp op outside of the
/// region.
void moveScalarUniformCode(gpu::WarpExecuteOnLane0Op op);

/// Lambda signature to compute a warp shuffle of a given value from a given
/// lane within a given warp size.
using WarpShuffleFromIdxFn =
    std::function<Value(Location, OpBuilder &b, Value, Value, int64_t)>;

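// As an illustration, a minimal sketch implementing this with `gpu.shuffle`
// in idx mode. It assumes the shuffled type is 32 bits wide and that the lane
// index arrives as an `index` value, hence the cast to i32.
// ```
// WarpShuffleFromIdxFn warpShuffleFromIdxFn =
//     [](Location loc, OpBuilder &b, Value srcValue, Value srcIdx,
//        int64_t warpSz) -> Value {
//   Type i32Type = b.getIntegerType(32);
//   Value srcIdxI32 = b.create<arith::IndexCastOp>(loc, i32Type, srcIdx);
//   Value warpSzI32 = b.create<arith::ConstantOp>(
//       loc, b.getIntegerAttr(i32Type, warpSz));
//   // gpu.shuffle idx broadcasts the value held by lane `srcIdx` to all
//   // lanes of the warp.
//   return b
//       .create<gpu::ShuffleOp>(loc, srcValue, srcIdxI32, warpSzI32,
//                               gpu::ShuffleMode::IDX)
//       .getShuffleResult();
// };
// ```
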
84 /// Collect patterns to propagate warp distribution. `distributionMapFn` is used
85 /// to decide how a value should be distributed when this cannot be inferred
86 /// from its uses.
87 ///
88 /// The separate control over the `vector.transfer_read` op pattern benefit
89 /// is given to ensure the order of reads/writes before and after distribution
90 /// is consistent. As noted above, writes are expected to have the highest
91 /// priority for distribution, but are only ever distributed if adjacent to the
92 /// yield. By making reads the lowest priority pattern, it will be the last
93 /// vector operation to distribute, meaning writes should propagate first. This
94 /// is relatively brittle when ops fail to distribute, but that is a limitation
95 /// of these propagation patterns when there is a dependency not modeled by SSA.
96 void populatePropagateWarpVectorDistributionPatterns(
97     RewritePatternSet &pattern, const DistributionMapFn &distributionMapFn,
98     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
99     PatternBenefit benefit = 1, PatternBenefit readBenefit = 0);
100 
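// For instance, combining the patterns above with their default benefits
// preserves the intended ordering: writes (benefit 2) distribute before
// generic propagation (benefit 1), and reads (benefit 0) go last. `patterns`,
// `distributionMapFn`, and `warpShuffleFromIdxFn` are assumed to be built as
// sketched above.
// ```
// vector::populateDistributeTransferWriteOpPatterns(
//     patterns, distributionMapFn, /*maxNumElementsToExtract=*/1,
//     /*benefit=*/2);
// vector::populatePropagateWarpVectorDistributionPatterns(
//     patterns, distributionMapFn, warpShuffleFromIdxFn, /*benefit=*/1,
//     /*readBenefit=*/0);
// ```
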
/// Lambda signature to compute a reduction of a distributed value for the
/// given reduction kind and size.
using DistributedReductionFn =
    std::function<Value(Location, OpBuilder &, Value, CombiningKind, uint32_t)>;

/// Collect patterns to distribute vector reduction ops using the given lambda
/// to distribute the reduction op.
void populateDistributeReduction(
    RewritePatternSet &patterns,
    const DistributedReductionFn &distributedReductionFn,
    PatternBenefit benefit = 1);

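// As an illustration, a minimal sketch of a `DistributedReductionFn` doing a
// butterfly reduction with `gpu.shuffle` xor mode. It assumes a 32-bit
// element type and a `patterns` set in scope; this mirrors a common warp
// reduction scheme, not the only valid one.
// ```
// DistributedReductionFn distributedReductionFn =
//     [](Location loc, OpBuilder &b, Value input, CombiningKind kind,
//        uint32_t size) -> Value {
//   // Reduce the lane-local vector to a scalar first.
//   Value laneVal = b.create<vector::ReductionOp>(loc, kind, input);
//   // Butterfly: combine values whose lane ids differ by i = 1, 2, 4, ...
//   for (uint64_t i = 1; i < size; i <<= 1) {
//     Value shuffled = b.create<gpu::ShuffleOp>(loc, laneVal, i,
//                                               /*width=*/size,
//                                               gpu::ShuffleMode::XOR)
//                          .getShuffleResult();
//     laneVal = vector::makeArithReduction(b, loc, kind, laneVal, shuffled);
//   }
//   return laneVal;
// };
// vector::populateDistributeReduction(patterns, distributedReductionFn);
// ```
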
} // namespace vector
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_