xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (revision b91d5af1ac3ad2c18b1dfde2061a6ac1d638e6e4)
//===- LowerVectorGather.cpp - Lower 'vector.gather' operation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.gather' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#define DEBUG_TYPE "vector-broadcast-lowering"

using namespace mlir;
using namespace mlir::vector;

namespace {
/// Flattens `vector.gather` ops of rank 2 or higher by unrolling the
/// outermost dimension. For example:
/// ```
/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
///        ... into vector<2x3xf32>
///
/// ==>
///
/// %0   = arith.constant dense<0.0> : vector<2x3xf32>
/// %g0  = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ...
/// %1   = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32>
/// %g1  = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ...
/// %g   = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32>
/// ```
///
/// When applied exhaustively, this will produce a sequence of 1-d gather ops.
///
/// Supports vector types with a fixed leading dimension.
struct FlattenGather : OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resultTy = op.getType();
    if (resultTy.getRank() < 2)
      return rewriter.notifyMatchFailure(op, "already flat");

    // Unrolling doesn't take vscale into account. The pattern is disabled for
    // vectors with a leading scalable dimension.
    if (resultTy.getScalableDims().front())
      return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");

    Location loc = op.getLoc();
    Value indexVec = op.getIndexVec();
    Value maskVec = op.getMask();
    Value passThruVec = op.getPassThru();

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultTy, rewriter.getZeroAttr(resultTy));

    VectorType subTy = VectorType::Builder(resultTy).dropDim(0);

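    // Unroll along the leading dimension: gather one rank-reduced slice at a
    // time and insert it into the accumulated result.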
    for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
      int64_t thisIdx[1] = {i};

      Value indexSubVec =
          rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
      Value maskSubVec =
          rewriter.create<vector::ExtractOp>(loc, maskVec, thisIdx);
      Value passThruSubVec =
          rewriter.create<vector::ExtractOp>(loc, passThruVec, thisIdx);
      Value subGather = rewriter.create<vector::GatherOp>(
          loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
          passThruSubVec);
      result =
          rewriter.create<vector::InsertOp>(loc, subGather, result, thisIdx);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
/// MemRef with updated indices that model the strided access.
///
/// ```mlir
///   %subview = memref.subview %M (...)
///     : memref<100x3xf32> to memref<100xf32, strided<[3]>>
///   %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
/// ```
/// ==>
/// ```mlir
///   %collapse_shape = memref.collapse_shape %M (...)
///     : memref<100x3xf32> into memref<300xf32>
///   %new_idxs = arith.muli %idxs, %c3 : vector<4xindex>
///   %gather = vector.gather %collapse_shape[%new_idxs] (...)
///     : memref<300xf32> (...)
/// ```
///
/// At the moment, this is effectively limited to reading a 1-D vector from a
/// 2-D MemRef, but it should be fairly straightforward to extend beyond that.
struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp op,
                                PatternRewriter &rewriter) const override {
    Value base = op.getBase();

    // TODO: Strided accesses might be coming from other ops as well
    auto subview = base.getDefiningOp<memref::SubViewOp>();
    if (!subview)
      return failure();

    auto sourceType = subview.getSource().getType();

    // TODO: Allow ranks > 2.
    if (sourceType.getRank() != 2)
      return failure();

    // Get the strides of the subview result's layout.
    auto layout = subview.getResult().getType().getLayout();
    auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
    if (!stridedLayoutAttr)
      return failure();

    // TODO: Allow the access to be strided in multiple dimensions.
    if (stridedLayoutAttr.getStrides().size() != 1)
      return failure();

    int64_t srcTrailingDim = sourceType.getShape().back();

    // Assume that the stride matches the trailing dimension of the source
    // memref.
    // TODO: Relax this assumption.
    if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
      return failure();

    // 1. Collapse the input memref so that it's "flat".
    SmallVector<ReassociationIndices> reassoc = {{0, 1}};
    Value collapsed = rewriter.create<memref::CollapseShapeOp>(
        op.getLoc(), subview.getSource(), reassoc);

    // 2. Generate new gather indices that will model the
    // strided access.
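    // Each index into the collapsed memref is the original index scaled by
    // the trailing-dim stride: newIdx = idx * srcTrailingDim.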
    IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
    VectorType vType = op.getIndexVec().getType();
    Value mulCst = rewriter.create<arith::ConstantOp>(
        op.getLoc(), vType, DenseElementsAttr::get(vType, stride));

    Value newIdxs =
        rewriter.create<arith::MulIOp>(op.getLoc(), op.getIndexVec(), mulCst);

    // 3. Create an updated gather op with the collapsed input memref and the
    // updated indices.
    Value newGather = rewriter.create<vector::GatherOp>(
        op.getLoc(), op.getResult().getType(), collapsed, op.getIndices(),
        newIdxs, op.getMask(), op.getPassThru());
    rewriter.replaceOp(op, newGather);

    return success();
  }
};

/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.load` or
/// `tensor.extract` ops. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
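///
/// For example (a sketch; value names, types, and the exact form of the
/// address computation are illustrative):
/// ```
/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
///        ... into vector<2xf32>
///
/// ==>
///
/// %m0 = vector.extract %mask[0] : ...
/// %r0 = scf.if %m0 -> (vector<2xf32>) {
///   // Load %base at offset %c0 + %v[0] and insert the loaded scalar into
///   // %pass_thru at position [0].
///   scf.yield ... : vector<2xf32>
/// } else {
///   scf.yield %pass_thru : vector<2xf32>
/// }
/// // ... repeated for element [1], chaining the result through %r0 ...
/// ```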
struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resultTy = op.getType();
    if (resultTy.getRank() != 1)
      return rewriter.notifyMatchFailure(op, "unsupported rank");

    if (resultTy.isScalable())
      return rewriter.notifyMatchFailure(op, "not a fixed-width vector");

    Location loc = op.getLoc();
    Type elemTy = resultTy.getElementType();
    // Vector type with a single element. Used to generate `vector.load` ops.
    VectorType elemVecTy = VectorType::get({1}, elemTy);

    Value condMask = op.getMask();
    Value base = op.getBase();

    // vector.load requires the most minor memref dim to have unit stride
    // (unless reading exactly 1 element)
    if (auto memType = dyn_cast<MemRefType>(base.getType())) {
      if (auto stridesAttr =
              dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
        if (stridesAttr.getStrides().back() != 1 &&
            resultTy.getNumElements() != 1)
          return failure();
      }
    }

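    // Convert the index vector to `index` type so that its elements can be
    // added directly to the trailing base offset below.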
    Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
        loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
        op.getIndexVec());
    auto baseOffsets = llvm::to_vector(op.getIndices());
    Value lastBaseOffset = baseOffsets.back();

    Value result = op.getPassThru();

    // Emit a conditional access for each vector element.
    for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
      int64_t thisIdx[1] = {i};
      Value condition =
          rewriter.create<vector::ExtractOp>(loc, condMask, thisIdx);
      Value index = rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
      baseOffsets.back() =
          rewriter.createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);

      auto loadBuilder = [&](OpBuilder &b, Location loc) {
        Value extracted;
        if (isa<MemRefType>(base.getType())) {
          // `vector.load` does not support scalar result; emit a vector load
          // and extract the single result instead.
          Value load =
              b.create<vector::LoadOp>(loc, elemVecTy, base, baseOffsets);
          int64_t zeroIdx[1] = {0};
          extracted = b.create<vector::ExtractOp>(loc, load, zeroIdx);
        } else {
          extracted = b.create<tensor::ExtractOp>(loc, base, baseOffsets);
        }

        Value newResult =
            b.create<vector::InsertOp>(loc, extracted, result, thisIdx);
        b.create<scf::YieldOp>(loc, newResult);
      };
      auto passThruBuilder = [result](OpBuilder &b, Location loc) {
        b.create<scf::YieldOp>(loc, result);
      };

      result =
          rewriter
              .create<scf::IfOp>(loc, condition, /*thenBuilder=*/loadBuilder,
                                 /*elseBuilder=*/passThruBuilder)
              .getResult(0);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

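// Note: clients enable these rewrites by adding them to a pattern set and
// running a rewrite driver. A minimal sketch (the driver choice and the
// surrounding pass boilerplate are illustrative, not part of this file):
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorGatherLoweringPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));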
void mlir::vector::populateVectorGatherLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<FlattenGather, RemoveStrideFromGatherSource,
               Gather1DToConditionalLoads>(patterns.getContext(), benefit);
}