1 //===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 10 #include "mlir/Dialect/Vector/IR/VectorOps.h" 11 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 12 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/ImplicitLocOpBuilder.h" 15 #include "mlir/IR/TypeUtilities.h" 16 17 #define DEBUG_TYPE "vector-drop-unit-dim" 18 19 using namespace mlir; 20 using namespace mlir::vector; 21 22 // Trims leading one dimensions from `oldType` and returns the result type. 23 // Returns `vector<1xT>` if `oldType` only has one element. 24 static VectorType trimLeadingOneDims(VectorType oldType) { 25 ArrayRef<int64_t> oldShape = oldType.getShape(); 26 ArrayRef<int64_t> newShape = 27 oldShape.drop_while([](int64_t dim) { return dim == 1; }); 28 // Make sure we have at least 1 dimension per vector type requirements. 29 if (newShape.empty()) 30 newShape = oldShape.take_back(); 31 return VectorType::get(newShape, oldType.getElementType()); 32 } 33 34 /// Return a smallVector of size `rank` containing all zeros. 35 static SmallVector<int64_t> splatZero(int64_t rank) { 36 return SmallVector<int64_t>(rank, 0); 37 } 38 namespace { 39 40 // Casts away leading one dimensions in vector.extract_strided_slice's vector 41 // input by inserting vector.broadcast. 
42 struct CastAwayExtractStridedSliceLeadingOneDim 43 : public OpRewritePattern<vector::ExtractStridedSliceOp> { 44 using OpRewritePattern::OpRewritePattern; 45 46 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, 47 PatternRewriter &rewriter) const override { 48 // vector.extract_strided_slice requires the input and output vector to have 49 // the same rank. Here we drop leading one dimensions from the input vector 50 // type to make sure we don't cause mismatch. 51 VectorType oldSrcType = extractOp.getSourceVectorType(); 52 VectorType newSrcType = trimLeadingOneDims(oldSrcType); 53 54 if (newSrcType.getRank() == oldSrcType.getRank()) 55 return failure(); 56 57 int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank(); 58 59 VectorType oldDstType = extractOp.getType(); 60 VectorType newDstType = 61 VectorType::get(oldDstType.getShape().drop_front(dropCount), 62 oldDstType.getElementType()); 63 64 Location loc = extractOp.getLoc(); 65 66 Value newSrcVector = rewriter.create<vector::ExtractOp>( 67 loc, extractOp.getVector(), splatZero(dropCount)); 68 69 // The offsets/sizes/strides attribute can have a less number of elements 70 // than the input vector's rank: it is meant for the leading dimensions. 71 auto newOffsets = rewriter.getArrayAttr( 72 extractOp.getOffsets().getValue().drop_front(dropCount)); 73 auto newSizes = rewriter.getArrayAttr( 74 extractOp.getSizes().getValue().drop_front(dropCount)); 75 auto newStrides = rewriter.getArrayAttr( 76 extractOp.getStrides().getValue().drop_front(dropCount)); 77 78 auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>( 79 loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides); 80 81 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType, 82 newExtractOp); 83 84 return success(); 85 } 86 }; 87 88 // Casts away leading one dimensions in vector.insert_strided_slice's vector 89 // inputs by inserting vector.broadcast. 
90 struct CastAwayInsertStridedSliceLeadingOneDim 91 : public OpRewritePattern<vector::InsertStridedSliceOp> { 92 using OpRewritePattern::OpRewritePattern; 93 94 LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp, 95 PatternRewriter &rewriter) const override { 96 VectorType oldSrcType = insertOp.getSourceVectorType(); 97 VectorType newSrcType = trimLeadingOneDims(oldSrcType); 98 VectorType oldDstType = insertOp.getDestVectorType(); 99 VectorType newDstType = trimLeadingOneDims(oldDstType); 100 101 int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank(); 102 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank(); 103 if (srcDropCount == 0 && dstDropCount == 0) 104 return failure(); 105 106 // Trim leading one dimensions from both operands. 107 Location loc = insertOp.getLoc(); 108 109 Value newSrcVector = rewriter.create<vector::ExtractOp>( 110 loc, insertOp.getSource(), splatZero(srcDropCount)); 111 Value newDstVector = rewriter.create<vector::ExtractOp>( 112 loc, insertOp.getDest(), splatZero(dstDropCount)); 113 114 auto newOffsets = rewriter.getArrayAttr( 115 insertOp.getOffsets().getValue().take_back(newDstType.getRank())); 116 auto newStrides = rewriter.getArrayAttr( 117 insertOp.getStrides().getValue().take_back(newSrcType.getRank())); 118 119 auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>( 120 loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides); 121 122 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType, 123 newInsertOp); 124 125 return success(); 126 } 127 }; 128 129 // Casts away leading one dimensions in vector.insert's vector inputs by 130 // inserting vector.broadcast. 
struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    // The inserted source may be a scalar, in which case there is no vector
    // type to trim; its rank is then treated as 0 below.
    Type oldSrcType = insertOp.getSourceType();
    Type newSrcType = oldSrcType;
    int64_t oldSrcRank = 0, newSrcRank = 0;
    if (auto type = oldSrcType.dyn_cast<VectorType>()) {
      newSrcType = trimLeadingOneDims(type);
      oldSrcRank = type.getRank();
      newSrcRank = newSrcType.cast<VectorType>().getRank();
    }

    VectorType oldDstType = insertOp.getDestVectorType();
    VectorType newDstType = trimLeadingOneDims(oldDstType);

    int64_t srcDropCount = oldSrcRank - newSrcRank;
    int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
    // Nothing to do if neither operand has leading unit dims to drop.
    if (srcDropCount == 0 && dstDropCount == 0)
      return failure();

    // Trim leading one dimensions from both operands.
    Location loc = insertOp.getLoc();

    // Scalars (rank 0) are passed through unchanged; vectors are trimmed via
    // an extract at an all-zero position.
    Value newSrcVector = insertOp.getSource();
    if (oldSrcRank != 0) {
      newSrcVector = rewriter.create<vector::ExtractOp>(
          loc, insertOp.getSource(), splatZero(srcDropCount));
    }
    Value newDstVector = rewriter.create<vector::ExtractOp>(
        loc, insertOp.getDest(), splatZero(dstDropCount));

    // The new position rank is the rank gap between the trimmed destination
    // and the trimmed source.
    unsigned oldPosRank = insertOp.getPosition().getValue().size();
    unsigned newPosRank = newDstType.getRank() - newSrcRank;
    SmallVector<Attribute> newPositions = llvm::to_vector(
        insertOp.getPosition().getValue().take_back(newPosRank));
    if (newPosRank > oldPosRank) {
      // More position entries are needed than the original op carried
      // (possible when the source lost more unit dims than the destination);
      // pad with zero indices.
      auto zeroAttr = rewriter.getZeroAttr(rewriter.getI64Type());
      newPositions.resize(newPosRank, zeroAttr);
    }

    auto newInsertOp = rewriter.create<vector::InsertOp>(
        loc, newDstType, newSrcVector, newDstVector,
        rewriter.getArrayAttr(newPositions));

    // Broadcast the trimmed result back to the original destination type.
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                     newInsertOp);

    return success();
  }
};

