//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "../PassDetail.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;
using namespace mlir::vector;

// Helper that returns the LLVM pointer-to-element type for `containerType`
// (any type exposing getElementType(), e.g. VectorType or MemRefType): the
// element type is lowered through `typeConverter`, then wrapped in a pointer.
template <typename T>
static LLVM::LLVMType getPtrToElementType(T containerType,
                                          LLVMTypeConverter &typeConverter) {
  return typeConverter.convertType(containerType.getElementType())
      .template cast<LLVM::LLVMType>()
      .getPointerTo();
}

// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
}

// Helper to reduce vector type by *all* but one rank at back.
52 static VectorType reducedVectorTypeBack(VectorType tp) { 53 assert((tp.getRank() > 1) && "unlowerable vector type"); 54 return VectorType::get(tp.getShape().take_back(), tp.getElementType()); 55 } 56 57 // Helper that picks the proper sequence for inserting. 58 static Value insertOne(ConversionPatternRewriter &rewriter, 59 LLVMTypeConverter &typeConverter, Location loc, 60 Value val1, Value val2, Type llvmType, int64_t rank, 61 int64_t pos) { 62 if (rank == 1) { 63 auto idxType = rewriter.getIndexType(); 64 auto constant = rewriter.create<LLVM::ConstantOp>( 65 loc, typeConverter.convertType(idxType), 66 rewriter.getIntegerAttr(idxType, pos)); 67 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 68 constant); 69 } 70 return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, 71 rewriter.getI64ArrayAttr(pos)); 72 } 73 74 // Helper that picks the proper sequence for inserting. 75 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, 76 Value into, int64_t offset) { 77 auto vectorType = into.getType().cast<VectorType>(); 78 if (vectorType.getRank() > 1) 79 return rewriter.create<InsertOp>(loc, from, into, offset); 80 return rewriter.create<vector::InsertElementOp>( 81 loc, vectorType, from, into, 82 rewriter.create<ConstantIndexOp>(loc, offset)); 83 } 84 85 // Helper that picks the proper sequence for extracting. 
86 static Value extractOne(ConversionPatternRewriter &rewriter, 87 LLVMTypeConverter &typeConverter, Location loc, 88 Value val, Type llvmType, int64_t rank, int64_t pos) { 89 if (rank == 1) { 90 auto idxType = rewriter.getIndexType(); 91 auto constant = rewriter.create<LLVM::ConstantOp>( 92 loc, typeConverter.convertType(idxType), 93 rewriter.getIntegerAttr(idxType, pos)); 94 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 95 constant); 96 } 97 return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, 98 rewriter.getI64ArrayAttr(pos)); 99 } 100 101 // Helper that picks the proper sequence for extracting. 102 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, 103 int64_t offset) { 104 auto vectorType = vector.getType().cast<VectorType>(); 105 if (vectorType.getRank() > 1) 106 return rewriter.create<ExtractOp>(loc, vector, offset); 107 return rewriter.create<vector::ExtractElementOp>( 108 loc, vectorType.getElementType(), vector, 109 rewriter.create<ConstantIndexOp>(loc, offset)); 110 } 111 112 // Helper that returns a subset of `arrayAttr` as a vector of int64_t. 113 // TODO(rriddle): Better support for attribute subtype forwarding + slicing. 114 static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr, 115 unsigned dropFront = 0, 116 unsigned dropBack = 0) { 117 assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds"); 118 auto range = arrayAttr.getAsRange<IntegerAttr>(); 119 SmallVector<int64_t, 4> res; 120 res.reserve(arrayAttr.size() - dropFront - dropBack); 121 for (auto it = range.begin() + dropFront, eit = range.end() - dropBack; 122 it != eit; ++it) 123 res.push_back((*it).getValue().getSExtValue()); 124 return res; 125 } 126 127 namespace { 128 129 /// Conversion pattern for a vector.matrix_multiply. 130 /// This is lowered directly to the proper llvm.intr.matrix.multiply. 
class VectorMatmulOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMatmulOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto matmulOp = cast<vector::MatmulOp>(op);
    auto adaptor = vector::MatmulOpOperandAdaptor(operands);
    // The matrix shape attributes are forwarded verbatim to the intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
        adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
        matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFlatTransposeOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto transOp = cast<vector::FlatTransposeOp>(op);
    auto adaptor = vector::FlatTransposeOpOperandAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter.convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};

