1 //===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 10 #include "mlir/Dialect/Vector/IR/VectorOps.h" 11 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 12 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/ImplicitLocOpBuilder.h" 15 #include "mlir/IR/TypeUtilities.h" 16 17 #define DEBUG_TYPE "vector-drop-unit-dim" 18 19 using namespace mlir; 20 using namespace mlir::vector; 21 22 // Trims leading one dimensions from `oldType` and returns the result type. 23 // Returns `vector<1xT>` if `oldType` only has one element. 24 static VectorType trimLeadingOneDims(VectorType oldType) { 25 ArrayRef<int64_t> oldShape = oldType.getShape(); 26 ArrayRef<int64_t> newShape = 27 oldShape.drop_while([](int64_t dim) { return dim == 1; }); 28 // Make sure we have at least 1 dimension per vector type requirements. 29 if (newShape.empty()) 30 newShape = oldShape.take_back(); 31 return VectorType::get(newShape, oldType.getElementType()); 32 } 33 34 /// Return a smallVector of size `rank` containing all zeros. 35 static SmallVector<int64_t> splatZero(int64_t rank) { 36 return SmallVector<int64_t>(rank, 0); 37 } 38 namespace { 39 40 // Casts away leading one dimensions in vector.extract_strided_slice's vector 41 // input by inserting vector.broadcast. 
42 struct CastAwayExtractStridedSliceLeadingOneDim 43 : public OpRewritePattern<vector::ExtractStridedSliceOp> { 44 using OpRewritePattern::OpRewritePattern; 45 46 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, 47 PatternRewriter &rewriter) const override { 48 // vector.extract_strided_slice requires the input and output vector to have 49 // the same rank. Here we drop leading one dimensions from the input vector 50 // type to make sure we don't cause mismatch. 51 VectorType oldSrcType = extractOp.getVectorType(); 52 VectorType newSrcType = trimLeadingOneDims(oldSrcType); 53 54 if (newSrcType.getRank() == oldSrcType.getRank()) 55 return failure(); 56 57 int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank(); 58 59 VectorType oldDstType = extractOp.getType(); 60 VectorType newDstType = 61 VectorType::get(oldDstType.getShape().drop_front(dropCount), 62 oldDstType.getElementType()); 63 64 Location loc = extractOp.getLoc(); 65 66 Value newSrcVector = rewriter.create<vector::ExtractOp>( 67 loc, extractOp.getVector(), splatZero(dropCount)); 68 69 // The offsets/sizes/strides attribute can have a less number of elements 70 // than the input vector's rank: it is meant for the leading dimensions. 71 auto newOffsets = rewriter.getArrayAttr( 72 extractOp.getOffsets().getValue().drop_front(dropCount)); 73 auto newSizes = rewriter.getArrayAttr( 74 extractOp.getSizes().getValue().drop_front(dropCount)); 75 auto newStrides = rewriter.getArrayAttr( 76 extractOp.getStrides().getValue().drop_front(dropCount)); 77 78 auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>( 79 loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides); 80 81 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType, 82 newExtractOp); 83 84 return success(); 85 } 86 }; 87 88 // Casts away leading one dimensions in vector.insert_strided_slice's vector 89 // inputs by inserting vector.broadcast. 
90 struct CastAwayInsertStridedSliceLeadingOneDim 91 : public OpRewritePattern<vector::InsertStridedSliceOp> { 92 using OpRewritePattern::OpRewritePattern; 93 94 LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp, 95 PatternRewriter &rewriter) const override { 96 VectorType oldSrcType = insertOp.getSourceVectorType(); 97 VectorType newSrcType = trimLeadingOneDims(oldSrcType); 98 VectorType oldDstType = insertOp.getDestVectorType(); 99 VectorType newDstType = trimLeadingOneDims(oldDstType); 100 101 int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank(); 102 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank(); 103 if (srcDropCount == 0 && dstDropCount == 0) 104 return failure(); 105 106 // Trim leading one dimensions from both operands. 107 Location loc = insertOp.getLoc(); 108 109 Value newSrcVector = rewriter.create<vector::ExtractOp>( 110 loc, insertOp.getSource(), splatZero(srcDropCount)); 111 Value newDstVector = rewriter.create<vector::ExtractOp>( 112 loc, insertOp.getDest(), splatZero(dstDropCount)); 113 114 auto newOffsets = rewriter.getArrayAttr( 115 insertOp.getOffsets().getValue().take_back(newDstType.getRank())); 116 auto newStrides = rewriter.getArrayAttr( 117 insertOp.getStrides().getValue().take_back(newSrcType.getRank())); 118 119 auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>( 120 loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides); 121 122 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType, 123 newInsertOp); 124 125 return success(); 126 } 127 }; 128 129 // Casts away leading one dimensions in vector.insert's vector inputs by 130 // inserting vector.broadcast. 
struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    // The source of vector.insert may be a scalar, in which case there are no
    // leading dims to trim; the ranks stay 0 and only the destination matters.
    Type oldSrcType = insertOp.getSourceType();
    Type newSrcType = oldSrcType;
    int64_t oldSrcRank = 0, newSrcRank = 0;
    if (auto type = oldSrcType.dyn_cast<VectorType>()) {
      newSrcType = trimLeadingOneDims(type);
      oldSrcRank = type.getRank();
      newSrcRank = newSrcType.cast<VectorType>().getRank();
    }

    VectorType oldDstType = insertOp.getDestVectorType();
    VectorType newDstType = trimLeadingOneDims(oldDstType);

    int64_t srcDropCount = oldSrcRank - newSrcRank;
    int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
    // Nothing to do when neither operand has leading unit dims.
    if (srcDropCount == 0 && dstDropCount == 0)
      return failure();

    // Trim leading one dimensions from both operands via all-zero extracts.
    Location loc = insertOp.getLoc();

    Value newSrcVector = insertOp.getSource();
    if (oldSrcRank != 0) {
      // Only extract when the source is a vector; a scalar is passed through.
      newSrcVector = rewriter.create<vector::ExtractOp>(
          loc, insertOp.getSource(), splatZero(srcDropCount));
    }
    Value newDstVector = rewriter.create<vector::ExtractOp>(
        loc, insertOp.getDest(), splatZero(dstDropCount));

    // The new position rank is fixed by the reduced dst/src ranks. Keep the
    // trailing entries of the old position; if more positions are needed than
    // the old op had (the dst lost fewer dims than the src), pad with zeros —
    // those positions index into dimensions known to have size 1.
    unsigned oldPosRank = insertOp.getPosition().getValue().size();
    unsigned newPosRank = newDstType.getRank() - newSrcRank;
    SmallVector<Attribute> newPositions = llvm::to_vector(
        insertOp.getPosition().getValue().take_back(newPosRank));
    if (newPosRank > oldPosRank) {
      auto zeroAttr = rewriter.getZeroAttr(rewriter.getI64Type());
      newPositions.resize(newPosRank, zeroAttr);
    }

    auto newInsertOp = rewriter.create<vector::InsertOp>(
        loc, newDstType, newSrcVector, newDstVector,
        rewriter.getArrayAttr(newPositions));

    // Broadcast restores the original destination type with its unit dims.
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                     newInsertOp);

    return success();
  }
};

