1 //===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/Arith/Utils/Utils.h" 15 #include "mlir/Dialect/Linalg/IR/Linalg.h" 16 #include "mlir/Dialect/Linalg/Utils/Utils.h" 17 #include "mlir/Dialect/MemRef/IR/MemRef.h" 18 #include "mlir/Dialect/Tensor/IR/Tensor.h" 19 #include "mlir/Dialect/Utils/StaticValueUtils.h" 20 #include "mlir/Interfaces/TilingInterface.h" 21 #include <optional> 22 23 using namespace mlir; 24 using namespace mlir::linalg; 25 26 //===----------------------------------------------------------------------===// 27 // Utility methods for implementation of Tiling Interface for Linalg ops 28 //===----------------------------------------------------------------------===// 29 30 /// Return the SSA values that represent the data point accessed using a given 31 /// `indexingMap` for a given point in the iteration space represented by `ivs`. 32 static SmallVector<Value> getIndicesForAccess(OpBuilder &b, Location loc, 33 AffineMap indexingMap, 34 ValueRange ivs) { 35 SmallVector<Value> indices; 36 indices.reserve(indexingMap.getNumResults()); 37 for (auto result : indexingMap.getResults()) { 38 AffineMap m = AffineMap::get(indexingMap.getNumDims(), 39 indexingMap.getNumSymbols(), result); 40 Value v = b.create<affine::AffineApplyOp>(loc, m, ivs); 41 indices.push_back(v); 42 } 43 return indices; 44 } 45 46 /// Method to inline the payload of a `linalgOp` given the iteration space 47 /// point and values for the arguments of the payload. 48 static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp, 49 ValueRange ivs, ValueRange argValues) { 50 Block *body = linalgOp.getBlock(); 51 IRMapping map; 52 map.map(body->getArguments(), argValues); 53 for (auto &op : body->without_terminator()) { 54 if (auto indexOp = dyn_cast<IndexOp>(&op)) { 55 map.map(indexOp.getResult(), ivs[indexOp.getDim()]); 56 continue; 57 } 58 b.clone(op, map); 59 } 60 61 Operation *terminator = body->getTerminator(); 62 Location loc = terminator->getLoc(); 63 for (const auto &operand : llvm::enumerate(terminator->getOperands())) { 64 Value toStore = map.lookupOrDefault(operand.value()); 65 OpOperand *storeInto = linalgOp.getDpsInitOperand(operand.index()); 66 auto indices = getIndicesForAccess( 67 b, loc, linalgOp.getMatchingIndexingMap(storeInto), ivs); 68 b.create<memref::StoreOp>( 69 loc, toStore, linalgOp.getDpsInitOperand(operand.index())->get(), 70 indices); 71 } 72 return success(); 73 } 74 75 //===----------------------------------------------------------------------===// 76 // External Model for implementing `TilingInterface` for `LinalgOp`s. 77 //===----------------------------------------------------------------------===// 78 79 namespace { 80 /// External model implementation of TilingInterface for LinalgOps. An external 81 /// model implementation is used for now till the use of `TilingInterface` is 82 /// on-par with the current Linalg tiling + fusion patterns. Once it is 83 /// maybe possible to move this into the op-definition (though there are 84 /// advantages to leaving it as an external model) 85 template <typename LinalgOpTy> 86 struct LinalgOpTilingInterface 87 : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>, 88 LinalgOpTy> { 89 /// Return the loop iterator type. 90 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { 91 LinalgOpTy concreteOp = cast<LinalgOpTy>(op); 92 return concreteOp.getIteratorTypesArray(); 93 } 94 95 /// Return the iteration domain range. 96 SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { 97 OpBuilder::InsertionGuard g(b); 98 b.setInsertionPoint(op); 99 Location loc = op->getLoc(); 100 LinalgOp linalgOp = cast<LinalgOp>(op); 101 SmallVector<OpFoldResult> allShapesSizes = 102 linalgOp.createFlatListOfOperandDims(b, loc); 103 AffineMap map = linalgOp.getShapesToLoopsMap(); 104 105 return llvm::to_vector( 106 llvm::map_range(map.getResults(), [&](AffineExpr loopExpr) { 107 OpFoldResult ofr = affine::makeComposedFoldedAffineApply( 108 b, loc, loopExpr, allShapesSizes); 109 return Range{b.getIndexAttr(0), ofr, b.getIndexAttr(1)}; 110 })); 111 } 112 113 /// Instantiate the tiled implementation of the operation. 114 FailureOr<TilingResult> 115 getTiledImplementation(Operation *op, OpBuilder &b, 116 ArrayRef<OpFoldResult> offsets, 117 ArrayRef<OpFoldResult> sizes) const { 118 // Leave the `sizeBounds` value empty. That is only needed when the `sizes` 119 // specified could lead to out of bounds accesses. 120 Location loc = op->getLoc(); 121 LinalgOp linalgOp = cast<LinalgOp>(op); 122 SmallVector<Value> valuesToTile = linalgOp->getOperands(); 123 SmallVector<Value> tiledOperands = makeTiledShapes( 124 b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true); 125 SmallVector<Operation *> generatedSlices = llvm::map_to_vector( 126 llvm::make_filter_range( 127 tiledOperands, 128 [](Value v) -> bool { 129 return isa_and_nonnull<tensor::ExtractSliceOp, memref::SubViewOp>( 130 v.getDefiningOp()); 131 }), 132 [](Value v) -> Operation * { return v.getDefiningOp(); }); 133 134 SmallVector<Type> resultTensorTypes = 135 getTensorOutputTypes(linalgOp, tiledOperands); 136 137 Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands); 138 offsetIndices(b, cast<LinalgOp>(tiledOp), offsets); 139 140 return TilingResult{ 141 {tiledOp}, SmallVector<Value>(tiledOp->getResults()), generatedSlices}; 142 } 143 144 /// Utility to fetch the offsets and sizes when applied as per the indexing 145 /// map of the linalg op. This helps in fusing the linalg op as a consumer of 146 /// a given slice op. 147 void 148 getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap, 149 ArrayRef<OpFoldResult> offsets, 150 ArrayRef<OpFoldResult> sizes, 151 SmallVectorImpl<OpFoldResult> &mappedOffsets, 152 SmallVectorImpl<OpFoldResult> &mappedSizes) const { 153 unsigned numLoops = linalgOp.getNumLoops(); 154 auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation()); 155 mappedOffsets.resize(numLoops); 156 mappedSizes.resize(numLoops); 157 if (!indexingMap.isPermutation()) { 158 SmallVector<Range> iterationDomain = 159 tilingInterfaceOp.getIterationDomain(b); 160 for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) { 161 mappedOffsets[index] = value.offset; 162 mappedSizes[index] = value.size; 163 } 164 } 165 for (const auto &&[index, value] : 166 llvm::enumerate(indexingMap.getResults())) { 167 unsigned dimPosition = cast<AffineDimExpr>(value).getPosition(); 168 mappedOffsets[dimPosition] = offsets[index]; 169 mappedSizes[dimPosition] = sizes[index]; 170 } 171 } 172 173 /// Method to return the position of the result tile computed by the tiled 174 /// operation. 175 LogicalResult getIterationDomainTileFromOperandTile( 176 Operation *op, OpBuilder &b, unsigned operandNumber, 177 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, 178 SmallVectorImpl<OpFoldResult> &iterDomainOffsets, 179 SmallVectorImpl<OpFoldResult> &iterDomainSizes) const { 180 auto linalgOp = cast<LinalgOp>(op); 181 182 // Check that the indexing map used for the operand is a projected 183 // permutation. This could be relaxed with a more general approach that can 184 // map the offsets and sizes from the operand to iteration space tiles 185 // (filling in full extent for dimensions not used to access the result). 186 AffineMap indexingMap = 187 linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber)); 188 if (!indexingMap.isProjectedPermutation()) { 189 return op->emitError() 190 << "unhandled get iter domain position when operand is not " 191 "accessed using a permuted projection"; 192 } 193 194 getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes, 195 iterDomainOffsets, iterDomainSizes); 196 return success(); 197 } 198 199 /// Return the details of the output tile generated by the tiled 200 /// implementation. 201 LogicalResult 202 getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, 203 ArrayRef<OpFoldResult> offsets, 204 ArrayRef<OpFoldResult> sizes, 205 SmallVector<OpFoldResult> &resultOffsets, 206 SmallVector<OpFoldResult> &resultSizes) const { 207 Location loc = op->getLoc(); 208 LinalgOp linalgOp = cast<LinalgOp>(op); 209 210 AffineExpr d0; 211 bindDims(b.getContext(), d0); 212 SmallVector<OpFoldResult> subShapeSizes = 213 llvm::to_vector(llvm::map_range(sizes, [&](OpFoldResult ofr) { 214 return affine::makeComposedFoldedAffineApply(b, loc, d0 - 1, ofr); 215 })); 216 217 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber); 218 SliceParameters sliceParams = computeSliceParameters( 219 b, loc, outOperand->get(), sizes, 220 linalgOp.getMatchingIndexingMap(outOperand), offsets, 221 /*ubs*/ {}, subShapeSizes, true); 222 resultOffsets = sliceParams.offsets; 223 resultSizes = sliceParams.sizes; 224 return success(); 225 } 226 227 LogicalResult getIterationDomainTileFromResultTile( 228 Operation *op, OpBuilder &b, unsigned resultNumber, 229 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, 230 SmallVectorImpl<OpFoldResult> &iterDomainOffsets, 231 SmallVectorImpl<OpFoldResult> &iterDomainSizes) const { 232 auto linalgOp = cast<LinalgOp>(op); 233 234 // Check that the indexing map used for the output is a projected 235 // permutation. This could be relaxed with a more general approach that can 236 // map the offsets and sizes from the result to iteration space tiles 237 // (filling in full extent for dimensions not used to access the result). 238 AffineMap indexingMap = 239 linalgOp.getIndexingMapMatchingResult(op->getResult(resultNumber)); 240 if (!indexingMap.isProjectedPermutation()) { 241 return op->emitOpError( 242 "unhandled tiled implementation generation when result is not " 243 "accessed using a permuted projection"); 244 } 245 246 getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes, 247 iterDomainOffsets, iterDomainSizes); 248 return success(); 249 } 250 251 FailureOr<TilingResult> 252 generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, 253 ArrayRef<OpFoldResult> offsets, 254 ArrayRef<OpFoldResult> sizes) const { 255 SmallVector<OpFoldResult> mappedOffsets, mappedSizes; 256 if (failed(getIterationDomainTileFromResultTile( 257 op, b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) { 258 return failure(); 259 } 260 auto tilingInterfaceOp = cast<TilingInterface>(op); 261 FailureOr<TilingResult> tilingResult = 262 tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes); 263 264 if (failed(tilingResult)) 265 return failure(); 266 267 if (tilingResult->tiledOps.size() != 1) 268 return op->emitOpError("failed to generate tiled implementation"); 269 270 return TilingResult{ 271 tilingResult->tiledOps, 272 SmallVector<Value>{tilingResult->tiledValues[resultNumber]}, 273 tilingResult->generatedSlices}; 274 } 275 276 /// Method to generate the tiled implementation of an operation from the tile 277 /// of the operand. 278 FailureOr<TilingResult> getTiledImplementationFromOperandTile( 279 Operation *op, OpBuilder &b, unsigned operandNumber, 280 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const { 281 SmallVector<OpFoldResult> mappedOffsets, mappedSizes; 282 if (failed(getIterationDomainTileFromOperandTile( 283 op, b, operandNumber, offsets, sizes, mappedOffsets, 284 mappedSizes))) { 285 return failure(); 286 } 287 return getTiledImplementation(op, b, mappedOffsets, mappedSizes); 288 } 289 290 LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder, 291 Location loc, 292 ValueRange ivs) const { 293 auto linalgOp = cast<LinalgOp>(op); 294 if (!linalgOp.hasPureBufferSemantics()) 295 return op->emitOpError("expected operation to have buffer semantics"); 296 297 SmallVector<Value> indexedValues; 298 indexedValues.reserve(linalgOp->getNumOperands()); 299 Location linalgOpLoc = op->getLoc(); 300 /// Load the data corresponding to the block arguments that 301 /// represent input operands. 302 for (OpOperand &operand : linalgOp->getOpOperands()) { 303 if (!linalgOp.payloadUsesValueFromOperand(&operand)) { 304 indexedValues.push_back(nullptr); 305 continue; 306 } 307 if (linalgOp.isScalar(&operand)) { 308 indexedValues.push_back(operand.get()); 309 continue; 310 } 311 SmallVector<Value> indices = getIndicesForAccess( 312 builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand), ivs); 313 Value load = 314 builder.create<memref::LoadOp>(linalgOpLoc, operand.get(), indices); 315 indexedValues.push_back(load); 316 } 317 318 /// Inline the op payload and store the result. 319 return inlinePayload(builder, linalgOp, ivs, indexedValues); 320 } 321 }; 322 323 //===----------------------------------------------------------------------===// 324 // External Model for implementing `PartialReductionInterface` for `LinalgOp`s. 325 //===----------------------------------------------------------------------===// 326 327 /// Return an AffineMap for a partial result for the given result number, 328 /// assuming the partial tiling strategy is outer-reduction loop + 329 /// inner-parallel tile. The returned AffineMap can be used as the replacement 330 /// AffineMap for the inner-parallel tile linalg op for the given result number. 331 /// 332 /// The new AffineMap is the old AffineMap with reduction dimensions appended 333 /// at end. 334 static AffineMap getPartialResultAffineMap(LinalgOp linalgOp, 335 ArrayRef<int> reductionDims, 336 unsigned resultNumber) { 337 AffineMap map = 338 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber)); 339 for (int redPos : reductionDims) { 340 map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()), 341 map.getNumResults()); 342 } 343 return map; 344 } 345 346 /// External model implementation of PartialReductionInterface for 347 /// LinalgOps. 348 template <typename LinalgOpTy> 349 struct LinalgOpPartialReductionInterface 350 : public PartialReductionOpInterface::ExternalModel< 351 LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> { 352 FailureOr<SmallVector<Value>> generateInitialTensorForPartialReduction( 353 Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes, 354 ArrayRef<int> reductionDims) const { 355 auto linalgOp = cast<LinalgOp>(op); 356 OpBuilder::InsertionGuard guard(b); 357 358 if (linalgOp.hasPureBufferSemantics()) 359 return op->emitOpError("expected operation to have tensor semantics"); 360 361 // LinalgOp implements TilingInterface. 362 auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation()); 363 SmallVector<OpFoldResult> shape = 364 llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b), 365 [](Range x) { return x.size; }); 366 367 SmallVector<OpFoldResult> tiledShape; 368 for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) { 369 if (isZeroIndex(tileSize)) { 370 tiledShape.push_back(dimSize); 371 } else { 372 tiledShape.push_back(tileSize); 373 } 374 } 375 376 SmallVector<Value> inits; 377 for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e; 378 ++initIdx) { 379 SmallVector<Operation *, 4> combinerOps; 380 if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx, 381 combinerOps) || 382 combinerOps.size() != 1) 383 return op->emitOpError("Failed to anaysis the reduction operation."); 384 385 Operation *reductionOp = combinerOps[0]; 386 std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp); 387 if (!identity.has_value()) 388 return op->emitOpError( 389 "Failed to get an identity value for the reduction operation."); 390 391 // Append the new partial result dimensions. 392 AffineMap partialMap = 393 getPartialResultAffineMap(linalgOp, reductionDims, initIdx); 394 SmallVector<OpFoldResult> partialResultShape; 395 for (AffineExpr dimExpr : partialMap.getResults()) { 396 auto dim = cast<AffineDimExpr>(dimExpr); 397 partialResultShape.push_back(tiledShape[dim.getPosition()]); 398 } 399 400 Type elType = 401 getElementTypeOrSelf(linalgOp->getResult(initIdx).getType()); 402 Value emptyTensor = 403 b.create<tensor::EmptyOp>(loc, partialResultShape, elType); 404 Value constantOp = b.create<arith::ConstantOp>(loc, *identity); 405 auto identityTensor = 406 b.create<linalg::FillOp>(loc, constantOp, emptyTensor); 407 inits.push_back(identityTensor.getResult(0)); 408 } 409 410 return inits; 411 } 412 413 FailureOr<TilingResult> 414 tileToPartialReduction(Operation *op, OpBuilder &b, Location loc, 415 ValueRange init, ArrayRef<OpFoldResult> offsets, 416 ArrayRef<OpFoldResult> sizes, 417 ArrayRef<int> reductionDims) const { 418 OpBuilder::InsertionGuard guard(b); 419 auto linalgOp = cast<LinalgOp>(op); 420 421 // Step 1. Extend init maps to have reduction dimension dims, since we 422 // are converting them to parallel dimensions. 423 SmallVector<AffineMap> newInitMaps; 424 newInitMaps.reserve(linalgOp.getNumDpsInits()); 425 for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) { 426 // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace 427 // this with a for range loop when we have it. 428 AffineMap newMap = 429 getPartialResultAffineMap(linalgOp, reductionDims, idx); 430 newInitMaps.push_back(newMap); 431 } 432 433 // Step 2a: Extract a slice of the input operands. 434 SmallVector<Value> tiledInputs = makeTiledShapes( 435 b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true); 436 SmallVector<Operation *> generatedSlices = llvm::map_to_vector( 437 llvm::make_filter_range( 438 tiledInputs, [](Value v) -> bool { return v.getDefiningOp(); }), 439 [](Value v) -> Operation * { return v.getDefiningOp(); }); 440 441 // Step 2b: Extract a slice of the init operands. 442 SmallVector<Value, 1> tiledInits; 443 for (auto [valueMap, valueToTile] : llvm::zip_equal(newInitMaps, init)) { 444 int64_t initRank = valueMap.getNumResults(); 445 SmallVector<OpFoldResult> initOffset(initRank, b.getIndexAttr(0)); 446 SmallVector<OpFoldResult> initStride(initRank, b.getIndexAttr(1)); 447 SmallVector<OpFoldResult> initSizes; 448 for (AffineExpr dimExpr : valueMap.getResults()) { 449 auto dim = cast<AffineDimExpr>(dimExpr); 450 initSizes.push_back(sizes[dim.getPosition()]); 451 } 452 // TODO: Use SubsetExtractOpInterface here once available. 453 auto extractSlice = b.create<tensor::ExtractSliceOp>( 454 loc, valueToTile, initOffset, initSizes, initStride); 455 tiledInits.push_back(extractSlice); 456 generatedSlices.push_back(extractSlice); 457 } 458 459 // Update the indexing maps. 460 SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray(); 461 // Change the init maps. 462 for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) { 463 // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace 464 // this with a for range loop when we have it. 465 OpOperand *initOperand = linalgOp.getDpsInitOperand(idx); 466 int64_t mapIdx = linalgOp.getIndexingMapIndex(initOperand); 467 newMaps[mapIdx] = newInitMaps[idx]; 468 } 469 470 // Step 3. Change the reduction dim iterator types. 471 SmallVector<utils::IteratorType> newIteratorTypes = 472 linalgOp.getIteratorTypesArray(); 473 for (int dim : reductionDims) 474 newIteratorTypes[dim] = utils::IteratorType::parallel; 475 476 // Step 4. Create the new generic op. 477 auto genericOp = 478 b.create<GenericOp>(loc, ValueRange(tiledInits).getTypes(), tiledInputs, 479 tiledInits, newMaps, newIteratorTypes); 480 IRMapping mapping; 481 op->getRegion(0).cloneInto(&genericOp.getRegion(), 482 genericOp.getRegion().begin(), mapping); 483 return TilingResult{ 484 {genericOp.getOperation()}, 485 llvm::map_to_vector(genericOp->getResults(), 486 [](OpResult r) -> Value { return r; }), 487 generatedSlices}; 488 } 489 490 FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b, 491 Location loc, ValueRange partialReduce, 492 ArrayRef<int> reductionDims) const { 493 auto linalgOp = cast<LinalgOp>(op); 494 495 // Permute the reduction dims as permuted by the partial result map. 496 497 int64_t numInits = linalgOp.getNumDpsInits(); 498 SmallVector<Operation *> mergeOperations; 499 SmallVector<Value> replacements; 500 for (int idx : llvm::seq(numInits)) { 501 // linalg.reduce's iteration space is the tiled result's iteration space 502 // (and not the tiled operation's iteration space). To account for this, 503 // permute the reduction dimensions based on the partial result map of the 504 // tiled result. 505 AffineMap partialMap = 506 getPartialResultAffineMap(linalgOp, reductionDims, idx); 507 SmallVector<int64_t> partialReductionDims; 508 for (auto [resultNum, dimExpr] : 509 llvm::enumerate(partialMap.getResults())) { 510 unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition(); 511 if (llvm::find(reductionDims, dim) != reductionDims.end()) { 512 partialReductionDims.push_back(resultNum); 513 } 514 } 515 516 Value partialResult = partialReduce[idx]; 517 Value init = linalgOp.getDpsInits()[idx]; 518 519 auto reduction = b.create<linalg::ReduceOp>( 520 loc, partialResult, init, partialReductionDims, 521 [&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) { 522 // Get the combiner op. 523 SmallVector<Operation *, 4> combinerOps; 524 matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps); 525 Operation *clonedReductionOp = b.clone(*combinerOps[0]); 526 // Combine the input at idx and output at numInits + idx. 527 clonedReductionOp->setOperand(0, inputs[0]); 528 clonedReductionOp->setOperand(1, inputs[1]); 529 b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0)); 530 }); 531 532 mergeOperations.push_back(reduction); 533 replacements.push_back(reduction->getResult(0)); 534 } 535 536 return MergeResult{mergeOperations, replacements}; 537 } 538 539 LogicalResult getPartialResultTilePosition( 540 Operation *op, OpBuilder &b, unsigned resultNumber, 541 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, 542 SmallVector<OpFoldResult> &resultOffsets, 543 SmallVector<OpFoldResult> &resultSizes, 544 ArrayRef<int> reductionDims) const { 545 auto linalgOp = cast<LinalgOp>(op); 546 547 AffineMap partialMap = 548 getPartialResultAffineMap(linalgOp, reductionDims, resultNumber); 549 for (AffineExpr dimExpr : partialMap.getResults()) { 550 unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition(); 551 resultSizes.push_back(sizes[dim]); 552 553 if (llvm::find(reductionDims, dim) != reductionDims.end()) { 554 // Reduction dims are reduced, and are always outputed in the same 555 // place. So use offset 0 for them. 556 resultOffsets.push_back(b.getIndexAttr(0)); 557 } else { 558 resultOffsets.push_back(offsets[dim]); 559 } 560 } 561 562 return success(); 563 } 564 }; 565 566 } // namespace 567 568 template <typename OpType> 569 static void registerOne(MLIRContext *ctx) { 570 OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx); 571 OpType::template attachInterface<LinalgOpPartialReductionInterface<OpType>>( 572 *ctx); 573 } 574 575 /// Variadic helper function. 576 template <typename... OpTypes> 577 static void registerAll(MLIRContext *ctx) { 578 (registerOne<OpTypes>(ctx), ...); 579 } 580 581 #define GET_OP_LIST 582 583 void mlir::linalg::registerTilingInterfaceExternalModels( 584 DialectRegistry ®istry) { 585 registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) { 586 registerOne<linalg::GenericOp>(ctx); 587 registerAll< 588 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" 589 >(ctx); 590 }); 591 } 592