/// Conversion pattern for a vector.reduction, mapped onto the
/// llvm.experimental.vector.reduce.* intrinsics. Integer reductions support
/// add/mul/min/max/and/or/xor on i32/i64 elements; floating-point reductions
/// support add/mul (with an optional accumulator operand) and min/max on
/// f32/f64 elements. Any other kind/element-type combination is rejected.
class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorReductionOpConversion(MLIRContext *context,
                                       LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<vector::ReductionOp>(op);
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter.convertType(eltType);
    if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      // NOTE(review): min/max lower to the *signed* smin/smax intrinsics;
      // signless integers are treated as signed here.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
            op, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>(
            op, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>(
            op, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>(
            op, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>(
            op, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();

    } else if (eltType.isF32() || eltType.isF64()) {
      // Floating-point reductions: add/mul/min/max
      if (kind == "add") {
        // Optional accumulator (or zero).
        Value acc = operands.size() > 1 ? operands[1]
                                        : rewriter.create<LLVM::ConstantOp>(
                                              op->getLoc(), llvmType,
                                              rewriter.getZeroAttr(eltType));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
            op, llvmType, acc, operands[0]);
      } else if (kind == "mul") {
        // Optional accumulator (or one).
        Value acc = operands.size() > 1
                        ? operands[1]
                        : rewriter.create<LLVM::ConstantOp>(
                              op->getLoc(), llvmType,
                              rewriter.getFloatAttr(eltType, 1.0));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
            op, llvmType, acc, operands[0]);
      } else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();
    }
    return failure();
  }
};

/// Conversion pattern for a vector.shuffle. Rank-1 shuffles with identical
/// operand types use LLVM's native shufflevector; all other cases are
/// expanded into per-position extract/insert pairs.
class VectorShuffleOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorShuffleOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ShuffleOpOperandAdaptor(operands);
    auto shuffleOp = cast<vector::ShuffleOp>(op);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter.convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(op, shuffle);
      return success();
    }

    // For all other cases, insert the individual values individually.
    // A mask entry >= v1Dim selects from the second operand (rebased index).
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
                                 rank, extPos);
      insert = insertOne(rewriter, typeConverter, loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(op, insert);
    return success();
  }
};

/// Conversion pattern for a vector.extractelement: a direct 1-1 mapping onto
/// LLVM's extractelement.
class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractElementOpConversion(MLIRContext *context,
                                            LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
    auto extractEltOp = cast<vector::ExtractElementOp>(op);
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter.convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        op, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};

/// Conversion pattern for a vector.extract. N-D vectors lower to LLVM arrays
/// of 1-D vectors, so a vector-typed result needs only an extractvalue, while
/// a scalar result needs an extractvalue down to the innermost 1-D vector
/// followed by an extractelement.
class VectorExtractOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ExtractOpOperandAdaptor(operands);
    auto extractOp = cast<vector::ExtractOp>(op);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter.convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(op, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = op->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(op, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fma. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fma %va, %va, %va:
///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
///    -> !llvm<"<8 x float>">
/// ```
class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFMAOp1DConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FMAOpOperandAdaptor(operands);
    vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
    VectorType vType = fmaOp.getVectorType();
    // Ranks >= 2 are handled by VectorFMAOpNDRewritePattern below.
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(),
                                             adaptor.acc());
    return success();
  }
};

