1 //===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/MemRef/IR/MemRef.h" 16 #include "mlir/Dialect/SCF/SCF.h" 17 #include "mlir/Dialect/StandardOps/IR/Ops.h" 18 19 using namespace mlir; 20 using namespace mlir::linalg; 21 22 /// Helper function to extract the operand types that are passed to the 23 /// generated CallOp. MemRefTypes have their layout canonicalized since the 24 /// information is not used in signature generation. 25 /// Note that static size information is not modified. 26 static SmallVector<Type, 4> extractOperandTypes(Operation *op) { 27 SmallVector<Type, 4> result; 28 result.reserve(op->getNumOperands()); 29 for (auto type : op->getOperandTypes()) { 30 // The underlying descriptor type (e.g. LLVM) does not have layout 31 // information. Canonicalizing the type at the level of std when going into 32 // a library call avoids needing to introduce DialectCastOp. 33 if (auto memrefType = type.dyn_cast<MemRefType>()) 34 result.push_back(eraseStridedLayout(memrefType)); 35 else 36 result.push_back(type); 37 } 38 return result; 39 } 40 41 // Get a SymbolRefAttr containing the library function name for the LinalgOp. 42 // If the library function does not exist, insert a declaration. 43 static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op, 44 PatternRewriter &rewriter) { 45 auto linalgOp = cast<LinalgOp>(op); 46 auto fnName = linalgOp.getLibraryCallName(); 47 if (fnName.empty()) { 48 op->emitWarning("No library call defined for: ") << *op; 49 return {}; 50 } 51 52 // fnName is a dynamic std::string, unique it via a SymbolRefAttr. 53 FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName); 54 auto module = op->getParentOfType<ModuleOp>(); 55 if (module.lookupSymbol(fnName)) { 56 return fnNameAttr; 57 } 58 59 SmallVector<Type, 4> inputTypes(extractOperandTypes(op)); 60 assert(op->getNumResults() == 0 && 61 "Library call for linalg operation can be generated only for ops that " 62 "have void return types"); 63 auto libFnType = rewriter.getFunctionType(inputTypes, {}); 64 65 OpBuilder::InsertionGuard guard(rewriter); 66 // Insert before module terminator. 67 rewriter.setInsertionPoint(module.getBody(), 68 std::prev(module.getBody()->end())); 69 FuncOp funcOp = 70 rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType); 71 // Insert a function attribute that will trigger the emission of the 72 // corresponding `_mlir_ciface_xxx` interface so that external libraries see 73 // a normalized ABI. This interface is added during std to llvm conversion. 74 funcOp->setAttr("llvm.emit_c_interface", UnitAttr::get(op->getContext())); 75 funcOp.setPrivate(); 76 return fnNameAttr; 77 } 78 79 static SmallVector<Value, 4> 80 createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc, 81 ValueRange operands) { 82 SmallVector<Value, 4> res; 83 res.reserve(operands.size()); 84 for (auto op : operands) { 85 auto memrefType = op.getType().dyn_cast<MemRefType>(); 86 if (!memrefType) { 87 res.push_back(op); 88 continue; 89 } 90 Value cast = 91 b.create<memref::CastOp>(loc, eraseStridedLayout(memrefType), op); 92 res.push_back(cast); 93 } 94 return res; 95 } 96 97 LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite( 98 LinalgOp op, PatternRewriter &rewriter) const { 99 // Only LinalgOp for which there is no specialized pattern go through this. 100 if (isa<CopyOp>(op)) 101 return failure(); 102 103 auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); 104 if (!libraryCallName) 105 return failure(); 106 107 // TODO: Add support for more complex library call signatures that include 108 // indices or captured values. 109 rewriter.replaceOpWithNewOp<mlir::CallOp>( 110 op, libraryCallName.getValue(), TypeRange(), 111 createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(), 112 op->getOperands())); 113 return success(); 114 } 115 116 LogicalResult mlir::linalg::CopyOpToLibraryCallRewrite::matchAndRewrite( 117 CopyOp op, PatternRewriter &rewriter) const { 118 auto inputPerm = op.inputPermutation(); 119 if (inputPerm.hasValue() && !inputPerm->isIdentity()) 120 return failure(); 121 auto outputPerm = op.outputPermutation(); 122 if (outputPerm.hasValue() && !outputPerm->isIdentity()) 123 return failure(); 124 125 auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); 126 if (!libraryCallName) 127 return failure(); 128 129 rewriter.replaceOpWithNewOp<mlir::CallOp>( 130 op, libraryCallName.getValue(), TypeRange(), 131 createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), 132 op.getOperands())); 133 return success(); 134 } 135 136 LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite( 137 CopyOp op, PatternRewriter &rewriter) const { 138 Value in = op.input(), out = op.output(); 139 140 // If either inputPerm or outputPerm are non-identities, insert transposes. 141 auto inputPerm = op.inputPermutation(); 142 if (inputPerm.hasValue() && !inputPerm->isIdentity()) 143 in = rewriter.create<memref::TransposeOp>(op.getLoc(), in, 144 AffineMapAttr::get(*inputPerm)); 145 auto outputPerm = op.outputPermutation(); 146 if (outputPerm.hasValue() && !outputPerm->isIdentity()) 147 out = rewriter.create<memref::TransposeOp>(op.getLoc(), out, 148 AffineMapAttr::get(*outputPerm)); 149 150 // If nothing was transposed, fail and let the conversion kick in. 151 if (in == op.input() && out == op.output()) 152 return failure(); 153 154 auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); 155 if (!libraryCallName) 156 return failure(); 157 158 rewriter.replaceOpWithNewOp<mlir::CallOp>( 159 op, libraryCallName.getValue(), TypeRange(), 160 createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), {in, out})); 161 return success(); 162 } 163 164 /// Populate the given list with patterns that convert from Linalg to Standard. 165 void mlir::linalg::populateLinalgToStandardConversionPatterns( 166 RewritePatternSet &patterns) { 167 // TODO: ConvOp conversion needs to export a descriptor with relevant 168 // attribute values such as kernel striding and dilation. 169 // clang-format off 170 patterns.add< 171 CopyOpToLibraryCallRewrite, 172 CopyTransposeRewrite, 173 LinalgOpToLibraryCallRewrite>(patterns.getContext()); 174 // clang-format on 175 } 176 177 namespace { 178 struct ConvertLinalgToStandardPass 179 : public ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> { 180 void runOnOperation() override; 181 }; 182 } // namespace 183 184 void ConvertLinalgToStandardPass::runOnOperation() { 185 auto module = getOperation(); 186 ConversionTarget target(getContext()); 187 target.addLegalDialect<AffineDialect, memref::MemRefDialect, scf::SCFDialect, 188 StandardOpsDialect>(); 189 target.addLegalOp<ModuleOp, FuncOp, ReturnOp>(); 190 target.addLegalOp<linalg::ExpandShapeOp, linalg::CollapseShapeOp, 191 linalg::RangeOp>(); 192 RewritePatternSet patterns(&getContext()); 193 populateLinalgToStandardConversionPatterns(patterns); 194 if (failed(applyFullConversion(module, target, std::move(patterns)))) 195 signalPassFailure(); 196 } 197 198 std::unique_ptr<OperationPass<ModuleOp>> 199 mlir::createConvertLinalgToStandardPass() { 200 return std::make_unique<ConvertLinalgToStandardPass>(); 201 } 202