1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 19 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 20 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 22 #include "mlir/Dialect/VectorOps/VectorOps.h" 23 #include "mlir/IR/Attributes.h" 24 #include "mlir/IR/Builders.h" 25 #include "mlir/IR/MLIRContext.h" 26 #include "mlir/IR/Module.h" 27 #include "mlir/IR/Operation.h" 28 #include "mlir/IR/PatternMatch.h" 29 #include "mlir/IR/StandardTypes.h" 30 #include "mlir/IR/Types.h" 31 #include "mlir/Pass/Pass.h" 32 #include "mlir/Pass/PassManager.h" 33 #include "mlir/Transforms/DialectConversion.h" 34 #include "mlir/Transforms/Passes.h" 35 36 #include "llvm/IR/DerivedTypes.h" 37 #include "llvm/IR/Module.h" 38 #include "llvm/IR/Type.h" 39 #include "llvm/Support/Allocator.h" 40 #include "llvm/Support/ErrorHandling.h" 41 42 using namespace mlir; 43 44 template <typename T> 45 static LLVM::LLVMType getPtrToElementType(T containerType, 46 LLVMTypeConverter &lowering) { 47 return lowering.convertType(containerType.getElementType()) 48 .template cast<LLVM::LLVMType>() 49 .getPointerTo(); 50 } 51 52 class VectorBroadcastOpConversion : public LLVMOpLowering { 53 public: 54 explicit VectorBroadcastOpConversion(MLIRContext *context, 55 LLVMTypeConverter &typeConverter) 56 : LLVMOpLowering(vector::BroadcastOp::getOperationName(), context, 57 typeConverter) {} 58 59 PatternMatchResult 60 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 61 ConversionPatternRewriter &rewriter) const override { 62 auto broadcastOp = cast<vector::BroadcastOp>(op); 63 VectorType dstVectorType = broadcastOp.getVectorType(); 64 if (lowering.convertType(dstVectorType) == nullptr) 65 return matchFailure(); 66 // Rewrite when the full vector type can be lowered (which 67 // implies all 'reduced' types can be lowered too). 68 VectorType srcVectorType = 69 broadcastOp.getSourceType().dyn_cast<VectorType>(); 70 rewriter.replaceOp( 71 op, expandRanks(operands[0], // source value to be expanded 72 op->getLoc(), // location of original broadcast 73 srcVectorType, dstVectorType, rewriter)); 74 return matchSuccess(); 75 } 76 77 private: 78 // Expands the given source value over all the ranks, as defined 79 // by the source and destination type (a null source type denotes 80 // expansion from a scalar value into a vector). 81 // 82 // TODO(ajcbik): consider replacing this one-pattern lowering 83 // with a two-pattern lowering using other vector 84 // ops once all insert/extract/shuffle operations 85 // are available with lowering implemention. 86 // 87 Value *expandRanks(Value *value, Location loc, VectorType srcVectorType, 88 VectorType dstVectorType, 89 ConversionPatternRewriter &rewriter) const { 90 assert((dstVectorType != nullptr) && "invalid result type in broadcast"); 91 // Determine rank of source and destination. 92 int64_t srcRank = srcVectorType ? srcVectorType.getRank() : 0; 93 int64_t dstRank = dstVectorType.getRank(); 94 int64_t curDim = dstVectorType.getDimSize(0); 95 if (srcRank < dstRank) 96 // Duplicate this rank. 97 return duplicateOneRank(value, loc, srcVectorType, dstVectorType, dstRank, 98 curDim, rewriter); 99 // If all trailing dimensions are the same, the broadcast consists of 100 // simply passing through the source value and we are done. Otherwise, 101 // any non-matching dimension forces a stretch along this rank. 102 assert((srcVectorType != nullptr) && (srcRank > 0) && 103 (srcRank == dstRank) && "invalid rank in broadcast"); 104 for (int64_t r = 0; r < dstRank; r++) { 105 if (srcVectorType.getDimSize(r) != dstVectorType.getDimSize(r)) { 106 return stretchOneRank(value, loc, srcVectorType, dstVectorType, dstRank, 107 curDim, rewriter); 108 } 109 } 110 return value; 111 } 112 113 // Picks the best way to duplicate a single rank. For the 1-D case, a 114 // single insert-elt/shuffle is the most efficient expansion. For higher 115 // dimensions, however, we need dim x insert-values on a new broadcast 116 // with one less leading dimension, which will be lowered "recursively" 117 // to matching LLVM IR. 118 // For example: 119 // v = broadcast s : f32 to vector<4x2xf32> 120 // becomes: 121 // x = broadcast s : f32 to vector<2xf32> 122 // v = [x,x,x,x] 123 // becomes: 124 // x = [s,s] 125 // v = [x,x,x,x] 126 Value *duplicateOneRank(Value *value, Location loc, VectorType srcVectorType, 127 VectorType dstVectorType, int64_t rank, int64_t dim, 128 ConversionPatternRewriter &rewriter) const { 129 Type llvmType = lowering.convertType(dstVectorType); 130 assert((llvmType != nullptr) && "unlowerable vector type"); 131 if (rank == 1) { 132 Value *undef = rewriter.create<LLVM::UndefOp>(loc, llvmType); 133 Value *expand = insertOne(undef, value, loc, llvmType, rank, 0, rewriter); 134 SmallVector<int32_t, 4> zeroValues(dim, 0); 135 return rewriter.create<LLVM::ShuffleVectorOp>( 136 loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues)); 137 } 138 Value *expand = expandRanks(value, loc, srcVectorType, 139 reducedVectorType(dstVectorType), rewriter); 140 Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType); 141 for (int64_t d = 0; d < dim; ++d) { 142 result = insertOne(result, expand, loc, llvmType, rank, d, rewriter); 143 } 144 return result; 145 } 146 147 // Picks the best way to stretch a single rank. For the 1-D case, a 148 // single insert-elt/shuffle is the most efficient expansion when at 149 // a stretch. Otherwise, every dimension needs to be expanded 150 // individually and individually inserted in the resulting vector. 151 // For example: 152 // v = broadcast w : vector<4x1x2xf32> to vector<4x2x2xf32> 153 // becomes: 154 // a = broadcast w[0] : vector<1x2xf32> to vector<2x2xf32> 155 // b = broadcast w[1] : vector<1x2xf32> to vector<2x2xf32> 156 // c = broadcast w[2] : vector<1x2xf32> to vector<2x2xf32> 157 // d = broadcast w[3] : vector<1x2xf32> to vector<2x2xf32> 158 // v = [a,b,c,d] 159 // becomes: 160 // x = broadcast w[0][0] : vector<2xf32> to vector <2x2xf32> 161 // y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32> 162 // a = [x, y] 163 // etc. 164 Value *stretchOneRank(Value *value, Location loc, VectorType srcVectorType, 165 VectorType dstVectorType, int64_t rank, int64_t dim, 166 ConversionPatternRewriter &rewriter) const { 167 Type llvmType = lowering.convertType(dstVectorType); 168 assert((llvmType != nullptr) && "unlowerable vector type"); 169 Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType); 170 bool atStretch = dim != srcVectorType.getDimSize(0); 171 if (rank == 1) { 172 Type redLlvmType = lowering.convertType(dstVectorType.getElementType()); 173 if (atStretch) { 174 Value *one = extractOne(value, loc, redLlvmType, rank, 0, rewriter); 175 Value *expand = 176 insertOne(result, one, loc, llvmType, rank, 0, rewriter); 177 SmallVector<int32_t, 4> zeroValues(dim, 0); 178 return rewriter.create<LLVM::ShuffleVectorOp>( 179 loc, expand, result, rewriter.getI32ArrayAttr(zeroValues)); 180 } 181 for (int64_t d = 0; d < dim; ++d) { 182 Value *one = extractOne(value, loc, redLlvmType, rank, d, rewriter); 183 result = insertOne(result, one, loc, llvmType, rank, d, rewriter); 184 } 185 } else { 186 VectorType redSrcType = reducedVectorType(srcVectorType); 187 VectorType redDstType = reducedVectorType(dstVectorType); 188 Type redLlvmType = lowering.convertType(redSrcType); 189 for (int64_t d = 0; d < dim; ++d) { 190 int64_t pos = atStretch ? 0 : d; 191 Value *one = extractOne(value, loc, redLlvmType, rank, pos, rewriter); 192 Value *expand = expandRanks(one, loc, redSrcType, redDstType, rewriter); 193 result = insertOne(result, expand, loc, llvmType, rank, d, rewriter); 194 } 195 } 196 return result; 197 } 198 199 // Picks the proper sequence for inserting. 200 Value *insertOne(Value *val1, Value *val2, Location loc, Type llvmType, 201 int64_t rank, int64_t pos, 202 ConversionPatternRewriter &rewriter) const { 203 if (rank == 1) { 204 auto idxType = rewriter.getIndexType(); 205 auto constant = rewriter.create<LLVM::ConstantOp>( 206 loc, lowering.convertType(idxType), 207 rewriter.getIntegerAttr(idxType, pos)); 208 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 209 constant); 210 } 211 return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, 212 rewriter.getI64ArrayAttr(pos)); 213 } 214 215 // Picks the proper sequence for extracting. 216 Value *extractOne(Value *value, Location loc, Type llvmType, int64_t rank, 217 int64_t pos, ConversionPatternRewriter &rewriter) const { 218 if (rank == 1) { 219 auto idxType = rewriter.getIndexType(); 220 auto constant = rewriter.create<LLVM::ConstantOp>( 221 loc, lowering.convertType(idxType), 222 rewriter.getIntegerAttr(idxType, pos)); 223 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, value, 224 constant); 225 } 226 return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, value, 227 rewriter.getI64ArrayAttr(pos)); 228 } 229 230 // Helper to reduce vector type by one rank. 231 static VectorType reducedVectorType(VectorType tp) { 232 assert((tp.getRank() > 1) && "unlowerable vector type"); 233 return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); 234 } 235 }; 236 237 class VectorExtractElementOpConversion : public LLVMOpLowering { 238 public: 239 explicit VectorExtractElementOpConversion(MLIRContext *context, 240 LLVMTypeConverter &typeConverter) 241 : LLVMOpLowering(vector::ExtractOp::getOperationName(), context, 242 typeConverter) {} 243 244 PatternMatchResult 245 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 246 ConversionPatternRewriter &rewriter) const override { 247 auto loc = op->getLoc(); 248 auto adaptor = vector::ExtractOpOperandAdaptor(operands); 249 auto extractOp = cast<vector::ExtractOp>(op); 250 auto vectorType = extractOp.vector()->getType().cast<VectorType>(); 251 auto resultType = extractOp.getResult()->getType(); 252 auto llvmResultType = lowering.convertType(resultType); 253 254 auto positionArrayAttr = extractOp.position(); 255 // One-shot extraction of vector from array (only requires extractvalue). 256 if (resultType.isa<VectorType>()) { 257 Value *extracted = rewriter.create<LLVM::ExtractValueOp>( 258 loc, llvmResultType, adaptor.vector(), positionArrayAttr); 259 rewriter.replaceOp(op, extracted); 260 return matchSuccess(); 261 } 262 263 // Potential extraction of 1-D vector from struct. 264 auto *context = op->getContext(); 265 Value *extracted = adaptor.vector(); 266 auto positionAttrs = positionArrayAttr.getValue(); 267 auto i32Type = rewriter.getIntegerType(32); 268 if (positionAttrs.size() > 1) { 269 auto nDVectorType = vectorType; 270 auto oneDVectorType = VectorType::get(nDVectorType.getShape().take_back(), 271 nDVectorType.getElementType()); 272 auto nMinusOnePositionAttrs = 273 ArrayAttr::get(positionAttrs.drop_back(), context); 274 extracted = rewriter.create<LLVM::ExtractValueOp>( 275 loc, lowering.convertType(oneDVectorType), extracted, 276 nMinusOnePositionAttrs); 277 } 278 279 // Remaining extraction of element from 1-D LLVM vector 280 auto position = positionAttrs.back().cast<IntegerAttr>(); 281 auto constant = rewriter.create<LLVM::ConstantOp>( 282 loc, lowering.convertType(i32Type), position); 283 extracted = 284 rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 285 rewriter.replaceOp(op, extracted); 286 287 return matchSuccess(); 288 } 289 }; 290 291 class VectorOuterProductOpConversion : public LLVMOpLowering { 292 public: 293 explicit VectorOuterProductOpConversion(MLIRContext *context, 294 LLVMTypeConverter &typeConverter) 295 : LLVMOpLowering(vector::OuterProductOp::getOperationName(), context, 296 typeConverter) {} 297 298 PatternMatchResult 299 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 300 ConversionPatternRewriter &rewriter) const override { 301 auto loc = op->getLoc(); 302 auto adaptor = vector::OuterProductOpOperandAdaptor(operands); 303 auto *ctx = op->getContext(); 304 auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>(); 305 auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>(); 306 auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements(); 307 auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements(); 308 auto llvmArrayOfVectType = lowering.convertType( 309 cast<vector::OuterProductOp>(op).getResult()->getType()); 310 Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType); 311 Value *a = adaptor.lhs(), *b = adaptor.rhs(); 312 Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front(); 313 SmallVector<Value *, 8> lhs, accs; 314 lhs.reserve(rankLHS); 315 accs.reserve(rankLHS); 316 for (unsigned d = 0, e = rankLHS; d < e; ++d) { 317 // shufflevector explicitly requires i32. 318 auto attr = rewriter.getI32IntegerAttr(d); 319 SmallVector<Attribute, 4> bcastAttr(rankRHS, attr); 320 auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx); 321 Value *aD = nullptr, *accD = nullptr; 322 // 1. Broadcast the element a[d] into vector aD. 323 aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr); 324 // 2. If acc is present, extract 1-d vector acc[d] into accD. 325 if (acc) 326 accD = rewriter.create<LLVM::ExtractValueOp>( 327 loc, vRHS, acc, rewriter.getI64ArrayAttr(d)); 328 // 3. Compute aD outer b (plus accD, if relevant). 329 Value *aOuterbD = 330 accD ? rewriter.create<LLVM::FMulAddOp>(loc, vRHS, aD, b, accD) 331 .getResult() 332 : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult(); 333 // 4. Insert as value `d` in the descriptor. 334 desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType, 335 desc, aOuterbD, 336 rewriter.getI64ArrayAttr(d)); 337 } 338 rewriter.replaceOp(op, desc); 339 return matchSuccess(); 340 } 341 }; 342 343 class VectorTypeCastOpConversion : public LLVMOpLowering { 344 public: 345 explicit VectorTypeCastOpConversion(MLIRContext *context, 346 LLVMTypeConverter &typeConverter) 347 : LLVMOpLowering(vector::TypeCastOp::getOperationName(), context, 348 typeConverter) {} 349 350 PatternMatchResult 351 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 352 ConversionPatternRewriter &rewriter) const override { 353 auto loc = op->getLoc(); 354 vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op); 355 MemRefType sourceMemRefType = 356 castOp.getOperand()->getType().cast<MemRefType>(); 357 MemRefType targetMemRefType = 358 castOp.getResult()->getType().cast<MemRefType>(); 359 360 // Only static shape casts supported atm. 361 if (!sourceMemRefType.hasStaticShape() || 362 !targetMemRefType.hasStaticShape()) 363 return matchFailure(); 364 365 auto llvmSourceDescriptorTy = 366 operands[0]->getType().dyn_cast<LLVM::LLVMType>(); 367 if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy()) 368 return matchFailure(); 369 MemRefDescriptor sourceMemRef(operands[0]); 370 371 auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType) 372 .dyn_cast_or_null<LLVM::LLVMType>(); 373 if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) 374 return matchFailure(); 375 376 int64_t offset; 377 SmallVector<int64_t, 4> strides; 378 auto successStrides = 379 getStridesAndOffset(sourceMemRefType, strides, offset); 380 bool isContiguous = (strides.back() == 1); 381 if (isContiguous) { 382 auto sizes = sourceMemRefType.getShape(); 383 for (int index = 0, e = strides.size() - 2; index < e; ++index) { 384 if (strides[index] != strides[index + 1] * sizes[index + 1]) { 385 isContiguous = false; 386 break; 387 } 388 } 389 } 390 // Only contiguous source tensors supported atm. 391 if (failed(successStrides) || !isContiguous) 392 return matchFailure(); 393 394 auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); 395 396 // Create descriptor. 397 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 398 Type llvmTargetElementTy = desc.getElementType(); 399 // Set allocated ptr. 400 Value *allocated = sourceMemRef.allocatedPtr(rewriter, loc); 401 allocated = 402 rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 403 desc.setAllocatedPtr(rewriter, loc, allocated); 404 // Set aligned ptr. 405 Value *ptr = sourceMemRef.alignedPtr(rewriter, loc); 406 ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 407 desc.setAlignedPtr(rewriter, loc, ptr); 408 // Fill offset 0. 409 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 410 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 411 desc.setOffset(rewriter, loc, zero); 412 413 // Fill size and stride descriptors in memref. 414 for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { 415 int64_t index = indexedSize.index(); 416 auto sizeAttr = 417 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 418 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 419 desc.setSize(rewriter, loc, index, size); 420 auto strideAttr = 421 rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]); 422 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 423 desc.setStride(rewriter, loc, index, stride); 424 } 425 426 rewriter.replaceOp(op, {desc}); 427 return matchSuccess(); 428 } 429 }; 430 431 /// Populate the given list with patterns that convert from Vector to LLVM. 432 void mlir::populateVectorToLLVMConversionPatterns( 433 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 434 patterns.insert<VectorBroadcastOpConversion, VectorExtractElementOpConversion, 435 VectorOuterProductOpConversion, VectorTypeCastOpConversion>( 436 converter.getDialect()->getContext(), converter); 437 } 438 439 namespace { 440 struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> { 441 void runOnModule() override; 442 }; 443 } // namespace 444 445 void LowerVectorToLLVMPass::runOnModule() { 446 // Convert to the LLVM IR dialect using the converter defined above. 447 OwningRewritePatternList patterns; 448 LLVMTypeConverter converter(&getContext()); 449 populateVectorToLLVMConversionPatterns(converter, patterns); 450 populateStdToLLVMConversionPatterns(converter, patterns); 451 452 ConversionTarget target(getContext()); 453 target.addLegalDialect<LLVM::LLVMDialect>(); 454 target.addDynamicallyLegalOp<FuncOp>( 455 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 456 if (failed( 457 applyPartialConversion(getModule(), target, patterns, &converter))) { 458 signalPassFailure(); 459 } 460 } 461 462 OpPassBase<ModuleOp> *mlir::createLowerVectorToLLVMPass() { 463 return new LowerVectorToLLVMPass(); 464 } 465 466 static PassRegistration<LowerVectorToLLVMPass> 467 pass("convert-vector-to-llvm", 468 "Lower the operations from the vector dialect into the LLVM dialect"); 469