//===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Interfaces/TilingInterface.h"

#include <optional>

using namespace mlir;
using namespace mlir::linalg;

//===----------------------------------------------------------------------===//
// Utility methods for implementation of Tiling Interface for Linalg ops
//===----------------------------------------------------------------------===//

/// Return the SSA values that represent the data point accessed using a given
/// `indexingMap` for a given point in the iteration space represented by
/// `ivs`.
static SmallVector<Value> getIndicesForAccess(OpBuilder &b, Location loc,
                                              AffineMap indexingMap,
                                              ValueRange ivs) {
  SmallVector<Value> indices;
  indices.reserve(indexingMap.getNumResults());
  for (auto result : indexingMap.getResults()) {
    AffineMap m = AffineMap::get(indexingMap.getNumDims(),
                                 indexingMap.getNumSymbols(), result);
    Value v = b.create<affine::AffineApplyOp>(loc, m, ivs);
    indices.push_back(v);
  }
  return indices;
}

/// Method to inline the payload of a `linalgOp` given the iteration space
/// point and values for the arguments of the payload.
static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp,
                                   ValueRange ivs, ValueRange argValues) {
  Block *body = linalgOp.getBlock();
  IRMapping map;
  map.map(body->getArguments(), argValues);
  for (auto &op : body->without_terminator()) {
    if (auto indexOp = dyn_cast<IndexOp>(&op)) {
      map.map(indexOp.getResult(), ivs[indexOp.getDim()]);
      continue;
    }
    b.clone(op, map);
  }

  Operation *terminator = body->getTerminator();
  Location loc = terminator->getLoc();
  for (const auto &operand : llvm::enumerate(terminator->getOperands())) {
    Value toStore = map.lookupOrDefault(operand.value());
    OpOperand *storeInto = linalgOp.getDpsInitOperand(operand.index());
    auto indices = getIndicesForAccess(
        b, loc, linalgOp.getMatchingIndexingMap(storeInto), ivs);
    b.create<memref::StoreOp>(
        loc, toStore, linalgOp.getDpsInitOperand(operand.index())->get(),
        indices);
  }
  return success();
}

//===----------------------------------------------------------------------===//
// External Model for implementing `TilingInterface` for `LinalgOp`s.
//===----------------------------------------------------------------------===//

namespace {
/// External model implementation of TilingInterface for LinalgOps. An external
/// model implementation is used for now till the use of `TilingInterface` is
/// on par with the current Linalg tiling + fusion patterns. Once it is, it may
/// be possible to move this into the op definition (though there are
/// advantages to leaving it as an external model).
template <typename LinalgOpTy>
struct LinalgOpTilingInterface
    : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
                                            LinalgOpTy> {
  /// Return the loop iterator types.
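  /// For example, a `linalg.matmul` reports `[parallel, parallel, reduction]`.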
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
    return concreteOp.getIteratorTypesArray();
  }

  /// Return the iteration domain range.
  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(op);
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(op);
    SmallVector<OpFoldResult> allShapesSizes =
        linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();

    return llvm::to_vector(
        llvm::map_range(map.getResults(), [&](AffineExpr loopExpr) {
          OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
              b, loc, loopExpr, allShapesSizes);
          return Range{b.getIndexAttr(0), ofr, b.getIndexAttr(1)};
        }));
  }

  /// Instantiate the tiled implementation of the operation.
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
    // specified could lead to out of bounds accesses.
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(op);
    SmallVector<Value> valuesToTile = linalgOp->getOperands();
    SmallVector<Value> tiledOperands = makeTiledShapes(
        b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
    SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
        llvm::make_filter_range(
            tiledOperands,
            [](Value v) -> bool {
              return isa_and_nonnull<tensor::ExtractSliceOp, memref::SubViewOp>(
                  v.getDefiningOp());
            }),
        [](Value v) -> Operation * { return v.getDefiningOp(); });

    SmallVector<Type> resultTensorTypes =
        getTensorOutputTypes(linalgOp, tiledOperands);

    Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
    offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);

    return TilingResult{
        {tiledOp}, SmallVector<Value>(tiledOp->getResults()), generatedSlices};
  }

  /// Utility to fetch the offsets and sizes of the iteration domain that
  /// correspond to the given offsets and sizes applied as per the indexing
  /// map of the linalg op. This helps in fusing the linalg op as a consumer
  /// of a given slice op.
  void getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b,
                              AffineMap indexingMap,
                              ArrayRef<OpFoldResult> offsets,
                              ArrayRef<OpFoldResult> sizes,
                              SmallVectorImpl<OpFoldResult> &mappedOffsets,
                              SmallVectorImpl<OpFoldResult> &mappedSizes) const {
    unsigned numLoops = linalgOp.getNumLoops();
    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
    mappedOffsets.resize(numLoops);
    mappedSizes.resize(numLoops);
    if (!indexingMap.isPermutation()) {
      SmallVector<Range> iterationDomain =
          tilingInterfaceOp.getIterationDomain(b);
      for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
        mappedOffsets[index] = value.offset;
        mappedSizes[index] = value.size;
      }
    }
    for (const auto &&[index, value] :
         llvm::enumerate(indexingMap.getResults())) {
      unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
      mappedOffsets[dimPosition] = offsets[index];
      mappedSizes[dimPosition] = sizes[index];
    }
  }

  /// Method to return the tile of the iteration domain that corresponds to a
  /// given tile of the operand.
  LogicalResult getIterationDomainTileFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
    auto linalgOp = cast<LinalgOp>(op);

    // Check that the indexing map used for the operand is a projected
    // permutation. This could be relaxed with a more general approach that can
    // map the offsets and sizes from the operand to iteration space tiles
    // (filling in full extent for dimensions not used to access the operand).
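    // (For instance, `affine_map<(d0, d1, d2) -> (d0, d2)>` is a projected
    // permutation, while `affine_map<(d0, d1) -> (d0 + d1)>` is not.)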
    AffineMap indexingMap =
        linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
    if (!indexingMap.isProjectedPermutation()) {
      return op->emitError()
             << "unhandled get iter domain position when operand is not "
                "accessed using a permuted projection";
    }

    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
                           iterDomainOffsets, iterDomainSizes);
    return success();
  }

  /// Return the details of the output tile generated by the tiled
  /// implementation.
  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(op);

    AffineExpr d0;
    bindDims(b.getContext(), d0);
    SmallVector<OpFoldResult> subShapeSizes =
        llvm::to_vector(llvm::map_range(sizes, [&](OpFoldResult ofr) {
          return affine::makeComposedFoldedAffineApply(b, loc, d0 - 1, ofr);
        }));

    OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
    SliceParameters sliceParams = computeSliceParameters(
        b, loc, outOperand->get(), sizes,
        linalgOp.getMatchingIndexingMap(outOperand), offsets,
        /*ubs*/ {}, subShapeSizes, true);
    resultOffsets = sliceParams.offsets;
    resultSizes = sliceParams.sizes;
    return success();
  }

  LogicalResult getIterationDomainTileFromResultTile(
      Operation *op, OpBuilder &b, unsigned resultNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
    auto linalgOp = cast<LinalgOp>(op);

    // Check that the indexing map used for the output is a projected
    // permutation. This could be relaxed with a more general approach that can
    // map the offsets and sizes from the result to iteration space tiles
    // (filling in full extent for dimensions not used to access the result).
    AffineMap indexingMap =
        linalgOp.getIndexingMapMatchingResult(op->getResult(resultNumber));
    if (!indexingMap.isProjectedPermutation()) {
      return op->emitOpError(
          "unhandled tiled implementation generation when result is not "
          "accessed using a permuted projection");
    }

    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
                           iterDomainOffsets, iterDomainSizes);
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
    if (failed(getIterationDomainTileFromResultTile(
            op, b, resultNumber, offsets, sizes, mappedOffsets,
            mappedSizes))) {
      return failure();
    }
    auto tilingInterfaceOp = cast<TilingInterface>(op);
    FailureOr<TilingResult> tilingResult =
        tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
                                                 mappedSizes);

    if (failed(tilingResult))
      return failure();

    if (tilingResult->tiledOps.size() != 1)
      return op->emitOpError("failed to generate tiled implementation");

    return TilingResult{
        tilingResult->tiledOps,
        SmallVector<Value>{tilingResult->tiledValues[resultNumber]},
        tilingResult->generatedSlices};
  }

  /// Method to generate the tiled implementation of an operation from a tile
  /// of the operand.
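  /// The operand tile is first mapped to a tile of the iteration domain, and
  /// the regular tiled implementation is then generated for that domain tile.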
  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, operandNumber, offsets, sizes, mappedOffsets,
            mappedSizes))) {
      return failure();
    }
    return getTiledImplementation(op, b, mappedOffsets, mappedSizes);
  }

  LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
                                             Location loc,
                                             ValueRange ivs) const {
    auto linalgOp = cast<LinalgOp>(op);
    if (!linalgOp.hasPureBufferSemantics())
      return op->emitOpError("expected operation to have buffer semantics");

    SmallVector<Value> indexedValues;
    indexedValues.reserve(linalgOp->getNumOperands());
    Location linalgOpLoc = op->getLoc();
    /// Load the data corresponding to the block arguments that
    /// represent input operands.
    for (OpOperand &operand : linalgOp->getOpOperands()) {
      if (!linalgOp.payloadUsesValueFromOperand(&operand)) {
        indexedValues.push_back(nullptr);
        continue;
      }
      if (linalgOp.isScalar(&operand)) {
        indexedValues.push_back(operand.get());
        continue;
      }
      SmallVector<Value> indices = getIndicesForAccess(
          builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand),
          ivs);
      Value load =
          builder.create<memref::LoadOp>(linalgOpLoc, operand.get(), indices);
      indexedValues.push_back(load);
    }

    /// Inline the op payload and store the result.
    return inlinePayload(builder, linalgOp, ivs, indexedValues);
  }
};

//===----------------------------------------------------------------------===//
// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
//===----------------------------------------------------------------------===//

/// Return an AffineMap for a partial result for the given result number,
/// assuming the partial tiling strategy is outer-reduction loop +
/// inner-parallel tile. The returned AffineMap can be used as the replacement
/// AffineMap for the inner-parallel tile linalg op for the given result number.
///
/// The new AffineMap is the old AffineMap with the reduction dimensions
/// appended at the end. For example, with an init map `(d0, d1, d2) -> (d0)`
/// and `reductionDims = {1, 2}`, the partial result map is
/// `(d0, d1, d2) -> (d0, d1, d2)`.
static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
                                           ArrayRef<int> reductionDims,
                                           unsigned resultNumber) {
  AffineMap map =
      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));
  for (int redPos : reductionDims) {
    map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
                           map.getNumResults());
  }
  return map;
}

/// External model implementation of PartialReductionInterface for
/// LinalgOps.
template <typename LinalgOpTy>
struct LinalgOpPartialReductionInterface
    : public PartialReductionOpInterface::ExternalModel<
          LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
  FailureOr<SmallVector<Value>> generateInitialTensorForPartialReduction(
      Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
      ArrayRef<int> reductionDims) const {
    auto linalgOp = cast<LinalgOp>(op);
    OpBuilder::InsertionGuard guard(b);

    if (linalgOp.hasPureBufferSemantics())
      return op->emitOpError("expected operation to have tensor semantics");

    // LinalgOp implements TilingInterface.
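    // Query the iteration domain so the partial result tensor can be sized
    // from the loop ranges (falling back to the full range where the tile
    // size is zero).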
    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
    SmallVector<OpFoldResult> shape =
        llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b),
                            [](Range x) { return x.size; });

    SmallVector<OpFoldResult> tiledShape;
    for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
      if (isZeroIndex(tileSize)) {
        tiledShape.push_back(dimSize);
      } else {
        tiledShape.push_back(tileSize);
      }
    }

    SmallVector<Value> inits;
    for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;
         ++initIdx) {
      SmallVector<Operation *, 4> combinerOps;
      if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
                          combinerOps) ||
          combinerOps.size() != 1)
        return op->emitOpError("failed to analyze the reduction operation");

      Operation *reductionOp = combinerOps[0];
      std::optional<TypedAttr> identity =
          arith::getNeutralElement(reductionOp);
      if (!identity.has_value())
        return op->emitOpError(
            "failed to get an identity value for the reduction operation");

      // Append the new partial result dimensions.
      AffineMap partialMap =
          getPartialResultAffineMap(linalgOp, reductionDims, initIdx);
      SmallVector<OpFoldResult> partialResultShape;
      for (AffineExpr dimExpr : partialMap.getResults()) {
        auto dim = cast<AffineDimExpr>(dimExpr);
        partialResultShape.push_back(tiledShape[dim.getPosition()]);
      }

      Type elType =
          getElementTypeOrSelf(linalgOp->getResult(initIdx).getType());
      Value emptyTensor =
          b.create<tensor::EmptyOp>(loc, partialResultShape, elType);
      Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
      auto identityTensor =
          b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
      inits.push_back(identityTensor.getResult(0));
    }

    return inits;
  }

  FailureOr<TilingResult>
  tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
                         ValueRange init, ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
                         ArrayRef<int> reductionDims) const {
    OpBuilder::InsertionGuard guard(b);
    auto linalgOp = cast<LinalgOp>(op);

    // Step 1. Extend the init maps with the reduction dimensions, since we
    // are converting them to parallel dimensions.
    SmallVector<AffineMap> newInitMaps;
    newInitMaps.reserve(linalgOp.getNumDpsInits());
    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
      // this with a for range loop when we have it.
      AffineMap newMap =
          getPartialResultAffineMap(linalgOp, reductionDims, idx);
      newInitMaps.push_back(newMap);
    }

    // Step 2a: Extract a slice of the input operands.
    SmallVector<Value> tiledInputs = makeTiledShapes(
        b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true);
    SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
        llvm::make_filter_range(
            tiledInputs,
            [](Value v) -> bool {
              return v.getDefiningOp<tensor::ExtractSliceOp>();
            }),
        [](Value v) -> Operation * { return v.getDefiningOp(); });

    // Step 2b: Extract a slice of the init operands.
    SmallVector<Value> tiledInits;
    for (auto [valueMap, valueToTile] : llvm::zip_equal(newInitMaps, init)) {
      int64_t initRank = valueMap.getNumResults();
      SmallVector<OpFoldResult> initOffset(initRank, b.getIndexAttr(0));
      SmallVector<OpFoldResult> initStride(initRank, b.getIndexAttr(1));
      SmallVector<OpFoldResult> initSizes;
      for (AffineExpr dimExpr : valueMap.getResults()) {
        auto dim = cast<AffineDimExpr>(dimExpr);
        initSizes.push_back(sizes[dim.getPosition()]);
      }
      // TODO: Use SubsetExtractOpInterface here once available.
      auto extractSlice = b.create<tensor::ExtractSliceOp>(
          loc, valueToTile, initOffset, initSizes, initStride);
      tiledInits.push_back(extractSlice);
      generatedSlices.push_back(extractSlice);
    }

    // Update the indexing maps.
    SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
    // Change the init maps.
    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
      // this with a for range loop when we have it.
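      // Swap in the extended partial result map for this init operand.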
      OpOperand *initOperand = linalgOp.getDpsInitOperand(idx);
      int64_t mapIdx = linalgOp.getIndexingMapIndex(initOperand);
      newMaps[mapIdx] = newInitMaps[idx];
    }

    // Step 3. Change the reduction dim iterator types.
    SmallVector<utils::IteratorType> newIteratorTypes =
        linalgOp.getIteratorTypesArray();
    for (int dim : reductionDims)
      newIteratorTypes[dim] = utils::IteratorType::parallel;

    // Step 4. Create the new generic op.
    auto genericOp =
        b.create<GenericOp>(loc, ValueRange(tiledInits).getTypes(),
                            tiledInputs, tiledInits, newMaps,
                            newIteratorTypes);
    IRMapping mapping;
    op->getRegion(0).cloneInto(&genericOp.getRegion(),
                               genericOp.getRegion().begin(), mapping);
    return TilingResult{
        {genericOp.getOperation()},
        llvm::map_to_vector(genericOp->getResults(),
                            [](OpResult r) -> Value { return r; }),
        generatedSlices};
  }

  FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b,
                                         Location loc,
                                         ValueRange partialReduce,
                                         ArrayRef<int> reductionDims) const {
    auto linalgOp = cast<LinalgOp>(op);

    // Permute the reduction dims as permuted by the partial result map.
    int64_t numInits = linalgOp.getNumDpsInits();
    SmallVector<Operation *> mergeOperations;
    SmallVector<Value> replacements;
    for (int idx : llvm::seq(numInits)) {
      // linalg.reduce's iteration space is the tiled result's iteration space
      // (and not the tiled operation's iteration space). To account for this,
      // permute the reduction dimensions based on the partial result map of
      // the tiled result.
      AffineMap partialMap =
          getPartialResultAffineMap(linalgOp, reductionDims, idx);
      SmallVector<int64_t> partialReductionDims;
      for (auto [resultNum, dimExpr] :
           llvm::enumerate(partialMap.getResults())) {
        unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
        if (llvm::find(reductionDims, dim) != reductionDims.end())
          partialReductionDims.push_back(resultNum);
      }

      Value partialResult = partialReduce[idx];
      Value init = linalgOp.getDpsInits()[idx];

      auto reduction = b.create<linalg::ReduceOp>(
          loc, partialResult, init, partialReductionDims,
          [&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) {
            // Get the combiner op.
            SmallVector<Operation *, 4> combinerOps;
            matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
            Operation *clonedReductionOp = b.clone(*combinerOps[0]);
            // Combine the input at idx and output at numInits + idx.
            clonedReductionOp->setOperand(0, inputs[0]);
            clonedReductionOp->setOperand(1, inputs[1]);
            b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
          });

      mergeOperations.push_back(reduction);
      replacements.push_back(reduction->getResult(0));
    }

    return MergeResult{mergeOperations, replacements};
  }

  LogicalResult getPartialResultTilePosition(
      Operation *op, OpBuilder &b, unsigned resultNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVector<OpFoldResult> &resultOffsets,
      SmallVector<OpFoldResult> &resultSizes,
      ArrayRef<int> reductionDims) const {
    auto linalgOp = cast<LinalgOp>(op);

    AffineMap partialMap =
        getPartialResultAffineMap(linalgOp, reductionDims, resultNumber);
    for (AffineExpr dimExpr : partialMap.getResults()) {
      unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
      resultSizes.push_back(sizes[dim]);

      if (llvm::find(reductionDims, dim) != reductionDims.end()) {
        // Reduction dims are reduced, and are always written to the same
        // place. So use offset 0 for them.
        resultOffsets.push_back(b.getIndexAttr(0));
      } else {
        resultOffsets.push_back(offsets[dim]);
      }
    }

    return success();
  }
};

} // namespace

template <typename OpType>
static void registerOne(MLIRContext *ctx) {
  OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
  OpType::template attachInterface<LinalgOpPartialReductionInterface<OpType>>(
      *ctx);
}

/// Variadic helper function.
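/// Registers both interface models for every op type in the pack by expanding
/// `registerOne` over a fold expression.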
template <typename... OpTypes>
static void registerAll(MLIRContext *ctx) {
  (registerOne<OpTypes>(ctx), ...);
}

#define GET_OP_LIST

void mlir::linalg::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
    registerOne<linalg::GenericOp>(ctx);
    registerAll<
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >(ctx);
  });
}
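// Example usage (a minimal sketch, not part of the upstream file): a client
// attaches these external models via the registry before creating the
// context, after which Linalg ops can be cast to `TilingInterface`:
//
//   DialectRegistry registry;
//   registry.insert<linalg::LinalgDialect>();
//   linalg::registerTilingInterfaceExternalModels(registry);
//   MLIRContext context(registry);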