//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/pass to remove usage of unit-extent dimensions
// to specify broadcasting in favor of a more canonical representation of the
// computation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "linalg-drop-unit-dims"

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Pattern to move init operands to ins when all the loops are parallel and
/// the block argument corresponding to the init is used in the region. This is
/// a fix-up for when unit reduction dimensions have all been folded away, in
/// which case the op becomes an elementwise generic op.
/// E.g., it converts
///
///  %0 = tensor.empty() : tensor<1x1xf32>
///  %1 = linalg.fill
///    ins(%cst : f32)
///    outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
///  %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
///                                        affine_map<(d0) -> (0, d0)>],
///                       iterator_types = ["parallel"]}
///    ins(%arg0 : tensor<1x?x1x1xf32>)
///    outs(%1 : tensor<1x1xf32>) {
///  ^bb0(%in: f32, %out: f32):
///    %3 = arith.addf %in, %out : f32
///    linalg.yield %3 : f32
///  } -> tensor<1x1xf32>
///
/// into
///
///  %0 = tensor.empty() : tensor<1x1xf32>
///  %1 = linalg.fill
///    ins(%cst : f32)
///    outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
///  %2 = tensor.empty() : tensor<1x1xf32>
///  %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
///                                        affine_map<(d0) -> (0, d0)>,
///                                        affine_map<(d0) -> (0, d0)>],
///                       iterator_types = ["parallel"]}
///    ins(%arg0, %1 : tensor<1x?x1x1xf32>, tensor<1x1xf32>)
///    outs(%2 : tensor<1x1xf32>) {
///  ^bb0(%in: f32, %in_0: f32, %out: f32):
///    %4 = arith.addf %in, %in_0 : f32
///    linalg.yield %4 : f32
///  } -> tensor<1x1xf32>
struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasPureTensorSemantics())
      return failure();
    if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();

    auto outputOperands = genericOp.getDpsInitsMutable();
    SetVector<OpOperand *> candidates;
    for (OpOperand &op : outputOperands) {
      if (genericOp.getMatchingBlockArgument(&op).use_empty())
        continue;
      candidates.insert(&op);
    }

    if (candidates.empty())
      return failure();

    // Compute the modified indexing maps.
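    // The new operand list is the original inputs, followed by the promoted
    // init operands (now read as inputs), followed by the original inits; the
    // indexing maps below are assembled in the same order.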
    int64_t origNumInput = genericOp.getNumDpsInputs();
    SmallVector<Value> newInputOperands = genericOp.getDpsInputs();
    SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
    SmallVector<AffineMap> newIndexingMaps;
    newIndexingMaps.append(indexingMaps.begin(),
                           std::next(indexingMaps.begin(), origNumInput));
    for (OpOperand *op : candidates) {
      newInputOperands.push_back(op->get());
      newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
    }
    newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
                           indexingMaps.end());

    Location loc = genericOp.getLoc();
    SmallVector<Value> newOutputOperands =
        llvm::to_vector(genericOp.getDpsInits());
    for (OpOperand *op : candidates) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointAfterValue(op->get());
      auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
      auto empty = rewriter.create<tensor::EmptyOp>(
          loc, tensor::getMixedSizes(rewriter, loc, op->get()), elemType);

      unsigned start = genericOp.getDpsInits().getBeginOperandIndex();
      newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
    }

    auto newOp = rewriter.create<GenericOp>(
        loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands,
        newIndexingMaps, genericOp.getIteratorTypesArray(),
        /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));

    OpBuilder::InsertionGuard guard(rewriter);
    Region &region = newOp.getRegion();
    Block *block = rewriter.createBlock(&region);
    IRMapping mapper;
    for (auto bbarg : genericOp.getRegionInputArgs())
      mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));

    for (OpOperand *op : candidates) {
      BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
      mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
    }

    for (OpOperand &op : outputOperands) {
      BlockArgument bbarg = genericOp.getMatchingBlockArgument(&op);
      if (candidates.count(&op))
        block->addArgument(bbarg.getType(), loc);
      else
        mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
    }

    for (auto &op : genericOp.getBody()->getOperations()) {
      rewriter.clone(op, mapper);
    }
    rewriter.replaceOp(genericOp, newOp.getResults());

    return success();
  }
};
} // namespace

//===---------------------------------------------------------------------===//
// Drop loops that are unit-extents within Linalg operations.
//===---------------------------------------------------------------------===//

/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting.
/// For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
/// tensor<5x5xf32>
/// {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = arith.addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
/// tensor<5x5xf32>
/// {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = arith.addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Update the index accesses of linalg operations having index semantics.
static void
replaceUnitDimIndexOps(GenericOp genericOp,
                       const llvm::SmallDenseSet<unsigned> &unitDims,
                       RewriterBase &rewriter) {
  for (IndexOp indexOp :
       llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(indexOp);
    if (unitDims.count(indexOp.getDim()) != 0) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
    } else {
      // Update the dimension of the index operation if needed.
      unsigned droppedDims = llvm::count_if(
          unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); });
      if (droppedDims != 0)
        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
                                             indexOp.getDim() - droppedDims);
    }
  }
}

/// Expand the given `result` so that the type matches the type of `origDest`.
/// The `reassociation` is used when `rankReductionStrategy` is set to
/// `RankReductionStrategy::ReassociativeReshape`.
static Value
expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
            ArrayRef<ReassociationIndices> reassociation,
            ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
  // There are no results for memref outputs.
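  // Hence expansion only ever applies to tensor outputs, and `origDest` is
  // guaranteed to be a ranked tensor here.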
  auto origResultType = cast<RankedTensorType>(origDest.getType());
  if (rankReductionStrategy ==
      ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
    unsigned rank = origResultType.getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, loc, origDest);
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    return rewriter.createOrFold<tensor::InsertSliceOp>(
        loc, result, origDest, offsets, sizes, strides);
  }

  assert(rankReductionStrategy ==
             ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
         "unknown rank reduction strategy");
  return rewriter
      .create<tensor::ExpandShapeOp>(loc, origResultType, result,
                                     reassociation)
      .getResult();
}

/// Collapse the given `operand` so that its type matches `targetShape`. The
/// `reassociation` is used when `rankReductionStrategy` is set to
/// `RankReductionStrategy::ReassociativeReshape`.
static Value collapseValue(
    RewriterBase &rewriter, Location loc, Value operand,
    ArrayRef<int64_t> targetShape, ArrayRef<ReassociationIndices> reassociation,
    ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
  if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
    if (rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      FailureOr<Value> rankReducingExtract =
          memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
                                                targetShape);
      assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
      return *rankReducingExtract;
    }

    assert(
        rankReductionStrategy ==
            ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
        "unknown rank reduction strategy");
    MemRefLayoutAttrInterface layout;
    auto targetType = MemRefType::get(targetShape, memrefType.getElementType(),
                                      layout, memrefType.getMemorySpace());
    return rewriter.create<memref::CollapseShapeOp>(loc, targetType, operand,
                                                    reassociation);
  }
  if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
    if (rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      FailureOr<Value> rankReducingExtract =
          tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
                                                     targetShape);
      assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
      return *rankReducingExtract;
    }

    assert(
        rankReductionStrategy ==
            ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
        "unknown rank reduction strategy");
    auto targetType =
        RankedTensorType::get(targetShape, tensorType.getElementType());
    return rewriter.create<tensor::CollapseShapeOp>(loc, targetType, operand,
                                                    reassociation);
  }
  llvm_unreachable("unsupported operand type");
}

/// Compute the modified metadata for the operands of an operation
/// whose unit dims are being dropped. Return the new indexing map
/// to use, the shape of the operand in the replacement op
/// and the `reassociation` to use to go from the original operand shape
/// to the modified operand shape.
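/// For example (illustrative, not taken from a test): for an operand of type
/// `tensor<1x?x1x5xf32>` indexed by `(d0, d1) -> (0, d0, 0, d1)` where no
/// loop dimension is dropped, this returns the index map `(d0, d1) -> (d0, d1)`,
/// the target shape `[?, 5]` and the reassociation `[[0, 1, 2], [3]]`.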
struct UnitExtentReplacementInfo {
  AffineMap indexMap;
  SmallVector<ReassociationIndices> reassociation;
  SmallVector<int64_t> targetShape;
};
static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
    MLIRContext *context, GenericOp genericOp, OpOperand *opOperand,
    llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
    ArrayRef<AffineExpr> dimReplacements) {
  UnitExtentReplacementInfo info;
  ReassociationIndices reassociationGroup;
  SmallVector<AffineExpr> newIndexExprs;
  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  ArrayRef<int64_t> operandShape = genericOp.getShape(opOperand);
  ArrayRef<AffineExpr> exprs = indexingMap.getResults();

  auto isUnitDim = [&](unsigned dim) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[dim])) {
      unsigned oldPosition = dimExpr.getPosition();
      return !oldDimsToNewDimsMap.count(oldPosition) &&
             (operandShape[dim] == 1);
    }
    // Handle the other case where the shape is 1, and is accessed using a
    // constant 0.
    if (operandShape[dim] == 1) {
      auto constAffineExpr = dyn_cast<AffineConstantExpr>(exprs[dim]);
      return constAffineExpr && constAffineExpr.getValue() == 0;
    }
    return false;
  };

  unsigned dim = 0;
  while (dim < operandShape.size() && isUnitDim(dim))
    reassociationGroup.push_back(dim++);
  while (dim < operandShape.size()) {
    assert(!isUnitDim(dim) && "expected non unit-extent");
    reassociationGroup.push_back(dim);
    AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
    newIndexExprs.push_back(newExpr);
    info.targetShape.push_back(operandShape[dim]);
    ++dim;
    // Fold all following dimensions that are unit-extent.
    while (dim < operandShape.size() && isUnitDim(dim)) {
      reassociationGroup.push_back(dim++);
    }
    info.reassociation.push_back(reassociationGroup);
    reassociationGroup.clear();
  }
  info.indexMap =
      AffineMap::get(oldDimsToNewDimsMap.size(), indexingMap.getNumSymbols(),
                     newIndexExprs, context);
  return info;
}

FailureOr<DropUnitDimsResult>
linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
                     const ControlDropUnitDims &options) {
  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
  if (indexingMaps.empty())
    return failure();

  // 1. Check if any of the iteration dimensions are unit-trip count. They will
  // end up being unit-trip count if they are used to index into a unit-dim
  // tensor/memref.
  AffineMap invertedMap =
      inversePermutation(concatAffineMaps(indexingMaps, rewriter.getContext()));
  if (!invertedMap) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "invalid indexing maps for operation");
  }
  SmallVector<int64_t> dims = genericOp.getStaticShape();

  // 1a. Get the allowed list of dimensions to drop from the `options`.
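  // Only dimensions that appear in this allow-list and have a static trip
  // count of 1 are considered for removal below.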
  SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp);
  if (allowedUnitDims.empty()) {
    return rewriter.notifyMatchFailure(
        genericOp, "control function returns no allowed unit dims to prune");
  }
  llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
                                               allowedUnitDims.end());
  llvm::SmallDenseSet<unsigned> unitDims;
  for (const auto &expr : enumerate(invertedMap.getResults())) {
    if (AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
      if (dims[dimExpr.getPosition()] == 1 &&
          unitDimsFilter.count(expr.index()))
        unitDims.insert(expr.index());
    }
  }

  // 2. Compute the iterator types of the modified op by dropping the one-trip
  // count loops.
  SmallVector<utils::IteratorType> newIteratorTypes;
  llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
  SmallVector<AffineExpr> dimReplacements;
  unsigned newDims = 0;
  for (auto [index, attr] :
       llvm::enumerate(genericOp.getIteratorTypesArray())) {
    if (unitDims.count(index)) {
      dimReplacements.push_back(
          getAffineConstantExpr(0, rewriter.getContext()));
    } else {
      newIteratorTypes.push_back(attr);
      oldDimToNewDimMap[index] = newDims;
      dimReplacements.push_back(
          getAffineDimExpr(newDims, rewriter.getContext()));
      newDims++;
    }
  }

  // 3. For each of the operands, find
  //    - the modified affine map to use,
  //    - the shape of the operand after the unit-dims are dropped,
  //    - the reassociation indices used to convert from the original
  //      operand type to the modified operand type (needed only when using
  //      reshapes as the rank reduction strategy).
  // Note that the indexing maps might need changing even if no unit
  // dimensions are dropped, to handle cases where `0` is used to access a
  // unit-extent tensor. Consider moving this out of this specific
  // transformation as a stand-alone transformation. Kept here right now due
  // to legacy.
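  // For example (illustrative): a map such as `(d0) -> (0, d0)` applied to a
  // `tensor<1x?xf32>` operand becomes `(d0) -> (d0)` on a collapsed
  // `tensor<?xf32>` operand, even when no iteration-space dimension is
  // removed.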
  SmallVector<AffineMap> newIndexingMaps;
  SmallVector<SmallVector<ReassociationIndices>> reassociations;
  SmallVector<SmallVector<int64_t>> targetShapes;
  SmallVector<bool> collapsed;
  auto hasCollapsibleType = [](OpOperand &operand) {
    Type operandType = operand.get().getType();
    if (auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
      return memrefOperandType.getLayout().isIdentity();
    }
    if (auto tensorOperandType = dyn_cast<RankedTensorType>(operandType)) {
      return tensorOperandType.getEncoding() == nullptr;
    }
    return false;
  };
  for (OpOperand &opOperand : genericOp->getOpOperands()) {
    auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
    ArrayRef<int64_t> shape = genericOp.getShape(&opOperand);
    if (!hasCollapsibleType(opOperand)) {
      AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
          dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
      newIndexingMaps.push_back(newIndexingMap);
      targetShapes.push_back(llvm::to_vector(shape));
      collapsed.push_back(false);
      reassociations.push_back({});
      continue;
    }
    auto replacementInfo = dropUnitExtentFromOperandMetadata(
        rewriter.getContext(), genericOp, &opOperand, oldDimToNewDimMap,
        dimReplacements);
    reassociations.push_back(replacementInfo.reassociation);
    newIndexingMaps.push_back(replacementInfo.indexMap);
    targetShapes.push_back(replacementInfo.targetShape);
    collapsed.push_back(!(replacementInfo.indexMap.getNumResults() ==
                          indexingMap.getNumResults()));
  }

  // Abort if the indexing maps of the result operation are not invertible
  // (i.e. not legal) or if no dimension was reduced.
  if (newIndexingMaps == indexingMaps ||
      !inversePermutation(
          concatAffineMaps(newIndexingMaps, rewriter.getContext())))
    return failure();

  Location loc = genericOp.getLoc();
  // 4. For each of the operands, collapse the operand to convert from the
  // original shape to the shape in the modified operation if needed, either
  // through use of reshapes or rank-reducing slices as specified in `options`.
  SmallVector<Value> newOperands;
  for (OpOperand &opOperand : genericOp->getOpOperands()) {
    int64_t idx = opOperand.getOperandNumber();
    if (!collapsed[idx]) {
      newOperands.push_back(opOperand.get());
      continue;
    }
    newOperands.push_back(collapseValue(rewriter, loc, opOperand.get(),
                                        targetShapes[idx], reassociations[idx],
                                        options.rankReductionStrategy));
  }

  // 5. Create the `linalg.generic` operation with the new operands,
  // indexing maps, iterator types and result types.
  ArrayRef<Value> newInputs =
      ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
  ArrayRef<Value> newOutputs =
      ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
  SmallVector<Type> resultTypes;
  resultTypes.reserve(genericOp.getNumResults());
  for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
    resultTypes.push_back(newOutputs[i].getType());
  GenericOp replacementOp =
      rewriter.create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
                                 newIndexingMaps, newIteratorTypes);
  rewriter.inlineRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
                              replacementOp.getRegion().begin());
  // 5a. Replace `linalg.index` operations that refer to the dropped unit
  // dimensions.
  replaceUnitDimIndexOps(replacementOp, unitDims, rewriter);

  // 6. If any result type changes, insert a reshape/slice to convert from the
  // original type to the new type.
  SmallVector<Value> resultReplacements;
  for (auto [index, result] : llvm::enumerate(replacementOp.getResults())) {
    unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
    Value origDest = genericOp.getDpsInitOperand(index)->get();
    if (!collapsed[opOperandIndex]) {
      resultReplacements.push_back(result);
      continue;
    }
    Value expandedValue = expandValue(rewriter, loc, result, origDest,
                                      reassociations[opOperandIndex],
                                      options.rankReductionStrategy);
    resultReplacements.push_back(expandedValue);
  }

  return DropUnitDimsResult{replacementOp, resultReplacements};
}

namespace {
struct DropUnitDims : public OpRewritePattern<GenericOp> {
  DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
               PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), options(std::move(options)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    FailureOr<DropUnitDimsResult> result =
        dropUnitDims(rewriter, genericOp, options);
    if (failed(result)) {
      return failure();
    }
    rewriter.replaceOp(genericOp, result->replacements);
    return success();
  }

private:
  ControlDropUnitDims options;
};
} // namespace

//===---------------------------------------------------------------------===//
// Drop dimensions that are unit-extents within tensor operations.
//===---------------------------------------------------------------------===//

namespace {
struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
  DropPadUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
                  PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), options(std::move(options)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    // 1a. Get the allowed list of dimensions to drop from the `options`.
    SmallVector<unsigned> allowedUnitDims = options.controlFn(padOp);
    if (allowedUnitDims.empty()) {
      return rewriter.notifyMatchFailure(
          padOp, "control function returns no allowed unit dims to prune");
    }

    if (padOp.getSourceType().getEncoding()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot collapse dims of tensor with encoding");
    }

    // Fail for non-constant padding values. The body of the pad could
    // depend on the padding indices and/or properties of the padded
    // tensor, so for now we fail.
    // TODO: Support non-constant padding values.
    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal) {
      return rewriter.notifyMatchFailure(
          padOp, "unimplemented: non-constant padding value");
    }

    ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape();
    int64_t padRank = sourceShape.size();

    auto isStaticZero = [](OpFoldResult f) {
      std::optional<int64_t> maybeInt = getConstantIntValue(f);
      return maybeInt && *maybeInt == 0;
    };

    llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
                                                 allowedUnitDims.end());
    llvm::SmallDenseSet<unsigned> unitDims;
    SmallVector<int64_t> newShape;
    SmallVector<OpFoldResult> newLowPad;
    SmallVector<OpFoldResult> newHighPad;
    for (const auto [dim, size, low, high] :
         zip_equal(llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
                   padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
      if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
          isStaticZero(high)) {
        unitDims.insert(dim);
      } else {
        newShape.push_back(size);
        newLowPad.push_back(low);
        newHighPad.push_back(high);
      }
    }

    if (unitDims.empty()) {
      return rewriter.notifyMatchFailure(padOp, "no unit dims to collapse");
    }

    ReassociationIndices reassociationGroup;
    SmallVector<ReassociationIndices> reassociationMap;
    int64_t dim = 0;
    while (dim < padRank && unitDims.contains(dim))
      reassociationGroup.push_back(dim++);
    while (dim < padRank) {
      assert(!unitDims.contains(dim) && "expected non unit-extent");
      reassociationGroup.push_back(dim);
      dim++;
      // Fold all following dimensions that are unit-extent.
      while (dim < padRank && unitDims.contains(dim))
        reassociationGroup.push_back(dim++);
      reassociationMap.push_back(reassociationGroup);
      reassociationGroup.clear();
    }

    Value collapsedSource =
        collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
                      reassociationMap, options.rankReductionStrategy);

    auto newPadOp = rewriter.create<tensor::PadOp>(
        padOp.getLoc(), /*result=*/Type(), collapsedSource, newLowPad,
        newHighPad, paddingVal, padOp.getNofold());

    Value dest = padOp.getResult();
    if (options.rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      SmallVector<OpFoldResult> expandedSizes;
      int64_t numUnitDims = 0;
      for (auto dim : llvm::seq(static_cast<int64_t>(0), padRank)) {
        if (unitDims.contains(dim)) {
          expandedSizes.push_back(rewriter.getIndexAttr(1));
          numUnitDims++;
          continue;
        }
        expandedSizes.push_back(tensor::getMixedSize(
            rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims));
      }
      dest = rewriter.create<tensor::EmptyOp>(
          padOp.getLoc(), expandedSizes,
          padOp.getResultType().getElementType());
    }

    Value expandedValue =
        expandValue(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
                    reassociationMap, options.rankReductionStrategy);
    rewriter.replaceOp(padOp, expandedValue);
    return success();
  }

private:
  ControlDropUnitDims options;
};
} // namespace

namespace {
/// Convert `extract_slice` operations to rank-reduced versions.
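/// This produces a slice with the unit dims of the result folded away,
/// followed by a `tensor.expand_shape` back to the original result type.
/// Illustrative sketch (exact syntax may differ across MLIR versions):
///
///   %0 = tensor.extract_slice %src[0, 0] [1, 4] [1, 1]
///       : tensor<8x4xf32> to tensor<1x4xf32>
///
/// becomes
///
///   %0 = tensor.extract_slice %src[0, 0] [1, 4] [1, 1]
///       : tensor<8x4xf32> to tensor<4xf32>
///   %1 = tensor.expand_shape %0 [[0, 1]] : tensor<4xf32> into tensor<1x4xf32>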
struct RankReducedExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = sliceOp.getType();
    SmallVector<OpFoldResult> targetShape;
    for (auto size : resultType.getShape())
      targetShape.push_back(rewriter.getIndexAttr(size));
    auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();

    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
    auto rankReducedType = cast<RankedTensorType>(
        tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
            reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
            strides));

    Location loc = sliceOp.getLoc();
    Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
        loc, rankReducedType, sliceOp.getSource(), offsets, sizes, strides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSlice, *reassociation);
    return success();
  }
};

/// Convert `insert_slice` operations to rank-reduced versions.
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertSliceOp.getSourceType();
    SmallVector<OpFoldResult> targetShape;
    for (auto size : sourceType.getShape())
      targetShape.push_back(rewriter.getIndexAttr(size));
    auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();

    Location loc = insertSliceOp.getLoc();
    tensor::CollapseShapeOp reshapedSource;
    {
      OpBuilder::InsertionGuard g(rewriter);
      // The only difference between InsertSliceOp and ParallelInsertSliceOp
      // is that the insertion point is just before the ParallelCombiningOp in
      // the parallel case.
      if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      reshapedSource = rewriter.create<tensor::CollapseShapeOp>(
          loc, insertSliceOp.getSource(), *reassociation);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, reshapedSource, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
static void
populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns,
                                              ControlDropUnitDims &options) {
  auto *context = patterns.getContext();
  patterns.add<DropUnitDims>(context, options);
  patterns.add<DropPadUnitDims>(context, options);
  // TODO: Patterns unrelated to unit dim folding should be factored out.
  patterns.add<RankReducedExtractSliceOp,
               RankReducedInsertSliceOp<tensor::InsertSliceOp>,
               RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
      context);
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::populateFoldTensorEmptyPatterns(patterns);
  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
  memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}

static void
populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns,
                                            ControlDropUnitDims &options) {
  auto *context = patterns.getContext();
  patterns.add<DropUnitDims>(context, options);
  patterns.add<DropPadUnitDims>(context, options);
  // TODO: Patterns unrelated to unit dim folding should be factored out.
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
  tensor::populateFoldTensorEmptyPatterns(patterns);
  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
  memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}

void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) {
  if (options.rankReductionStrategy ==
      linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
    populateFoldUnitExtentDimsViaSlicesPatterns(patterns, options);
  } else if (options.rankReductionStrategy ==
             linalg::ControlDropUnitDims::RankReductionStrategy::
                 ReassociativeReshape) {
    populateFoldUnitExtentDimsViaReshapesPatterns(patterns, options);
  }
}

void mlir::linalg::populateMoveInitOperandsToInputPattern(
    RewritePatternSet &patterns) {
  patterns.add<MoveInitOperandsToInput>(patterns.getContext());
}

namespace {
/// Pass that removes unit-extent dims within generic ops.
struct LinalgFoldUnitExtentDimsPass
    : public impl::LinalgFoldUnitExtentDimsPassBase<
          LinalgFoldUnitExtentDimsPass> {
  using impl::LinalgFoldUnitExtentDimsPassBase<
      LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();
    RewritePatternSet patterns(context);
    ControlDropUnitDims options;
    if (useRankReducingSlices) {
      options.rankReductionStrategy = linalg::ControlDropUnitDims::
          RankReductionStrategy::ExtractInsertSlice;
    }
    linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
    populateMoveInitOperandsToInputPattern(patterns);
    (void)applyPatternsGreedily(op, std::move(patterns));
  }
};

} // namespace

namespace {

/// Returns reassociation indices for collapsing/expanding a
/// tensor of rank `rank` at position `pos`.
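/// For example (illustrative): for `rank = 4` and `pos = 1` the result is
/// `[[0], [1, 2], [3]]`, i.e. the unit dim at position 1 is merged with its
/// right neighbor; for `pos = rank - 1` it is merged with its left neighbor.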
static SmallVector<ReassociationIndices>
getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
  SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
  bool lastDim = pos == rank - 1;
  if (rank > 2) {
    for (int64_t i = 0; i < rank - 1; i++) {
      if (i == pos || (lastDim && i == pos - 1))
        reassociation[i] = ReassociationIndices{i, i + 1};
      else if (i < pos)
        reassociation[i] = ReassociationIndices{i};
      else
        reassociation[i] = ReassociationIndices{i + 1};
    }
  }
  return reassociation;
}

/// Returns a collapsed `val` where the collapsing occurs at dim `pos`.
/// If `pos < 0`, then don't collapse.
static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
                                    int64_t pos) {
  if (pos < 0)
    return val;
  auto valType = cast<ShapedType>(val.getType());
  SmallVector<int64_t> collapsedShape(valType.getShape());
  collapsedShape.erase(collapsedShape.begin() + pos);
  return collapseValue(
      rewriter, val.getLoc(), val, collapsedShape,
      getReassociationForReshapeAtDim(valType.getRank(), pos),
      ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
}

/// Base class for all rank reduction patterns for contraction ops
/// with unit dimensions. All patterns should convert one named op
/// to another named op. Intended to reduce only one iteration space dim
/// at a time.
/// Reducing multiple dims will happen with recursive application of
/// pattern rewrites.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
  using OpRewritePattern<FromOpTy>::OpRewritePattern;

  /// Collapse all collapsible operands.
  SmallVector<Value>
  collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
                   ArrayRef<int64_t> operandCollapseDims) const {
    assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
           "expected 3 operands and dims");
    return llvm::map_to_vector(
        llvm::zip(operands, operandCollapseDims), [&](auto pair) {
          return collapseSingletonDimAt(rewriter, std::get<0>(pair),
                                        std::get<1>(pair));
        });
  }

  /// Expand result tensor.
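  /// The result is expanded back to `expandedType` by reinserting the unit
  /// dimension at `dim` (illustrative: a `tensor<4x8xf32>` result with
  /// `dim = 0` is expanded to `tensor<1x4x8xf32>`).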
  Value expandResult(PatternRewriter &rewriter, Value result,
                     RankedTensorType expandedType, int64_t dim) const {
    return rewriter.create<tensor::ExpandShapeOp>(
        result.getLoc(), expandedType, result,
        getReassociationForReshapeAtDim(expandedType.getRank(), dim));
  }

  LogicalResult matchAndRewrite(FromOpTy contractionOp,
                                PatternRewriter &rewriter) const override {
    auto loc = contractionOp.getLoc();
    auto inputs = contractionOp.getDpsInputs();
    auto inits = contractionOp.getDpsInits();
    if (inputs.size() != 2 || inits.size() != 1)
      return rewriter.notifyMatchFailure(contractionOp,
                                         "expected 2 inputs and 1 init");
    auto lhs = inputs[0];
    auto rhs = inputs[1];
    auto init = inits[0];
    SmallVector<Value> operands{lhs, rhs, init};

    SmallVector<int64_t> operandUnitDims;
    if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
      return rewriter.notifyMatchFailure(contractionOp,
                                         "no reducible dims found");

    SmallVector<Value> collapsedOperands =
        collapseOperands(rewriter, operands, operandUnitDims);
    Value collapsedLhs = collapsedOperands[0];
    Value collapsedRhs = collapsedOperands[1];
    Value collapsedInit = collapsedOperands[2];
    SmallVector<Type, 1> collapsedResultTy;
    if (isa<RankedTensorType>(collapsedInit.getType()))
      collapsedResultTy.push_back(collapsedInit.getType());
    auto collapsedOp = rewriter.create<ToOpTy>(
        loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
        ValueRange{collapsedInit});
    for (auto attr : contractionOp->getAttrs()) {
      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
        continue;
      collapsedOp->setAttr(attr.getName(), attr.getValue());
    }

    auto results = contractionOp.getResults();
    assert(results.size() < 2 && "expected at most one result");
    if (results.empty()) {
      rewriter.replaceOp(contractionOp, collapsedOp);
    } else {
      rewriter.replaceOp(
          contractionOp,
          expandResult(rewriter, collapsedOp.getResultTensors()[0],
                       cast<RankedTensorType>(results[0].getType()),
                       operandUnitDims[2]));
    }

    return success();
  }

  /// Populate `operandUnitDims` with 3 indices indicating the unit dim
  /// for each operand that should be collapsed in this pattern. If an
  /// operand shouldn't be collapsed, the index should be negative.
  virtual LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
};

/// Patterns for unbatching batched contraction ops.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  /// Look for unit batch dims to collapse.
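  /// For example (illustrative): `linalg.batch_matmul` with operands of type
  /// `tensor<1x4x8xf32>` and `tensor<1x8x16xf32>` (unit batch dimension) can
  /// be rewritten as `linalg.matmul` on `tensor<4x8xf32>` and
  /// `tensor<8x16xf32>`, with the result expanded back to
  /// `tensor<1x4x16xf32>`.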
  LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
    FailureOr<ContractionDimensions> maybeContractionDims =
        inferContractionDims(op);
    if (failed(maybeContractionDims)) {
      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
      return failure();
    }
    ContractionDimensions contractionDims = maybeContractionDims.value();

    if (contractionDims.batch.size() != 1)
      return failure();
    auto batchDim = contractionDims.batch[0];
    SmallVector<std::pair<Value, unsigned>, 3> bOperands;
    op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
    if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
          return cast<ShapedType>(std::get<0>(pair).getType())
                     .getShape()[std::get<1>(pair)] != 1;
        })) {
      LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
      return failure();
    }

    operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
                                           std::get<1>(bOperands[1]),
                                           std::get<1>(bOperands[2])};
    return success();
  }
};

/// Patterns for reducing non-batch dimensions.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  /// Helper for determining whether the lhs/init or rhs/init are reduced.
  static bool constexpr reduceLeft =
      (std::is_same_v<FromOpTy, BatchMatmulOp> &&
       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
      (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
      (std::is_same_v<FromOpTy, MatmulOp> &&
       std::is_same_v<ToOpTy, VecmatOp>) ||
      (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
       std::is_same_v<ToOpTy, VecmatOp>) ||
      (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);

  /// Look for non-batch spatial dims to collapse.
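  /// For example (illustrative): `linalg.matmul` with a `tensor<1x8xf32>` lhs
  /// (unit `m` dimension) can be rewritten as `linalg.vecmat` on a
  /// `tensor<8xf32>` lhs, and `linalg.matmul` with a `tensor<8x1xf32>` rhs
  /// (unit `n` dimension) as `linalg.matvec`.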
  LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
    FailureOr<ContractionDimensions> maybeContractionDims =
        inferContractionDims(op);
    if (failed(maybeContractionDims)) {
      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
      return failure();
    }
    ContractionDimensions contractionDims = maybeContractionDims.value();

    if constexpr (reduceLeft) {
      auto m = contractionDims.m[0];
      SmallVector<std::pair<Value, unsigned>, 2> mOperands;
      op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
      if (mOperands.size() != 2)
        return failure();
      if (llvm::all_of(mOperands, [](auto pair) {
            return cast<ShapedType>(std::get<0>(pair).getType())
                       .getShape()[std::get<1>(pair)] == 1;
          })) {
        operandUnitDims = SmallVector<int64_t>{std::get<1>(mOperands[0]), -1,
                                               std::get<1>(mOperands[1])};
        return success();
      }
    } else {
      auto n = contractionDims.n[0];
      SmallVector<std::pair<Value, unsigned>, 2> nOperands;
      op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
      if (nOperands.size() != 2)
        return failure();
      if (llvm::all_of(nOperands, [](auto pair) {
            return cast<ShapedType>(std::get<0>(pair).getType())
                       .getShape()[std::get<1>(pair)] == 1;
          })) {
        operandUnitDims = SmallVector<int64_t>{-1, std::get<1>(nOperands[0]),
                                               std::get<1>(nOperands[1])};
        return success();
      }
    }
    LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
    return failure();
  }
};

} // namespace

void mlir::linalg::populateContractionOpRankReducingPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  // Unbatching patterns for unit batch size.
  patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
  patterns
      .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
          context);
  patterns
      .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
          context);
  patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
  patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);

  // Non-batch rank 1 reducing patterns.
  patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
  patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
  patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
  patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
  // Batch rank 1 reducing patterns.
  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
  patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
      context);
  patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
      context);

  // Non-batch rank 0 reducing patterns.
  patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
  patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
}