//===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <numeric>

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "vector-drop-unit-dim"

using namespace mlir;
using namespace mlir::vector;

// Trims leading one dimensions from `oldType` and returns the result type.
// Only leading *fixed-size* unit dims are trimmed: the loop stops at the first
// scalable dim, since a scalable `[1]` is a runtime multiple of 1 and not a
// true unit dim. Returns `vector<1xT>` if `oldType` consists solely of
// fixed-size unit dims, because a vector type needs at least one dimension.
static VectorType trimLeadingOneDims(VectorType oldType) {
  ArrayRef<int64_t> oldShape = oldType.getShape();
  ArrayRef<int64_t> newShape = oldShape;

  // The scalability flags are walked in lockstep with the shape so the
  // trimmed type keeps the correct per-dimension scalability.
  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
  ArrayRef<bool> newScalableDims = oldScalableDims;

  while (!newShape.empty() && newShape.front() == 1 &&
         !newScalableDims.front()) {
    newShape = newShape.drop_front(1);
    newScalableDims = newScalableDims.drop_front(1);
  }

  // Make sure we have at least 1 dimension per vector type requirements.
  if (newShape.empty()) {
    newShape = oldShape.take_back();
    newScalableDims = oldType.getScalableDims().take_back();
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}

/// Return a SmallVector of size `rank` containing all zeros.
49 static SmallVector<int64_t> splatZero(int64_t rank) { 50 return SmallVector<int64_t>(rank, 0); 51 } 52 namespace { 53 54 // Casts away leading one dimensions in vector.extract_strided_slice's vector 55 // input by inserting vector.broadcast. 56 struct CastAwayExtractStridedSliceLeadingOneDim 57 : public OpRewritePattern<vector::ExtractStridedSliceOp> { 58 using OpRewritePattern::OpRewritePattern; 59 60 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, 61 PatternRewriter &rewriter) const override { 62 // vector.extract_strided_slice requires the input and output vector to have 63 // the same rank. Here we drop leading one dimensions from the input vector 64 // type to make sure we don't cause mismatch. 65 VectorType oldSrcType = extractOp.getSourceVectorType(); 66 VectorType newSrcType = trimLeadingOneDims(oldSrcType); 67 68 if (newSrcType.getRank() == oldSrcType.getRank()) 69 return failure(); 70 71 int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank(); 72 73 VectorType oldDstType = extractOp.getType(); 74 VectorType newDstType = 75 VectorType::get(oldDstType.getShape().drop_front(dropCount), 76 oldDstType.getElementType()); 77 78 Location loc = extractOp.getLoc(); 79 80 Value newSrcVector = rewriter.create<vector::ExtractOp>( 81 loc, extractOp.getVector(), splatZero(dropCount)); 82 83 // The offsets/sizes/strides attribute can have a less number of elements 84 // than the input vector's rank: it is meant for the leading dimensions. 
85 auto newOffsets = rewriter.getArrayAttr( 86 extractOp.getOffsets().getValue().drop_front(dropCount)); 87 auto newSizes = rewriter.getArrayAttr( 88 extractOp.getSizes().getValue().drop_front(dropCount)); 89 auto newStrides = rewriter.getArrayAttr( 90 extractOp.getStrides().getValue().drop_front(dropCount)); 91 92 auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>( 93 loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides); 94 95 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType, 96 newExtractOp); 97 98 return success(); 99 } 100 }; 101 102 // Casts away leading one dimensions in vector.insert_strided_slice's vector 103 // inputs by inserting vector.broadcast. 104 struct CastAwayInsertStridedSliceLeadingOneDim 105 : public OpRewritePattern<vector::InsertStridedSliceOp> { 106 using OpRewritePattern::OpRewritePattern; 107 108 LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp, 109 PatternRewriter &rewriter) const override { 110 VectorType oldSrcType = insertOp.getSourceVectorType(); 111 VectorType newSrcType = trimLeadingOneDims(oldSrcType); 112 VectorType oldDstType = insertOp.getDestVectorType(); 113 VectorType newDstType = trimLeadingOneDims(oldDstType); 114 115 int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank(); 116 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank(); 117 if (srcDropCount == 0 && dstDropCount == 0) 118 return failure(); 119 120 // Trim leading one dimensions from both operands. 
    Location loc = insertOp.getLoc();

    // Strip the dropped unit dims with zero-index extracts.
    Value newSrcVector = rewriter.create<vector::ExtractOp>(
        loc, insertOp.getSource(), splatZero(srcDropCount));
    Value newDstVector = rewriter.create<vector::ExtractOp>(
        loc, insertOp.getDest(), splatZero(dstDropCount));

    // The offsets cover the (trimmed) destination rank, the strides the
    // (trimmed) source rank, so take the matching trailing entries.
    auto newOffsets = rewriter.getArrayAttr(
        insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
    auto newStrides = rewriter.getArrayAttr(
        insertOp.getStrides().getValue().take_back(newSrcType.getRank()));

    auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
        loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);

    // Broadcast back to the original destination type.
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                     newInsertOp);

    return success();
  }
};

