1 //===- Generalization.cpp - linalg named ops to generic ops --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the Linalg generalization pass. It converts named 10 // Linalg ops to linalg.generic ops. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Dialect/Linalg/IR/Linalg.h" 16 #include "mlir/Dialect/Linalg/Passes.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/IR/AffineMap.h" 19 #include "mlir/IR/Attributes.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/ImplicitLocOpBuilder.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 #include "llvm/ADT/SmallVector.h" 25 #include "llvm/Support/Debug.h" 26 27 #define DEBUG_TYPE "linalg-generalization" 28 29 using namespace mlir; 30 using namespace mlir::linalg; 31 32 static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) { 33 // Check if the operation is a LinalgOp but not a GenericOp. 34 if (isa<GenericOp>(linalgOp)) 35 return failure(); 36 // Check if the operation has a region builder. 
37 if (!linalgOp.getRegionBuilder()) 38 return failure(); 39 return success(); 40 } 41 42 FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter, 43 LinalgOp linalgOp) { 44 if (failed(generalizeNamedOpPrecondition(linalgOp))) 45 return rewriter.notifyMatchFailure(linalgOp, "preconditions not met"); 46 47 SmallVector<Value> inputOperands = linalgOp.getInputOperands(); 48 SmallVector<Value> outputOperands = linalgOp.getOutputOperands(); 49 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMaps(); 50 SmallVector<StringRef> iterators = llvm::to_vector<4>( 51 linalgOp.iterator_types().getAsValueRange<StringAttr>()); 52 SmallVector<RankedTensorType> resultTypes = linalgOp.getOutputTensorTypes(); 53 SmallVector<Type> types(resultTypes.begin(), resultTypes.end()); 54 55 // All named ops have a region attached that can be inlined. 56 assert(linalgOp->getNumRegions() == 1 && 57 "expect named op to have one region attached"); 58 GenericOp genericOp = 59 rewriter.create<GenericOp>(linalgOp.getLoc(), types, inputOperands, 60 outputOperands, indexingMaps, iterators); 61 rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.region(), 62 genericOp.region().begin()); 63 rewriter.replaceOp(linalgOp, genericOp->getResults()); 64 return genericOp; 65 } 66 67 namespace { 68 69 struct LinalgGeneralizationPass 70 : public LinalgGeneralizationBase<LinalgGeneralizationPass> { 71 void runOnFunction() override; 72 }; 73 74 } // namespace 75 76 void LinalgGeneralizationPass::runOnFunction() { 77 FuncOp func = getFunction(); 78 RewritePatternSet patterns(&getContext()); 79 populateLinalgNamedOpsGeneralizationPatterns(patterns); 80 (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns)); 81 } 82 83 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns( 84 RewritePatternSet &patterns, const LinalgTransformationFilter &marker) { 85 patterns.add<LinalgGeneralizationPattern>(patterns.getContext(), marker); 86 } 87 88 
/// Creates a new instance of the Linalg generalization pass, which operates on
/// FuncOp operations and converts named Linalg ops to linalg.generic ops.
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
  std::unique_ptr<OperationPass<FuncOp>> pass =
      std::make_unique<LinalgGeneralizationPass>();
  return pass;
}