//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Target/LLVMIR/TypeTranslation.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
}

// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType());
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
                                              rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                       Value into, int64_t offset) {
  auto vectorType = into.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<InsertOp>(loc, from, into, offset);
  return rewriter.create<vector::InsertElementOp>(
      loc, vectorType, from, into,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
                                               rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for extracting.
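// E.g., extracting at a static position from a 1-D vector emits
// vector.extractelement with a constant index, while a higher-rank source
// emits vector.extract (illustrative).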
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                        int64_t offset) {
  auto vectorType = vector.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<ExtractOp>(loc, vector, offset);
  return rewriter.create<vector::ExtractElementOp>(
      loc, vectorType.getElementType(), vector,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
// TODO: Better support for attribute subtype forwarding + slicing.
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                              unsigned dropFront = 0,
                                              unsigned dropBack = 0) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t, 4> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

// Helper that casts `value` to `targetType`, which is either index or an
// integer type: index <-> integer uses an index_cast, while integers of
// different widths use a sign extension or truncation.
static Value createCastToIndexLike(ConversionPatternRewriter &rewriter,
                                   Location loc, Type targetType, Value value) {
  if (targetType == value.getType())
    return value;

  bool targetIsIndex = targetType.isIndex();
  bool valueIsIndex = value.getType().isIndex();
  if (targetIsIndex ^ valueIsIndex)
    return rewriter.create<IndexCastOp>(loc, targetType, value);

  auto targetIntegerType = targetType.dyn_cast<IntegerType>();
  auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
  assert(targetIntegerType && valueIntegerType &&
         "unexpected cast between types other than integers and index");
  assert(targetIntegerType.getSignedness() ==
         valueIntegerType.getSignedness());

  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
    return rewriter.create<SignExtendIOp>(loc, targetIntegerType, value);
  return rewriter.create<TruncateIOp>(loc, targetIntegerType, value);
}

// Helper that returns a vector comparison that constructs a mask:
//     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
//       much more compact, IR for this operation, but LLVM eventually
//       generates more elaborate instructions for this intrinsic since it
//       is very conservative on the boundary conditions.
static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
                                   Operation *op, bool enableIndexOptimizations,
                                   int64_t dim, Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If we can assume all indices fit in 32-bit, we perform the vector
  // comparison in 32-bit to get a higher degree of SIMD parallelism.
  // Otherwise we perform the vector comparison using 64-bit indices.
  Value indices;
  Type idxType;
  if (enableIndexOptimizations) {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI32VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
    idxType = rewriter.getI32Type();
  } else {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI64VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
    idxType = rewriter.getI64Type();
  }
  // Add in an offset if requested.
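  // E.g., with dim = 4 and an offset `o`, the comparison below becomes
  //   [o, o+1, o+2, o+3] < [b, b, b, b]   (a sketch).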
  if (off) {
    Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
    Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
}

// Helper that returns the data layout alignment of a memref.
static LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
                                        MemRefType memrefType,
                                        unsigned &align) {
  Type elementTy = typeConverter.convertType(memrefType.getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
  return success();
}

// Helper that returns the base address of a memref.
static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
                             Value memref, MemRefType memRefType, Value &base) {
  // Inspect stride and offset structure.
  //
  // TODO: flat memory only for now, generalize.
  //
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
      offset != 0 || memRefType.getMemorySpace() != 0)
    return failure();
  base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
  return success();
}

// Helper that returns a vector of pointers, given a memref base and an index
// vector.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    Location loc, Value memref, Value indices,
                                    MemRefType memRefType, VectorType vType,
                                    Type iType, Value &ptrs) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
  return success();
}

// Casts a strided element pointer to a pointer to a vector of elements. The
// vector pointer is always in address space 0, so an addrspacecast is needed
// when the source memref is not in address space 0.
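// E.g., for f32 elements and a vector type vector<8xf32>, a pointer in
// address space 3 would be cast roughly as follows (a sketch; the exact types
// come from the type converter):
//   llvm.addrspacecast %p : !llvm.ptr<f32, 3> to !llvm.ptr<vector<8xf32>>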
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
                         Value ptr, MemRefType memRefType, Type vt) {
  auto pType = LLVM::LLVMPointerType::get(vt);
  if (memRefType.getMemorySpace() == 0)
    return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
  return rewriter.create<LLVM::AddrSpaceCastOp>(loc, pType, ptr);
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferReadOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferReadOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  VectorType fillType = xferOp.getVectorType();
  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());

  Type vecTy = typeConverter.convertType(xferOp.getVectorType());
  if (!vecTy)
    return failure();

  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();

  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
      rewriter.getI32IntegerAttr(align));
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferWriteOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();
  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
                                             align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferWriteOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();

  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      xferOp, adaptor.vector(), dataPtr, mask,
      rewriter.getI32IntegerAttr(align));
  return success();
}

static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
                                                  ArrayRef<Value> operands) {
  return TransferReadOpAdaptor(operands);
}

static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
                                                   ArrayRef<Value> operands) {
  return TransferWriteOpAdaptor(operands);
}

namespace {

/// Conversion pattern for a vector.bitcast.
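///
/// Example (illustrative; the result uses the converted 1-D vector type):
/// ```
///   %1 = vector.bitcast %0 : vector<4xf32> to vector<2xi64>
/// ```
/// lowers to a single llvm.bitcast on the converted operand.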
class VectorBitCastOpConversion
    : public ConvertOpToLLVMPattern<vector::BitCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp bitCastOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType resultTy = bitCastOp.getType();
    if (resultTy.getRank() != 1)
      return failure();
    Type newResultTy = typeConverter->convertType(resultTy);
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
                                                 operands[0]);
    return success();
  }
};

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
class VectorMatmulOpConversion
    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::MatmulOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        matmulOp, typeConverter->convertType(matmulOp.res().getType()),
        adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
        matmulOp.lhs_columns(), matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
class VectorFlatTransposeOpConversion
    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FlatTransposeOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};

/// Conversion pattern for a vector.maskedload.
class VectorMaskedLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = load->getLoc();
    auto adaptor = vector::MaskedLoadOpAdaptor(operands);
    MemRefType memRefType = load.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address.
    auto vtype = typeConverter->convertType(load.getResultVectorType());
    Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
                                               adaptor.indices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);

    rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
        load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.maskedstore.
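///
/// Example (a sketch; operand order follows this revision's assembly format
/// and the alignment comes from the data layout):
/// ```
///   vector.maskedstore %base[%i], %mask, %value
///     : memref<?xf32>, vector<16xi1>, vector<16xf32>
/// ```
/// becomes an llvm.intr.masked.store through a pointer to the LLVM vector
/// type.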
class VectorMaskedStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = store->getLoc();
    auto adaptor = vector::MaskedStoreOpAdaptor(operands);
    MemRefType memRefType = store.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address.
    auto vtype = typeConverter->convertType(store.getValueVectorType());
    Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
                                               adaptor.indices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);

    rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
        store, adaptor.value(), ptr, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = gather->getLoc();
    auto adaptor = vector::GatherOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
                                  align)))
      return failure();

    // Get index ptrs.
    VectorType vType = gather.getResultVectorType();
    Type iType = gather.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              gather.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
        adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    auto adaptor = vector::ScatterOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
                                  align)))
      return failure();

    // Get index ptrs.
    VectorType vType = scatter.getValueVectorType();
    Type iType = scatter.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              scatter.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
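    // (The per-lane pointers computed by getIndexedPtrs above become the
    // scatter destinations; the printed form is llvm.intr.masked.scatter.)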
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.value(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    auto adaptor = vector::ExpandLoadOpAdaptor(operands);
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getResultVectorType());
    Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
                                           adaptor.indices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    auto adaptor = vector::CompressStoreOpAdaptor(operands);
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
                                           adaptor.indices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.value(), ptr, adaptor.mask());
    return success();
  }
};

/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
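      // E.g., vector.reduction "add", %v : vector<16xi32> into i32 lowers to
      // llvm.intr.vector.reduce.add (illustrative printed form).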
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "min" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "max" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
            reductionOp, llvmType, operands[0]);
      else
        return failure();
      return success();
    }

    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max.
    if (kind == "add") {
      // Optional accumulator (or zero).
      Value acc = operands.size() > 1 ? operands[1]
                                      : rewriter.create<LLVM::ConstantOp>(
                                            reductionOp->getLoc(), llvmType,
                                            rewriter.getZeroAttr(eltType));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
          reductionOp, llvmType, acc, operands[0],
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "mul") {
      // Optional accumulator (or one).
      Value acc = operands.size() > 1
                      ? operands[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getFloatAttr(eltType, 1.0));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
          reductionOp, llvmType, acc, operands[0],
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "min")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
          reductionOp, llvmType, operands[0]);
    else if (kind == "max")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
          reductionOp, llvmType, operands[0]);
    else
      return failure();
    return success();
  }

private:
  const bool reassociateFPReductions;
};

/// Conversion pattern for a vector.create_mask (1-D only).
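///
/// Example (a sketch):
/// ```
///   %m = vector.create_mask %size : vector<16xi1>
/// ```
/// becomes the vector comparison [0, 1, .., 15] < [%size, .., %size] built by
/// `buildVectorComparison` above.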
class VectorCreateMaskOpConversion
    : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
                                        bool enableIndexOpt)
      : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    int64_t rank = dstType.getRank();
    if (rank == 1) {
      rewriter.replaceOp(
          op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
                                    dstType.getDimSize(0), operands[0]));
      return success();
    }
    return failure();
  }

private:
  const bool enableIndexOptimizations;
};

/// Conversion pattern for a vector.shuffle.
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto adaptor = vector::ShuffleOpAdaptor(operands);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, extract each source value and insert it into the
    // result one by one.
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 llvmType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};

/// Conversion pattern for a vector.extractelement.
class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp,
                  ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::ExtractElementOpAdaptor(operands);
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};

/// Conversion pattern for a vector.extract.
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto adaptor = vector::ExtractOpAdaptor(operands);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = extractOp->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of an element from a 1-D LLVM vector.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// It does not match vectors with rank n >= 2.
///
/// Example:
/// ```
///   vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///   llvm.intr.fmuladd %va, %va, %va :
///     (!llvm<"<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
///     -> !llvm<"<8 x f32>">
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FMAOpAdaptor(operands);
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
                                                 adaptor.rhs(), adaptor.acc());
    return success();
  }
};

/// Conversion pattern for a vector.insertelement.
class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::InsertElementOpAdaptor(operands);
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
        adaptor.position());
    return success();
  }
};

/// Conversion pattern for a vector.insert.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto adaptor = vector::InsertOpAdaptor(operands);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires
    // insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = insertOp->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///   %r = splat %f0 : vector<2x4xf32>
///   %va = vector.extract %a[0] : vector<2x4xf32>
///   %vb = vector.extract %b[0] : vector<2x4xf32>
///   %vc = vector.extract %c[0] : vector<2x4xf32>
///   %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///   %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///   %va2 = vector.extract %a[1] : vector<2x4xf32>
///   %vb2 = vector.extract %b[1] : vector<2x4xf32>
///   %vc2 = vector.extract %c[1] : vector<2x4xf32>
///   %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///   %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///   // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// When ranks are different, InsertStridedSlice needs to extract a properly
// ranked vector from the destination vector into which to insert. This pattern
// only takes care of this part and forwards the rest of the conversion to
// another pattern that converts InsertStridedSlice for operands of the same
// rank.
//
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have different ranks. In this case:
//   1. the proper subvector is extracted from the destination vector
//   2. a new InsertStridedSlice op is created to insert the source in the
//      destination subvector
//   3. the destination subvector is inserted back in the proper place
//   4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2. will be picked up by a
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff == 0)
      return failure();

    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract / insert the subvector of matching rank and InsertStridedSlice
    // on it.
    Value extracted = rewriter.create<ExtractOp>(
        loc, op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};

// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. The op is unrolled along its leading dimension: for each
// subvector of the source,
//   1. the subvector is extracted from the source
//   2. if it is itself a vector, it is combined with the corresponding
//      destination subvector through a new, rank-reduced InsertStridedSliceOp
//   3. the result is inserted back into the destination at the proper offset.
// The rank-reduced InsertStridedSliceOp from step 2. is handled again by this
// same pattern.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
      : OpRewritePattern<InsertStridedSliceOp>(ctx) {
    // This pattern creates recursive InsertStridedSliceOp, but the recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff != 0)
      return failure();

    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the most major dimension.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. Extract the proper subvector (or element) from the source.
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from the
        //    destination. Otherwise we are at the element level and there is
        //    no need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        //    smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /*dropFront=*/1),
            getI64SubArray(op.strides(), /*dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
static llvm::Optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return None;
  if (!strides.empty() && strides.back() != 1)
    return None;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getAffineMaps().empty() ||
      memRefType.getAffineMaps().front().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from shapes. This can only ever
  // work in static cases because MemRefType is underspecified to represent
  // contiguous dynamic shapes in other ways than with just empty/identity
  // layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 2; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
      return None;
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return None;
  }
  return strides;
}

/// Conversion pattern for a vector.type_cast.
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        operands[0].getType().dyn_cast<LLVM::LLVMStructType>();
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(operands[0]);

    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, [](int64_t stride) {
          return ShapedType::isDynamicStrideOrOffset(stride);
        }))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Set offset to 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};

/// Conversion pattern that converts a 1-D vector transfer read/write op into a
/// sequence of:
/// 1. Get the source/dst address as an LLVM vector pointer.
/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 4. Create a mask where offsetVector is compared against memref upper bound.
/// 5. Rewrite op as a masked read or write.
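///
/// Example for a masked 1-D read (a sketch; %pad is the padding value):
/// ```
///   %v = vector.transfer_read %A[%i], %pad : memref<?xf32>, vector<8xf32>
/// ```
/// becomes an llvm.intr.masked.load whose mask compares the lane indices
/// [%i .. %i+7] against the memref size, with %pad splatted as pass-through.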
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
public:
  explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
                                    bool enableIndexOpt)
      : ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();
    if (xferOp.permutation_map() !=
        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
                                       xferOp.getVectorType().getRank(),
                                       xferOp->getContext()))
      return failure();
    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
    if (!memRefType)
      return failure();
    // Only contiguous source tensors supported atm.
    auto strides = computeContiguousStrides(memRefType);
    if (!strides)
      return failure();

    auto toLLVMTy = [&](Type t) {
      return this->getTypeConverter()->convertType(t);
    };

    Location loc = xferOp->getLoc();

    if (auto memrefVectorElementType =
            memRefType.getElementType().template dyn_cast<VectorType>()) {
      // Memref has vector element type.
      if (memrefVectorElementType.getElementType() !=
          xferOp.getVectorType().getElementType())
        return failure();
#ifndef NDEBUG
      // Check that the memref vector element type is a suffix of `vectorType`.
      unsigned memrefVecEltRank = memrefVectorElementType.getRank();
      unsigned resultVecRank = xferOp.getVectorType().getRank();
      assert(memrefVecEltRank <= resultVecRank);
      // TODO: Move this to isSuffix in Vector/Utils.h.
      unsigned rankOffset = resultVecRank - memrefVecEltRank;
      auto memrefVecEltShape = memrefVectorElementType.getShape();
      auto resultVecShape = xferOp.getVectorType().getShape();
      for (unsigned i = 0; i < memrefVecEltRank; ++i)
        assert(memrefVecEltShape[i] == resultVecShape[rankOffset + i] &&
               "memref vector element shape should match suffix of vector "
               "result shape.");
#endif // ifndef NDEBUG
    }

    // 1. Get the source/dst address as an LLVM vector pointer.
    VectorType vtp = xferOp.getVectorType();
    Value dataPtr = this->getStridedElementPtr(
        loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
    Value vectorDataPtr =
        castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));

    if (!xferOp.isMaskedDim(0))
      return replaceTransferOpWithLoadOrStore(rewriter,
                                              *this->getTypeConverter(), loc,
                                              xferOp, operands, vectorDataPtr);

    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
    // 4. Let dim be the memref dimension; compute the vector comparison mask:
    //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
    //
    // TODO: when the leaf transfer rank is k > 1, we need the last `k`
    //       dimensions here.
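    // E.g., an 8-lane read at offset %i from a memref whose relevant size is
    // %d yields mask = [%i, %i+1, .., %i+7] < [%d, %d, .., %d] (a sketch).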
    unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue();
    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
    Value off = xferOp.indices()[lastIndex];
    Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
    Value mask = buildVectorComparison(
        rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);

    // 5. Rewrite as a masked read / write.
    return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(),
                                       loc, xferOp, operands, vectorDataPtr,
                                       mask);
  }

private:
  const bool enableIndexOptimizations;
};

/// Conversion pattern for a vector.print.
class VectorPrintOpConversion
    : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::PrintOpAdaptor(operands);
    Type printType = printOp.getPrintType();

    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure the element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    Operation *printer;
    if (eltType.isF32()) {
      printer = getPrintFloat(printOp);
    } else if (eltType.isF64()) {
      printer = getPrintDouble(printOp);
    } else if (eltType.isIndex()) {
      printer = getPrintU64(printOp);
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = getPrintU64(printOp);
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = getPrintI64(printOp);
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
              conversion);
    emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank, PrintConversion conversion) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<ZeroExtendIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<SignExtendIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc, getPrintOpen(op));
    Operation *printComma = getPrintComma(op);
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter->convertType(
          rank > 1 ? reducedType : vectorType.getElementType());
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc, getPrintClose(op));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(),
                                  rewriter.getSymbolRefAttr(ref), params);
  }

  // Helper for printer method declaration (first hit) and lookup.
  static Operation *getPrint(Operation *op, StringRef name,
                             ArrayRef<Type> params) {
    auto module = op->getParentOfType<ModuleOp>();
    auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
    if (func)
      return func;
    OpBuilder moduleBuilder(module.getBodyRegion());
    return moduleBuilder.create<LLVM::LLVMFuncOp>(
        op->getLoc(), name,
        LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(op->getContext()),
                                    params));
  }

  // Helpers for method names.
  Operation *getPrintI64(Operation *op) const {
    return getPrint(op, "printI64", IntegerType::get(op->getContext(), 64));
  }
  Operation *getPrintU64(Operation *op) const {
    return getPrint(op, "printU64", IntegerType::get(op->getContext(), 64));
  }
  Operation *getPrintFloat(Operation *op) const {
    return getPrint(op, "printF32", Float32Type::get(op->getContext()));
  }
  Operation *getPrintDouble(Operation *op) const {
    return getPrint(op, "printF64", Float64Type::get(op->getContext()));
  }
  Operation *getPrintOpen(Operation *op) const {
    return getPrint(op, "printOpen", {});
  }
  Operation *getPrintClose(Operation *op) const {
    return getPrint(op, "printClose", {});
  }
  Operation *getPrintComma(Operation *op) const {
    return getPrint(op, "printComma", {});
  }
  Operation *getPrintNewline(Operation *op) const {
    return getPrint(op, "printNewline", {});
  }
};

/// Progressive lowering of ExtractStridedSliceOp to either:
///   1. express single offset extract as a direct shuffle.
///   2. extract + lower rank strided_slice + insert for the n-D case.
class VectorExtractStridedSliceOpConversion
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
      : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
    // This pattern creates recursive ExtractStridedSliceOp, but the recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();

    assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    auto elemType = dstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    // A single offset can be lowered more efficiently to a shuffle.
    if (op.offsets().getValue().size() == 1) {
      SmallVector<int64_t, 4> offsets;
      offsets.reserve(size);
      for (int64_t off = offset, e = offset + size * stride; off < e;
           off += stride)
        offsets.push_back(off);
      rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
                                             op.vector(),
                                             rewriter.getI64ArrayAttr(offsets));
      return success();
    }

    // Extract/insert on a lower ranked extract strided slice op.
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      Value one = extractOne(rewriter, loc, op.vector(), off);
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, one, getI64SubArray(op.offsets(), /*dropFront=*/1),
          getI64SubArray(op.sizes(), /*dropFront=*/1),
          getI64SubArray(op.strides(), /*dropFront=*/1));
      res = insertOne(rewriter, loc, extracted, res, idx);
    }
    rewriter.replaceOp(op, res);
    return success();
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
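///
/// A typical use from a lowering pass (a sketch; pass and target setup
/// omitted):
/// ```
///   OwningRewritePatternList patterns;
///   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
///   populateVectorToLLVMConversionPatterns(converter, patterns);
/// ```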
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    bool reassociateFPReductions, bool enableIndexOptimizations) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  // clang-format off
  patterns.insert<VectorFMAOpNDRewritePattern,
                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                  VectorInsertStridedSliceOpSameRankRewritePattern,
                  VectorExtractStridedSliceOpConversion>(ctx);
  patterns.insert<VectorReductionOpConversion>(
      converter, reassociateFPReductions);
  patterns.insert<VectorCreateMaskOpConversion,
                  VectorTransferConversion<TransferReadOp>,
                  VectorTransferConversion<TransferWriteOp>>(
      converter, enableIndexOptimizations);
  patterns
      .insert<VectorBitCastOpConversion,
              VectorShuffleOpConversion,
              VectorExtractElementOpConversion,
              VectorExtractOpConversion,
              VectorFMAOp1DConversion,
              VectorInsertElementOpConversion,
              VectorInsertOpConversion,
              VectorPrintOpConversion,
              VectorTypeCastOpConversion,
              VectorMaskedLoadOpConversion,
              VectorMaskedStoreOpConversion,
              VectorGatherOpConversion,
              VectorScatterOpConversion,
              VectorExpandLoadOpConversion,
              VectorCompressStoreOpConversion>(converter);
  // clang-format on
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  patterns.insert<VectorMatmulOpConversion>(converter);
  patterns.insert<VectorFlatTransposeOpConversion>(converter);
}