/// Conversion pattern for a vector.insertelement: a direct 1-1 mapping onto
/// LLVM's insertelement.
class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertElementOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
    auto insertEltOp = cast<vector::InsertElementOp>(op);
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter.convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
    return success();
  }
};

/// Conversion pattern for a vector.insert. Mirrors VectorExtractOpConversion:
/// a vector-typed source needs only an insertvalue; a scalar source needs an
/// extractvalue down to the innermost 1-D vector, an insertelement into it,
/// and an insertvalue of the updated 1-D vector back into the aggregate.
class VectorInsertOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::InsertOpOperandAdaptor(operands);
    auto insertOp = cast<vector::InsertOp>(op);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter.convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(op, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = op->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter.convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(op, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    // Rank 1 is handled by VectorFMAOp1DConversion.
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    // Seed the result with an all-zero splat, then fill every outermost
    // position with a rank-reduced FMA.
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// When ranks are different, InsertStridedSlice needs to extract a properly
// ranked vector from the destination vector into which to insert.
// This pattern only takes care of this part and forwards the rest of the
// conversion to another pattern that converts InsertStridedSlice for operands
// of the same rank.
//
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have different ranks. In this case:
//   1. the proper subvector is extracted from the destination vector
//   2. a new InsertStridedSlice op is created to insert the source in the
//      destination subvector
//   3. the destination subvector is inserted back in the proper place
//   4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2. will be picked up by a
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    // Equal-rank cases are handled by the SameRank pattern below.
    if (rankDiff == 0)
      return failure();

    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract / insert the subvector of matching rank and InsertStridedSlice
    // on it. The leading `rankDiff` offsets address the subvector; the
    // trailing `rankRest` offsets apply within it.
    Value extracted =
        rewriter.create<ExtractOp>(loc, op.dest(),
                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};

// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. For each position along the outermost dimension of the
// slice:
//   1. the corresponding subvector (or scalar) is extracted from the source
//   2. unless already at the element level, a smaller InsertStridedSliceOp is
//      created to recursively insert it into the matching destination
//      subvector
//   3. the result is inserted back into the accumulated destination vector.
// The recursive InsertStridedSlice from step 2. is picked up again by this
// same pattern, with strictly smaller rank.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    // Different-rank cases are handled by the DifferentRank pattern above.
    if (rankDiff != 0)
      return failure();

    // Identical types: the insertion covers the whole destination.
    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the most major dimension.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. extract the proper subvector (or element) from source
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from destination
        // Otherwise we are at the element level and no need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        // smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /* dropFront=*/1),
            getI64SubArray(op.strides(), /* dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
  /// This pattern creates recursive InsertStridedSliceOp, but the recursion is
  /// bounded as the rank is strictly decreasing.
  bool hasBoundedRewriteRecursion() const final { return true; }
};

/// Conversion pattern for a vector.type_cast on a statically shaped,
/// contiguous memref: rebuilds the memref descriptor with the target element
/// type, a zero offset, and the target shape's sizes/strides, bitcasting the
/// allocated and aligned pointers in place (no data is moved).
class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTypeCastOpConversion(MLIRContext *context,
                                      LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        operands[0].getType().dyn_cast<LLVM::LLVMType>();
    if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
      return failure();
    MemRefDescriptor sourceMemRef(operands[0]);

    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMType>();
    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
      return failure();

    // Check that the source memref is contiguous: innermost stride 1 and each
    // outer stride equal to the product of the inner stride and size.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides =
        getStridesAndOffset(sourceMemRefType, strides, offset);
    bool isContiguous = (strides.back() == 1);
    if (isContiguous) {
      auto sizes = sourceMemRefType.getShape();
      for (int index = 0, e = strides.size() - 2; index < e; ++index) {
        if (strides[index] != strides[index + 1] * sizes[index + 1]) {
          isContiguous = false;
          break;
        }
      }
    }
    // Only contiguous source tensors supported atm.
    if (failed(successStrides) || !isContiguous)
      return failure();

    auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
753 for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { 754 int64_t index = indexedSize.index(); 755 auto sizeAttr = 756 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 757 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 758 desc.setSize(rewriter, loc, index, size); 759 auto strideAttr = 760 rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]); 761 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 762 desc.setStride(rewriter, loc, index, stride); 763 } 764 765 rewriter.replaceOp(op, {desc}); 766 return success(); 767 } 768 }; 769 770 LogicalResult getLLVMTypeAndAlignment(LLVMTypeConverter &typeConverter, 771 Type type, LLVM::LLVMType &llvmType, 772 unsigned &align) { 773 auto convertedType = typeConverter.convertType(type); 774 if (!convertedType) 775 return failure(); 776 777 llvmType = convertedType.template cast<LLVM::LLVMType>(); 778 auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout(); 779 align = dataLayout.getPrefTypeAlignment(llvmType.getUnderlyingType()); 780 return success(); 781 } 782 783 LogicalResult 784 replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, 785 LLVMTypeConverter &typeConverter, Location loc, 786 TransferReadOp xferOp, 787 ArrayRef<Value> operands, Value dataPtr) { 788 LLVM::LLVMType vecTy; 789 unsigned align; 790 if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(), 791 vecTy, align))) 792 return failure(); 793 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr); 794 return success(); 795 } 796 797 LogicalResult replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, 798 LLVMTypeConverter &typeConverter, 799 Location loc, TransferReadOp xferOp, 800 ArrayRef<Value> operands, 801 Value dataPtr, Value mask) { 802 auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); }; 803 VectorType fillType = xferOp.getVectorType(); 804 Value fill = 
rewriter.create<SplatOp>(loc, fillType, xferOp.padding()); 805 fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill); 806 807 LLVM::LLVMType vecTy; 808 unsigned align; 809 if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(), 810 vecTy, align))) 811 return failure(); 812 813 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 814 xferOp, vecTy, dataPtr, mask, ValueRange{fill}, 815 rewriter.getI32IntegerAttr(align)); 816 return success(); 817 } 818 819 LogicalResult 820 replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, 821 LLVMTypeConverter &typeConverter, Location loc, 822 TransferWriteOp xferOp, 823 ArrayRef<Value> operands, Value dataPtr) { 824 auto adaptor = TransferWriteOpOperandAdaptor(operands); 825 LLVM::LLVMType vecTy; 826 unsigned align; 827 if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(), 828 vecTy, align))) 829 return failure(); 830 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr); 831 return success(); 832 } 833 834 LogicalResult replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, 835 LLVMTypeConverter &typeConverter, 836 Location loc, TransferWriteOp xferOp, 837 ArrayRef<Value> operands, 838 Value dataPtr, Value mask) { 839 auto adaptor = TransferWriteOpOperandAdaptor(operands); 840 LLVM::LLVMType vecTy; 841 unsigned align; 842 if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(), 843 vecTy, align))) 844 return failure(); 845 846 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 847 xferOp, adaptor.vector(), dataPtr, mask, 848 rewriter.getI32IntegerAttr(align)); 849 return success(); 850 } 851 852 static TransferReadOpOperandAdaptor 853 getTransferOpAdapter(TransferReadOp xferOp, ArrayRef<Value> operands) { 854 return TransferReadOpOperandAdaptor(operands); 855 } 856 857 static TransferWriteOpOperandAdaptor 858 getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) { 859 return 
      TransferWriteOpOperandAdaptor(operands);
}

/// Returns true if `map` is a "minor identity" for the given rank, i.e. its
/// first `rank` results are exactly the trailing `rank` dimensions
/// (d_{n-rank} ... d_{n-1}) in order.
bool isMinorIdentity(AffineMap map, unsigned rank) {
  if (map.getNumResults() < rank)
    return false;
  unsigned startDim = map.getNumDims() - rank;
  for (unsigned i = 0; i < rank; ++i)
    if (map.getResult(i) != getAffineDimExpr(startDim + i, map.getContext()))
      return false;
  return true;
}

/// Conversion pattern that converts a 1-D vector transfer read/write op in a
/// sequence of:
/// 1. Bitcast or addrspacecast to vector form.
/// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 3. Create a mask where offsetVector is compared against memref upper bound.
/// 4. Rewrite op as a masked read or write.
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTransferConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConv)
      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
                             typeConv) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto xferOp = cast<ConcreteOp>(op);
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    // Only 1-D transfers with at least one index are handled by this pattern.
    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();
    // A minor-identity permutation map corresponds to a contiguous access.
    if (!isMinorIdentity(xferOp.permutation_map(),
                         xferOp.getVectorType().getRank()))
      return failure();

    auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };

    Location loc = op->getLoc();
    Type i64Type = rewriter.getIntegerType(64);
    MemRefType memRefType = xferOp.getMemRefType();

    // 1. Get the source/dst address as an LLVM vector pointer.
    //    The vector pointer would always be on address space 0, therefore
    //    addrspacecast shall be used when source/dst memrefs are not on
    //    address space 0.
    // TODO: support alignment when possible.
    Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
                               adaptor.indices(), rewriter, getModule());
    auto vecTy =
        toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
    Value vectorDataPtr;
    if (memRefType.getMemorySpace() == 0)
      vectorDataPtr =
          rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
    else
      vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, vecTy.getPointerTo(), dataPtr);

    // Unmasked transfers lower directly to a plain load/store.
    if (!xferOp.isMaskedDim(0))
      return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
                                              xferOp, operands, vectorDataPtr);

    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
    unsigned vecWidth = vecTy.getVectorNumElements();
    VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
    SmallVector<int64_t, 8> indices;
    indices.reserve(vecWidth);
    for (unsigned i = 0; i < vecWidth; ++i)
      indices.push_back(i);
    Value linearIndices = rewriter.create<ConstantOp>(
        loc, vectorCmpType,
        DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
    // Move the standard-dialect constant into the LLVM dialect type system.
    linearIndices = rewriter.create<LLVM::DialectCastOp>(
        loc, toLLVMTy(vectorCmpType), linearIndices);

    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
    // TODO(ntv, ajcbik): when the leaf transfer rank is k > 1 we need the last
    // `k` dimensions here.
    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
    Value offsetIndex = *(xferOp.indices().begin() + lastIndex);
    offsetIndex = rewriter.create<IndexCastOp>(loc, i64Type, offsetIndex);
    Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
    Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);

    // 4. Let dim the memref dimension, compute the vector comparison mask:
    //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim ..
    //    dim ]
    Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
    dim = rewriter.create<IndexCastOp>(loc, i64Type, dim);
    dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
    Value mask =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
    mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
                                                mask);

    // 5. Rewrite as a masked read / write.
    return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
                                       operands, vectorDataPtr, mask);
  }
};

/// Lowers vector.print by fully unrolling the printed value into calls to a
/// small runtime support library.
class VectorPrintOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorPrintOpConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
                             typeConverter) {}

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO(ajcbik): rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto printOp = cast<vector::PrintOp>(op);
    auto adaptor = vector::PrintOpOperandAdaptor(operands);
    Type printType = printOp.getPrintType();

    if (typeConverter.convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support (currently just Float/Double).
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    // Select the runtime print routine matching the element type; bail out on
    // anything the support library cannot print.
    Operation *printer;
    if (eltType.isSignlessInteger(1))
      printer = getPrintI1(op);
    else if (eltType.isSignlessInteger(32))
      printer = getPrintI32(op);
    else if (eltType.isSignlessInteger(64))
      printer = getPrintI64(op);
    else if (eltType.isF32())
      printer = getPrintFloat(op);
    else if (eltType.isF64())
      printer = getPrintDouble(op);
    else
      return failure();

    // Unroll vector into elementary print calls.
    emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
    emitCall(rewriter, op->getLoc(), getPrintNewline(op));
    rewriter.eraseOp(op);
    return success();
  }

private:
  // Recursively emits print calls for one rank of `value`: scalars are printed
  // directly, nested ranks are wrapped in brackets and comma-separated.
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc, getPrintOpen(op));
    Operation *printComma = getPrintComma(op);
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter.convertType(
          rank > 1 ?
                     reducedType : vectorType.getElementType());
      Value nestedVal =
          extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
      // Separate sibling elements with commas (but not after the last one).
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc, getPrintClose(op));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
                                  rewriter.getSymbolRefAttr(ref), params);
  }

  // Helper for printer method declaration (first hit) and lookup.
  // Declares a void function `name(params)` at module scope on first use and
  // returns the existing declaration on subsequent lookups.
  static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect,
                             StringRef name, ArrayRef<LLVM::LLVMType> params) {
    auto module = op->getParentOfType<ModuleOp>();
    auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
    if (func)
      return func;
    OpBuilder moduleBuilder(module.getBodyRegion());
    return moduleBuilder.create<LLVM::LLVMFuncOp>(
        op->getLoc(), name,
        LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect),
                                      params, /*isVarArg=*/false));
  }

  // Helpers for method names.