// Turns vector.transfer_read on vector with leading 1 dimensions into 185 // vector.shape_cast followed by vector.transfer_read on vector without leading 186 // 1 dimensions. 187 struct CastAwayTransferReadLeadingOneDim 188 : public OpRewritePattern<vector::TransferReadOp> { 189 using OpRewritePattern::OpRewritePattern; 190 191 LogicalResult matchAndRewrite(vector::TransferReadOp read, 192 PatternRewriter &rewriter) const override { 193 // TODO: support 0-d corner case. 194 if (read.getTransferRank() == 0) 195 return failure(); 196 197 if (read.getMask()) 198 return failure(); 199 200 auto shapedType = read.getSource().getType().cast<ShapedType>(); 201 if (shapedType.getElementType() != read.getVectorType().getElementType()) 202 return failure(); 203 204 VectorType oldType = read.getVectorType(); 205 VectorType newType = trimLeadingOneDims(oldType); 206 207 if (newType == oldType) 208 return failure(); 209 210 AffineMap oldMap = read.getPermutationMap(); 211 ArrayRef<AffineExpr> newResults = 212 oldMap.getResults().take_back(newType.getRank()); 213 AffineMap newMap = 214 AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults, 215 rewriter.getContext()); 216 217 ArrayAttr inBoundsAttr; 218 if (read.getInBounds()) 219 inBoundsAttr = rewriter.getArrayAttr( 220 read.getInBoundsAttr().getValue().take_back(newType.getRank())); 221 222 auto newRead = rewriter.create<vector::TransferReadOp>( 223 read.getLoc(), newType, read.getSource(), read.getIndices(), 224 AffineMapAttr::get(newMap), read.getPadding(), /*mask=*/Value(), 225 inBoundsAttr); 226 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead); 227 228 return success(); 229 } 230 }; 231 232 // Turns vector.transfer_write on vector with leading 1 dimensions into 233 // vector.shape_cast followed by vector.transfer_write on vector without leading 234 // 1 dimensions. 
235 struct CastAwayTransferWriteLeadingOneDim 236 : public OpRewritePattern<vector::TransferWriteOp> { 237 using OpRewritePattern::OpRewritePattern; 238 239 LogicalResult matchAndRewrite(vector::TransferWriteOp write, 240 PatternRewriter &rewriter) const override { 241 // TODO: support 0-d corner case. 242 if (write.getTransferRank() == 0) 243 return failure(); 244 245 if (write.getMask()) 246 return failure(); 247 248 auto shapedType = write.getSource().getType().dyn_cast<ShapedType>(); 249 if (shapedType.getElementType() != write.getVectorType().getElementType()) 250 return failure(); 251 252 VectorType oldType = write.getVectorType(); 253 VectorType newType = trimLeadingOneDims(oldType); 254 if (newType == oldType) 255 return failure(); 256 int64_t dropDim = oldType.getRank() - newType.getRank(); 257 258 AffineMap oldMap = write.getPermutationMap(); 259 ArrayRef<AffineExpr> newResults = 260 oldMap.getResults().take_back(newType.getRank()); 261 AffineMap newMap = 262 AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults, 263 rewriter.getContext()); 264 265 ArrayAttr inBoundsAttr; 266 if (write.getInBounds()) 267 inBoundsAttr = rewriter.getArrayAttr( 268 write.getInBoundsAttr().getValue().take_back(newType.getRank())); 269 270 auto newVector = rewriter.create<vector::ExtractOp>( 271 write.getLoc(), write.getVector(), splatZero(dropDim)); 272 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( 273 write, newVector, write.getSource(), write.getIndices(), 274 AffineMapAttr::get(newMap), inBoundsAttr); 275 276 return success(); 277 } 278 }; 279 280 /// Turns vector.contract on vector with leading 1 dimensions into 281 /// vector.extract followed by vector.contract on vector without leading 282 /// 1 dimensions. Also performs tranpose of lhs and rhs operands if required 283 /// prior to extract. 
struct CastAwayContractionLeadingOneDim
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    // Only vector accumulators of rank >= 2 with a leading unit dim qualify.
    VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
    if (oldAccType == nullptr)
      return failure();
    if (oldAccType.getRank() < 2)
      return failure();
    // TODO: implement masks.
    if (llvm::size(contractOp.getMasks()) != 0)
      return failure();
    if (oldAccType.getShape()[0] != 1)
      return failure();
    // Currently we support only dropping one dim but the pattern can be
    // applied greedily to drop more.
    int64_t dropDim = 1;

    auto oldIndexingMaps = contractOp.getIndexingMaps();
    SmallVector<AffineMap> newIndexingMaps;

    auto oldIteratorTypes = contractOp.getIteratorTypes();
    SmallVector<Attribute> newIteratorTypes;

    // The iteration-space dim to drop is the one the acc map (index 2) uses
    // for its leading (unit) result.
    int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);

    if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
      // Only parallel type iterators can be dropped.
      return failure();

    // Rebuild the iterator-type list without the dropped dim.
    for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
      int64_t currDim = it.index();
      if (currDim == dimToDrop)
        continue;
      newIteratorTypes.push_back(it.value());
    }

    SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
                                   contractOp.getAcc()};
    SmallVector<Value> newOperands;

    for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
      // Check if the dim to be dropped exists as a leading dim in the operand;
      // if it does then we use vector.extract to drop it.
      bool validExtract = false;
      SmallVector<AffineExpr> results;
      auto map = it.value();
      int64_t orginalZeroDim = it.value().getDimPosition(0);
      if (orginalZeroDim != dimToDrop) {
        // There are two reasons to be in this path: 1. We need to
        // transpose the operand to make the dim to be dropped
        // leading. 2. The dim to be dropped does not exist and in
        // that case we don't want to add a unit transpose but we must
        // check all the indices to make sure this is the case.
        bool tranposeNeeded = false;
        SmallVector<int64_t> perm;
        SmallVector<AffineExpr> transposeResults;

        // Build a permutation that moves the dropped dim to the front while
        // keeping the relative order of the remaining dims.
        for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
          int64_t currDim = map.getDimPosition(i);
          if (currDim == dimToDrop) {
            tranposeNeeded = true;
            perm.insert(perm.begin(), i);
            auto targetExpr = rewriter.getAffineDimExpr(currDim);
            transposeResults.insert(transposeResults.begin(), targetExpr);
          } else {
            perm.push_back(i);
            auto targetExpr = rewriter.getAffineDimExpr(currDim);
            transposeResults.push_back(targetExpr);
          }
        }
        // Do the transpose now if needed so that we can drop the
        // correct dim using extract later.
        if (tranposeNeeded) {
          map = AffineMap::get(map.getNumDims(), 0, transposeResults,
                               contractOp.getContext());
          operands[it.index()] = rewriter.create<vector::TransposeOp>(
              contractOp.getLoc(), operands[it.index()], perm);
        }
      }
      // We have taken care to have the dim to be dropped be
      // the leading dim. If it's still not leading that means it
      // does not exist in this operand and hence we do not need
      // an extract.
      if (map.getDimPosition(0) == dimToDrop)
        validExtract = true;

      // Rebuild the map results without the dropped dim, shifting every dim
      // position above it down by one.
      for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
        int64_t currDim = map.getDimPosition(i);
        if (currDim == dimToDrop)
          // This is the dim we are dropping.
          continue;
        auto targetExpr = rewriter.getAffineDimExpr(
            currDim < dimToDrop ? currDim : currDim - 1);
        results.push_back(targetExpr);
      }
      newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
                                               contractOp.getContext()));
      // Extract if it's a valid extraction, otherwise use the operand
      // without extraction.
      newOperands.push_back(validExtract
                                ? rewriter.create<vector::ExtractOp>(
                                      contractOp.getLoc(), operands[it.index()],
                                      splatZero(dropDim))
                                : operands[it.index()]);
    }
    auto newContractOp = rewriter.create<vector::ContractionOp>(
        contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
        rewriter.getAffineMapArrayAttr(newIndexingMaps),
        rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
    // Broadcast restores the original (unit-leading-dim) result type.
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        contractOp, contractOp->getResultTypes()[0], newContractOp);
    return success();
  }
};

