xref: /llvm-project/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp (revision 5550c821897ab77e664977121a0e90ad5be1ff59)
1 //===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
10 
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/Func/IR/FuncOps.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/Linalg/IR/Linalg.h"
15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/SCF/IR/SCF.h"
18 #include "mlir/Pass/Pass.h"
19 
20 namespace mlir {
21 #define GEN_PASS_DEF_CONVERTLINALGTOSTANDARD
22 #include "mlir/Conversion/Passes.h.inc"
23 } // namespace mlir
24 
25 using namespace mlir;
26 using namespace mlir::linalg;
27 
makeStridedLayoutDynamic(MemRefType type)28 static MemRefType makeStridedLayoutDynamic(MemRefType type) {
29   return MemRefType::Builder(type).setLayout(StridedLayoutAttr::get(
30       type.getContext(), ShapedType::kDynamic,
31       SmallVector<int64_t>(type.getRank(), ShapedType::kDynamic)));
32 }
33 
34 /// Helper function to extract the operand types that are passed to the
35 /// generated CallOp. MemRefTypes have their layout canonicalized since the
36 /// information is not used in signature generation.
37 /// Note that static size information is not modified.
extractOperandTypes(Operation * op)38 static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
39   SmallVector<Type, 4> result;
40   result.reserve(op->getNumOperands());
41   for (auto type : op->getOperandTypes()) {
42     // The underlying descriptor type (e.g. LLVM) does not have layout
43     // information. Canonicalizing the type at the level of std when going into
44     // a library call avoids needing to introduce DialectCastOp.
45     if (auto memrefType = dyn_cast<MemRefType>(type))
46       result.push_back(makeStridedLayoutDynamic(memrefType));
47     else
48       result.push_back(type);
49   }
50   return result;
51 }
52 
53 // Get a SymbolRefAttr containing the library function name for the LinalgOp.
54 // If the library function does not exist, insert a declaration.
55 static FailureOr<FlatSymbolRefAttr>
getLibraryCallSymbolRef(Operation * op,PatternRewriter & rewriter)56 getLibraryCallSymbolRef(Operation *op, PatternRewriter &rewriter) {
57   auto linalgOp = cast<LinalgOp>(op);
58   auto fnName = linalgOp.getLibraryCallName();
59   if (fnName.empty())
60     return rewriter.notifyMatchFailure(op, "No library call defined for: ");
61 
62   // fnName is a dynamic std::string, unique it via a SymbolRefAttr.
63   FlatSymbolRefAttr fnNameAttr =
64       SymbolRefAttr::get(rewriter.getContext(), fnName);
65   auto module = op->getParentOfType<ModuleOp>();
66   if (module.lookupSymbol(fnNameAttr.getAttr()))
67     return fnNameAttr;
68 
69   SmallVector<Type, 4> inputTypes(extractOperandTypes(op));
70   if (op->getNumResults() != 0) {
71     return rewriter.notifyMatchFailure(
72         op,
73         "Library call for linalg operation can be generated only for ops that "
74         "have void return types");
75   }
76   auto libFnType = rewriter.getFunctionType(inputTypes, {});
77 
78   OpBuilder::InsertionGuard guard(rewriter);
79   // Insert before module terminator.
80   rewriter.setInsertionPoint(module.getBody(),
81                              std::prev(module.getBody()->end()));
82   func::FuncOp funcOp = rewriter.create<func::FuncOp>(
83       op->getLoc(), fnNameAttr.getValue(), libFnType);
84   // Insert a function attribute that will trigger the emission of the
85   // corresponding `_mlir_ciface_xxx` interface so that external libraries see
86   // a normalized ABI. This interface is added during std to llvm conversion.
87   funcOp->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
88                   UnitAttr::get(op->getContext()));
89   funcOp.setPrivate();
90   return fnNameAttr;
91 }
92 
93 static SmallVector<Value, 4>
createTypeCanonicalizedMemRefOperands(OpBuilder & b,Location loc,ValueRange operands)94 createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
95                                       ValueRange operands) {
96   SmallVector<Value, 4> res;
97   res.reserve(operands.size());
98   for (auto op : operands) {
99     auto memrefType = dyn_cast<MemRefType>(op.getType());
100     if (!memrefType) {
101       res.push_back(op);
102       continue;
103     }
104     Value cast =
105         b.create<memref::CastOp>(loc, makeStridedLayoutDynamic(memrefType), op);
106     res.push_back(cast);
107   }
108   return res;
109 }
110 
matchAndRewrite(LinalgOp op,PatternRewriter & rewriter) const111 LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
112     LinalgOp op, PatternRewriter &rewriter) const {
113   auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
114   if (failed(libraryCallName))
115     return failure();
116 
117   // TODO: Add support for more complex library call signatures that include
118   // indices or captured values.
119   rewriter.replaceOpWithNewOp<func::CallOp>(
120       op, libraryCallName->getValue(), TypeRange(),
121       createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(),
122                                             op->getOperands()));
123   return success();
124 }
125 
126 /// Populate the given list with patterns that convert from Linalg to Standard.
populateLinalgToStandardConversionPatterns(RewritePatternSet & patterns)127 void mlir::linalg::populateLinalgToStandardConversionPatterns(
128     RewritePatternSet &patterns) {
129   // TODO: ConvOp conversion needs to export a descriptor with relevant
130   // attribute values such as kernel striding and dilation.
131   patterns.add<LinalgOpToLibraryCallRewrite>(patterns.getContext());
132 }
133 
134 namespace {
135 struct ConvertLinalgToStandardPass
136     : public impl::ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> {
137   void runOnOperation() override;
138 };
139 } // namespace
140 
runOnOperation()141 void ConvertLinalgToStandardPass::runOnOperation() {
142   auto module = getOperation();
143   ConversionTarget target(getContext());
144   target.addLegalDialect<affine::AffineDialect, arith::ArithDialect,
145                          func::FuncDialect, memref::MemRefDialect,
146                          scf::SCFDialect>();
147   target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
148   RewritePatternSet patterns(&getContext());
149   populateLinalgToStandardConversionPatterns(patterns);
150   if (failed(applyFullConversion(module, target, std::move(patterns))))
151     signalPassFailure();
152 }
153 
154 std::unique_ptr<OperationPass<ModuleOp>>
createConvertLinalgToStandardPass()155 mlir::createConvertLinalgToStandardPass() {
156   return std::make_unique<ConvertLinalgToStandardPass>();
157 }
158