Lines Matching defs:rewriter
45 static Value insertOne(ConversionPatternRewriter &rewriter,
51 auto idxType = rewriter.getIndexType();
52 auto constant = rewriter.create<LLVM::ConstantOp>(
54 rewriter.getIntegerAttr(idxType, pos));
55 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
58 return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos);
62 static Value extractOne(ConversionPatternRewriter &rewriter,
66 auto idxType = rewriter.getIndexType();
67 auto constant = rewriter.create<LLVM::ConstantOp>(
69 rewriter.getIntegerAttr(idxType, pos));
70 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
73 return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
102 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
113 return rewriter.create<LLVM::GEPOp>(
144 ConversionPatternRewriter &rewriter) const override {
150 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
165 ConversionPatternRewriter &rewriter) const override {
166 rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
183 ConversionPatternRewriter &rewriter) const override {
184 rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
197 ConversionPatternRewriter &rewriter) {
198 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align,
206 ConversionPatternRewriter &rewriter) {
207 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
214 ConversionPatternRewriter &rewriter) {
215 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
223 ConversionPatternRewriter &rewriter) {
224 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
238 ConversionPatternRewriter &rewriter) const override {
256 adaptor.getIndices(), rewriter);
258 rewriter);
271 ConversionPatternRewriter &rewriter) const override {
286 adaptor.getIndices(), rewriter);
295 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
298 rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
300 adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
305 auto callback = [align, memRefType, base, ptr, loc, &rewriter,
310 rewriter, loc, typeConverter, memRefType, base, ptr,
313 return rewriter.create<LLVM::masked_gather>(
315 /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
320 gather, vectorOperands, *getTypeConverter(), callback, rewriter);
332 ConversionPatternRewriter &rewriter) const override {
347 adaptor.getIndices(), rewriter);
349 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
353 rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
355 rewriter.getI32IntegerAttr(align));
368 ConversionPatternRewriter &rewriter) const override {
375 adaptor.getIndices(), rewriter);
377 rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
391 ConversionPatternRewriter &rewriter) const override {
397 adaptor.getIndices(), rewriter);
399 rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
419 ConversionPatternRewriter &rewriter,
421 return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
422 rewriter.getZeroAttr(llvmType));
427 ConversionPatternRewriter &rewriter,
429 return rewriter.create<LLVM::ConstantOp>(
430 loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
435 ConversionPatternRewriter &rewriter,
437 return rewriter.create<LLVM::ConstantOp>(
438 loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
443 ConversionPatternRewriter &rewriter,
445 return rewriter.create<LLVM::ConstantOp>(
447 rewriter.getIntegerAttr(
453 ConversionPatternRewriter &rewriter,
455 return rewriter.create<LLVM::ConstantOp>(
457 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
463 ConversionPatternRewriter &rewriter,
465 return rewriter.create<LLVM::ConstantOp>(
467 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
473 ConversionPatternRewriter &rewriter,
475 return rewriter.create<LLVM::ConstantOp>(
477 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
483 ConversionPatternRewriter &rewriter,
485 return rewriter.create<LLVM::ConstantOp>(
487 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
493 ConversionPatternRewriter &rewriter,
496 return rewriter.create<LLVM::ConstantOp>(
498 rewriter.getFloatAttr(
505 ConversionPatternRewriter &rewriter,
508 return rewriter.create<LLVM::ConstantOp>(
510 rewriter.getFloatAttr(
518 static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
524 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
531 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
537 Value baseVecLength = rewriter.create<LLVM::ConstantOp>(
538 loc, rewriter.getI32Type(),
539 rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
545 Value vScale = rewriter.create<vector::VectorScaleOp>(loc);
547 rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), vScale);
549 rewriter.create<arith::MulIOp>(loc, baseVecLength, vScale);
559 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
562 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
565 result = rewriter.create<ScalarOp>(loc, accumulator, result);
575 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
577 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
580 rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
581 result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
609 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
612 rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);
616 rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
642 static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
648 return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
657 lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
662 rewriter, loc, llvmType, vectorOperand.getType());
663 const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
666 rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
671 lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
674 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
676 return rewriter.create<LLVMRedIntrinOp>(loc, llvmType,
686 lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
689 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
691 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
698 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
700 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
703 createVectorLengthValue(rewriter, loc, vectorOperand.getType());
704 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
712 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
717 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
722 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
736 ConversionPatternRewriter &rewriter) const override {
752 rewriter, loc, llvmType, operand, acc);
758 rewriter, loc, llvmType, operand, acc);
762 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
767 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
772 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
777 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
784 rewriter, loc, llvmType, operand, acc);
790 rewriter, loc, llvmType, operand, acc);
796 rewriter, loc, llvmType, operand, acc);
801 rewriter.replaceOp(reductionOp, result);
823 rewriter, loc, llvmType, operand, acc, fmf);
827 rewriter, loc, llvmType, operand, acc, fmf);
831 rewriter, loc, llvmType, operand, acc, fmf);
835 rewriter, loc, llvmType, operand, acc, fmf);
838 rewriter, loc, llvmType, operand, acc, fmf);
841 rewriter, loc, llvmType, operand, acc, fmf);
845 rewriter.replaceOp(reductionOp, result);
869 ConversionPatternRewriter &rewriter) const final {
874 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
881 ConversionPatternRewriter &rewriter) const = 0;
893 ConversionPatternRewriter &rewriter) const override {
912 ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
918 ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
924 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
929 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
934 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
939 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
944 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
949 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
954 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
959 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
964 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
969 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
974 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
979 rewriter.replaceOp(maskOp, result);
991 ConversionPatternRewriter &rewriter) const override {
1016 Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
1019 rewriter.replaceOp(shuffleOp, llvmShuffleOp);
1030 Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
1038 Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
1040 insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1043 rewriter.replaceOp(shuffleOp, insert);
1056 ConversionPatternRewriter &rewriter) const override {
1066 auto idxType = rewriter.getIndexType();
1067 auto zero = rewriter.create<LLVM::ConstantOp>(
1069 rewriter.getIntegerAttr(idxType, 0));
1070 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1075 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1088 ConversionPatternRewriter &rewriter) const override {
1097 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1121 Type idxType = typeConverter->convertType(rewriter.getIndexType());
1122 positionVec.push_back(rewriter.getZeroAttr(idxType));
1138 extracted = rewriter.create<LLVM::ExtractValueOp>(
1143 extracted = rewriter.create<LLVM::ExtractElementOp>(
1144 loc, extracted, getAsLLVMValue(rewriter, loc, positionVec.back()));
1147 rewriter.replaceOp(extractOp, extracted);
1172 ConversionPatternRewriter &rewriter) const override {
1177 rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
1190 ConversionPatternRewriter &rewriter) const override {
1200 auto idxType = rewriter.getIndexType();
1201 auto zero = rewriter.create<LLVM::ConstantOp>(
1203 rewriter.getIntegerAttr(idxType, 0));
1204 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1209 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1223 ConversionPatternRewriter &rewriter) const override {
1233 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1239 rewriter.replaceOp(insertOp, adaptor.getSource());
1248 Value inserted = rewriter.create<LLVM::InsertValueOp>(
1250 rewriter.replaceOp(insertOp, inserted);
1262 extracted = rewriter.create<LLVM::ExtractValueOp>(
1267 Value inserted = rewriter.create<LLVM::InsertElementOp>(
1269 adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));
1276 inserted = rewriter.create<LLVM::InsertValueOp>(
1281 rewriter.replaceOp(insertOp, inserted);
1294 ConversionPatternRewriter &rewriter) const override {
1295 rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1309 ConversionPatternRewriter &rewriter) const override {
1310 rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1349 PatternRewriter &rewriter) const override {
1356 Value zero = rewriter.create<arith::ConstantOp>(
1357 loc, elemType, rewriter.getZeroAttr(elemType));
1358 Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
1360 Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
1361 Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
1362 Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
1363 Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
1364 desc = rewriter.create<InsertOp>(loc, fma, desc, i);
1366 rewriter.replaceOp(op, desc);
1408 ConversionPatternRewriter &rewriter) const override {
1441 auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
1444 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1446 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1447 desc.setAllocatedPtr(rewriter, loc, allocated);
1450 Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1451 desc.setAlignedPtr(rewriter, loc, ptr);
1453 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
1454 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
1455 desc.setOffset(rewriter, loc, zero);
1462 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
1463 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
1464 desc.setSize(rewriter, loc, index, size);
1465 auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
1467 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
1468 desc.setStride(rewriter, loc, index, stride);
1471 rewriter.replaceOp(castOp, {desc});
1488 ConversionPatternRewriter &rewriter) const override {
1493 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1495 Value indices = rewriter.create<LLVM::StepVectorOp>(
1498 auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
1500 Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
1501 Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1503 rewriter.replaceOp(op, comp);
1530 ConversionPatternRewriter &rewriter) const override {
1543 if (failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
1550 LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
1573 emitCall(rewriter, printOp->getLoc(), op.value());
1576 rewriter.eraseOp(printOp);
1590 LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
1646 value = rewriter.create<arith::ExtUIOp>(
1647 loc, IntegerType::get(rewriter.getContext(), 64), value);
1650 value = rewriter.create<arith::ExtSIOp>(
1651 loc, IntegerType::get(rewriter.getContext(), 64), value);
1654 value = rewriter.create<LLVM::BitcastOp>(
1655 loc, IntegerType::get(rewriter.getContext(), 16), value);
1660 emitCall(rewriter, loc, printer.value(), value);
1665 static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1667 rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
1679 ConversionPatternRewriter &rewriter) const override {
1686 Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
1687 auto zero = rewriter.create<LLVM::ConstantOp>(
1689 typeConverter->convertType(rewriter.getIntegerType(32)),
1690 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1694 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1700 auto v = rewriter.create<LLVM::InsertElementOp>(
1707 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
1721 ConversionPatternRewriter &rewriter) const override {
1736 Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
1740 Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
1741 auto zero = rewriter.create<LLVM::ConstantOp>(
1742 loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1743 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1744 Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1750 v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
1754 nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
1755 desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
1757 rewriter.replaceOp(splatOp, desc);
1770 ConversionPatternRewriter &rewriter) const override {
1774 return rewriter.notifyMatchFailure(interleaveOp,
1778 rewriter.replaceOpWithNewOp<LLVM::vector_interleave2>(
1794 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1809 ConversionPatternRewriter &rewriter) const override {
1817 return rewriter.notifyMatchFailure(deinterleaveOp,
1825 auto intrinsic = rewriter.create<LLVM::vector_deinterleave2>(
1828 auto evenResult = rewriter.create<LLVM::ExtractValueOp>(
1830 auto oddResult = rewriter.create<LLVM::ExtractValueOp>(
1833 rewriter.replaceOp(deinterleaveOp, ValueRange{evenResult, oddResult});
1854 auto poison = rewriter.create<LLVM::PoisonOp>(loc, sourceType);
1855 auto evenShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
1857 auto oddShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
1860 rewriter.replaceOp(deinterleaveOp, ValueRange{evenShuffle, oddShuffle});
1872 ConversionPatternRewriter &rewriter) const override {
1878 return rewriter.notifyMatchFailure(fromElementsOp,
1881 Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
1883 result = rewriter.create<vector::InsertOp>(loc, val, result, idx);
1884 rewriter.replaceOp(fromElementsOp, result);
1896 ConversionPatternRewriter &rewriter) const override {
1902 rewriter.replaceOpWithNewOp<LLVM::StepVectorOp>(stepOp, llvmType);