// Casts away leading one dimensions in vector.insert's vector inputs by
// inserting vector.broadcast.
struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    // The source may be a scalar; in that case both ranks stay 0 and only the
    // destination's leading unit dims can be dropped.
    Type oldSrcType = insertOp.getSourceType();
    Type newSrcType = oldSrcType;
    int64_t oldSrcRank = 0, newSrcRank = 0;
    if (auto type = dyn_cast<VectorType>(oldSrcType)) {
      newSrcType = trimLeadingOneDims(type);
      oldSrcRank = type.getRank();
      newSrcRank = cast<VectorType>(newSrcType).getRank();
    }

    VectorType oldDstType = insertOp.getDestVectorType();
    VectorType newDstType = trimLeadingOneDims(oldDstType);

    int64_t srcDropCount = oldSrcRank - newSrcRank;
    int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
    if (srcDropCount == 0 && dstDropCount == 0)
      return failure();

    // Trim leading one dimensions from both operands.
    Location loc = insertOp.getLoc();

    Value newSrcVector = insertOp.getSource();
    // A rank-0 source is passed through unchanged (nothing to extract from).
    if (oldSrcRank != 0) {
      newSrcVector = rewriter.create<vector::ExtractOp>(
          loc, insertOp.getSource(), splatZero(srcDropCount));
    }
    Value newDstVector = rewriter.create<vector::ExtractOp>(
        loc, insertOp.getDest(), splatZero(dstDropCount));

    // New position rank needs to be computed in two steps: (1) if destination
    // type has leading unit dims, we also trim the position array accordingly,
    // then (2) if source type also has leading unit dims, we need to append
    // zeroes to the position array accordingly.
    unsigned oldPosRank = insertOp.getNumIndices();
    unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
    SmallVector<OpFoldResult> oldPosition = insertOp.getMixedPosition();
    SmallVector<OpFoldResult> newPosition =
        llvm::to_vector(ArrayRef(oldPosition).take_back(newPosRank));
    newPosition.resize(newDstType.getRank() - newSrcRank,
                       rewriter.getI64IntegerAttr(0));

    auto newInsertOp = rewriter.create<vector::InsertOp>(
        loc, newSrcVector, newDstVector, newPosition);

    // Broadcast back to the original destination type.
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                     newInsertOp);

    return success();
  }
};