1067 Operation *getPrintI1(Operation *op) const { 1068 LLVM::LLVMDialect *dialect = typeConverter.getDialect(); 1069 return getPrint(op, dialect, "print_i1", 1070 LLVM::LLVMType::getInt1Ty(dialect)); 1071 } 1072 Operation *getPrintI32(Operation *op) const { 1073 LLVM::LLVMDialect *dialect = typeConverter.getDialect(); 1074 return getPrint(op, dialect, "print_i32", 1075 LLVM::LLVMType::getInt32Ty(dialect)); 1076 } 1077 Operation *getPrintI64(Operation *op) const { 1078 LLVM::LLVMDialect *dialect = typeConverter.getDialect(); 1079 return getPrint(op, dialect, "print_i64", 1080 LLVM::LLVMType::getInt64Ty(dialect)); 1081 } 1082 Operation *getPrintFloat(Operation *op) const { 1083 LLVM::LLVMDialect *dialect = typeConverter.getDialect(); 1084 return getPrint(op, dialect, "print_f32", 1085 LLVM::LLVMType::getFloatTy(dialect)); 1086 } 1087 Operation *getPrintDouble(Operation *op) const { 1088 LLVM::LLVMDialect *dialect = typeConverter.getDialect(); 1089 return getPrint(op, dialect, "print_f64", 1090 LLVM::LLVMType::getDoubleTy(dialect)); 1091 } 1092 Operation *getPrintOpen(Operation *op) const { 1093 return getPrint(op, typeConverter.getDialect(), "print_open", {}); 1094 } 1095 Operation *getPrintClose(Operation *op) const { 1096 return getPrint(op, typeConverter.getDialect(), "print_close", {}); 1097 } 1098 Operation *getPrintComma(Operation *op) const { 1099 return getPrint(op, typeConverter.getDialect(), "print_comma", {}); 1100 } 1101 Operation *getPrintNewline(Operation *op) const { 1102 return getPrint(op, typeConverter.getDialect(), "print_newline", {}); 1103 } 1104 }; 1105 1106 /// Progressive lowering of ExtractStridedSliceOp to either: 1107 /// 1. extractelement + insertelement for the 1-D case 1108 /// 2. extract + optional strided_slice + insert for the n-D case. 
1109 class VectorStridedSliceOpConversion 1110 : public OpRewritePattern<ExtractStridedSliceOp> { 1111 public: 1112 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; 1113 1114 LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 1115 PatternRewriter &rewriter) const override { 1116 auto dstType = op.getResult().getType().cast<VectorType>(); 1117 1118 assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); 1119 1120 int64_t offset = 1121 op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 1122 int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt(); 1123 int64_t stride = 1124 op.strides().getValue().front().cast<IntegerAttr>().getInt(); 1125 1126 auto loc = op.getLoc(); 1127 auto elemType = dstType.getElementType(); 1128 assert(elemType.isSignlessIntOrIndexOrFloat()); 1129 Value zero = rewriter.create<ConstantOp>(loc, elemType, 1130 rewriter.getZeroAttr(elemType)); 1131 Value res = rewriter.create<SplatOp>(loc, dstType, zero); 1132 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 1133 off += stride, ++idx) { 1134 Value extracted = extractOne(rewriter, loc, op.vector(), off); 1135 if (op.offsets().getValue().size() > 1) { 1136 extracted = rewriter.create<ExtractStridedSliceOp>( 1137 loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1), 1138 getI64SubArray(op.sizes(), /* dropFront=*/1), 1139 getI64SubArray(op.strides(), /* dropFront=*/1)); 1140 } 1141 res = insertOne(rewriter, loc, extracted, res, idx); 1142 } 1143 rewriter.replaceOp(op, {res}); 1144 return success(); 1145 } 1146 /// This pattern creates recursive ExtractStridedSliceOp, but the recursion is 1147 /// bounded as the rank is strictly decreasing. 1148 bool hasBoundedRewriteRecursion() const final { return true; } 1149 }; 1150 1151 } // namespace 1152 1153 /// Populate the given list with patterns that convert from Vector to LLVM. 
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  // clang-format off
  // Patterns that only need the context (plain rewrite patterns).
  patterns.insert<VectorFMAOpNDRewritePattern,
                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                  VectorInsertStridedSliceOpSameRankRewritePattern,
                  VectorStridedSliceOpConversion>(ctx);
  // Conversion patterns that also need the LLVM type converter.
  patterns
      .insert<VectorReductionOpConversion,
              VectorShuffleOpConversion,
              VectorExtractElementOpConversion,
              VectorExtractOpConversion,
              VectorFMAOp1DConversion,
              VectorInsertElementOpConversion,
              VectorInsertOpConversion,
              VectorPrintOpConversion,
              VectorTransferConversion<TransferReadOp>,
              VectorTransferConversion<TransferWriteOp>,
              VectorTypeCastOpConversion>(ctx, converter);
  // clang-format on
}