// Casts away leading one dimensions from the result (and vector operands) of
// any elementwise-mappable op by inserting vector.extract/vector.broadcast.
class CastAwayElementwiseLeadingOneDim : public RewritePattern {
public:
  CastAwayElementwiseLeadingOneDim(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // Only single-result elementwise ops are handled.
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>();
    if (!vecType)
      return failure();
    VectorType newVecType = trimLeadingOneDims(vecType);
    if (newVecType == vecType)
      return failure();
    int64_t dropDim = vecType.getRank() - newVecType.getRank();
    // Peel leading unit dims off every vector operand; scalar operands pass
    // through unchanged (elementwise ops may mix vectors and scalars).
    SmallVector<Value, 4> newOperands;
    for (Value operand : op->getOperands()) {
      if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
        newOperands.push_back(rewriter.create<vector::ExtractOp>(
            op->getLoc(), operand, splatZero(dropDim)));
      } else {
        newOperands.push_back(operand);
      }
    }
    // Recreate the same op on the reduced vectors, preserving attributes.
    Operation *newOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        newOperands, newVecType, op->getAttrs());
    // Broadcast restores the original result type with its leading unit dims.
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
                                                     newOp->getResult(0));
    return success();
  }
};

} // namespace

// Populates `patterns` with all the cast-away-leading-unit-dim patterns above,
// plus shape_cast folding to clean up the casts the patterns introduce.
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<CastAwayExtractStridedSliceLeadingOneDim,
           CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
           CastAwayTransferReadLeadingOneDim,
           CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
           CastAwayContractionLeadingOneDim>(patterns.getContext());
  populateShapeCastFoldingPatterns(patterns);
}