// Drops the leading unit dims of a transfer-op mask so it matches the new
// (trimmed) vector type `newType` under the new permutation map `newMap`.
static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask,
                                  VectorType newType, AffineMap newMap,
                                  VectorType oldMaskType) {
  // Infer the type of the new mask from the new map.
  VectorType newMaskType = inferTransferOpMaskType(newType, newMap);

  // If the new mask is broadcastable to the old result type, we can safely
  // use a `vector.extract` to get the new mask. Otherwise the best we can
  // do is shape cast.
209 if (vector::isBroadcastableTo(newMaskType, oldMaskType) == 210 BroadcastableToResult::Success) { 211 int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank(); 212 return b.create<vector::ExtractOp>(loc, mask, splatZero(dropDim)); 213 } 214 return b.create<vector::ShapeCastOp>(loc, newMaskType, mask); 215 } 216 217 // Turns vector.transfer_read on vector with leading 1 dimensions into 218 // vector.shape_cast followed by vector.transfer_read on vector without leading 219 // 1 dimensions. 220 struct CastAwayTransferReadLeadingOneDim 221 : public OpRewritePattern<vector::TransferReadOp> { 222 using OpRewritePattern::OpRewritePattern; 223 224 LogicalResult matchAndRewrite(vector::TransferReadOp read, 225 PatternRewriter &rewriter) const override { 226 // TODO: support 0-d corner case. 227 if (read.getTransferRank() == 0) 228 return failure(); 229 230 auto shapedType = cast<ShapedType>(read.getSource().getType()); 231 if (shapedType.getElementType() != read.getVectorType().getElementType()) 232 return failure(); 233 234 VectorType oldType = read.getVectorType(); 235 VectorType newType = trimLeadingOneDims(oldType); 236 237 if (newType == oldType) 238 return failure(); 239 240 AffineMap oldMap = read.getPermutationMap(); 241 ArrayRef<AffineExpr> newResults = 242 oldMap.getResults().take_back(newType.getRank()); 243 AffineMap newMap = 244 AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults, 245 rewriter.getContext()); 246 247 ArrayAttr inBoundsAttr; 248 if (read.getInBounds()) 249 inBoundsAttr = rewriter.getArrayAttr( 250 read.getInBoundsAttr().getValue().take_back(newType.getRank())); 251 252 Value mask = Value(); 253 if (read.getMask()) { 254 VectorType maskType = read.getMaskType(); 255 mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(), 256 newType, newMap, maskType); 257 } 258 259 auto newRead = rewriter.create<vector::TransferReadOp>( 260 read.getLoc(), newType, read.getSource(), read.getIndices(), 261 
AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr); 262 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead); 263 264 return success(); 265 } 266 }; 267 268 // Turns vector.transfer_write on vector with leading 1 dimensions into 269 // vector.shape_cast followed by vector.transfer_write on vector without leading 270 // 1 dimensions. 271 struct CastAwayTransferWriteLeadingOneDim 272 : public OpRewritePattern<vector::TransferWriteOp> { 273 using OpRewritePattern::OpRewritePattern; 274 275 LogicalResult matchAndRewrite(vector::TransferWriteOp write, 276 PatternRewriter &rewriter) const override { 277 // TODO: support 0-d corner case. 278 if (write.getTransferRank() == 0) 279 return failure(); 280 281 auto shapedType = dyn_cast<ShapedType>(write.getSource().getType()); 282 if (shapedType.getElementType() != write.getVectorType().getElementType()) 283 return failure(); 284 285 VectorType oldType = write.getVectorType(); 286 VectorType newType = trimLeadingOneDims(oldType); 287 if (newType == oldType) 288 return failure(); 289 int64_t dropDim = oldType.getRank() - newType.getRank(); 290 291 AffineMap oldMap = write.getPermutationMap(); 292 ArrayRef<AffineExpr> newResults = 293 oldMap.getResults().take_back(newType.getRank()); 294 AffineMap newMap = 295 AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults, 296 rewriter.getContext()); 297 298 ArrayAttr inBoundsAttr; 299 if (write.getInBounds()) 300 inBoundsAttr = rewriter.getArrayAttr( 301 write.getInBoundsAttr().getValue().take_back(newType.getRank())); 302 303 auto newVector = rewriter.create<vector::ExtractOp>( 304 write.getLoc(), write.getVector(), splatZero(dropDim)); 305 306 if (write.getMask()) { 307 VectorType maskType = write.getMaskType(); 308 Value newMask = dropUnitDimsFromMask( 309 rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType); 310 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( 311 write, newVector, write.getSource(), 
          write.getIndices(),
          AffineMapAttr::get(newMap), newMask, inBoundsAttr);
      return success();
    }

    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        write, newVector, write.getSource(), write.getIndices(),
        AffineMapAttr::get(newMap), inBoundsAttr);
    return success();
  }
};

} // namespace