/// Populate the given list with patterns that lower the vector
/// matrix-intrinsic ops (matmul, flat transpose) to LLVM.
void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.insert<VectorMatmulOpConversion>(ctx, converter);
  patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
}

namespace {
/// Pass that lowers operations from the vector dialect to the LLVM dialect.
struct LowerVectorToLLVMPass
    : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
  void runOnOperation() override;
};
} // namespace

void LowerVectorToLLVMPass::runOnOperation() {
  // Perform progressive lowering of operations on slices and
  // all contraction operations. Also applies folding and DCE.
  {
    OwningRewritePatternList patterns;
    populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
    populateVectorSlicesLoweringPatterns(patterns, &getContext());
    populateVectorContractLoweringPatterns(patterns, &getContext());
    applyPatternsAndFoldGreedily(getOperation(), patterns);
  }

  // Convert to the LLVM IR dialect.
1203 LLVMTypeConverter converter(&getContext()); 1204 OwningRewritePatternList patterns; 1205 populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 1206 populateVectorToLLVMConversionPatterns(converter, patterns); 1207 populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 1208 populateStdToLLVMConversionPatterns(converter, patterns); 1209 1210 LLVMConversionTarget target(getContext()); 1211 target.addDynamicallyLegalOp<FuncOp>( 1212 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 1213 if (failed(applyPartialConversion(getOperation(), target, patterns, 1214 &converter))) { 1215 signalPassFailure(); 1216 } 1217 } 1218 1219 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertVectorToLLVMPass() { 1220 return std::make_unique<LowerVectorToLLVMPass>(); 1221 } 1222