// Turns vector.transfer_read on vector with leading 1 dimensions into 185 // vector.shape_cast followed by vector.transfer_read on vector without leading 186 // 1 dimensions. 187 struct CastAwayTransferReadLeadingOneDim 188 : public OpRewritePattern<vector::TransferReadOp> { 189 using OpRewritePattern::OpRewritePattern; 190 191 LogicalResult matchAndRewrite(vector::TransferReadOp read, 192 PatternRewriter &rewriter) const override { 193 // TODO: support 0-d corner case. 194 if (read.getTransferRank() == 0) 195 return failure(); 196 197 if (read.getMask()) 198 return failure(); 199 200 auto shapedType = read.getSource().getType().cast<ShapedType>(); 201 if (shapedType.getElementType() != read.getVectorType().getElementType()) 202 return failure(); 203 204 VectorType oldType = read.getVectorType(); 205 VectorType newType = trimLeadingOneDims(oldType); 206 207 if (newType == oldType) 208 return failure(); 209 210 AffineMap oldMap = read.getPermutationMap(); 211 ArrayRef<AffineExpr> newResults = 212 oldMap.getResults().take_back(newType.getRank()); 213 AffineMap newMap = 214 AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults, 215 rewriter.getContext()); 216 217 ArrayAttr inBoundsAttr; 218 if (read.getInBounds()) 219 inBoundsAttr = rewriter.getArrayAttr( 220 read.getInBoundsAttr().getValue().take_back(newType.getRank())); 221 222 auto newRead = rewriter.create<vector::TransferReadOp>( 223 read.getLoc(), newType, read.getSource(), read.getIndices(), 224 AffineMapAttr::get(newMap), read.getPadding(), /*mask=*/Value(), 225 inBoundsAttr); 226 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead); 227 228 return success(); 229 } 230 }; 231 232 // Turns vector.transfer_write on vector with leading 1 dimensions into 233 // vector.shape_cast followed by vector.transfer_write on vector without leading 234 // 1 dimensions. 
235 struct CastAwayTransferWriteLeadingOneDim 236 : public OpRewritePattern<vector::TransferWriteOp> { 237 using OpRewritePattern::OpRewritePattern; 238 239 LogicalResult matchAndRewrite(vector::TransferWriteOp write, 240 PatternRewriter &rewriter) const override { 241 // TODO: support 0-d corner case. 242 if (write.getTransferRank() == 0) 243 return failure(); 244 245 if (write.getMask()) 246 return failure(); 247 248 auto shapedType = write.getSource().getType().dyn_cast<ShapedType>(); 249 if (shapedType.getElementType() != write.getVectorType().getElementType()) 250 return failure(); 251 252 VectorType oldType = write.getVectorType(); 253 VectorType newType = trimLeadingOneDims(oldType); 254 if (newType == oldType) 255 return failure(); 256 int64_t dropDim = oldType.getRank() - newType.getRank(); 257 258 AffineMap oldMap = write.getPermutationMap(); 259 ArrayRef<AffineExpr> newResults = 260 oldMap.getResults().take_back(newType.getRank()); 261 AffineMap newMap = 262 AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults, 263 rewriter.getContext()); 264 265 ArrayAttr inBoundsAttr; 266 if (write.getInBounds()) 267 inBoundsAttr = rewriter.getArrayAttr( 268 write.getInBoundsAttr().getValue().take_back(newType.getRank())); 269 270 auto newVector = rewriter.create<vector::ExtractOp>( 271 write.getLoc(), write.getVector(), splatZero(dropDim)); 272 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( 273 write, newVector, write.getSource(), write.getIndices(), 274 AffineMapAttr::get(newMap), inBoundsAttr); 275 276 return success(); 277 } 278 }; 279 280 /// Turns vector.contract on vector with leading 1 dimensions into 281 /// vector.extract followed by vector.contract on vector without leading 282 /// 1 dimensions. Also performs tranpose of lhs and rhs operands if required 283 /// prior to extract. 
struct CastAwayContractionLeadingOneDim
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    // Only contractions with a vector accumulator whose leading dimension is
    // a unit dim are handled.
    VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
    if (oldAccType == nullptr)
      return failure();
    if (oldAccType.getRank() < 2)
      return failure();
    if (oldAccType.getShape()[0] != 1)
      return failure();
    // Currently we support only dropping one dim but the pattern can be
    // applied greedily to drop more.
    int64_t dropDim = 1;

    auto oldIndexingMaps = contractOp.getIndexingMapsArray();
    SmallVector<AffineMap> newIndexingMaps;

    auto oldIteratorTypes = contractOp.getIteratorTypes();
    SmallVector<Attribute> newIteratorTypes;

    // The iteration-space dimension to drop is the one the accumulator map
    // (index 2) uses for its leading result.
    int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);

    if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
      // Only parallel type iterators can be dropped.
      return failure();

    // Drop the iterator entry for `dimToDrop`, keeping the others in order.
    for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
      int64_t currDim = it.index();
      if (currDim == dimToDrop)
        continue;
      newIteratorTypes.push_back(it.value());
    }

    SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
                                   contractOp.getAcc()};
    SmallVector<Value> newOperands;

    for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
      // Check if the dim to be dropped exists as a leading dim in the operand;
      // if it does then we use vector.extract to drop it.
      bool validExtract = false;
      SmallVector<AffineExpr> results;
      auto map = it.value();
      int64_t orginalZeroDim = it.value().getDimPosition(0);
      if (orginalZeroDim != dimToDrop) {
        // There are two reasons to be in this path: 1. We need to transpose
        // the operand to make the dim to be dropped leading. 2. The dim to be
        // dropped does not exist and in that case we don't want to add a unit
        // transpose, but we must check all the indices to make sure this is
        // the case.
        bool tranposeNeeded = false;
        SmallVector<int64_t> perm;
        SmallVector<AffineExpr> transposeResults;

        // Build a permutation that moves any occurrence of `dimToDrop` to the
        // front while preserving the relative order of the other dims.
        for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
          int64_t currDim = map.getDimPosition(i);
          if (currDim == dimToDrop) {
            tranposeNeeded = true;
            perm.insert(perm.begin(), i);
            auto targetExpr = rewriter.getAffineDimExpr(currDim);
            transposeResults.insert(transposeResults.begin(), targetExpr);
          } else {
            perm.push_back(i);
            auto targetExpr = rewriter.getAffineDimExpr(currDim);
            transposeResults.push_back(targetExpr);
          }
        }
        // Do the transpose now if needed so that we can drop the correct dim
        // using extract later.
        if (tranposeNeeded) {
          map = AffineMap::get(map.getNumDims(), 0, transposeResults,
                               contractOp.getContext());
          operands[it.index()] = rewriter.create<vector::TransposeOp>(
              contractOp.getLoc(), operands[it.index()], perm);
        }
      }
      // We have taken care to have the dim to be dropped be the leading dim.
      // If it's still not leading, that means it does not exist in this
      // operand and hence we do not need an extract.
      if (map.getDimPosition(0) == dimToDrop)
        validExtract = true;

      // Rebuild the indexing map without `dimToDrop`, shifting higher dims
      // down by one.
      for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
        int64_t currDim = map.getDimPosition(i);
        if (currDim == dimToDrop)
          // This is the dim we are dropping.
          continue;
        auto targetExpr = rewriter.getAffineDimExpr(
            currDim < dimToDrop ? currDim : currDim - 1);
        results.push_back(targetExpr);
      }
      newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
                                               contractOp.getContext()));
      // Extract if it's a valid extraction, otherwise use the operand without
      // extraction.
      newOperands.push_back(validExtract
                                ? rewriter.create<vector::ExtractOp>(
                                      contractOp.getLoc(), operands[it.index()],
                                      splatZero(dropDim))
                                : operands[it.index()]);
    }
    // Rebuild the contraction on the trimmed operands/maps/iterators and
    // broadcast the unit dim back onto the result.
    auto newContractOp = rewriter.create<vector::ContractionOp>(
        contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
        rewriter.getAffineMapArrayAttr(newIndexingMaps),
        rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        contractOp, contractOp->getResultTypes()[0], newContractOp);
    return success();
  }
};

