19e39a5d9SLei Zhang //===- Generalization.cpp - linalg named ops to generic ops --------------===// 29e39a5d9SLei Zhang // 39e39a5d9SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 49e39a5d9SLei Zhang // See https://llvm.org/LICENSE.txt for license information. 59e39a5d9SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 69e39a5d9SLei Zhang // 79e39a5d9SLei Zhang //===----------------------------------------------------------------------===// 89e39a5d9SLei Zhang // 99e39a5d9SLei Zhang // This file implements the Linalg generalization pass. It converts named 109e39a5d9SLei Zhang // Linalg ops to linalg.generic ops. 119e39a5d9SLei Zhang // 129e39a5d9SLei Zhang //===----------------------------------------------------------------------===// 139e39a5d9SLei Zhang 149e39a5d9SLei Zhang #include "PassDetail.h" 15b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h" 169e39a5d9SLei Zhang #include "mlir/Dialect/Linalg/Passes.h" 179e39a5d9SLei Zhang #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 189e39a5d9SLei Zhang #include "mlir/IR/AffineMap.h" 199e39a5d9SLei Zhang #include "mlir/IR/Attributes.h" 209e39a5d9SLei Zhang #include "mlir/IR/Builders.h" 214519ca3dSNicolas Vasilache #include "mlir/IR/ImplicitLocOpBuilder.h" 229e39a5d9SLei Zhang #include "mlir/IR/PatternMatch.h" 239e39a5d9SLei Zhang #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 249e39a5d9SLei Zhang #include "llvm/ADT/SmallVector.h" 259e39a5d9SLei Zhang #include "llvm/Support/Debug.h" 269e39a5d9SLei Zhang 279e39a5d9SLei Zhang #define DEBUG_TYPE "linalg-generalization" 289e39a5d9SLei Zhang 299e39a5d9SLei Zhang using namespace mlir; 305a451e48STobias Gysi using namespace mlir::linalg; 319e39a5d9SLei Zhang 32*9a7d111fSNicolas Vasilache static LogicalResult generalizeNamedOpPrecondition(Operation *op) { 33e826db62STobias Gysi LinalgOp namedOp = dyn_cast<LinalgOp>(op); 34e826db62STobias Gysi // Check if the operation is a LinalgOp but not a GenericOp. 35e826db62STobias Gysi if (!namedOp || isa<GenericOp>(op)) 36e826db62STobias Gysi return failure(); 37e826db62STobias Gysi // Check if the operation has a region builder. 38e826db62STobias Gysi if (!namedOp.getRegionBuilder()) 39e826db62STobias Gysi return failure(); 40e826db62STobias Gysi return success(); 41e826db62STobias Gysi } 42e826db62STobias Gysi 43*9a7d111fSNicolas Vasilache FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter, 44e826db62STobias Gysi LinalgOp namedOp) { 45*9a7d111fSNicolas Vasilache if (failed(generalizeNamedOpPrecondition(namedOp))) 46*9a7d111fSNicolas Vasilache return rewriter.notifyMatchFailure(namedOp, "preconditions not met"); 47*9a7d111fSNicolas Vasilache 48ad10d965STobias Gysi SmallVector<Value> inputOperands = namedOp.getInputOperands(); 49ad10d965STobias Gysi SmallVector<Value> outputOperands = namedOp.getOutputOperands(); 505a451e48STobias Gysi SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps(); 515a451e48STobias Gysi SmallVector<StringRef> iterators = llvm::to_vector<4>( 525a451e48STobias Gysi namedOp.iterator_types().getAsValueRange<StringAttr>()); 535a451e48STobias Gysi SmallVector<RankedTensorType> resultTypes = namedOp.getOutputTensorTypes(); 545a451e48STobias Gysi SmallVector<Type> types(resultTypes.begin(), resultTypes.end()); 555a451e48STobias Gysi 56eaa52750STobias Gysi // All named ops have a region attached that can be inlined. 57eaa52750STobias Gysi assert(namedOp->getNumRegions() == 1 && 58eaa52750STobias Gysi "expect named op to have one region attached"); 59ad10d965STobias Gysi GenericOp genericOp = 60ad10d965STobias Gysi rewriter.create<GenericOp>(namedOp.getLoc(), types, inputOperands, 61ad10d965STobias Gysi outputOperands, indexingMaps, iterators); 625a451e48STobias Gysi rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(), 635a451e48STobias Gysi genericOp.region().begin()); 64*9a7d111fSNicolas Vasilache rewriter.replaceOp(namedOp, genericOp->getResults()); 655a451e48STobias Gysi return genericOp; 665a451e48STobias Gysi } 675a451e48STobias Gysi 689e39a5d9SLei Zhang namespace { 699e39a5d9SLei Zhang 709e39a5d9SLei Zhang struct LinalgGeneralizationPass 719e39a5d9SLei Zhang : public LinalgGeneralizationBase<LinalgGeneralizationPass> { 729e39a5d9SLei Zhang void runOnFunction() override; 739e39a5d9SLei Zhang }; 749e39a5d9SLei Zhang 759e39a5d9SLei Zhang } // namespace 769e39a5d9SLei Zhang 779e39a5d9SLei Zhang void LinalgGeneralizationPass::runOnFunction() { 789e39a5d9SLei Zhang FuncOp func = getFunction(); 79dc4e913bSChris Lattner RewritePatternSet patterns(&getContext()); 805a451e48STobias Gysi populateLinalgNamedOpsGeneralizationPatterns(patterns); 81e21adfa3SRiver Riddle (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns)); 829e39a5d9SLei Zhang } 839e39a5d9SLei Zhang 849e39a5d9SLei Zhang void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns( 851fc096afSMehdi Amini RewritePatternSet &patterns, const LinalgTransformationFilter &marker) { 86e826db62STobias Gysi patterns.add<LinalgGeneralizationPattern>(patterns.getContext(), marker); 879e39a5d9SLei Zhang } 889e39a5d9SLei Zhang 899e39a5d9SLei Zhang std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() { 909e39a5d9SLei Zhang return std::make_unique<LinalgGeneralizationPass>(); 919e39a5d9SLei Zhang } 92