//===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <numeric>

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "vector-drop-unit-dim"

using namespace mlir;
using namespace mlir::vector;

// Trims leading one dimensions from `oldType` and returns the result type.
// Returns `vector<1xT>` if `oldType` only has one element.
static VectorType trimLeadingOneDims(VectorType oldType) {
  ArrayRef<int64_t> oldShape = oldType.getShape();
  ArrayRef<int64_t> newShape = oldShape;

  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
  ArrayRef<bool> newScalableDims = oldScalableDims;

  while (!newShape.empty() && newShape.front() == 1 &&
         !newScalableDims.front()) {
    newShape = newShape.drop_front(1);
    newScalableDims = newScalableDims.drop_front(1);
  }

  // Make sure we have at least 1 dimension, per vector type requirements.
  if (newShape.empty()) {
    newShape = oldShape.take_back();
    newScalableDims = oldType.getScalableDims().take_back();
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}

/// Return a SmallVector of size `rank` containing all zeros.
static SmallVector<int64_t> splatZero(int64_t rank) {
  return SmallVector<int64_t>(rank, 0);
}
namespace {

// Casts away leading one dimensions in vector.extract_strided_slice's vector
// input by inserting vector.broadcast.
struct CastAwayExtractStridedSliceLeadingOneDim
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    // vector.extract_strided_slice requires the input and output vectors to
    // have the same rank. Here we drop leading one dimensions from the input
    // vector type to make sure we don't create a rank mismatch.
    VectorType oldSrcType = extractOp.getSourceVectorType();
    VectorType newSrcType = trimLeadingOneDims(oldSrcType);

    if (newSrcType.getRank() == oldSrcType.getRank())
      return failure();

    int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();

    VectorType oldDstType = extractOp.getType();
    VectorType newDstType =
        VectorType::get(oldDstType.getShape().drop_front(dropCount),
                        oldDstType.getElementType(),
                        oldDstType.getScalableDims().drop_front(dropCount));

    Location loc = extractOp.getLoc();

    Value newSrcVector = rewriter.create<vector::ExtractOp>(
        loc, extractOp.getVector(), splatZero(dropCount));

    // The offsets/sizes/strides attributes can have fewer elements than the
    // input vector's rank: they apply to the leading dimensions.
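    // E.g. (illustrative): with dropCount = 1, offsets [0, 2] become [2],
    // sizes [1, 4] become [4], and strides [1, 1] become [1].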
    auto newOffsets = rewriter.getArrayAttr(
        extractOp.getOffsets().getValue().drop_front(dropCount));
    auto newSizes = rewriter.getArrayAttr(
        extractOp.getSizes().getValue().drop_front(dropCount));
    auto newStrides = rewriter.getArrayAttr(
        extractOp.getStrides().getValue().drop_front(dropCount));

    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
        loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
                                                     newExtractOp);

    return success();
  }
};

// Casts away leading one dimensions in vector.insert_strided_slice's vector
// inputs by inserting vector.broadcast.
struct CastAwayInsertStridedSliceLeadingOneDim
    : public OpRewritePattern<vector::InsertStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    VectorType oldSrcType = insertOp.getSourceVectorType();
    VectorType newSrcType = trimLeadingOneDims(oldSrcType);
    VectorType oldDstType = insertOp.getDestVectorType();
    VectorType newDstType = trimLeadingOneDims(oldDstType);

    int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
    int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
    if (srcDropCount == 0 && dstDropCount == 0)
      return failure();

    // Trim leading one dimensions from both operands.
    Location loc = insertOp.getLoc();

    Value newSrcVector = rewriter.create<vector::ExtractOp>(
        loc, insertOp.getSource(), splatZero(srcDropCount));
    Value newDstVector = rewriter.create<vector::ExtractOp>(
        loc, insertOp.getDest(), splatZero(dstDropCount));

    auto newOffsets = rewriter.getArrayAttr(
        insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
    auto newStrides = rewriter.getArrayAttr(
        insertOp.getStrides().getValue().take_back(newSrcType.getRank()));

    auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
        loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                     newInsertOp);

    return success();
  }
};

// Casts away leading one dimensions in vector.insert's vector inputs by
// inserting vector.broadcast.
struct CastAwayInsertLeadingOneDim
    : public OpRewritePattern<vector::InsertOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    Type oldSrcType = insertOp.getSourceType();
    Type newSrcType = oldSrcType;
    int64_t oldSrcRank = 0, newSrcRank = 0;
    if (auto type = dyn_cast<VectorType>(oldSrcType)) {
      newSrcType = trimLeadingOneDims(type);
      oldSrcRank = type.getRank();
      newSrcRank = cast<VectorType>(newSrcType).getRank();
    }

    VectorType oldDstType = insertOp.getDestVectorType();
    VectorType newDstType = trimLeadingOneDims(oldDstType);

    int64_t srcDropCount = oldSrcRank - newSrcRank;
    int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
    if (srcDropCount == 0 && dstDropCount == 0)
      return failure();

    // Trim leading one dimensions from both operands.
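    // E.g. (illustrative): for a source of type vector<1x4xf32> and a
    // destination of type vector<1x1x4xf32>, srcDropCount = 1 and
    // dstDropCount = 2, so both operands are reduced to vector<4xf32> below.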
    Location loc = insertOp.getLoc();

    Value newSrcVector = insertOp.getSource();
    if (oldSrcRank != 0) {
      newSrcVector = rewriter.create<vector::ExtractOp>(
          loc, insertOp.getSource(), splatZero(srcDropCount));
    }
    Value newDstVector = rewriter.create<vector::ExtractOp>(
        loc, insertOp.getDest(), splatZero(dstDropCount));

    // The new position rank needs to be computed in two steps: (1) if the
    // destination type has leading unit dims, we also trim the position array
    // accordingly, and then (2) if the source type also has leading unit dims,
    // we need to append zeroes to the position array accordingly.
    unsigned oldPosRank = insertOp.getNumIndices();
    unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
    SmallVector<OpFoldResult> oldPosition = insertOp.getMixedPosition();
    SmallVector<OpFoldResult> newPosition =
        llvm::to_vector(ArrayRef(oldPosition).take_back(newPosRank));
    newPosition.resize(newDstType.getRank() - newSrcRank,
                       rewriter.getI64IntegerAttr(0));

    auto newInsertOp = rewriter.create<vector::InsertOp>(
        loc, newSrcVector, newDstVector, newPosition);

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                     newInsertOp);

    return success();
  }
};

static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask,
                                  VectorType newType, AffineMap newMap,
                                  VectorType oldMaskType) {
  // Infer the type of the new mask from the new map.
  VectorType newMaskType = inferTransferOpMaskType(newType, newMap);

  // If the new mask is broadcastable to the old mask type, we can safely
  // use a `vector.extract` to get the new mask. Otherwise the best we can
  // do is shape cast.
  if (vector::isBroadcastableTo(newMaskType, oldMaskType) ==
      BroadcastableToResult::Success) {
    int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
    return b.create<vector::ExtractOp>(loc, mask, splatZero(dropDim));
  }
  return b.create<vector::ShapeCastOp>(loc, newMaskType, mask);
}

// Turns vector.transfer_read on a vector with leading 1 dimensions into
// vector.transfer_read on a vector without leading 1 dimensions, followed by
// vector.broadcast back to the original vector shape.
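//
// Example (illustrative; default identity permutation map and in-bounds
// attributes omitted):
//   %0 = vector.transfer_read %src[%c0, %c0], %pad
//          : memref<4x8xf32>, vector<1x4xf32>
// is rewritten to:
//   %0 = vector.transfer_read %src[%c0, %c0], %pad
//          : memref<4x8xf32>, vector<4xf32>
//   %1 = vector.broadcast %0 : vector<4xf32> to vector<1x4xf32>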
struct CastAwayTransferReadLeadingOneDim
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    // TODO(#78787): Masked ops are not supported yet.
    if (cast<MaskableOpInterface>(read.getOperation()).isMasked())
      return failure();
    // TODO: support 0-d corner case.
    if (read.getTransferRank() == 0)
      return failure();

    auto shapedType = cast<ShapedType>(read.getSource().getType());
    if (shapedType.getElementType() != read.getVectorType().getElementType())
      return failure();

    VectorType oldType = read.getVectorType();
    VectorType newType = trimLeadingOneDims(oldType);

    if (newType == oldType)
      return failure();

    AffineMap oldMap = read.getPermutationMap();
    ArrayRef<AffineExpr> newResults =
        oldMap.getResults().take_back(newType.getRank());
    AffineMap newMap =
        AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
                       rewriter.getContext());

    ArrayAttr inBoundsAttr;
    if (read.getInBounds())
      inBoundsAttr = rewriter.getArrayAttr(
          read.getInBoundsAttr().getValue().take_back(newType.getRank()));

    Value mask = Value();
    if (read.getMask()) {
      VectorType maskType = read.getMaskType();
      mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(),
                                  newType, newMap, maskType);
    }

    auto newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), newType, read.getSource(), read.getIndices(),
        AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);

    return success();
  }
};

// Turns vector.transfer_write on a vector with leading 1 dimensions into
// vector.extract on the written value followed by vector.transfer_write on a
// vector without leading 1 dimensions.
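//
// Example (illustrative; default identity permutation map and in-bounds
// attributes omitted):
//   vector.transfer_write %v, %dst[%c0, %c0]
//     : vector<1x4xf32>, memref<4x8xf32>
// is rewritten to:
//   %0 = vector.extract %v[0] : vector<4xf32> from vector<1x4xf32>
//   vector.transfer_write %0, %dst[%c0, %c0]
//     : vector<4xf32>, memref<4x8xf32>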
struct CastAwayTransferWriteLeadingOneDim
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    // TODO(#78787): Masked ops are not supported yet.
    if (cast<MaskableOpInterface>(write.getOperation()).isMasked())
      return failure();
    // TODO: support 0-d corner case.
    if (write.getTransferRank() == 0)
      return failure();

    auto shapedType = cast<ShapedType>(write.getSource().getType());
    if (shapedType.getElementType() != write.getVectorType().getElementType())
      return failure();

    VectorType oldType = write.getVectorType();
    VectorType newType = trimLeadingOneDims(oldType);
    if (newType == oldType)
      return failure();
    int64_t dropDim = oldType.getRank() - newType.getRank();

    AffineMap oldMap = write.getPermutationMap();
    ArrayRef<AffineExpr> newResults =
        oldMap.getResults().take_back(newType.getRank());
    AffineMap newMap =
        AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
                       rewriter.getContext());

    ArrayAttr inBoundsAttr;
    if (write.getInBounds())
      inBoundsAttr = rewriter.getArrayAttr(
          write.getInBoundsAttr().getValue().take_back(newType.getRank()));

    auto newVector = rewriter.create<vector::ExtractOp>(
        write.getLoc(), write.getVector(), splatZero(dropDim));

    if (write.getMask()) {
      VectorType maskType = write.getMaskType();
      Value newMask = dropUnitDimsFromMask(
          rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
      rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
          write, newVector, write.getSource(), write.getIndices(),
          AffineMapAttr::get(newMap), newMask, inBoundsAttr);
      return success();
    }

    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        write, newVector, write.getSource(), write.getIndices(),
        AffineMapAttr::get(newMap), inBoundsAttr);
    return success();
  }
};

} // namespace

FailureOr<Value>
mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
                                               MaskingOpInterface maskingOp,
                                               RewriterBase &rewriter) {
  VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
  if (oldAccType == nullptr)
    return failure();
  if (oldAccType.getRank() < 2)
    return failure();
  if (oldAccType.getShape()[0] != 1)
    return failure();
  // Currently we support only dropping one dim, but the pattern can be applied
  // greedily to drop more.
  int64_t dropDim = 1;

  auto oldIndexingMaps = contractOp.getIndexingMapsArray();
  SmallVector<AffineMap> newIndexingMaps;

  auto oldIteratorTypes = contractOp.getIteratorTypes();
  SmallVector<Attribute> newIteratorTypes;

  int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);

  if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
    // Only parallel iterators can be dropped.
    return failure();

  for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
    int64_t currDim = it.index();
    if (currDim == dimToDrop)
      continue;
    newIteratorTypes.push_back(it.value());
  }

  SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
                                 contractOp.getAcc()};
  SmallVector<Value> newOperands;
  auto loc = contractOp.getLoc();

  for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
    // Check if the dim to be dropped exists as a leading dim in the operand;
    // if it does, then we use vector.extract to drop it.
    bool validExtract = false;
    SmallVector<AffineExpr> results;
    auto map = it.value();
    int64_t originalZeroDim = it.value().getDimPosition(0);
    if (originalZeroDim != dimToDrop) {
      // There are two reasons to be on this path: 1. we need to transpose the
      // operand to make the dim to be dropped the leading dim, or 2.
      // the dim to be dropped does not exist in this operand, in which case we
      // don't want to add a unit transpose, but we must check all the indices
      // to make sure this is the case.
      bool transposeNeeded = false;
      SmallVector<int64_t> perm;
      SmallVector<AffineExpr> transposeResults;

      for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
        int64_t currDim = map.getDimPosition(i);
        if (currDim == dimToDrop) {
          transposeNeeded = true;
          perm.insert(perm.begin(), i);
          auto targetExpr = rewriter.getAffineDimExpr(currDim);
          transposeResults.insert(transposeResults.begin(), targetExpr);
        } else {
          perm.push_back(i);
          auto targetExpr = rewriter.getAffineDimExpr(currDim);
          transposeResults.push_back(targetExpr);
        }
      }

      // Checks if only the outer, unit dimensions (of size 1) are permuted.
      // Such transposes do not materially affect the underlying vector and
      // can be omitted. E.g.: perm [1, 0, 2] applied to vector<1x1x8xi32>.
      bool transposeNonOuterUnitDims = false;
      auto operandShape = cast<ShapedType>(operands[it.index()].getType());
      for (auto [index, dim] :
           llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) {
        if (dim != static_cast<int64_t>(index) &&
            operandShape.getDimSize(index) != 1) {
          transposeNonOuterUnitDims = true;
          break;
        }
      }

      // Do the transpose now if needed so that we can drop the
      // correct dim using extract later.
      if (transposeNeeded) {
        map = AffineMap::get(map.getNumDims(), 0, transposeResults,
                             contractOp.getContext());
        if (transposeNonOuterUnitDims) {
          operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
              loc, operands[it.index()], perm);
        }
      }
    }
    // We have taken care to have the dim to be dropped be the leading dim. If
    // it is still not leading, that means it does not exist in this operand,
    // and hence we do not need an extract.
    if (map.getDimPosition(0) == dimToDrop)
      validExtract = true;

    for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
      int64_t currDim = map.getDimPosition(i);
      if (currDim == dimToDrop)
        // This is the dim we are dropping.
        continue;
      auto targetExpr = rewriter.getAffineDimExpr(
          currDim < dimToDrop ? currDim : currDim - 1);
      results.push_back(targetExpr);
    }
    newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
                                             contractOp.getContext()));
    // Extract if it is a valid extraction, otherwise use the operand without
    // extraction.
    newOperands.push_back(
        validExtract ? rewriter.create<vector::ExtractOp>(
                           loc, operands[it.index()], splatZero(dropDim))
                     : operands[it.index()]);
  }

  // Depending on whether this vector.contract is masked, the replacing op
  // should either be a new vector.contract op or a vector.mask op.
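  // E.g. (illustrative): an accumulator of type vector<1x8x16xf32> is reduced
  // to vector<8x16xf32> for the new contraction; the dropped leading unit dim
  // is re-introduced by the trailing vector.broadcast below.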
  Operation *newOp = rewriter.create<vector::ContractionOp>(
      loc, newOperands[0], newOperands[1], newOperands[2],
      rewriter.getAffineMapArrayAttr(newIndexingMaps),
      rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());

  if (maskingOp) {
    auto newMask = rewriter.create<vector::ExtractOp>(loc, maskingOp.getMask(),
                                                      splatZero(dropDim));

    newOp = mlir::vector::maskOperation(rewriter, newOp, newMask);
  }

  return rewriter
      .create<vector::BroadcastOp>(loc, contractOp->getResultTypes()[0],
                                   newOp->getResults()[0])
      .getResult();
}

namespace {

/// Turns vector.contract on a vector with leading 1 dimensions into
/// vector.extract followed by vector.contract on a vector without leading
/// 1 dimensions. Also performs a transpose of the lhs and rhs operands if
/// required prior to the extract.
struct CastAwayContractionLeadingOneDim
    : public MaskableOpRewritePattern<vector::ContractionOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
                            MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    return castAwayContractionLeadingOneDim(contractOp, maskingOp, rewriter);
  }
};

/// Looks at elementwise operations on vectors with at least one leading
/// dimension equal to 1, e.g. vector<1x[4]x1xf32> (but not
/// vector<2x[4]x1xf32>), casts away the leading one dimensions (_plural_) and
/// then broadcasts the result.
///
/// Example before:
///   %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32>
/// Example after:
///   %2 = arith.mulf %0, %1 : vector<4x1xf32>
///   %3 = vector.broadcast %2 : vector<4x1xf32> to vector<1x4x1xf32>
///
/// Scalable vectors are supported.
class CastAwayElementwiseLeadingOneDim : public RewritePattern {
public:
  CastAwayElementwiseLeadingOneDim(MLIRContext *context,
                                   PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0]);
    if (!vecType)
      return failure();
    VectorType newVecType = trimLeadingOneDims(vecType);
    if (newVecType == vecType)
      return failure();
    int64_t dropDim = vecType.getRank() - newVecType.getRank();
    SmallVector<Value, 4> newOperands;
    for (Value operand : op->getOperands()) {
      if (auto opVecType = dyn_cast<VectorType>(operand.getType())) {
        newOperands.push_back(rewriter.create<vector::ExtractOp>(
            op->getLoc(), operand, splatZero(dropDim)));
      } else {
        newOperands.push_back(operand);
      }
    }
    Operation *newOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        newOperands, newVecType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
                                                     newOp->getResult(0));
    return success();
  }
};

// Drops leading 1 dimensions from vector.constant_mask and inserts a
// vector.broadcast back to the original shape.
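//
// Example (illustrative):
//   %0 = vector.constant_mask [1, 3] : vector<1x4xi1>
// is rewritten to:
//   %0 = vector.constant_mask [3] : vector<4xi1>
//   %1 = vector.broadcast %0 : vector<4xi1> to vector<1x4xi1>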
struct CastAwayConstantMaskLeadingOneDim
    : public OpRewritePattern<vector::ConstantMaskOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
                                PatternRewriter &rewriter) const override {
    VectorType oldType = mask.getType();
    VectorType newType = trimLeadingOneDims(oldType);

    if (newType == oldType)
      return failure();

    int64_t dropDim = oldType.getRank() - newType.getRank();
    ArrayRef<int64_t> dimSizes = mask.getMaskDimSizes();

    // If any of the dropped unit dims has a size of `0`, the entire mask is a
    // zero mask, else the unit dim has no effect on the mask.
    int64_t flatLeadingSize =
        std::accumulate(dimSizes.begin(), dimSizes.begin() + dropDim + 1,
                        static_cast<int64_t>(1), std::multiplies<int64_t>());
    SmallVector<int64_t> newDimSizes = {flatLeadingSize};
    newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());

    auto newMask = rewriter.create<vector::ConstantMaskOp>(
        mask.getLoc(), newType, newDimSizes);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
    return success();
  }
};

} // namespace

void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<CastAwayExtractStridedSliceLeadingOneDim,
           CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
           CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
           CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
           CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
}