// Casts away leading one dimensions from the operands and result of any
// elementwise-mappable, single-result operation on vectors, rebuilding the op
// on the trimmed type and broadcasting the result back.
class CastAwayElementwiseLeadingOneDim : public RewritePattern {
public:
  CastAwayElementwiseLeadingOneDim(MLIRContext *context,
                                   PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // Only elementwise-mappable ops with exactly one result are handled.
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>();
    if (!vecType)
      return failure();
    VectorType newVecType = trimLeadingOneDims(vecType);
    if (newVecType == vecType)
      return failure();
    int64_t dropDim = vecType.getRank() - newVecType.getRank();
    SmallVector<Value, 4> newOperands;
    // Trim vector operands via extract; non-vector operands pass through.
    for (Value operand : op->getOperands()) {
      if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
        newOperands.push_back(rewriter.create<vector::ExtractOp>(
            op->getLoc(), operand, splatZero(dropDim)));
      } else {
        newOperands.push_back(operand);
      }
    }
    // Recreate the same operation with trimmed operands and result type, then
    // broadcast the unit dims back onto the result.
    Operation *newOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        newOperands, newVecType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
                                                     newOp->getResult(0));
    return success();
  }
};

} // namespace

// Registers all cast-away patterns above, plus shape-cast folding so the
// broadcasts/extracts introduced here can be cleaned up.
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<CastAwayExtractStridedSliceLeadingOneDim,
           CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
           CastAwayTransferReadLeadingOneDim,
           CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
           CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
}