19e39a5d9SLei Zhang //===- Generalization.cpp - linalg named ops to generic ops --------------===// 29e39a5d9SLei Zhang // 39e39a5d9SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 49e39a5d9SLei Zhang // See https://llvm.org/LICENSE.txt for license information. 59e39a5d9SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 69e39a5d9SLei Zhang // 79e39a5d9SLei Zhang //===----------------------------------------------------------------------===// 89e39a5d9SLei Zhang // 99e39a5d9SLei Zhang // This file implements the Linalg generalization pass. It converts named 109e39a5d9SLei Zhang // Linalg ops to linalg.generic ops. 119e39a5d9SLei Zhang // 129e39a5d9SLei Zhang //===----------------------------------------------------------------------===// 139e39a5d9SLei Zhang 14039b969bSMichele Scuttari #include "mlir/Dialect/Linalg/Passes.h" 1567d0d7acSMichele Scuttari 1667d0d7acSMichele Scuttari #include "mlir/Dialect/Func/IR/FuncOps.h" 1767d0d7acSMichele Scuttari #include "mlir/Dialect/Linalg/IR/Linalg.h" 189e39a5d9SLei Zhang #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 199e39a5d9SLei Zhang #include "mlir/IR/AffineMap.h" 209e39a5d9SLei Zhang #include "mlir/IR/Attributes.h" 219e39a5d9SLei Zhang #include "mlir/IR/Builders.h" 224519ca3dSNicolas Vasilache #include "mlir/IR/ImplicitLocOpBuilder.h" 239e39a5d9SLei Zhang #include "mlir/IR/PatternMatch.h" 249e39a5d9SLei Zhang #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 259e39a5d9SLei Zhang #include "llvm/ADT/SmallVector.h" 269e39a5d9SLei Zhang #include "llvm/Support/Debug.h" 279e39a5d9SLei Zhang 2867d0d7acSMichele Scuttari namespace mlir { 291e98d488SQuinn Dawkins #define GEN_PASS_DEF_LINALGGENERALIZENAMEDOPSPASS 3067d0d7acSMichele Scuttari #include "mlir/Dialect/Linalg/Passes.h.inc" 3167d0d7acSMichele Scuttari } // namespace mlir 3267d0d7acSMichele Scuttari 339e39a5d9SLei Zhang #define DEBUG_TYPE "linalg-generalization" 349e39a5d9SLei Zhang 359e39a5d9SLei Zhang using namespace mlir; 365a451e48STobias Gysi using namespace mlir::linalg; 379e39a5d9SLei Zhang 38c05db638SNicolas Vasilache static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) { 3909d09fc3SLorenzo Chelini // Bailout if `linalgOp` is already a generic or a linalg.map. We cannot 4009d09fc3SLorenzo Chelini // trivially generalize a `linalg.map`, as it does not use the output as 4109d09fc3SLorenzo Chelini // region arguments in the block. 4209d09fc3SLorenzo Chelini if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp)) 43e826db62STobias Gysi return failure(); 4446cfdfb5SNicolas Vasilache // Check if the operation has exactly one region. 4546cfdfb5SNicolas Vasilache if (linalgOp->getNumRegions() != 1) { 4646cfdfb5SNicolas Vasilache assert(linalgOp->getNumRegions() == 0 && "op with multiple regions"); 4746cfdfb5SNicolas Vasilache // TOD: Otherwise it needs to be built explicitly from the region builder. 48e826db62STobias Gysi return failure(); 4946cfdfb5SNicolas Vasilache } 50e826db62STobias Gysi return success(); 51e826db62STobias Gysi } 52e826db62STobias Gysi 539a7d111fSNicolas Vasilache FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter, 54c05db638SNicolas Vasilache LinalgOp linalgOp) { 55c05db638SNicolas Vasilache if (failed(generalizeNamedOpPrecondition(linalgOp))) 56c05db638SNicolas Vasilache return rewriter.notifyMatchFailure(linalgOp, "preconditions not met"); 579a7d111fSNicolas Vasilache 580b2197b0SMatthias Springer SmallVector<Value> inputs = linalgOp.getDpsInputs(); 590b2197b0SMatthias Springer ValueRange outputs = linalgOp.getDpsInits(); 60d2c0572bSJacques Pienaar SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray(); 61e6598b05SOleg Shyshkov SmallVector<utils::IteratorType> iterators = linalgOp.getIteratorTypesArray(); 620a8e3dd4SMatthias Springer SmallVector<Type> resultTypes = linalgOp.hasPureTensorSemantics() 63a7cccb9cSAlexander Belyaev ? TypeRange(ValueRange(outputs)) 64a7cccb9cSAlexander Belyaev : TypeRange{}; 655a451e48STobias Gysi 66eaa52750STobias Gysi // All named ops have a region attached that can be inlined. 67c05db638SNicolas Vasilache assert(linalgOp->getNumRegions() == 1 && 68eaa52750STobias Gysi "expect named op to have one region attached"); 69a7cccb9cSAlexander Belyaev GenericOp genericOp = rewriter.create<GenericOp>( 70a7cccb9cSAlexander Belyaev linalgOp.getLoc(), resultTypes, inputs, outputs, indexingMaps, iterators); 71d3b3f765SJacques Pienaar rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.getRegion(), 72d3b3f765SJacques Pienaar genericOp.getRegion().begin()); 73c05db638SNicolas Vasilache rewriter.replaceOp(linalgOp, genericOp->getResults()); 745a451e48STobias Gysi return genericOp; 755a451e48STobias Gysi } 765a451e48STobias Gysi 779e39a5d9SLei Zhang namespace { 789e39a5d9SLei Zhang 791e98d488SQuinn Dawkins struct LinalgGeneralizeNamedOpsPass 801e98d488SQuinn Dawkins : public impl::LinalgGeneralizeNamedOpsPassBase< 811e98d488SQuinn Dawkins LinalgGeneralizeNamedOpsPass> { 821e98d488SQuinn Dawkins using impl::LinalgGeneralizeNamedOpsPassBase< 831e98d488SQuinn Dawkins LinalgGeneralizeNamedOpsPass>::LinalgGeneralizeNamedOpsPassBase; 8441574554SRiver Riddle void runOnOperation() override; 859e39a5d9SLei Zhang }; 869e39a5d9SLei Zhang 879e39a5d9SLei Zhang } // namespace 889e39a5d9SLei Zhang 891e98d488SQuinn Dawkins void LinalgGeneralizeNamedOpsPass::runOnOperation() { 90dc4e913bSChris Lattner RewritePatternSet patterns(&getContext()); 915a451e48STobias Gysi populateLinalgNamedOpsGeneralizationPatterns(patterns); 92*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 939e39a5d9SLei Zhang } 949e39a5d9SLei Zhang 959e39a5d9SLei Zhang void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns( 96bcfbf8ccSNicolas Vasilache RewritePatternSet &patterns) { 97bcfbf8ccSNicolas Vasilache patterns.add<LinalgGeneralizationPattern>(patterns.getContext()); 989e39a5d9SLei Zhang } 99