xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
19e39a5d9SLei Zhang //===- Generalization.cpp - linalg named ops to generic ops  --------------===//
29e39a5d9SLei Zhang //
39e39a5d9SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
49e39a5d9SLei Zhang // See https://llvm.org/LICENSE.txt for license information.
59e39a5d9SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
69e39a5d9SLei Zhang //
79e39a5d9SLei Zhang //===----------------------------------------------------------------------===//
89e39a5d9SLei Zhang //
99e39a5d9SLei Zhang // This file implements the Linalg generalization pass. It converts named
109e39a5d9SLei Zhang // Linalg ops to linalg.generic ops.
119e39a5d9SLei Zhang //
129e39a5d9SLei Zhang //===----------------------------------------------------------------------===//
139e39a5d9SLei Zhang 
14039b969bSMichele Scuttari #include "mlir/Dialect/Linalg/Passes.h"
1567d0d7acSMichele Scuttari 
1667d0d7acSMichele Scuttari #include "mlir/Dialect/Func/IR/FuncOps.h"
1767d0d7acSMichele Scuttari #include "mlir/Dialect/Linalg/IR/Linalg.h"
189e39a5d9SLei Zhang #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
199e39a5d9SLei Zhang #include "mlir/IR/AffineMap.h"
209e39a5d9SLei Zhang #include "mlir/IR/Attributes.h"
219e39a5d9SLei Zhang #include "mlir/IR/Builders.h"
224519ca3dSNicolas Vasilache #include "mlir/IR/ImplicitLocOpBuilder.h"
239e39a5d9SLei Zhang #include "mlir/IR/PatternMatch.h"
249e39a5d9SLei Zhang #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
259e39a5d9SLei Zhang #include "llvm/ADT/SmallVector.h"
269e39a5d9SLei Zhang #include "llvm/Support/Debug.h"
279e39a5d9SLei Zhang 
2867d0d7acSMichele Scuttari namespace mlir {
291e98d488SQuinn Dawkins #define GEN_PASS_DEF_LINALGGENERALIZENAMEDOPSPASS
3067d0d7acSMichele Scuttari #include "mlir/Dialect/Linalg/Passes.h.inc"
3167d0d7acSMichele Scuttari } // namespace mlir
3267d0d7acSMichele Scuttari 
339e39a5d9SLei Zhang #define DEBUG_TYPE "linalg-generalization"
349e39a5d9SLei Zhang 
359e39a5d9SLei Zhang using namespace mlir;
365a451e48STobias Gysi using namespace mlir::linalg;
379e39a5d9SLei Zhang 
38c05db638SNicolas Vasilache static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
3909d09fc3SLorenzo Chelini   // Bailout if `linalgOp` is already a generic or a linalg.map. We cannot
4009d09fc3SLorenzo Chelini   // trivially generalize a `linalg.map`, as it does not use the output as
4109d09fc3SLorenzo Chelini   // region arguments in the block.
4209d09fc3SLorenzo Chelini   if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp))
43e826db62STobias Gysi     return failure();
4446cfdfb5SNicolas Vasilache   // Check if the operation has exactly one region.
4546cfdfb5SNicolas Vasilache   if (linalgOp->getNumRegions() != 1) {
4646cfdfb5SNicolas Vasilache     assert(linalgOp->getNumRegions() == 0 && "op with multiple regions");
4746cfdfb5SNicolas Vasilache     // TOD: Otherwise it needs to be built explicitly from the region builder.
48e826db62STobias Gysi     return failure();
4946cfdfb5SNicolas Vasilache   }
50e826db62STobias Gysi   return success();
51e826db62STobias Gysi }
52e826db62STobias Gysi 
539a7d111fSNicolas Vasilache FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
54c05db638SNicolas Vasilache                                                      LinalgOp linalgOp) {
55c05db638SNicolas Vasilache   if (failed(generalizeNamedOpPrecondition(linalgOp)))
56c05db638SNicolas Vasilache     return rewriter.notifyMatchFailure(linalgOp, "preconditions not met");
579a7d111fSNicolas Vasilache 
580b2197b0SMatthias Springer   SmallVector<Value> inputs = linalgOp.getDpsInputs();
590b2197b0SMatthias Springer   ValueRange outputs = linalgOp.getDpsInits();
60d2c0572bSJacques Pienaar   SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
61e6598b05SOleg Shyshkov   SmallVector<utils::IteratorType> iterators = linalgOp.getIteratorTypesArray();
620a8e3dd4SMatthias Springer   SmallVector<Type> resultTypes = linalgOp.hasPureTensorSemantics()
63a7cccb9cSAlexander Belyaev                                       ? TypeRange(ValueRange(outputs))
64a7cccb9cSAlexander Belyaev                                       : TypeRange{};
655a451e48STobias Gysi 
66eaa52750STobias Gysi   // All named ops have a region attached that can be inlined.
67c05db638SNicolas Vasilache   assert(linalgOp->getNumRegions() == 1 &&
68eaa52750STobias Gysi          "expect named op to have one region attached");
69a7cccb9cSAlexander Belyaev   GenericOp genericOp = rewriter.create<GenericOp>(
70a7cccb9cSAlexander Belyaev       linalgOp.getLoc(), resultTypes, inputs, outputs, indexingMaps, iterators);
71d3b3f765SJacques Pienaar   rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.getRegion(),
72d3b3f765SJacques Pienaar                               genericOp.getRegion().begin());
73c05db638SNicolas Vasilache   rewriter.replaceOp(linalgOp, genericOp->getResults());
745a451e48STobias Gysi   return genericOp;
755a451e48STobias Gysi }
765a451e48STobias Gysi 
779e39a5d9SLei Zhang namespace {
789e39a5d9SLei Zhang 
791e98d488SQuinn Dawkins struct LinalgGeneralizeNamedOpsPass
801e98d488SQuinn Dawkins     : public impl::LinalgGeneralizeNamedOpsPassBase<
811e98d488SQuinn Dawkins           LinalgGeneralizeNamedOpsPass> {
821e98d488SQuinn Dawkins   using impl::LinalgGeneralizeNamedOpsPassBase<
831e98d488SQuinn Dawkins       LinalgGeneralizeNamedOpsPass>::LinalgGeneralizeNamedOpsPassBase;
8441574554SRiver Riddle   void runOnOperation() override;
859e39a5d9SLei Zhang };
869e39a5d9SLei Zhang 
879e39a5d9SLei Zhang } // namespace
889e39a5d9SLei Zhang 
891e98d488SQuinn Dawkins void LinalgGeneralizeNamedOpsPass::runOnOperation() {
90dc4e913bSChris Lattner   RewritePatternSet patterns(&getContext());
915a451e48STobias Gysi   populateLinalgNamedOpsGeneralizationPatterns(patterns);
92*09dfc571SJacques Pienaar   (void)applyPatternsGreedily(getOperation(), std::move(patterns));
939e39a5d9SLei Zhang }
949e39a5d9SLei Zhang 
959e39a5d9SLei Zhang void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
96bcfbf8ccSNicolas Vasilache     RewritePatternSet &patterns) {
97bcfbf8ccSNicolas Vasilache   patterns.add<LinalgGeneralizationPattern>(patterns.getContext());
989e39a5d9SLei Zhang }
99