xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- Generalization.cpp - linalg named ops to generic ops  --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg generalization pass. It converts named
10 // Linalg ops to linalg.generic ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Passes.h"
15 
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/Dialect/Linalg/IR/Linalg.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/IR/AffineMap.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/ImplicitLocOpBuilder.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/Support/Debug.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_LINALGGENERALIZENAMEDOPSPASS
30 #include "mlir/Dialect/Linalg/Passes.h.inc"
31 } // namespace mlir
32 
33 #define DEBUG_TYPE "linalg-generalization"
34 
35 using namespace mlir;
36 using namespace mlir::linalg;
37 
38 static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
39   // Bailout if `linalgOp` is already a generic or a linalg.map. We cannot
40   // trivially generalize a `linalg.map`, as it does not use the output as
41   // region arguments in the block.
42   if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp))
43     return failure();
44   // Check if the operation has exactly one region.
45   if (linalgOp->getNumRegions() != 1) {
46     assert(linalgOp->getNumRegions() == 0 && "op with multiple regions");
47     // TOD: Otherwise it needs to be built explicitly from the region builder.
48     return failure();
49   }
50   return success();
51 }
52 
53 FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
54                                                      LinalgOp linalgOp) {
55   if (failed(generalizeNamedOpPrecondition(linalgOp)))
56     return rewriter.notifyMatchFailure(linalgOp, "preconditions not met");
57 
58   SmallVector<Value> inputs = linalgOp.getDpsInputs();
59   ValueRange outputs = linalgOp.getDpsInits();
60   SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
61   SmallVector<utils::IteratorType> iterators = linalgOp.getIteratorTypesArray();
62   SmallVector<Type> resultTypes = linalgOp.hasPureTensorSemantics()
63                                       ? TypeRange(ValueRange(outputs))
64                                       : TypeRange{};
65 
66   // All named ops have a region attached that can be inlined.
67   assert(linalgOp->getNumRegions() == 1 &&
68          "expect named op to have one region attached");
69   GenericOp genericOp = rewriter.create<GenericOp>(
70       linalgOp.getLoc(), resultTypes, inputs, outputs, indexingMaps, iterators);
71   rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.getRegion(),
72                               genericOp.getRegion().begin());
73   rewriter.replaceOp(linalgOp, genericOp->getResults());
74   return genericOp;
75 }
76 
77 namespace {
78 
79 struct LinalgGeneralizeNamedOpsPass
80     : public impl::LinalgGeneralizeNamedOpsPassBase<
81           LinalgGeneralizeNamedOpsPass> {
82   using impl::LinalgGeneralizeNamedOpsPassBase<
83       LinalgGeneralizeNamedOpsPass>::LinalgGeneralizeNamedOpsPassBase;
84   void runOnOperation() override;
85 };
86 
87 } // namespace
88 
89 void LinalgGeneralizeNamedOpsPass::runOnOperation() {
90   RewritePatternSet patterns(&getContext());
91   populateLinalgNamedOpsGeneralizationPatterns(patterns);
92   (void)applyPatternsGreedily(getOperation(), std::move(patterns));
93 }
94 
95 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
96     RewritePatternSet &patterns) {
97   patterns.add<LinalgGeneralizationPattern>(patterns.getContext());
98 }
99