1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 19 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 20 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 22 #include "mlir/Dialect/VectorOps/VectorOps.h" 23 #include "mlir/IR/Attributes.h" 24 #include "mlir/IR/Builders.h" 25 #include "mlir/IR/MLIRContext.h" 26 #include "mlir/IR/Module.h" 27 #include "mlir/IR/Operation.h" 28 #include "mlir/IR/PatternMatch.h" 29 #include "mlir/IR/StandardTypes.h" 30 #include "mlir/IR/Types.h" 31 #include "mlir/Pass/Pass.h" 32 #include "mlir/Pass/PassManager.h" 33 #include "mlir/Transforms/DialectConversion.h" 34 #include "mlir/Transforms/Passes.h" 35 36 #include "llvm/IR/DerivedTypes.h" 37 #include "llvm/IR/Module.h" 38 #include "llvm/IR/Type.h" 39 #include "llvm/Support/Allocator.h" 40 #include "llvm/Support/ErrorHandling.h" 41 42 using namespace mlir; 43 44 template <typename T> 45 static LLVM::LLVMType getPtrToElementType(T containerType, 46 LLVMTypeConverter &lowering) { 47 return lowering.convertType(containerType.getElementType()) 48 .template cast<LLVM::LLVMType>() 49 .getPointerTo(); 50 } 51 52 class VectorExtractElementOpConversion : public LLVMOpLowering { 53 public: 54 explicit VectorExtractElementOpConversion(MLIRContext *context, 55 LLVMTypeConverter &typeConverter) 56 : LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context, 57 typeConverter) {} 58 59 PatternMatchResult 60 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 61 ConversionPatternRewriter &rewriter) const override { 62 auto loc = op->getLoc(); 63 auto adaptor = vector::ExtractElementOpOperandAdaptor(operands); 64 auto extractOp = cast<vector::ExtractElementOp>(op); 65 auto vectorType = extractOp.vector()->getType().cast<VectorType>(); 66 auto resultType = extractOp.getResult()->getType(); 67 auto llvmResultType = lowering.convertType(resultType); 68 69 auto positionArrayAttr = extractOp.position(); 70 // One-shot extraction of vector from array (only requires extractvalue). 71 if (resultType.isa<VectorType>()) { 72 Value *extracted = rewriter.create<LLVM::ExtractValueOp>( 73 loc, llvmResultType, adaptor.vector(), positionArrayAttr); 74 rewriter.replaceOp(op, extracted); 75 return matchSuccess(); 76 } 77 78 // Potential extraction of 1-D vector from struct. 79 auto *context = op->getContext(); 80 Value *extracted = adaptor.vector(); 81 auto positionAttrs = positionArrayAttr.getValue(); 82 auto i32Type = rewriter.getIntegerType(32); 83 if (positionAttrs.size() > 1) { 84 auto nDVectorType = vectorType; 85 auto oneDVectorType = VectorType::get(nDVectorType.getShape().take_back(), 86 nDVectorType.getElementType()); 87 auto nMinusOnePositionAttrs = 88 ArrayAttr::get(positionAttrs.drop_back(), context); 89 extracted = rewriter.create<LLVM::ExtractValueOp>( 90 loc, lowering.convertType(oneDVectorType), extracted, 91 nMinusOnePositionAttrs); 92 } 93 94 // Remaining extraction of element from 1-D LLVM vector 95 auto position = positionAttrs.back().cast<IntegerAttr>(); 96 auto constant = rewriter.create<LLVM::ConstantOp>( 97 loc, lowering.convertType(i32Type), position); 98 extracted = 99 rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 100 rewriter.replaceOp(op, extracted); 101 102 return matchSuccess(); 103 } 104 }; 105 106 class VectorOuterProductOpConversion : public LLVMOpLowering { 107 public: 108 explicit VectorOuterProductOpConversion(MLIRContext *context, 109 LLVMTypeConverter &typeConverter) 110 : LLVMOpLowering(vector::OuterProductOp::getOperationName(), context, 111 typeConverter) {} 112 113 PatternMatchResult 114 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 115 ConversionPatternRewriter &rewriter) const override { 116 auto loc = op->getLoc(); 117 auto adaptor = vector::OuterProductOpOperandAdaptor(operands); 118 auto *ctx = op->getContext(); 119 auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>(); 120 auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>(); 121 auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements(); 122 auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements(); 123 auto llvmArrayOfVectType = lowering.convertType( 124 cast<vector::OuterProductOp>(op).getResult()->getType()); 125 Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType); 126 Value *a = adaptor.lhs(), *b = adaptor.rhs(); 127 Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front(); 128 SmallVector<Value *, 8> lhs, accs; 129 lhs.reserve(rankLHS); 130 accs.reserve(rankLHS); 131 for (unsigned d = 0, e = rankLHS; d < e; ++d) { 132 // shufflevector explicitly requires i32. 133 auto attr = rewriter.getI32IntegerAttr(d); 134 SmallVector<Attribute, 4> bcastAttr(rankRHS, attr); 135 auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx); 136 Value *aD = nullptr, *accD = nullptr; 137 // 1. Broadcast the element a[d] into vector aD. 138 aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr); 139 // 2. If acc is present, extract 1-d vector acc[d] into accD. 140 if (acc) 141 accD = rewriter.create<LLVM::ExtractValueOp>( 142 loc, vRHS, acc, rewriter.getI64ArrayAttr(d)); 143 // 3. Compute aD outer b (plus accD, if relevant). 144 Value *aOuterbD = 145 accD ? rewriter.create<LLVM::FMulAddOp>(loc, vRHS, aD, b, accD) 146 .getResult() 147 : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult(); 148 // 4. Insert as value `d` in the descriptor. 149 desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType, 150 desc, aOuterbD, 151 rewriter.getI64ArrayAttr(d)); 152 } 153 rewriter.replaceOp(op, desc); 154 return matchSuccess(); 155 } 156 }; 157 158 class VectorTypeCastOpConversion : public LLVMOpLowering { 159 public: 160 explicit VectorTypeCastOpConversion(MLIRContext *context, 161 LLVMTypeConverter &typeConverter) 162 : LLVMOpLowering(vector::TypeCastOp::getOperationName(), context, 163 typeConverter) {} 164 165 PatternMatchResult 166 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 167 ConversionPatternRewriter &rewriter) const override { 168 auto loc = op->getLoc(); 169 vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op); 170 MemRefType sourceMemRefType = 171 castOp.getOperand()->getType().cast<MemRefType>(); 172 MemRefType targetMemRefType = 173 castOp.getResult()->getType().cast<MemRefType>(); 174 175 // Only static shape casts supported atm. 176 if (!sourceMemRefType.hasStaticShape() || 177 !targetMemRefType.hasStaticShape()) 178 return matchFailure(); 179 180 auto llvmSourceDescriptorTy = 181 operands[0]->getType().dyn_cast<LLVM::LLVMType>(); 182 if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy()) 183 return matchFailure(); 184 MemRefDescriptor sourceMemRef(operands[0]); 185 186 auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType) 187 .dyn_cast_or_null<LLVM::LLVMType>(); 188 if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) 189 return matchFailure(); 190 191 int64_t offset; 192 SmallVector<int64_t, 4> strides; 193 auto successStrides = 194 getStridesAndOffset(sourceMemRefType, strides, offset); 195 bool isContiguous = (strides.back() == 1); 196 if (isContiguous) { 197 auto sizes = sourceMemRefType.getShape(); 198 for (int index = 0, e = strides.size() - 2; index < e; ++index) { 199 if (strides[index] != strides[index + 1] * sizes[index + 1]) { 200 isContiguous = false; 201 break; 202 } 203 } 204 } 205 // Only contiguous source tensors supported atm. 206 if (failed(successStrides) || !isContiguous) 207 return matchFailure(); 208 209 auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); 210 211 // Create descriptor. 212 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 213 Type llvmTargetElementTy = desc.getElementType(); 214 // Set allocated ptr. 215 Value *allocated = sourceMemRef.allocatedPtr(rewriter, loc); 216 allocated = 217 rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 218 desc.setAllocatedPtr(rewriter, loc, allocated); 219 // Set aligned ptr. 220 Value *ptr = sourceMemRef.alignedPtr(rewriter, loc); 221 ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 222 desc.setAlignedPtr(rewriter, loc, ptr); 223 // Fill offset 0. 224 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 225 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 226 desc.setOffset(rewriter, loc, zero); 227 228 // Fill size and stride descriptors in memref. 229 for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { 230 int64_t index = indexedSize.index(); 231 auto sizeAttr = 232 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 233 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 234 desc.setSize(rewriter, loc, index, size); 235 auto strideAttr = 236 rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]); 237 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 238 desc.setStride(rewriter, loc, index, stride); 239 } 240 241 rewriter.replaceOp(op, {desc}); 242 return matchSuccess(); 243 } 244 }; 245 246 /// Populate the given list with patterns that convert from Vector to LLVM. 247 void mlir::populateVectorToLLVMConversionPatterns( 248 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 249 patterns.insert<VectorExtractElementOpConversion, 250 VectorOuterProductOpConversion, VectorTypeCastOpConversion>( 251 converter.getDialect()->getContext(), converter); 252 } 253 254 namespace { 255 struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> { 256 void runOnModule() override; 257 }; 258 } // namespace 259 260 void LowerVectorToLLVMPass::runOnModule() { 261 // Convert to the LLVM IR dialect using the converter defined above. 262 OwningRewritePatternList patterns; 263 LLVMTypeConverter converter(&getContext()); 264 populateVectorToLLVMConversionPatterns(converter, patterns); 265 populateStdToLLVMConversionPatterns(converter, patterns); 266 267 ConversionTarget target(getContext()); 268 target.addLegalDialect<LLVM::LLVMDialect>(); 269 target.addDynamicallyLegalOp<FuncOp>( 270 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 271 if (failed( 272 applyPartialConversion(getModule(), target, patterns, &converter))) { 273 signalPassFailure(); 274 } 275 } 276 277 OpPassBase<ModuleOp> *mlir::createLowerVectorToLLVMPass() { 278 return new LowerVectorToLLVMPass(); 279 } 280 281 static PassRegistration<LowerVectorToLLVMPass> 282 pass("convert-vector-to-llvm", 283 "Lower the operations from the vector dialect into the LLVM dialect"); 284