xref: /llvm-project/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp (revision f1844f15c1ad54b78f2d84087df4b51fe5f703f6)
1 //===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/SCF/SCF.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 
19 using namespace mlir;
20 using namespace mlir::linalg;
21 
22 /// Helper function to extract the operand types that are passed to the
23 /// generated CallOp. MemRefTypes have their layout canonicalized since the
24 /// information is not used in signature generation.
25 /// Note that static size information is not modified.
26 static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
27   SmallVector<Type, 4> result;
28   result.reserve(op->getNumOperands());
29   for (auto type : op->getOperandTypes()) {
30     // The underlying descriptor type (e.g. LLVM) does not have layout
31     // information. Canonicalizing the type at the level of std when going into
32     // a library call avoids needing to introduce DialectCastOp.
33     if (auto memrefType = type.dyn_cast<MemRefType>())
34       result.push_back(eraseStridedLayout(memrefType));
35     else
36       result.push_back(type);
37   }
38   return result;
39 }
40 
41 // Get a SymbolRefAttr containing the library function name for the LinalgOp.
42 // If the library function does not exist, insert a declaration.
43 static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
44                                                  PatternRewriter &rewriter) {
45   auto linalgOp = cast<LinalgOp>(op);
46   auto fnName = linalgOp.getLibraryCallName();
47   if (fnName.empty()) {
48     op->emitWarning("No library call defined for: ") << *op;
49     return {};
50   }
51 
52   // fnName is a dynamic std::string, unique it via a SymbolRefAttr.
53   FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
54   auto module = op->getParentOfType<ModuleOp>();
55   if (module.lookupSymbol(fnName)) {
56     return fnNameAttr;
57   }
58 
59   SmallVector<Type, 4> inputTypes(extractOperandTypes(op));
60   assert(op->getNumResults() == 0 &&
61          "Library call for linalg operation can be generated only for ops that "
62          "have void return types");
63   auto libFnType = rewriter.getFunctionType(inputTypes, {});
64 
65   OpBuilder::InsertionGuard guard(rewriter);
66   // Insert before module terminator.
67   rewriter.setInsertionPoint(module.getBody(),
68                              std::prev(module.getBody()->end()));
69   FuncOp funcOp =
70       rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType);
71   // Insert a function attribute that will trigger the emission of the
72   // corresponding `_mlir_ciface_xxx` interface so that external libraries see
73   // a normalized ABI. This interface is added during std to llvm conversion.
74   funcOp->setAttr("llvm.emit_c_interface", UnitAttr::get(op->getContext()));
75   funcOp.setPrivate();
76   return fnNameAttr;
77 }
78 
79 static SmallVector<Value, 4>
80 createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
81                                       ValueRange operands) {
82   SmallVector<Value, 4> res;
83   res.reserve(operands.size());
84   for (auto op : operands) {
85     auto memrefType = op.getType().dyn_cast<MemRefType>();
86     if (!memrefType) {
87       res.push_back(op);
88       continue;
89     }
90     Value cast =
91         b.create<memref::CastOp>(loc, eraseStridedLayout(memrefType), op);
92     res.push_back(cast);
93   }
94   return res;
95 }
96 
97 LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
98     LinalgOp op, PatternRewriter &rewriter) const {
99   // Only LinalgOp for which there is no specialized pattern go through this.
100   if (isa<CopyOp>(op))
101     return failure();
102 
103   auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
104   if (!libraryCallName)
105     return failure();
106 
107   // TODO: Add support for more complex library call signatures that include
108   // indices or captured values.
109   rewriter.replaceOpWithNewOp<mlir::CallOp>(
110       op, libraryCallName.getValue(), TypeRange(),
111       createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(),
112                                             op->getOperands()));
113   return success();
114 }
115 
116 LogicalResult mlir::linalg::CopyOpToLibraryCallRewrite::matchAndRewrite(
117     CopyOp op, PatternRewriter &rewriter) const {
118   auto inputPerm = op.inputPermutation();
119   if (inputPerm.hasValue() && !inputPerm->isIdentity())
120     return failure();
121   auto outputPerm = op.outputPermutation();
122   if (outputPerm.hasValue() && !outputPerm->isIdentity())
123     return failure();
124 
125   auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
126   if (!libraryCallName)
127     return failure();
128 
129   rewriter.replaceOpWithNewOp<mlir::CallOp>(
130       op, libraryCallName.getValue(), TypeRange(),
131       createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(),
132                                             op.getOperands()));
133   return success();
134 }
135 
136 LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite(
137     CopyOp op, PatternRewriter &rewriter) const {
138   Value in = op.input(), out = op.output();
139 
140   // If either inputPerm or outputPerm are non-identities, insert transposes.
141   auto inputPerm = op.inputPermutation();
142   if (inputPerm.hasValue() && !inputPerm->isIdentity())
143     in = rewriter.create<memref::TransposeOp>(op.getLoc(), in,
144                                               AffineMapAttr::get(*inputPerm));
145   auto outputPerm = op.outputPermutation();
146   if (outputPerm.hasValue() && !outputPerm->isIdentity())
147     out = rewriter.create<memref::TransposeOp>(op.getLoc(), out,
148                                                AffineMapAttr::get(*outputPerm));
149 
150   // If nothing was transposed, fail and let the conversion kick in.
151   if (in == op.input() && out == op.output())
152     return failure();
153 
154   auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
155   if (!libraryCallName)
156     return failure();
157 
158   rewriter.replaceOpWithNewOp<mlir::CallOp>(
159       op, libraryCallName.getValue(), TypeRange(),
160       createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), {in, out}));
161   return success();
162 }
163 
164 /// Populate the given list with patterns that convert from Linalg to Standard.
165 void mlir::linalg::populateLinalgToStandardConversionPatterns(
166     RewritePatternSet &patterns) {
167   // TODO: ConvOp conversion needs to export a descriptor with relevant
168   // attribute values such as kernel striding and dilation.
169   // clang-format off
170   patterns.add<
171       CopyOpToLibraryCallRewrite,
172       CopyTransposeRewrite,
173       LinalgOpToLibraryCallRewrite>(patterns.getContext());
174   // clang-format on
175 }
176 
177 namespace {
178 struct ConvertLinalgToStandardPass
179     : public ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> {
180   void runOnOperation() override;
181 };
182 } // namespace
183 
184 void ConvertLinalgToStandardPass::runOnOperation() {
185   auto module = getOperation();
186   ConversionTarget target(getContext());
187   target.addLegalDialect<AffineDialect, memref::MemRefDialect, scf::SCFDialect,
188                          StandardOpsDialect>();
189   target.addLegalOp<ModuleOp, FuncOp, ReturnOp>();
190   target.addLegalOp<linalg::ExpandShapeOp, linalg::CollapseShapeOp,
191                     linalg::RangeOp>();
192   RewritePatternSet patterns(&getContext());
193   populateLinalgToStandardConversionPatterns(patterns);
194   if (failed(applyFullConversion(module, target, std::move(patterns))))
195     signalPassFailure();
196 }
197 
198 std::unique_ptr<OperationPass<ModuleOp>>
199 mlir::createConvertLinalgToStandardPass() {
200   return std::make_unique<ConvertLinalgToStandardPass>();
201 }
202