1 //===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <numeric> 10 11 #include "mlir/Dialect/Arith/IR/Arith.h" 12 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 13 #include "mlir/Dialect/Vector/IR/VectorOps.h" 14 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 15 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 16 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/TypeUtilities.h" 19 20 #define DEBUG_TYPE "vector-drop-unit-dim" 21 22 using namespace mlir; 23 using namespace mlir::vector; 24 25 // Trims leading one dimensions from `oldType` and returns the result type. 26 // Returns `vector<1xT>` if `oldType` only has one element. 27 static VectorType trimLeadingOneDims(VectorType oldType) { 28 ArrayRef<int64_t> oldShape = oldType.getShape(); 29 ArrayRef<int64_t> newShape = oldShape; 30 31 ArrayRef<bool> oldScalableDims = oldType.getScalableDims(); 32 ArrayRef<bool> newScalableDims = oldScalableDims; 33 34 while (!newShape.empty() && newShape.front() == 1 && 35 !newScalableDims.front()) { 36 newShape = newShape.drop_front(1); 37 newScalableDims = newScalableDims.drop_front(1); 38 } 39 40 // Make sure we have at least 1 dimension per vector type requirements. 41 if (newShape.empty()) { 42 newShape = oldShape.take_back(); 43 newScalableDims = oldType.getScalableDims().take_back(); 44 } 45 return VectorType::get(newShape, oldType.getElementType(), newScalableDims); 46 } 47 48 /// Return a smallVector of size `rank` containing all zeros. 49 static SmallVector<int64_t> splatZero(int64_t rank) { 50 return SmallVector<int64_t>(rank, 0); 51 } 52 namespace { 53 54 // Casts away leading one dimensions in vector.extract_strided_slice's vector 55 // input by inserting vector.broadcast. 56 struct CastAwayExtractStridedSliceLeadingOneDim 57 : public OpRewritePattern<vector::ExtractStridedSliceOp> { 58 using OpRewritePattern::OpRewritePattern; 59 60 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, 61 PatternRewriter &rewriter) const override { 62 // vector.extract_strided_slice requires the input and output vector to have 63 // the same rank. Here we drop leading one dimensions from the input vector 64 // type to make sure we don't cause mismatch. 65 VectorType oldSrcType = extractOp.getSourceVectorType(); 66 VectorType newSrcType = trimLeadingOneDims(oldSrcType); 67 68 if (newSrcType.getRank() == oldSrcType.getRank()) 69 return failure(); 70 71 int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank(); 72 73 VectorType oldDstType = extractOp.getType(); 74 VectorType newDstType = 75 VectorType::get(oldDstType.getShape().drop_front(dropCount), 76 oldDstType.getElementType()); 77 78 Location loc = extractOp.getLoc(); 79 80 Value newSrcVector = rewriter.create<vector::ExtractOp>( 81 loc, extractOp.getVector(), splatZero(dropCount)); 82 83 // The offsets/sizes/strides attribute can have a less number of elements 84 // than the input vector's rank: it is meant for the leading dimensions. 85 auto newOffsets = rewriter.getArrayAttr( 86 extractOp.getOffsets().getValue().drop_front(dropCount)); 87 auto newSizes = rewriter.getArrayAttr( 88 extractOp.getSizes().getValue().drop_front(dropCount)); 89 auto newStrides = rewriter.getArrayAttr( 90 extractOp.getStrides().getValue().drop_front(dropCount)); 91 92 auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>( 93 loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides); 94 95 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType, 96 newExtractOp); 97 98 return success(); 99 } 100 }; 101 102 // Casts away leading one dimensions in vector.insert_strided_slice's vector 103 // inputs by inserting vector.broadcast. 104 struct CastAwayInsertStridedSliceLeadingOneDim 105 : public OpRewritePattern<vector::InsertStridedSliceOp> { 106 using OpRewritePattern::OpRewritePattern; 107 108 LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp, 109 PatternRewriter &rewriter) const override { 110 VectorType oldSrcType = insertOp.getSourceVectorType(); 111 VectorType newSrcType = trimLeadingOneDims(oldSrcType); 112 VectorType oldDstType = insertOp.getDestVectorType(); 113 VectorType newDstType = trimLeadingOneDims(oldDstType); 114 115 int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank(); 116 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank(); 117 if (srcDropCount == 0 && dstDropCount == 0) 118 return failure(); 119 120 // Trim leading one dimensions from both operands. 121 Location loc = insertOp.getLoc(); 122 123 Value newSrcVector = rewriter.create<vector::ExtractOp>( 124 loc, insertOp.getSource(), splatZero(srcDropCount)); 125 Value newDstVector = rewriter.create<vector::ExtractOp>( 126 loc, insertOp.getDest(), splatZero(dstDropCount)); 127 128 auto newOffsets = rewriter.getArrayAttr( 129 insertOp.getOffsets().getValue().take_back(newDstType.getRank())); 130 auto newStrides = rewriter.getArrayAttr( 131 insertOp.getStrides().getValue().take_back(newSrcType.getRank())); 132 133 auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>( 134 loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides); 135 136 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType, 137 newInsertOp); 138 139 return success(); 140 } 141 }; 142 143 // Casts away leading one dimensions in vector.insert's vector inputs by 144 // inserting vector.broadcast. 145 struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> { 146 using OpRewritePattern::OpRewritePattern; 147 148 LogicalResult matchAndRewrite(vector::InsertOp insertOp, 149 PatternRewriter &rewriter) const override { 150 Type oldSrcType = insertOp.getSourceType(); 151 Type newSrcType = oldSrcType; 152 int64_t oldSrcRank = 0, newSrcRank = 0; 153 if (auto type = dyn_cast<VectorType>(oldSrcType)) { 154 newSrcType = trimLeadingOneDims(type); 155 oldSrcRank = type.getRank(); 156 newSrcRank = cast<VectorType>(newSrcType).getRank(); 157 } 158 159 VectorType oldDstType = insertOp.getDestVectorType(); 160 VectorType newDstType = trimLeadingOneDims(oldDstType); 161 162 int64_t srcDropCount = oldSrcRank - newSrcRank; 163 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank(); 164 if (srcDropCount == 0 && dstDropCount == 0) 165 return failure(); 166 167 // Trim leading one dimensions from both operands. 168 Location loc = insertOp.getLoc(); 169 170 Value newSrcVector = insertOp.getSource(); 171 if (oldSrcRank != 0) { 172 newSrcVector = rewriter.create<vector::ExtractOp>( 173 loc, insertOp.getSource(), splatZero(srcDropCount)); 174 } 175 Value newDstVector = rewriter.create<vector::ExtractOp>( 176 loc, insertOp.getDest(), splatZero(dstDropCount)); 177 178 // New position rank needs to be computed in two steps: (1) if destination 179 // type has leading unit dims, we also trim the position array accordingly, 180 // then (2) if source type also has leading unit dims, we need to append 181 // zeroes to the position array accordingly. 182 unsigned oldPosRank = insertOp.getNumIndices(); 183 unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount); 184 SmallVector<OpFoldResult> oldPosition = insertOp.getMixedPosition(); 185 SmallVector<OpFoldResult> newPosition = 186 llvm::to_vector(ArrayRef(oldPosition).take_back(newPosRank)); 187 newPosition.resize(newDstType.getRank() - newSrcRank, 188 rewriter.getI64IntegerAttr(0)); 189 190 auto newInsertOp = rewriter.create<vector::InsertOp>( 191 loc, newSrcVector, newDstVector, newPosition); 192 193 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType, 194 newInsertOp); 195 196 return success(); 197 } 198 }; 199 200 static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask, 201 VectorType newType, AffineMap newMap, 202 VectorType oldMaskType) { 203 // Infer the type of the new mask from the new map. 204 VectorType newMaskType = inferTransferOpMaskType(newType, newMap); 205 206 // If the new mask is broadcastable to the old result type, we can safely 207 // use a `vector.extract` to get the new mask. Otherwise the best we can 208 // do is shape cast. 209 if (vector::isBroadcastableTo(newMaskType, oldMaskType) == 210 BroadcastableToResult::Success) { 211 int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank(); 212 return b.create<vector::ExtractOp>(loc, mask, splatZero(dropDim)); 213 } 214 return b.create<vector::ShapeCastOp>(loc, newMaskType, mask); 215 } 216 217 // Turns vector.transfer_read on vector with leading 1 dimensions into 218 // vector.shape_cast followed by vector.transfer_read on vector without leading 219 // 1 dimensions. 220 struct CastAwayTransferReadLeadingOneDim 221 : public OpRewritePattern<vector::TransferReadOp> { 222 using OpRewritePattern::OpRewritePattern; 223 224 LogicalResult matchAndRewrite(vector::TransferReadOp read, 225 PatternRewriter &rewriter) const override { 226 // TODO(#78787): Not supported masked op yet. 227 if (cast<MaskableOpInterface>(read.getOperation()).isMasked()) 228 return failure(); 229 // TODO: support 0-d corner case. 230 if (read.getTransferRank() == 0) 231 return failure(); 232 233 auto shapedType = cast<ShapedType>(read.getSource().getType()); 234 if (shapedType.getElementType() != read.getVectorType().getElementType()) 235 return failure(); 236 237 VectorType oldType = read.getVectorType(); 238 VectorType newType = trimLeadingOneDims(oldType); 239 240 if (newType == oldType) 241 return failure(); 242 243 AffineMap oldMap = read.getPermutationMap(); 244 ArrayRef<AffineExpr> newResults = 245 oldMap.getResults().take_back(newType.getRank()); 246 AffineMap newMap = 247 AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults, 248 rewriter.getContext()); 249 250 ArrayAttr inBoundsAttr; 251 if (read.getInBounds()) 252 inBoundsAttr = rewriter.getArrayAttr( 253 read.getInBoundsAttr().getValue().take_back(newType.getRank())); 254 255 Value mask = Value(); 256 if (read.getMask()) { 257 VectorType maskType = read.getMaskType(); 258 mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(), 259 newType, newMap, maskType); 260 } 261 262 auto newRead = rewriter.create<vector::TransferReadOp>( 263 read.getLoc(), newType, read.getSource(), read.getIndices(), 264 AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr); 265 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead); 266 267 return success(); 268 } 269 }; 270 271 // Turns vector.transfer_write on vector with leading 1 dimensions into 272 // vector.shape_cast followed by vector.transfer_write on vector without leading 273 // 1 dimensions. 274 struct CastAwayTransferWriteLeadingOneDim 275 : public OpRewritePattern<vector::TransferWriteOp> { 276 using OpRewritePattern::OpRewritePattern; 277 278 LogicalResult matchAndRewrite(vector::TransferWriteOp write, 279 PatternRewriter &rewriter) const override { 280 // TODO(#78787): Not supported masked op yet. 281 if (cast<MaskableOpInterface>(write.getOperation()).isMasked()) 282 return failure(); 283 // TODO: support 0-d corner case. 284 if (write.getTransferRank() == 0) 285 return failure(); 286 287 auto shapedType = dyn_cast<ShapedType>(write.getSource().getType()); 288 if (shapedType.getElementType() != write.getVectorType().getElementType()) 289 return failure(); 290 291 VectorType oldType = write.getVectorType(); 292 VectorType newType = trimLeadingOneDims(oldType); 293 if (newType == oldType) 294 return failure(); 295 int64_t dropDim = oldType.getRank() - newType.getRank(); 296 297 AffineMap oldMap = write.getPermutationMap(); 298 ArrayRef<AffineExpr> newResults = 299 oldMap.getResults().take_back(newType.getRank()); 300 AffineMap newMap = 301 AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults, 302 rewriter.getContext()); 303 304 ArrayAttr inBoundsAttr; 305 if (write.getInBounds()) 306 inBoundsAttr = rewriter.getArrayAttr( 307 write.getInBoundsAttr().getValue().take_back(newType.getRank())); 308 309 auto newVector = rewriter.create<vector::ExtractOp>( 310 write.getLoc(), write.getVector(), splatZero(dropDim)); 311 312 if (write.getMask()) { 313 VectorType maskType = write.getMaskType(); 314 Value newMask = dropUnitDimsFromMask( 315 rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType); 316 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( 317 write, newVector, write.getSource(), write.getIndices(), 318 AffineMapAttr::get(newMap), newMask, inBoundsAttr); 319 return success(); 320 } 321 322 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( 323 write, newVector, write.getSource(), write.getIndices(), 324 AffineMapAttr::get(newMap), inBoundsAttr); 325 return success(); 326 } 327 }; 328 329 } // namespace 330 331 LogicalResult 332 mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, 333 RewriterBase &rewriter) { 334 // TODO(#78787): Not supported masked op yet. 335 if (cast<MaskableOpInterface>(contractOp.getOperation()).isMasked()) 336 return failure(); 337 VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType()); 338 if (oldAccType == nullptr) 339 return failure(); 340 if (oldAccType.getRank() < 2) 341 return failure(); 342 if (oldAccType.getShape()[0] != 1) 343 return failure(); 344 // currently we support only dropping one dim but the pattern can be applied 345 // greedily to drop more. 346 int64_t dropDim = 1; 347 348 auto oldIndexingMaps = contractOp.getIndexingMapsArray(); 349 SmallVector<AffineMap> newIndexingMaps; 350 351 auto oldIteratorTypes = contractOp.getIteratorTypes(); 352 SmallVector<Attribute> newIteratorTypes; 353 354 int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0); 355 356 if (!isParallelIterator(oldIteratorTypes[dimToDrop])) 357 // only parallel type iterators can be dropped. 358 return failure(); 359 360 for (const auto &it : llvm::enumerate(oldIteratorTypes)) { 361 int64_t currDim = it.index(); 362 if (currDim == dimToDrop) 363 continue; 364 newIteratorTypes.push_back(it.value()); 365 } 366 367 SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(), 368 contractOp.getAcc()}; 369 SmallVector<Value> newOperands; 370 371 for (const auto &it : llvm::enumerate(oldIndexingMaps)) { 372 // Check if the dim to be dropped exists as a leading dim in the operand 373 // if it does then we use vector.extract to drop it. 374 bool validExtract = false; 375 SmallVector<AffineExpr> results; 376 auto map = it.value(); 377 int64_t orginalZeroDim = it.value().getDimPosition(0); 378 if (orginalZeroDim != dimToDrop) { 379 // There are two reasons to be in this path, 1. We need to 380 // tranpose the operand to make the dim to be dropped 381 // leading. 2. The dim to be dropped does not exist and in 382 // that case we dont want to add a unit tranpose but we must 383 // check all the indices to make sure this is the case. 384 bool tranposeNeeded = false; 385 SmallVector<int64_t> perm; 386 SmallVector<AffineExpr> transposeResults; 387 388 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { 389 int64_t currDim = map.getDimPosition(i); 390 if (currDim == dimToDrop) { 391 tranposeNeeded = true; 392 perm.insert(perm.begin(), i); 393 auto targetExpr = rewriter.getAffineDimExpr(currDim); 394 transposeResults.insert(transposeResults.begin(), targetExpr); 395 } else { 396 perm.push_back(i); 397 auto targetExpr = rewriter.getAffineDimExpr(currDim); 398 transposeResults.push_back(targetExpr); 399 } 400 } 401 // Do the tranpose now if needed so that we can drop the 402 // correct dim using extract later. 403 if (tranposeNeeded) { 404 map = AffineMap::get(map.getNumDims(), 0, transposeResults, 405 contractOp.getContext()); 406 operands[it.index()] = rewriter.create<vector::TransposeOp>( 407 contractOp.getLoc(), operands[it.index()], perm); 408 } 409 } 410 // We have taken care to have the dim to be dropped be 411 // the leading dim. If its still not leading that means it 412 // does not exist in this operand and hence we do not need 413 // an extract. 414 if (map.getDimPosition(0) == dimToDrop) 415 validExtract = true; 416 417 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { 418 int64_t currDim = map.getDimPosition(i); 419 if (currDim == dimToDrop) 420 // This is the dim we are dropping. 421 continue; 422 auto targetExpr = rewriter.getAffineDimExpr( 423 currDim < dimToDrop ? currDim : currDim - 1); 424 results.push_back(targetExpr); 425 } 426 newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results, 427 contractOp.getContext())); 428 // Extract if its a valid extraction, otherwise use the operand 429 // without extraction. 430 newOperands.push_back( 431 validExtract ? rewriter.create<vector::ExtractOp>(contractOp.getLoc(), 432 operands[it.index()], 433 splatZero(dropDim)) 434 : operands[it.index()]); 435 } 436 auto newContractOp = rewriter.create<vector::ContractionOp>( 437 contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2], 438 rewriter.getAffineMapArrayAttr(newIndexingMaps), 439 rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind()); 440 rewriter.replaceOpWithNewOp<vector::BroadcastOp>( 441 contractOp, contractOp->getResultTypes()[0], newContractOp); 442 return success(); 443 } 444 445 namespace { 446 447 /// Turns vector.contract on vector with leading 1 dimensions into 448 /// vector.extract followed by vector.contract on vector without leading 449 /// 1 dimensions. Also performs tranpose of lhs and rhs operands if required 450 /// prior to extract. 451 struct CastAwayContractionLeadingOneDim 452 : public OpRewritePattern<vector::ContractionOp> { 453 using OpRewritePattern::OpRewritePattern; 454 455 LogicalResult matchAndRewrite(vector::ContractionOp contractOp, 456 PatternRewriter &rewriter) const override { 457 return castAwayContractionLeadingOneDim(contractOp, rewriter); 458 } 459 }; 460 461 /// Looks at elementwise operations on vectors with at least one leading 462 /// dimension equal 1, e.g. vector<1x[4]x1xf32> (but not vector<2x[4]x1xf32>), 463 /// and cast aways the leading one dimensions (_plural_) and then broadcasts 464 /// the results. 465 /// 466 /// Example before: 467 /// %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32> 468 /// Example after: 469 /// %2 = arith.mulf %0, %1 : vector<4x1xf32> 470 /// %3 = vector.broadcast %2 : vector<4x1xf32> to vector<1x4x1xf32> 471 /// 472 /// Does support scalable vectors. 473 class CastAwayElementwiseLeadingOneDim : public RewritePattern { 474 public: 475 CastAwayElementwiseLeadingOneDim(MLIRContext *context, 476 PatternBenefit benefit = 1) 477 : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {} 478 479 LogicalResult matchAndRewrite(Operation *op, 480 PatternRewriter &rewriter) const override { 481 if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) 482 return failure(); 483 auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0]); 484 if (!vecType) 485 return failure(); 486 VectorType newVecType = trimLeadingOneDims(vecType); 487 if (newVecType == vecType) 488 return failure(); 489 int64_t dropDim = vecType.getRank() - newVecType.getRank(); 490 SmallVector<Value, 4> newOperands; 491 for (Value operand : op->getOperands()) { 492 if (auto opVecType = dyn_cast<VectorType>(operand.getType())) { 493 newOperands.push_back(rewriter.create<vector::ExtractOp>( 494 op->getLoc(), operand, splatZero(dropDim))); 495 } else { 496 newOperands.push_back(operand); 497 } 498 } 499 Operation *newOp = 500 rewriter.create(op->getLoc(), op->getName().getIdentifier(), 501 newOperands, newVecType, op->getAttrs()); 502 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType, 503 newOp->getResult(0)); 504 return success(); 505 } 506 }; 507 508 // Drops leading 1 dimensions from vector.constant_mask and inserts a 509 // vector.broadcast back to the original shape. 510 struct CastAwayConstantMaskLeadingOneDim 511 : public OpRewritePattern<vector::ConstantMaskOp> { 512 using OpRewritePattern::OpRewritePattern; 513 514 LogicalResult matchAndRewrite(vector::ConstantMaskOp mask, 515 PatternRewriter &rewriter) const override { 516 VectorType oldType = mask.getType(); 517 VectorType newType = trimLeadingOneDims(oldType); 518 519 if (newType == oldType) 520 return failure(); 521 522 int64_t dropDim = oldType.getRank() - newType.getRank(); 523 SmallVector<int64_t> dimSizes; 524 for (auto attr : mask.getMaskDimSizes()) 525 dimSizes.push_back(llvm::cast<IntegerAttr>(attr).getInt()); 526 527 // If any of the dropped unit dims has a size of `0`, the entire mask is a 528 // zero mask, else the unit dim has no effect on the mask. 529 int64_t flatLeadingSize = 530 std::accumulate(dimSizes.begin(), dimSizes.begin() + dropDim + 1, 531 static_cast<int64_t>(1), std::multiplies<int64_t>()); 532 SmallVector<int64_t> newDimSizes({flatLeadingSize}); 533 newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end()); 534 535 auto newMask = rewriter.create<vector::ConstantMaskOp>( 536 mask.getLoc(), newType, rewriter.getI64ArrayAttr(newDimSizes)); 537 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask); 538 return success(); 539 } 540 }; 541 542 } // namespace 543 544 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns( 545 RewritePatternSet &patterns, PatternBenefit benefit) { 546 patterns 547 .add<CastAwayExtractStridedSliceLeadingOneDim, 548 CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim, 549 CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim, 550 CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim, 551 CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit); 552 populateShapeCastFoldingPatterns(patterns, benefit); 553 } 554