LogicalResult
mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
                                               RewriterBase &rewriter) {
  // Only contractions whose accumulator is a vector with a leading unit dim
  // are handled.
  VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
  if (oldAccType == nullptr)
    return failure();
  if (oldAccType.getRank() < 2)
    return failure();
  if (oldAccType.getShape()[0] != 1)
    return failure();
  // Currently we support only dropping one dim but the pattern can be applied
  // greedily to drop more.
  int64_t dropDim = 1;

  auto oldIndexingMaps = contractOp.getIndexingMapsArray();
  SmallVector<AffineMap> newIndexingMaps;

  auto oldIteratorTypes = contractOp.getIteratorTypes();
  SmallVector<Attribute> newIteratorTypes;

  // The iteration dim mapped to the accumulator's leading (unit) dim is the
  // one being dropped; indexing map #2 is the accumulator map.
  int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);

  if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
    // Only parallel type iterators can be dropped.
    return failure();

  // Keep every iterator type except the dropped dim's.
  for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
    int64_t currDim = it.index();
    if (currDim == dimToDrop)
      continue;
    newIteratorTypes.push_back(it.value());
  }

  SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
                                 contractOp.getAcc()};
  SmallVector<Value> newOperands;

  for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
    // Check if the dim to be dropped exists as a leading dim in the operand
    // if it does then we use vector.extract to drop it.
    bool validExtract = false;
    SmallVector<AffineExpr> results;
    auto map = it.value();
    int64_t orginalZeroDim = it.value().getDimPosition(0);
    if (orginalZeroDim != dimToDrop) {
      // There are two reasons to be in this path, 1. We need to
      // transpose the operand to make the dim to be dropped
      // leading. 2. The dim to be dropped does not exist and in
      // that case we dont want to add a unit transpose but we must
      // check all the indices to make sure this is the case.
      bool tranposeNeeded = false;
      SmallVector<int64_t> perm;
      SmallVector<AffineExpr> transposeResults;

      // Build a permutation that moves the dropped dim to the front while
      // keeping the relative order of all other dims.
      for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
        int64_t currDim = map.getDimPosition(i);
        if (currDim == dimToDrop) {
          tranposeNeeded = true;
          perm.insert(perm.begin(), i);
          auto targetExpr = rewriter.getAffineDimExpr(currDim);
          transposeResults.insert(transposeResults.begin(), targetExpr);
        } else {
          perm.push_back(i);
          auto targetExpr = rewriter.getAffineDimExpr(currDim);
          transposeResults.push_back(targetExpr);
        }
      }
      // Do the transpose now if needed so that we can drop the
      // correct dim using extract later.
      if (tranposeNeeded) {
        map = AffineMap::get(map.getNumDims(), 0, transposeResults,
                             contractOp.getContext());
        operands[it.index()] = rewriter.create<vector::TransposeOp>(
            contractOp.getLoc(), operands[it.index()], perm);
      }
    }
    // We have taken care to have the dim to be dropped be
    // the leading dim. If its still not leading that means it
    // does not exist in this operand and hence we do not need
    // an extract.
    if (map.getDimPosition(0) == dimToDrop)
      validExtract = true;

    // Rebuild the indexing map over one fewer iteration dim: dims above the
    // dropped one shift down by one.
    for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
      int64_t currDim = map.getDimPosition(i);
      if (currDim == dimToDrop)
        // This is the dim we are dropping.
        continue;
      auto targetExpr = rewriter.getAffineDimExpr(
          currDim < dimToDrop ? currDim : currDim - 1);
      results.push_back(targetExpr);
    }
    newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
                                             contractOp.getContext()));
    // Extract if it's a valid extraction, otherwise use the operand
    // without extraction.
    newOperands.push_back(
        validExtract ? rewriter.create<vector::ExtractOp>(contractOp.getLoc(),
                                                          operands[it.index()],
                                                          splatZero(dropDim))
                     : operands[it.index()]);
  }
  auto newContractOp = rewriter.create<vector::ContractionOp>(
      contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
      rewriter.getAffineMapArrayAttr(newIndexingMaps),
      rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
  // Broadcast the reduced-rank result back to the original result type.
  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
      contractOp, contractOp->getResultTypes()[0], newContractOp);
  return success();
}

