//===- LowerVectorGather.cpp - Lower 'vector.gather' operation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.gather' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#define DEBUG_TYPE "vector-broadcast-lowering"

using namespace mlir;
using namespace mlir::vector;

namespace {
/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
/// outermost dimension. For example:
/// ```
/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
///        ... into vector<2x3xf32>
///
/// ==>
///
/// %0 = arith.constant dense<0.0> : vector<2x3xf32>
/// %g0 = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ...
/// %1 = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32>
/// %g1 = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ...
/// %g = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32>
/// ```
///
/// When applied exhaustively, this will produce a sequence of 1-d gather ops.
///
/// Supports vector types with a fixed leading dimension.
struct FlattenGather : OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resultTy = op.getType();
    if (resultTy.getRank() < 2)
      return rewriter.notifyMatchFailure(op, "already flat");

    // Unrolling doesn't take vscale into account. Pattern is disabled for
    // vectors with leading scalable dim(s).
    if (resultTy.getScalableDims().front())
      return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");

    Location loc = op.getLoc();
    Value indexVec = op.getIndexVec();
    Value maskVec = op.getMask();
    Value passThruVec = op.getPassThru();

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultTy, rewriter.getZeroAttr(resultTy));

    VectorType subTy = VectorType::Builder(resultTy).dropDim(0);

    for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
      int64_t thisIdx[1] = {i};

      Value indexSubVec =
          rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
      Value maskSubVec =
          rewriter.create<vector::ExtractOp>(loc, maskVec, thisIdx);
      Value passThruSubVec =
          rewriter.create<vector::ExtractOp>(loc, passThruVec, thisIdx);
      Value subGather = rewriter.create<vector::GatherOp>(
          loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
          passThruSubVec);
      result =
          rewriter.create<vector::InsertOp>(loc, subGather, result, thisIdx);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
/// MemRef with updated indices that model the strided access.
///
/// ```mlir
/// %subview = memref.subview %M (...)
///   : memref<100x3xf32> to memref<100xf32, strided<[3]>>
/// %gather = vector.gather %subview[%idxs] (...)
///   : memref<100xf32, strided<[3]>>
/// ```
/// ==>
/// ```mlir
/// %collapse_shape = memref.collapse_shape %M (...)
///   : memref<100x3xf32> into memref<300xf32>
/// %new_idxs = arith.muli %idxs, %c3 : vector<4xindex>
/// %gather = vector.gather %collapse_shape[%new_idxs] (...)
///   : memref<300xf32> (...)
/// ```
///
/// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,
/// but should be fairly straightforward to extend beyond that.
struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp op,
                                PatternRewriter &rewriter) const override {
    Value base = op.getBase();

    // TODO: Strided accesses might be coming from other ops as well
    auto subview = base.getDefiningOp<memref::SubViewOp>();
    if (!subview)
      return failure();

    auto sourceType = subview.getSource().getType();

    // TODO: Allow ranks > 2.
    if (sourceType.getRank() != 2)
      return failure();

    // Get strides
    auto layout = subview.getResult().getType().getLayout();
    auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
    if (!stridedLayoutAttr)
      return failure();

    // TODO: Allow the access to be strided in multiple dimensions.
    if (stridedLayoutAttr.getStrides().size() != 1)
      return failure();

    int64_t srcTrailingDim = sourceType.getShape().back();

    // Assume that the stride matches the trailing dimension of the source
    // memref.
    // TODO: Relax this assumption.
    if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
      return failure();

    // 1. Collapse the input memref so that it's "flat".
    SmallVector<ReassociationIndices> reassoc = {{0, 1}};
    Value collapsed = rewriter.create<memref::CollapseShapeOp>(
        op.getLoc(), subview.getSource(), reassoc);

    // 2. Generate new gather indices that will model the
    // strided access.
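    // E.g. with a trailing source dim (and hence stride) of 3, gather indices
    // [0, 1, 2, 3] into the subview become [0, 3, 6, 9] into the collapsed
    // memref.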
    IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
    VectorType vType = op.getIndexVec().getType();
    Value mulCst = rewriter.create<arith::ConstantOp>(
        op.getLoc(), vType, DenseElementsAttr::get(vType, stride));

    Value newIdxs =
        rewriter.create<arith::MulIOp>(op.getLoc(), op.getIndexVec(), mulCst);

    // 3. Create an updated gather op with the collapsed input memref and the
    // updated indices.
    Value newGather = rewriter.create<vector::GatherOp>(
        op.getLoc(), op.getResult().getType(), collapsed, op.getIndices(),
        newIdxs, op.getMask(), op.getPassThru());
    rewriter.replaceOp(op, newGather);

    return success();
  }
};

/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resultTy = op.getType();
    if (resultTy.getRank() != 1)
      return rewriter.notifyMatchFailure(op, "unsupported rank");

    if (resultTy.isScalable())
      return rewriter.notifyMatchFailure(op, "not a fixed-width vector");

    Location loc = op.getLoc();
    Type elemTy = resultTy.getElementType();
    // Vector type with a single element. Used to generate `vector.loads`.
    VectorType elemVecTy = VectorType::get({1}, elemTy);

    Value condMask = op.getMask();
    Value base = op.getBase();

    // vector.load requires the most minor memref dim to have unit stride
    // (unless reading exactly 1 element)
    if (auto memType = dyn_cast<MemRefType>(base.getType())) {
      if (auto stridesAttr =
              dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
        if (stridesAttr.getStrides().back() != 1 &&
            resultTy.getNumElements() != 1)
          return failure();
      }
    }

    Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
        loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
        op.getIndexVec());
    auto baseOffsets = llvm::to_vector(op.getIndices());
    Value lastBaseOffset = baseOffsets.back();

    Value result = op.getPassThru();

    // Emit a conditional access for each vector element.
    for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
      int64_t thisIdx[1] = {i};
      Value condition =
          rewriter.create<vector::ExtractOp>(loc, condMask, thisIdx);
      Value index = rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
      baseOffsets.back() =
          rewriter.createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);

      auto loadBuilder = [&](OpBuilder &b, Location loc) {
        Value extracted;
        if (isa<MemRefType>(base.getType())) {
          // `vector.load` does not support scalar result; emit a vector load
          // and extract the single result instead.
          Value load =
              b.create<vector::LoadOp>(loc, elemVecTy, base, baseOffsets);
          int64_t zeroIdx[1] = {0};
          extracted = b.create<vector::ExtractOp>(loc, load, zeroIdx);
        } else {
          extracted = b.create<tensor::ExtractOp>(loc, base, baseOffsets);
        }

        Value newResult =
            b.create<vector::InsertOp>(loc, extracted, result, thisIdx);
        b.create<scf::YieldOp>(loc, newResult);
      };
      auto passThruBuilder = [result](OpBuilder &b, Location loc) {
        b.create<scf::YieldOp>(loc, result);
      };

      result =
          rewriter
              .create<scf::IfOp>(loc, condition, /*thenBuilder=*/loadBuilder,
                                 /*elseBuilder=*/passThruBuilder)
              .getResult(0);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

void mlir::vector::populateVectorGatherLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<FlattenGather, RemoveStrideFromGatherSource,
               Gather1DToConditionalLoads>(patterns.getContext(), benefit);
}
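
// A minimal usage sketch (not part of the lowering itself): one way a pass
// might drive the patterns registered above with the greedy rewriter. The
// pass class name `LowerGatherTestPass` is hypothetical;
// `applyPatternsAndFoldGreedily` is the standard driver declared in
// "mlir/Transforms/GreedyPatternRewriteDriver.h".
//
//   void LowerGatherTestPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     vector::populateVectorGatherLoweringPatterns(patterns);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }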