xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp (revision 9a7d111f4fb65ad7343dcbd4f35ee608100634e8)
19e39a5d9SLei Zhang //===- Generalization.cpp - linalg named ops to generic ops  --------------===//
29e39a5d9SLei Zhang //
39e39a5d9SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
49e39a5d9SLei Zhang // See https://llvm.org/LICENSE.txt for license information.
59e39a5d9SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
69e39a5d9SLei Zhang //
79e39a5d9SLei Zhang //===----------------------------------------------------------------------===//
89e39a5d9SLei Zhang //
99e39a5d9SLei Zhang // This file implements the Linalg generalization pass. It converts named
109e39a5d9SLei Zhang // Linalg ops to linalg.generic ops.
119e39a5d9SLei Zhang //
129e39a5d9SLei Zhang //===----------------------------------------------------------------------===//
139e39a5d9SLei Zhang 
149e39a5d9SLei Zhang #include "PassDetail.h"
15b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
169e39a5d9SLei Zhang #include "mlir/Dialect/Linalg/Passes.h"
179e39a5d9SLei Zhang #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
189e39a5d9SLei Zhang #include "mlir/IR/AffineMap.h"
199e39a5d9SLei Zhang #include "mlir/IR/Attributes.h"
209e39a5d9SLei Zhang #include "mlir/IR/Builders.h"
214519ca3dSNicolas Vasilache #include "mlir/IR/ImplicitLocOpBuilder.h"
229e39a5d9SLei Zhang #include "mlir/IR/PatternMatch.h"
239e39a5d9SLei Zhang #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
249e39a5d9SLei Zhang #include "llvm/ADT/SmallVector.h"
259e39a5d9SLei Zhang #include "llvm/Support/Debug.h"
269e39a5d9SLei Zhang 
279e39a5d9SLei Zhang #define DEBUG_TYPE "linalg-generalization"
289e39a5d9SLei Zhang 
299e39a5d9SLei Zhang using namespace mlir;
305a451e48STobias Gysi using namespace mlir::linalg;
319e39a5d9SLei Zhang 
32*9a7d111fSNicolas Vasilache static LogicalResult generalizeNamedOpPrecondition(Operation *op) {
33e826db62STobias Gysi   LinalgOp namedOp = dyn_cast<LinalgOp>(op);
34e826db62STobias Gysi   // Check if the operation is a LinalgOp but not a GenericOp.
35e826db62STobias Gysi   if (!namedOp || isa<GenericOp>(op))
36e826db62STobias Gysi     return failure();
37e826db62STobias Gysi   // Check if the operation has a region builder.
38e826db62STobias Gysi   if (!namedOp.getRegionBuilder())
39e826db62STobias Gysi     return failure();
40e826db62STobias Gysi   return success();
41e826db62STobias Gysi }
42e826db62STobias Gysi 
43*9a7d111fSNicolas Vasilache FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
44e826db62STobias Gysi                                                      LinalgOp namedOp) {
45*9a7d111fSNicolas Vasilache   if (failed(generalizeNamedOpPrecondition(namedOp)))
46*9a7d111fSNicolas Vasilache     return rewriter.notifyMatchFailure(namedOp, "preconditions not met");
47*9a7d111fSNicolas Vasilache 
48ad10d965STobias Gysi   SmallVector<Value> inputOperands = namedOp.getInputOperands();
49ad10d965STobias Gysi   SmallVector<Value> outputOperands = namedOp.getOutputOperands();
505a451e48STobias Gysi   SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps();
515a451e48STobias Gysi   SmallVector<StringRef> iterators = llvm::to_vector<4>(
525a451e48STobias Gysi       namedOp.iterator_types().getAsValueRange<StringAttr>());
535a451e48STobias Gysi   SmallVector<RankedTensorType> resultTypes = namedOp.getOutputTensorTypes();
545a451e48STobias Gysi   SmallVector<Type> types(resultTypes.begin(), resultTypes.end());
555a451e48STobias Gysi 
56eaa52750STobias Gysi   // All named ops have a region attached that can be inlined.
57eaa52750STobias Gysi   assert(namedOp->getNumRegions() == 1 &&
58eaa52750STobias Gysi          "expect named op to have one region attached");
59ad10d965STobias Gysi   GenericOp genericOp =
60ad10d965STobias Gysi       rewriter.create<GenericOp>(namedOp.getLoc(), types, inputOperands,
61ad10d965STobias Gysi                                  outputOperands, indexingMaps, iterators);
625a451e48STobias Gysi   rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(),
635a451e48STobias Gysi                               genericOp.region().begin());
64*9a7d111fSNicolas Vasilache   rewriter.replaceOp(namedOp, genericOp->getResults());
655a451e48STobias Gysi   return genericOp;
665a451e48STobias Gysi }
675a451e48STobias Gysi 
689e39a5d9SLei Zhang namespace {
699e39a5d9SLei Zhang 
709e39a5d9SLei Zhang struct LinalgGeneralizationPass
719e39a5d9SLei Zhang     : public LinalgGeneralizationBase<LinalgGeneralizationPass> {
729e39a5d9SLei Zhang   void runOnFunction() override;
739e39a5d9SLei Zhang };
749e39a5d9SLei Zhang 
759e39a5d9SLei Zhang } // namespace
769e39a5d9SLei Zhang 
779e39a5d9SLei Zhang void LinalgGeneralizationPass::runOnFunction() {
789e39a5d9SLei Zhang   FuncOp func = getFunction();
79dc4e913bSChris Lattner   RewritePatternSet patterns(&getContext());
805a451e48STobias Gysi   populateLinalgNamedOpsGeneralizationPatterns(patterns);
81e21adfa3SRiver Riddle   (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
829e39a5d9SLei Zhang }
839e39a5d9SLei Zhang 
849e39a5d9SLei Zhang void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
851fc096afSMehdi Amini     RewritePatternSet &patterns, const LinalgTransformationFilter &marker) {
86e826db62STobias Gysi   patterns.add<LinalgGeneralizationPattern>(patterns.getContext(), marker);
879e39a5d9SLei Zhang }
889e39a5d9SLei Zhang 
899e39a5d9SLei Zhang std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
909e39a5d9SLei Zhang   return std::make_unique<LinalgGeneralizationPass>();
919e39a5d9SLei Zhang }
92