namespace {

/// Turns vector.contract on vector with leading 1 dimensions into
/// vector.extract followed by vector.contract on vector without leading
/// 1 dimensions. Also performs transpose of lhs and rhs operands if required
/// prior to extract.
struct CastAwayContractionLeadingOneDim
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    return castAwayContractionLeadingOneDim(contractOp, rewriter);
  }
};

/// Looks at elementwise operations on vectors with at least one leading
/// dimension equal 1, e.g. vector<1x[4]x1xf32> (but not vector<2x[4]x1xf32>),
/// and cast aways the leading one dimensions (_plural_) and then broadcasts
/// the results.
456 /// 457 /// Example before: 458 /// %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32> 459 /// Example after: 460 /// %2 = arith.mulf %0, %1 : vector<4x1xf32> 461 /// %3 = vector.broadcast %2 : vector<4x1xf32> to vector<1x4x1xf32> 462 /// 463 /// Does support scalable vectors. 464 class CastAwayElementwiseLeadingOneDim : public RewritePattern { 465 public: 466 CastAwayElementwiseLeadingOneDim(MLIRContext *context, 467 PatternBenefit benefit = 1) 468 : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {} 469 470 LogicalResult matchAndRewrite(Operation *op, 471 PatternRewriter &rewriter) const override { 472 if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) 473 return failure(); 474 auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0]); 475 if (!vecType) 476 return failure(); 477 VectorType newVecType = trimLeadingOneDims(vecType); 478 if (newVecType == vecType) 479 return failure(); 480 int64_t dropDim = vecType.getRank() - newVecType.getRank(); 481 SmallVector<Value, 4> newOperands; 482 for (Value operand : op->getOperands()) { 483 if (auto opVecType = dyn_cast<VectorType>(operand.getType())) { 484 newOperands.push_back(rewriter.create<vector::ExtractOp>( 485 op->getLoc(), operand, splatZero(dropDim))); 486 } else { 487 newOperands.push_back(operand); 488 } 489 } 490 Operation *newOp = 491 rewriter.create(op->getLoc(), op->getName().getIdentifier(), 492 newOperands, newVecType, op->getAttrs()); 493 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType, 494 newOp->getResult(0)); 495 return success(); 496 } 497 }; 498 499 // Drops leading 1 dimensions from vector.constant_mask and inserts a 500 // vector.broadcast back to the original shape. 
501 struct CastAwayConstantMaskLeadingOneDim 502 : public OpRewritePattern<vector::ConstantMaskOp> { 503 using OpRewritePattern::OpRewritePattern; 504 505 LogicalResult matchAndRewrite(vector::ConstantMaskOp mask, 506 PatternRewriter &rewriter) const override { 507 VectorType oldType = mask.getType(); 508 VectorType newType = trimLeadingOneDims(oldType); 509 510 if (newType == oldType) 511 return failure(); 512 513 int64_t dropDim = oldType.getRank() - newType.getRank(); 514 SmallVector<int64_t> dimSizes; 515 for (auto attr : mask.getMaskDimSizes()) 516 dimSizes.push_back(llvm::cast<IntegerAttr>(attr).getInt()); 517 518 // If any of the dropped unit dims has a size of `0`, the entire mask is a 519 // zero mask, else the unit dim has no effect on the mask. 520 int64_t flatLeadingSize = 521 std::accumulate(dimSizes.begin(), dimSizes.begin() + dropDim + 1, 522 static_cast<int64_t>(1), std::multiplies<int64_t>()); 523 SmallVector<int64_t> newDimSizes({flatLeadingSize}); 524 newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end()); 525 526 auto newMask = rewriter.create<vector::ConstantMaskOp>( 527 mask.getLoc(), newType, rewriter.getI64ArrayAttr(newDimSizes)); 528 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask); 529 return success(); 530 } 531 }; 532 533 } // namespace 534 535 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns( 536 RewritePatternSet &patterns, PatternBenefit benefit) { 537 patterns 538 .add<CastAwayExtractStridedSliceLeadingOneDim, 539 CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim, 540 CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim, 541 CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim, 542 CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit); 543 populateShapeCastFoldingPatterns(patterns, benefit); 544 } 545