1 //===- Generalization.cpp - linalg named ops to generic ops --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the Linalg generalization pass. It converts named 10 // Linalg ops to linalg.generic ops. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Passes.h" 15 16 #include "mlir/Dialect/Func/IR/FuncOps.h" 17 #include "mlir/Dialect/Linalg/IR/Linalg.h" 18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 19 #include "mlir/IR/AffineMap.h" 20 #include "mlir/IR/Attributes.h" 21 #include "mlir/IR/Builders.h" 22 #include "mlir/IR/ImplicitLocOpBuilder.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 25 #include "llvm/ADT/SmallVector.h" 26 #include "llvm/Support/Debug.h" 27 28 namespace mlir { 29 #define GEN_PASS_DEF_LINALGGENERALIZENAMEDOPSPASS 30 #include "mlir/Dialect/Linalg/Passes.h.inc" 31 } // namespace mlir 32 33 #define DEBUG_TYPE "linalg-generalization" 34 35 using namespace mlir; 36 using namespace mlir::linalg; 37 38 static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) { 39 // Bailout if `linalgOp` is already a generic or a linalg.map. We cannot 40 // trivially generalize a `linalg.map`, as it does not use the output as 41 // region arguments in the block. 42 if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp)) 43 return failure(); 44 // Check if the operation has exactly one region. 45 if (linalgOp->getNumRegions() != 1) { 46 assert(linalgOp->getNumRegions() == 0 && "op with multiple regions"); 47 // TOD: Otherwise it needs to be built explicitly from the region builder. 48 return failure(); 49 } 50 return success(); 51 } 52 53 FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter, 54 LinalgOp linalgOp) { 55 if (failed(generalizeNamedOpPrecondition(linalgOp))) 56 return rewriter.notifyMatchFailure(linalgOp, "preconditions not met"); 57 58 SmallVector<Value> inputs = linalgOp.getDpsInputs(); 59 ValueRange outputs = linalgOp.getDpsInits(); 60 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray(); 61 SmallVector<utils::IteratorType> iterators = linalgOp.getIteratorTypesArray(); 62 SmallVector<Type> resultTypes = linalgOp.hasPureTensorSemantics() 63 ? TypeRange(ValueRange(outputs)) 64 : TypeRange{}; 65 66 // All named ops have a region attached that can be inlined. 67 assert(linalgOp->getNumRegions() == 1 && 68 "expect named op to have one region attached"); 69 GenericOp genericOp = rewriter.create<GenericOp>( 70 linalgOp.getLoc(), resultTypes, inputs, outputs, indexingMaps, iterators); 71 rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.getRegion(), 72 genericOp.getRegion().begin()); 73 rewriter.replaceOp(linalgOp, genericOp->getResults()); 74 return genericOp; 75 } 76 77 namespace { 78 79 struct LinalgGeneralizeNamedOpsPass 80 : public impl::LinalgGeneralizeNamedOpsPassBase< 81 LinalgGeneralizeNamedOpsPass> { 82 using impl::LinalgGeneralizeNamedOpsPassBase< 83 LinalgGeneralizeNamedOpsPass>::LinalgGeneralizeNamedOpsPassBase; 84 void runOnOperation() override; 85 }; 86 87 } // namespace 88 89 void LinalgGeneralizeNamedOpsPass::runOnOperation() { 90 RewritePatternSet patterns(&getContext()); 91 populateLinalgNamedOpsGeneralizationPatterns(patterns); 92 (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 93 } 94 95 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns( 96 RewritePatternSet &patterns) { 97 patterns.add<LinalgGeneralizationPattern>(patterns.getContext()); 98 } 99