1 //===- Generalization.cpp - linalg named ops to generic ops --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the Linalg generalization pass. It converts named 10 // Linalg ops to linalg.generic ops. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Dialect/Linalg/IR/Linalg.h" 16 #include "mlir/Dialect/Linalg/Passes.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/IR/AffineMap.h" 19 #include "mlir/IR/Attributes.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/ImplicitLocOpBuilder.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 #include "llvm/ADT/SmallVector.h" 25 #include "llvm/Support/Debug.h" 26 27 #define DEBUG_TYPE "linalg-generalization" 28 29 using namespace mlir; 30 using namespace mlir::linalg; 31 32 static LogicalResult generalizeNamedOpPrecondition(Operation *op) { 33 LinalgOp namedOp = dyn_cast<LinalgOp>(op); 34 // Check if the operation is a LinalgOp but not a GenericOp. 35 if (!namedOp || isa<GenericOp>(op)) 36 return failure(); 37 // Check if the operation has a region builder. 
38 if (!namedOp.getRegionBuilder()) 39 return failure(); 40 return success(); 41 } 42 43 FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter, 44 LinalgOp namedOp) { 45 if (failed(generalizeNamedOpPrecondition(namedOp))) 46 return rewriter.notifyMatchFailure(namedOp, "preconditions not met"); 47 48 SmallVector<Value> inputOperands = namedOp.getInputOperands(); 49 SmallVector<Value> outputOperands = namedOp.getOutputOperands(); 50 SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps(); 51 SmallVector<StringRef> iterators = llvm::to_vector<4>( 52 namedOp.iterator_types().getAsValueRange<StringAttr>()); 53 SmallVector<RankedTensorType> resultTypes = namedOp.getOutputTensorTypes(); 54 SmallVector<Type> types(resultTypes.begin(), resultTypes.end()); 55 56 // All named ops have a region attached that can be inlined. 57 assert(namedOp->getNumRegions() == 1 && 58 "expect named op to have one region attached"); 59 GenericOp genericOp = 60 rewriter.create<GenericOp>(namedOp.getLoc(), types, inputOperands, 61 outputOperands, indexingMaps, iterators); 62 rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(), 63 genericOp.region().begin()); 64 rewriter.replaceOp(namedOp, genericOp->getResults()); 65 return genericOp; 66 } 67 68 namespace { 69 70 struct LinalgGeneralizationPass 71 : public LinalgGeneralizationBase<LinalgGeneralizationPass> { 72 void runOnFunction() override; 73 }; 74 75 } // namespace 76 77 void LinalgGeneralizationPass::runOnFunction() { 78 FuncOp func = getFunction(); 79 RewritePatternSet patterns(&getContext()); 80 populateLinalgNamedOpsGeneralizationPatterns(patterns); 81 (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns)); 82 } 83 84 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns( 85 RewritePatternSet &patterns, const LinalgTransformationFilter &marker) { 86 patterns.add<LinalgGeneralizationPattern>(patterns.getContext(), marker); 87 } 88 89 
/// Creates a new instance of the Linalg generalization pass and returns
/// ownership of it to the caller.
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
  std::unique_ptr<OperationPass<FuncOp>> generalizationPass =
      std::make_unique<LinalgGeneralizationPass>();
  return generalizationPass;
}