1cf6a7c19SMahesh Ravishankar //===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===// 2cf6a7c19SMahesh Ravishankar // 3cf6a7c19SMahesh Ravishankar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4cf6a7c19SMahesh Ravishankar // See https://llvm.org/LICENSE.txt for license information. 5cf6a7c19SMahesh Ravishankar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6cf6a7c19SMahesh Ravishankar // 7cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===// 8cf6a7c19SMahesh Ravishankar 9cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" 10cf6a7c19SMahesh Ravishankar 113310fe55SThomas Raoux #include "mlir/Analysis/SliceAnalysis.h" 12cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Affine/IR/AffineOps.h" 13abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 14abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/Utils/Utils.h" 15cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/IR/Linalg.h" 16cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/Utils/Utils.h" 176f03a10eSMahesh Ravishankar #include "mlir/Dialect/MemRef/IR/MemRef.h" 18cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Tensor/IR/Tensor.h" 196f03a10eSMahesh Ravishankar #include "mlir/Dialect/Utils/StaticValueUtils.h" 20cf6a7c19SMahesh Ravishankar #include "mlir/Interfaces/TilingInterface.h" 21a1fe1f5fSKazu Hirata #include <optional> 22cf6a7c19SMahesh Ravishankar 23cf6a7c19SMahesh Ravishankar using namespace mlir; 24cf6a7c19SMahesh Ravishankar using namespace mlir::linalg; 25cf6a7c19SMahesh Ravishankar 266f03a10eSMahesh Ravishankar //===----------------------------------------------------------------------===// 276f03a10eSMahesh Ravishankar // Utility methods for implementation of Tiling Interface for Linalg ops 286f03a10eSMahesh Ravishankar //===----------------------------------------------------------------------===// 29cf6a7c19SMahesh Ravishankar 306f03a10eSMahesh Ravishankar /// Return the SSA values that represent the data point accessed using a given 316f03a10eSMahesh Ravishankar /// `indexingMap` for a given point in the iteration space represented by `ivs`. 326f03a10eSMahesh Ravishankar static SmallVector<Value> getIndicesForAccess(OpBuilder &b, Location loc, 336f03a10eSMahesh Ravishankar AffineMap indexingMap, 346f03a10eSMahesh Ravishankar ValueRange ivs) { 356f03a10eSMahesh Ravishankar SmallVector<Value> indices; 366f03a10eSMahesh Ravishankar indices.reserve(indexingMap.getNumResults()); 376f03a10eSMahesh Ravishankar for (auto result : indexingMap.getResults()) { 386f03a10eSMahesh Ravishankar AffineMap m = AffineMap::get(indexingMap.getNumDims(), 396f03a10eSMahesh Ravishankar indexingMap.getNumSymbols(), result); 404c48f016SMatthias Springer Value v = b.create<affine::AffineApplyOp>(loc, m, ivs); 416f03a10eSMahesh Ravishankar indices.push_back(v); 426f03a10eSMahesh Ravishankar } 436f03a10eSMahesh Ravishankar return indices; 446f03a10eSMahesh Ravishankar } 456f03a10eSMahesh Ravishankar 466f03a10eSMahesh Ravishankar /// Method to inline the payload of a `linalgOp` given the iteration space 476f03a10eSMahesh Ravishankar /// point and values for the arguments of the payload. 486f03a10eSMahesh Ravishankar static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp, 496f03a10eSMahesh Ravishankar ValueRange ivs, ValueRange argValues) { 506f03a10eSMahesh Ravishankar Block *body = linalgOp.getBlock(); 514d67b278SJeff Niu IRMapping map; 526f03a10eSMahesh Ravishankar map.map(body->getArguments(), argValues); 536f03a10eSMahesh Ravishankar for (auto &op : body->without_terminator()) { 546f03a10eSMahesh Ravishankar if (auto indexOp = dyn_cast<IndexOp>(&op)) { 55d3b3f765SJacques Pienaar map.map(indexOp.getResult(), ivs[indexOp.getDim()]); 566f03a10eSMahesh Ravishankar continue; 576f03a10eSMahesh Ravishankar } 586f03a10eSMahesh Ravishankar b.clone(op, map); 596f03a10eSMahesh Ravishankar } 606f03a10eSMahesh Ravishankar 616f03a10eSMahesh Ravishankar Operation *terminator = body->getTerminator(); 626f03a10eSMahesh Ravishankar Location loc = terminator->getLoc(); 6328e5e3d6SMehdi Amini for (const auto &operand : llvm::enumerate(terminator->getOperands())) { 646f03a10eSMahesh Ravishankar Value toStore = map.lookupOrDefault(operand.value()); 65b4db15a9SAlexander Belyaev OpOperand *storeInto = linalgOp.getDpsInitOperand(operand.index()); 666f03a10eSMahesh Ravishankar auto indices = getIndicesForAccess( 671227b8abSOleg Shyshkov b, loc, linalgOp.getMatchingIndexingMap(storeInto), ivs); 68b4db15a9SAlexander Belyaev b.create<memref::StoreOp>( 69b4db15a9SAlexander Belyaev loc, toStore, linalgOp.getDpsInitOperand(operand.index())->get(), 706f03a10eSMahesh Ravishankar indices); 716f03a10eSMahesh Ravishankar } 726f03a10eSMahesh Ravishankar return success(); 736f03a10eSMahesh Ravishankar } 746f03a10eSMahesh Ravishankar 756f03a10eSMahesh Ravishankar //===----------------------------------------------------------------------===// 766f03a10eSMahesh Ravishankar // External Model for implementing `TilingInterface` for `LinalgOp`s. 776f03a10eSMahesh Ravishankar //===----------------------------------------------------------------------===// 786f03a10eSMahesh Ravishankar 796f03a10eSMahesh Ravishankar namespace { 80cf6a7c19SMahesh Ravishankar /// External model implementation of TilingInterface for LinalgOps. An external 81cf6a7c19SMahesh Ravishankar /// model implementation is used for now till the use of `TilingInterface` is 82cf6a7c19SMahesh Ravishankar /// on-par with the current Linalg tiling + fusion patterns. Once it is 83cf6a7c19SMahesh Ravishankar /// maybe possible to move this into the op-definition (though there are 84cf6a7c19SMahesh Ravishankar /// advantages to leaving it as an external model) 85cf6a7c19SMahesh Ravishankar template <typename LinalgOpTy> 86cf6a7c19SMahesh Ravishankar struct LinalgOpTilingInterface 87cf6a7c19SMahesh Ravishankar : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>, 88cf6a7c19SMahesh Ravishankar LinalgOpTy> { 89cf6a7c19SMahesh Ravishankar /// Return the loop iterator type. 904f1c1242SOleg Shyshkov SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { 91cf6a7c19SMahesh Ravishankar LinalgOpTy concreteOp = cast<LinalgOpTy>(op); 92e6598b05SOleg Shyshkov return concreteOp.getIteratorTypesArray(); 93cf6a7c19SMahesh Ravishankar } 94cf6a7c19SMahesh Ravishankar 95cf6a7c19SMahesh Ravishankar /// Return the iteration domain range. 96cf6a7c19SMahesh Ravishankar SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { 972f637fe7SMahesh Ravishankar OpBuilder::InsertionGuard g(b); 982f637fe7SMahesh Ravishankar b.setInsertionPoint(op); 99cf6a7c19SMahesh Ravishankar Location loc = op->getLoc(); 100cf6a7c19SMahesh Ravishankar LinalgOp linalgOp = cast<LinalgOp>(op); 101e99fae89SAlex Zinenko SmallVector<OpFoldResult> allShapesSizes = 102e99fae89SAlex Zinenko linalgOp.createFlatListOfOperandDims(b, loc); 103cf6a7c19SMahesh Ravishankar AffineMap map = linalgOp.getShapesToLoopsMap(); 104e99fae89SAlex Zinenko 105e99fae89SAlex Zinenko return llvm::to_vector( 106e99fae89SAlex Zinenko llvm::map_range(map.getResults(), [&](AffineExpr loopExpr) { 1074c48f016SMatthias Springer OpFoldResult ofr = affine::makeComposedFoldedAffineApply( 1084c48f016SMatthias Springer b, loc, loopExpr, allShapesSizes); 109e99fae89SAlex Zinenko return Range{b.getIndexAttr(0), ofr, b.getIndexAttr(1)}; 110cf6a7c19SMahesh Ravishankar })); 111cf6a7c19SMahesh Ravishankar } 112cf6a7c19SMahesh Ravishankar 1132b2ce50fSAbhishek Varma /// Instantiate the tiled implementation of the operation. 114809e3d8cSMahesh Ravishankar FailureOr<TilingResult> 11554794284SMatthias Springer getTiledImplementation(Operation *op, OpBuilder &b, 116cf6a7c19SMahesh Ravishankar ArrayRef<OpFoldResult> offsets, 11754794284SMatthias Springer ArrayRef<OpFoldResult> sizes) const { 118cf6a7c19SMahesh Ravishankar // Leave the `sizeBounds` value empty. That is only needed when the `sizes` 119cf6a7c19SMahesh Ravishankar // specified could lead to out of bounds accesses. 120cf6a7c19SMahesh Ravishankar Location loc = op->getLoc(); 121cf6a7c19SMahesh Ravishankar LinalgOp linalgOp = cast<LinalgOp>(op); 122a7cccb9cSAlexander Belyaev SmallVector<Value> valuesToTile = linalgOp->getOperands(); 123d5f0969cSMaheshRavishankar SmallVector<Value> tiledOperands = makeTiledShapes( 124e99fae89SAlex Zinenko b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true); 125d5f0969cSMaheshRavishankar SmallVector<Operation *> generatedSlices = llvm::map_to_vector( 126d5f0969cSMaheshRavishankar llvm::make_filter_range( 127d5f0969cSMaheshRavishankar tiledOperands, 128d5f0969cSMaheshRavishankar [](Value v) -> bool { 129d5f0969cSMaheshRavishankar return isa_and_nonnull<tensor::ExtractSliceOp, memref::SubViewOp>( 130d5f0969cSMaheshRavishankar v.getDefiningOp()); 131d5f0969cSMaheshRavishankar }), 132d5f0969cSMaheshRavishankar [](Value v) -> Operation * { return v.getDefiningOp(); }); 133cf6a7c19SMahesh Ravishankar 134a7cccb9cSAlexander Belyaev SmallVector<Type> resultTensorTypes = 135a7cccb9cSAlexander Belyaev getTensorOutputTypes(linalgOp, tiledOperands); 136cf6a7c19SMahesh Ravishankar 137f286af29SAlexander Belyaev Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands); 138e99fae89SAlex Zinenko offsetIndices(b, cast<LinalgOp>(tiledOp), offsets); 139cf6a7c19SMahesh Ravishankar 140d5f0969cSMaheshRavishankar return TilingResult{ 141d5f0969cSMaheshRavishankar {tiledOp}, SmallVector<Value>(tiledOp->getResults()), generatedSlices}; 142cf6a7c19SMahesh Ravishankar } 143cf6a7c19SMahesh Ravishankar 1442b2ce50fSAbhishek Varma /// Utility to fetch the offsets and sizes when applied as per the indexing 1452b2ce50fSAbhishek Varma /// map of the linalg op. This helps in fusing the linalg op as a consumer of 1462b2ce50fSAbhishek Varma /// a given slice op. 1472b2ce50fSAbhishek Varma void 1482b2ce50fSAbhishek Varma getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap, 1492b2ce50fSAbhishek Varma ArrayRef<OpFoldResult> offsets, 1502b2ce50fSAbhishek Varma ArrayRef<OpFoldResult> sizes, 1512b2ce50fSAbhishek Varma SmallVectorImpl<OpFoldResult> &mappedOffsets, 1522b2ce50fSAbhishek Varma SmallVectorImpl<OpFoldResult> &mappedSizes) const { 1532b2ce50fSAbhishek Varma unsigned numLoops = linalgOp.getNumLoops(); 1542b2ce50fSAbhishek Varma auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation()); 1552b2ce50fSAbhishek Varma mappedOffsets.resize(numLoops); 1562b2ce50fSAbhishek Varma mappedSizes.resize(numLoops); 1572b2ce50fSAbhishek Varma if (!indexingMap.isPermutation()) { 1582b2ce50fSAbhishek Varma SmallVector<Range> iterationDomain = 1592b2ce50fSAbhishek Varma tilingInterfaceOp.getIterationDomain(b); 1602b2ce50fSAbhishek Varma for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) { 1612b2ce50fSAbhishek Varma mappedOffsets[index] = value.offset; 1622b2ce50fSAbhishek Varma mappedSizes[index] = value.size; 1632b2ce50fSAbhishek Varma } 1642b2ce50fSAbhishek Varma } 1652b2ce50fSAbhishek Varma for (const auto &&[index, value] : 1662b2ce50fSAbhishek Varma llvm::enumerate(indexingMap.getResults())) { 1672b2ce50fSAbhishek Varma unsigned dimPosition = cast<AffineDimExpr>(value).getPosition(); 1682b2ce50fSAbhishek Varma mappedOffsets[dimPosition] = offsets[index]; 1692b2ce50fSAbhishek Varma mappedSizes[dimPosition] = sizes[index]; 1702b2ce50fSAbhishek Varma } 1712b2ce50fSAbhishek Varma } 1722b2ce50fSAbhishek Varma 1732b2ce50fSAbhishek Varma /// Method to return the position of the result tile computed by the tiled 1742b2ce50fSAbhishek Varma /// operation. 1752b2ce50fSAbhishek Varma LogicalResult getIterationDomainTileFromOperandTile( 1762b2ce50fSAbhishek Varma Operation *op, OpBuilder &b, unsigned operandNumber, 1772b2ce50fSAbhishek Varma ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, 1782b2ce50fSAbhishek Varma SmallVectorImpl<OpFoldResult> &iterDomainOffsets, 1792b2ce50fSAbhishek Varma SmallVectorImpl<OpFoldResult> &iterDomainSizes) const { 1802b2ce50fSAbhishek Varma auto linalgOp = cast<LinalgOp>(op); 1812b2ce50fSAbhishek Varma 1822b2ce50fSAbhishek Varma // Check that the indexing map used for the operand is a projected 1832b2ce50fSAbhishek Varma // permutation. This could be relaxed with a more general approach that can 1842b2ce50fSAbhishek Varma // map the offsets and sizes from the operand to iteration space tiles 1852b2ce50fSAbhishek Varma // (filling in full extent for dimensions not used to access the result). 1862b2ce50fSAbhishek Varma AffineMap indexingMap = 1872b2ce50fSAbhishek Varma linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber)); 1882b2ce50fSAbhishek Varma if (!indexingMap.isProjectedPermutation()) { 1892b2ce50fSAbhishek Varma return op->emitError() 1902b2ce50fSAbhishek Varma << "unhandled get iter domain position when operand is not " 1912b2ce50fSAbhishek Varma "accessed using a permuted projection"; 1922b2ce50fSAbhishek Varma } 1932b2ce50fSAbhishek Varma 1942b2ce50fSAbhishek Varma getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes, 1952b2ce50fSAbhishek Varma iterDomainOffsets, iterDomainSizes); 1962b2ce50fSAbhishek Varma return success(); 1972b2ce50fSAbhishek Varma } 1982b2ce50fSAbhishek Varma 1992b2ce50fSAbhishek Varma /// Return the details of the output tile generated by the tiled 2002b2ce50fSAbhishek Varma /// implementation. 201cf6a7c19SMahesh Ravishankar LogicalResult 202cf6a7c19SMahesh Ravishankar getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, 203cf6a7c19SMahesh Ravishankar ArrayRef<OpFoldResult> offsets, 204cf6a7c19SMahesh Ravishankar ArrayRef<OpFoldResult> sizes, 205f220c359SOleksandr "Alex" Zinenko SmallVector<OpFoldResult> &resultOffsets, 206f220c359SOleksandr "Alex" Zinenko SmallVector<OpFoldResult> &resultSizes) const { 207cf6a7c19SMahesh Ravishankar Location loc = op->getLoc(); 208cf6a7c19SMahesh Ravishankar LinalgOp linalgOp = cast<LinalgOp>(op); 209cf6a7c19SMahesh Ravishankar 210cf6a7c19SMahesh Ravishankar AffineExpr d0; 211cf6a7c19SMahesh Ravishankar bindDims(b.getContext(), d0); 212e99fae89SAlex Zinenko SmallVector<OpFoldResult> subShapeSizes = 213e99fae89SAlex Zinenko llvm::to_vector(llvm::map_range(sizes, [&](OpFoldResult ofr) { 2144c48f016SMatthias Springer return affine::makeComposedFoldedAffineApply(b, loc, d0 - 1, ofr); 215cf6a7c19SMahesh Ravishankar })); 216e99fae89SAlex Zinenko 217b4db15a9SAlexander Belyaev OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber); 2181227b8abSOleg Shyshkov SliceParameters sliceParams = computeSliceParameters( 2191227b8abSOleg Shyshkov b, loc, outOperand->get(), sizes, 2201227b8abSOleg Shyshkov linalgOp.getMatchingIndexingMap(outOperand), offsets, 221cf6a7c19SMahesh Ravishankar /*ubs*/ {}, subShapeSizes, true); 22206c02d5dSThomas Raoux resultOffsets = sliceParams.offsets; 22306c02d5dSThomas Raoux resultSizes = sliceParams.sizes; 224cf6a7c19SMahesh Ravishankar return success(); 225cf6a7c19SMahesh Ravishankar } 2262f637fe7SMahesh Ravishankar 2277ef08eacSYun-Fly LogicalResult getIterationDomainTileFromResultTile( 2287ef08eacSYun-Fly Operation *op, OpBuilder &b, unsigned resultNumber, 2297ef08eacSYun-Fly ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, 2307ef08eacSYun-Fly SmallVectorImpl<OpFoldResult> &iterDomainOffsets, 2317ef08eacSYun-Fly SmallVectorImpl<OpFoldResult> &iterDomainSizes) const { 2322f637fe7SMahesh Ravishankar auto linalgOp = cast<LinalgOp>(op); 2332f637fe7SMahesh Ravishankar 2342f637fe7SMahesh Ravishankar // Check that the indexing map used for the output is a projected 2352f637fe7SMahesh Ravishankar // permutation. This could be relaxed with a more general approach that can 2362f637fe7SMahesh Ravishankar // map the offsets and sizes from the result to iteration space tiles 2372f637fe7SMahesh Ravishankar // (filling in full extent for dimensions not used to access the result). 2382f637fe7SMahesh Ravishankar AffineMap indexingMap = 2391227b8abSOleg Shyshkov linalgOp.getIndexingMapMatchingResult(op->getResult(resultNumber)); 2402f637fe7SMahesh Ravishankar if (!indexingMap.isProjectedPermutation()) { 2412f637fe7SMahesh Ravishankar return op->emitOpError( 2422f637fe7SMahesh Ravishankar "unhandled tiled implementation generation when result is not " 2432f637fe7SMahesh Ravishankar "accessed using a permuted projection"); 2442f637fe7SMahesh Ravishankar } 2457ef08eacSYun-Fly 2462b2ce50fSAbhishek Varma getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes, 2477ef08eacSYun-Fly iterDomainOffsets, iterDomainSizes); 2487ef08eacSYun-Fly return success(); 2497ef08eacSYun-Fly } 2507ef08eacSYun-Fly 2517ef08eacSYun-Fly FailureOr<TilingResult> 2527ef08eacSYun-Fly generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, 2537ef08eacSYun-Fly ArrayRef<OpFoldResult> offsets, 2547ef08eacSYun-Fly ArrayRef<OpFoldResult> sizes) const { 2557ef08eacSYun-Fly SmallVector<OpFoldResult> mappedOffsets, mappedSizes; 2567ef08eacSYun-Fly if (failed(getIterationDomainTileFromResultTile( 2577ef08eacSYun-Fly op, b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) { 2587ef08eacSYun-Fly return failure(); 2597ef08eacSYun-Fly } 2602f637fe7SMahesh Ravishankar auto tilingInterfaceOp = cast<TilingInterface>(op); 261809e3d8cSMahesh Ravishankar FailureOr<TilingResult> tilingResult = 2622b2ce50fSAbhishek Varma tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes); 2632b2ce50fSAbhishek Varma 2642b2ce50fSAbhishek Varma if (failed(tilingResult)) 2652b2ce50fSAbhishek Varma return failure(); 2662b2ce50fSAbhishek Varma 267809e3d8cSMahesh Ravishankar if (tilingResult->tiledOps.size() != 1) 2682f637fe7SMahesh Ravishankar return op->emitOpError("failed to generate tiled implementation"); 2692f637fe7SMahesh Ravishankar 270809e3d8cSMahesh Ravishankar return TilingResult{ 271809e3d8cSMahesh Ravishankar tilingResult->tiledOps, 272d5f0969cSMaheshRavishankar SmallVector<Value>{tilingResult->tiledValues[resultNumber]}, 273d5f0969cSMaheshRavishankar tilingResult->generatedSlices}; 2742f637fe7SMahesh Ravishankar } 2756f03a10eSMahesh Ravishankar 2762b2ce50fSAbhishek Varma /// Method to generate the tiled implementation of an operation from the tile 2772b2ce50fSAbhishek Varma /// of the operand. 2782b2ce50fSAbhishek Varma FailureOr<TilingResult> getTiledImplementationFromOperandTile( 2792b2ce50fSAbhishek Varma Operation *op, OpBuilder &b, unsigned operandNumber, 2802b2ce50fSAbhishek Varma ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const { 2812b2ce50fSAbhishek Varma SmallVector<OpFoldResult> mappedOffsets, mappedSizes; 2822b2ce50fSAbhishek Varma if (failed(getIterationDomainTileFromOperandTile( 2832b2ce50fSAbhishek Varma op, b, operandNumber, offsets, sizes, mappedOffsets, 2842b2ce50fSAbhishek Varma mappedSizes))) { 2852b2ce50fSAbhishek Varma return failure(); 2862b2ce50fSAbhishek Varma } 2872b2ce50fSAbhishek Varma return getTiledImplementation(op, b, mappedOffsets, mappedSizes); 2882b2ce50fSAbhishek Varma } 2892b2ce50fSAbhishek Varma 2906f03a10eSMahesh Ravishankar LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder, 2916f03a10eSMahesh Ravishankar Location loc, 2926f03a10eSMahesh Ravishankar ValueRange ivs) const { 2936f03a10eSMahesh Ravishankar auto linalgOp = cast<LinalgOp>(op); 2940a8e3dd4SMatthias Springer if (!linalgOp.hasPureBufferSemantics()) 2956f03a10eSMahesh Ravishankar return op->emitOpError("expected operation to have buffer semantics"); 2966f03a10eSMahesh Ravishankar 2976f03a10eSMahesh Ravishankar SmallVector<Value> indexedValues; 298a7cccb9cSAlexander Belyaev indexedValues.reserve(linalgOp->getNumOperands()); 2996f03a10eSMahesh Ravishankar Location linalgOpLoc = op->getLoc(); 3006f03a10eSMahesh Ravishankar /// Load the data corresponding to the block arguments that 3016f03a10eSMahesh Ravishankar /// represent input operands. 302a7cccb9cSAlexander Belyaev for (OpOperand &operand : linalgOp->getOpOperands()) { 303a7cccb9cSAlexander Belyaev if (!linalgOp.payloadUsesValueFromOperand(&operand)) { 3046f03a10eSMahesh Ravishankar indexedValues.push_back(nullptr); 3056f03a10eSMahesh Ravishankar continue; 3066f03a10eSMahesh Ravishankar } 307a7cccb9cSAlexander Belyaev if (linalgOp.isScalar(&operand)) { 308a7cccb9cSAlexander Belyaev indexedValues.push_back(operand.get()); 3096f03a10eSMahesh Ravishankar continue; 3106f03a10eSMahesh Ravishankar } 3116f03a10eSMahesh Ravishankar SmallVector<Value> indices = getIndicesForAccess( 312a7cccb9cSAlexander Belyaev builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand), ivs); 3136f03a10eSMahesh Ravishankar Value load = 314a7cccb9cSAlexander Belyaev builder.create<memref::LoadOp>(linalgOpLoc, operand.get(), indices); 3156f03a10eSMahesh Ravishankar indexedValues.push_back(load); 3166f03a10eSMahesh Ravishankar } 3176f03a10eSMahesh Ravishankar 3186f03a10eSMahesh Ravishankar /// Inline the op payload and store the result. 3196f03a10eSMahesh Ravishankar return inlinePayload(builder, linalgOp, ivs, indexedValues); 3206f03a10eSMahesh Ravishankar } 321cf6a7c19SMahesh Ravishankar }; 322cf6a7c19SMahesh Ravishankar 3233310fe55SThomas Raoux //===----------------------------------------------------------------------===// 3243310fe55SThomas Raoux // External Model for implementing `PartialReductionInterface` for `LinalgOp`s. 3253310fe55SThomas Raoux //===----------------------------------------------------------------------===// 3263310fe55SThomas Raoux 327*91bbebc7SKunwar Grover /// Return an AffineMap for a partial result for the given result number, 328*91bbebc7SKunwar Grover /// assuming the partial tiling strategy is outer-reduction loop + 329*91bbebc7SKunwar Grover /// inner-parallel tile. The returned AffineMap can be used as the replacement 330*91bbebc7SKunwar Grover /// AffineMap for the inner-parallel tile linalg op for the given result number. 331*91bbebc7SKunwar Grover /// 332*91bbebc7SKunwar Grover /// The new AffineMap is the old AffineMap with reduction dimensions appended 333*91bbebc7SKunwar Grover /// at end. 334*91bbebc7SKunwar Grover static AffineMap getPartialResultAffineMap(LinalgOp linalgOp, 335*91bbebc7SKunwar Grover ArrayRef<int> reductionDims, 336*91bbebc7SKunwar Grover unsigned resultNumber) { 337*91bbebc7SKunwar Grover AffineMap map = 338*91bbebc7SKunwar Grover linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber)); 339*91bbebc7SKunwar Grover for (int redPos : reductionDims) { 340*91bbebc7SKunwar Grover map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()), 341*91bbebc7SKunwar Grover map.getNumResults()); 342*91bbebc7SKunwar Grover } 343*91bbebc7SKunwar Grover return map; 344*91bbebc7SKunwar Grover } 345*91bbebc7SKunwar Grover 346*91bbebc7SKunwar Grover /// External model implementation of PartialReductionInterface for 347*91bbebc7SKunwar Grover /// LinalgOps. 3483310fe55SThomas Raoux template <typename LinalgOpTy> 3493310fe55SThomas Raoux struct LinalgOpPartialReductionInterface 3503310fe55SThomas Raoux : public PartialReductionOpInterface::ExternalModel< 3513310fe55SThomas Raoux LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> { 3529329b20dSKunwar Grover FailureOr<SmallVector<Value>> generateInitialTensorForPartialReduction( 3533310fe55SThomas Raoux Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes, 3543310fe55SThomas Raoux ArrayRef<int> reductionDims) const { 3553310fe55SThomas Raoux auto linalgOp = cast<LinalgOp>(op); 3563310fe55SThomas Raoux OpBuilder::InsertionGuard guard(b); 3572cc5f5d4SGroverkss 3580a8e3dd4SMatthias Springer if (linalgOp.hasPureBufferSemantics()) 3593310fe55SThomas Raoux return op->emitOpError("expected operation to have tensor semantics"); 3609329b20dSKunwar Grover 361*91bbebc7SKunwar Grover // LinalgOp implements TilingInterface. 362*91bbebc7SKunwar Grover auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation()); 363*91bbebc7SKunwar Grover SmallVector<OpFoldResult> shape = 364*91bbebc7SKunwar Grover llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b), 365*91bbebc7SKunwar Grover [](Range x) { return x.size; }); 366*91bbebc7SKunwar Grover 367*91bbebc7SKunwar Grover SmallVector<OpFoldResult> tiledShape; 368*91bbebc7SKunwar Grover for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) { 369*91bbebc7SKunwar Grover if (isZeroIndex(tileSize)) { 370*91bbebc7SKunwar Grover tiledShape.push_back(dimSize); 371*91bbebc7SKunwar Grover } else { 372*91bbebc7SKunwar Grover tiledShape.push_back(tileSize); 373*91bbebc7SKunwar Grover } 374*91bbebc7SKunwar Grover } 375*91bbebc7SKunwar Grover 3769329b20dSKunwar Grover SmallVector<Value> inits; 3779329b20dSKunwar Grover for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e; 3789329b20dSKunwar Grover ++initIdx) { 3793310fe55SThomas Raoux SmallVector<Operation *, 4> combinerOps; 3809329b20dSKunwar Grover if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx, 3819329b20dSKunwar Grover combinerOps) || 3823310fe55SThomas Raoux combinerOps.size() != 1) 3833310fe55SThomas Raoux return op->emitOpError("Failed to anaysis the reduction operation."); 3843310fe55SThomas Raoux 3853310fe55SThomas Raoux Operation *reductionOp = combinerOps[0]; 386f8e59b09SQuentin Colombet std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp); 3873310fe55SThomas Raoux if (!identity.has_value()) 3883310fe55SThomas Raoux return op->emitOpError( 3893310fe55SThomas Raoux "Failed to get an identity value for the reduction operation."); 3903310fe55SThomas Raoux 391*91bbebc7SKunwar Grover // Append the new partial result dimensions. 392*91bbebc7SKunwar Grover AffineMap partialMap = 393*91bbebc7SKunwar Grover getPartialResultAffineMap(linalgOp, reductionDims, initIdx); 394*91bbebc7SKunwar Grover SmallVector<OpFoldResult> partialResultShape; 395*91bbebc7SKunwar Grover for (AffineExpr dimExpr : partialMap.getResults()) { 396*91bbebc7SKunwar Grover auto dim = cast<AffineDimExpr>(dimExpr); 397*91bbebc7SKunwar Grover partialResultShape.push_back(tiledShape[dim.getPosition()]); 398*91bbebc7SKunwar Grover } 3992cc5f5d4SGroverkss 400*91bbebc7SKunwar Grover Type elType = 401*91bbebc7SKunwar Grover getElementTypeOrSelf(linalgOp->getResult(initIdx).getType()); 402*91bbebc7SKunwar Grover Value emptyTensor = 403*91bbebc7SKunwar Grover b.create<tensor::EmptyOp>(loc, partialResultShape, elType); 4043310fe55SThomas Raoux Value constantOp = b.create<arith::ConstantOp>(loc, *identity); 4053310fe55SThomas Raoux auto identityTensor = 4063310fe55SThomas Raoux b.create<linalg::FillOp>(loc, constantOp, emptyTensor); 4079329b20dSKunwar Grover inits.push_back(identityTensor.getResult(0)); 4089329b20dSKunwar Grover } 4099329b20dSKunwar Grover 4109329b20dSKunwar Grover return inits; 4113310fe55SThomas Raoux } 4123310fe55SThomas Raoux 413b99d0b34SMaheshRavishankar FailureOr<TilingResult> 414b99d0b34SMaheshRavishankar tileToPartialReduction(Operation *op, OpBuilder &b, Location loc, 415b99d0b34SMaheshRavishankar ValueRange init, ArrayRef<OpFoldResult> offsets, 4163310fe55SThomas Raoux ArrayRef<OpFoldResult> sizes, 4173310fe55SThomas Raoux ArrayRef<int> reductionDims) const { 4183310fe55SThomas Raoux OpBuilder::InsertionGuard guard(b); 4193310fe55SThomas Raoux auto linalgOp = cast<LinalgOp>(op); 4203310fe55SThomas Raoux 4219329b20dSKunwar Grover // Step 1. Extend init maps to have reduction dimension dims, since we 4229329b20dSKunwar Grover // are converting them to parallel dimensions. 4239329b20dSKunwar Grover SmallVector<AffineMap> newInitMaps; 4249329b20dSKunwar Grover newInitMaps.reserve(linalgOp.getNumDpsInits()); 4259329b20dSKunwar Grover for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) { 4269329b20dSKunwar Grover // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace 4279329b20dSKunwar Grover // this with a for range loop when we have it. 4289329b20dSKunwar Grover AffineMap newMap = 429*91bbebc7SKunwar Grover getPartialResultAffineMap(linalgOp, reductionDims, idx); 4309329b20dSKunwar Grover newInitMaps.push_back(newMap); 4313310fe55SThomas Raoux } 4323310fe55SThomas Raoux 4339329b20dSKunwar Grover // Step 2a: Extract a slice of the input operands. 434d5f0969cSMaheshRavishankar SmallVector<Value> tiledInputs = makeTiledShapes( 4359329b20dSKunwar Grover b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true); 436d5f0969cSMaheshRavishankar SmallVector<Operation *> generatedSlices = llvm::map_to_vector( 437d5f0969cSMaheshRavishankar llvm::make_filter_range( 438d5f0969cSMaheshRavishankar tiledInputs, [](Value v) -> bool { return v.getDefiningOp(); }), 439d5f0969cSMaheshRavishankar [](Value v) -> Operation * { return v.getDefiningOp(); }); 4403310fe55SThomas Raoux 4419329b20dSKunwar Grover // Step 2b: Extract a slice of the init operands. 4429329b20dSKunwar Grover SmallVector<Value, 1> tiledInits; 4439329b20dSKunwar Grover for (auto [valueMap, valueToTile] : llvm::zip_equal(newInitMaps, init)) { 4449329b20dSKunwar Grover int64_t initRank = valueMap.getNumResults(); 4459329b20dSKunwar Grover SmallVector<OpFoldResult> initOffset(initRank, b.getIndexAttr(0)); 4469329b20dSKunwar Grover SmallVector<OpFoldResult> initStride(initRank, b.getIndexAttr(1)); 4479329b20dSKunwar Grover SmallVector<OpFoldResult> initSizes; 4489329b20dSKunwar Grover for (AffineExpr dimExpr : valueMap.getResults()) { 4499329b20dSKunwar Grover auto dim = cast<AffineDimExpr>(dimExpr); 4509329b20dSKunwar Grover initSizes.push_back(sizes[dim.getPosition()]); 4519329b20dSKunwar Grover } 4529329b20dSKunwar Grover // TODO: Use SubsetExtractOpInterface here once available. 4539329b20dSKunwar Grover auto extractSlice = b.create<tensor::ExtractSliceOp>( 4549329b20dSKunwar Grover loc, valueToTile, initOffset, initSizes, initStride); 4559329b20dSKunwar Grover tiledInits.push_back(extractSlice); 456d5f0969cSMaheshRavishankar generatedSlices.push_back(extractSlice); 4579329b20dSKunwar Grover } 4583310fe55SThomas Raoux 4599329b20dSKunwar Grover // Update the indexing maps. 4609329b20dSKunwar Grover SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray(); 4619329b20dSKunwar Grover // Change the init maps. 4629329b20dSKunwar Grover for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) { 4639329b20dSKunwar Grover // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace 4649329b20dSKunwar Grover // this with a for range loop when we have it. 4659329b20dSKunwar Grover OpOperand *initOperand = linalgOp.getDpsInitOperand(idx); 4669329b20dSKunwar Grover int64_t mapIdx = linalgOp.getIndexingMapIndex(initOperand); 4679329b20dSKunwar Grover newMaps[mapIdx] = newInitMaps[idx]; 4689329b20dSKunwar Grover } 4699329b20dSKunwar Grover 4709329b20dSKunwar Grover // Step 3. Change the reduction dim iterator types. 471e6598b05SOleg Shyshkov SmallVector<utils::IteratorType> newIteratorTypes = 472e6598b05SOleg Shyshkov linalgOp.getIteratorTypesArray(); 4732cc5f5d4SGroverkss for (int dim : reductionDims) 4742cc5f5d4SGroverkss newIteratorTypes[dim] = utils::IteratorType::parallel; 4759329b20dSKunwar Grover 4769329b20dSKunwar Grover // Step 4. Create the new generic op. 4773310fe55SThomas Raoux auto genericOp = 4789329b20dSKunwar Grover b.create<GenericOp>(loc, ValueRange(tiledInits).getTypes(), tiledInputs, 4799329b20dSKunwar Grover tiledInits, newMaps, newIteratorTypes); 4804d67b278SJeff Niu IRMapping mapping; 4813310fe55SThomas Raoux op->getRegion(0).cloneInto(&genericOp.getRegion(), 4823310fe55SThomas Raoux genericOp.getRegion().begin(), mapping); 483b99d0b34SMaheshRavishankar return TilingResult{ 484b99d0b34SMaheshRavishankar {genericOp.getOperation()}, 485b99d0b34SMaheshRavishankar llvm::map_to_vector(genericOp->getResults(), 486d5f0969cSMaheshRavishankar [](OpResult r) -> Value { return r; }), 487d5f0969cSMaheshRavishankar generatedSlices}; 4883310fe55SThomas Raoux } 4893310fe55SThomas Raoux 490b99d0b34SMaheshRavishankar FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b, 491b99d0b34SMaheshRavishankar Location loc, ValueRange partialReduce, 4923310fe55SThomas Raoux ArrayRef<int> reductionDims) const { 4933310fe55SThomas Raoux auto linalgOp = cast<LinalgOp>(op); 494*91bbebc7SKunwar Grover 495*91bbebc7SKunwar Grover // Permute the reduction dims as permuted by the partial result map. 496*91bbebc7SKunwar Grover 4979329b20dSKunwar Grover int64_t numInits = linalgOp.getNumDpsInits(); 498*91bbebc7SKunwar Grover SmallVector<Operation *> mergeOperations; 499*91bbebc7SKunwar Grover SmallVector<Value> replacements; 500*91bbebc7SKunwar Grover for (int idx : llvm::seq(numInits)) { 501*91bbebc7SKunwar Grover // linalg.reduce's iteration space is the tiled result's iteration space 502*91bbebc7SKunwar Grover // (and not the tiled operation's iteration space). To account for this, 503*91bbebc7SKunwar Grover // permute the reduction dimensions based on the partial result map of the 504*91bbebc7SKunwar Grover // tiled result. 505*91bbebc7SKunwar Grover AffineMap partialMap = 506*91bbebc7SKunwar Grover getPartialResultAffineMap(linalgOp, reductionDims, idx); 507*91bbebc7SKunwar Grover SmallVector<int64_t> partialReductionDims; 508*91bbebc7SKunwar Grover for (auto [resultNum, dimExpr] : 509*91bbebc7SKunwar Grover llvm::enumerate(partialMap.getResults())) { 510*91bbebc7SKunwar Grover unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition(); 511*91bbebc7SKunwar Grover if (llvm::find(reductionDims, dim) != reductionDims.end()) { 512*91bbebc7SKunwar Grover partialReductionDims.push_back(resultNum); 513*91bbebc7SKunwar Grover } 514*91bbebc7SKunwar Grover } 515*91bbebc7SKunwar Grover 516*91bbebc7SKunwar Grover Value partialResult = partialReduce[idx]; 517*91bbebc7SKunwar Grover Value init = linalgOp.getDpsInits()[idx]; 518*91bbebc7SKunwar Grover 519*91bbebc7SKunwar Grover auto reduction = b.create<linalg::ReduceOp>( 520*91bbebc7SKunwar Grover loc, partialResult, init, partialReductionDims, 521*91bbebc7SKunwar Grover [&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) { 5229329b20dSKunwar Grover // Get the combiner op. 5239329b20dSKunwar Grover SmallVector<Operation *, 4> combinerOps; 5249329b20dSKunwar Grover matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps); 5259329b20dSKunwar Grover Operation *clonedReductionOp = b.clone(*combinerOps[0]); 5269329b20dSKunwar Grover // Combine the input at idx and output at numInits + idx. 527*91bbebc7SKunwar Grover clonedReductionOp->setOperand(0, inputs[0]); 528*91bbebc7SKunwar Grover clonedReductionOp->setOperand(1, inputs[1]); 529*91bbebc7SKunwar Grover b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0)); 5303310fe55SThomas Raoux }); 531*91bbebc7SKunwar Grover 532*91bbebc7SKunwar Grover mergeOperations.push_back(reduction); 533*91bbebc7SKunwar Grover replacements.push_back(reduction->getResult(0)); 534*91bbebc7SKunwar Grover } 535*91bbebc7SKunwar Grover 536*91bbebc7SKunwar Grover return MergeResult{mergeOperations, replacements}; 537*91bbebc7SKunwar Grover } 538*91bbebc7SKunwar Grover 539*91bbebc7SKunwar Grover LogicalResult getPartialResultTilePosition( 540*91bbebc7SKunwar Grover Operation *op, OpBuilder &b, unsigned resultNumber, 541*91bbebc7SKunwar Grover ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, 542*91bbebc7SKunwar Grover SmallVector<OpFoldResult> &resultOffsets, 543*91bbebc7SKunwar Grover SmallVector<OpFoldResult> &resultSizes, 544*91bbebc7SKunwar Grover ArrayRef<int> reductionDims) const { 545*91bbebc7SKunwar Grover auto linalgOp = cast<LinalgOp>(op); 546*91bbebc7SKunwar Grover 547*91bbebc7SKunwar Grover AffineMap partialMap = 548*91bbebc7SKunwar Grover getPartialResultAffineMap(linalgOp, reductionDims, resultNumber); 549*91bbebc7SKunwar Grover for (AffineExpr dimExpr : partialMap.getResults()) { 550*91bbebc7SKunwar Grover unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition(); 551*91bbebc7SKunwar Grover resultSizes.push_back(sizes[dim]); 552*91bbebc7SKunwar Grover 553*91bbebc7SKunwar Grover if (llvm::find(reductionDims, dim) != reductionDims.end()) { 554*91bbebc7SKunwar Grover // Reduction dims are reduced, and are always outputed in the same 555*91bbebc7SKunwar Grover // place. So use offset 0 for them. 556*91bbebc7SKunwar Grover resultOffsets.push_back(b.getIndexAttr(0)); 557*91bbebc7SKunwar Grover } else { 558*91bbebc7SKunwar Grover resultOffsets.push_back(offsets[dim]); 559*91bbebc7SKunwar Grover } 560*91bbebc7SKunwar Grover } 561*91bbebc7SKunwar Grover 562*91bbebc7SKunwar Grover return success(); 5633310fe55SThomas Raoux } 5643310fe55SThomas Raoux }; 5653310fe55SThomas Raoux 566cf6a7c19SMahesh Ravishankar } // namespace 567cf6a7c19SMahesh Ravishankar 5682f637fe7SMahesh Ravishankar template <typename OpType> 5692f637fe7SMahesh Ravishankar static void registerOne(MLIRContext *ctx) { 570cf6a7c19SMahesh Ravishankar OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx); 5713310fe55SThomas Raoux OpType::template attachInterface<LinalgOpPartialReductionInterface<OpType>>( 5723310fe55SThomas Raoux *ctx); 573cf6a7c19SMahesh Ravishankar } 574cf6a7c19SMahesh Ravishankar 575cf6a7c19SMahesh Ravishankar /// Variadic helper function. 5762f637fe7SMahesh Ravishankar template <typename... OpTypes> 5772f637fe7SMahesh Ravishankar static void registerAll(MLIRContext *ctx) { 57826d811b3SMarkus Böck (registerOne<OpTypes>(ctx), ...); 579cf6a7c19SMahesh Ravishankar } 580cf6a7c19SMahesh Ravishankar 581cf6a7c19SMahesh Ravishankar #define GET_OP_LIST 582cf6a7c19SMahesh Ravishankar 583cf6a7c19SMahesh Ravishankar void mlir::linalg::registerTilingInterfaceExternalModels( 584cf6a7c19SMahesh Ravishankar DialectRegistry ®istry) { 585cf6a7c19SMahesh Ravishankar registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) { 586cf6a7c19SMahesh Ravishankar registerOne<linalg::GenericOp>(ctx); 587cf6a7c19SMahesh Ravishankar registerAll< 588cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" 589cf6a7c19SMahesh Ravishankar >(ctx); 590cf6a7c19SMahesh Ravishankar }); 591cf6a7c19SMahesh Ravishankar } 592