Lines Matching defs:rewriter

57 static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
68 AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
70 rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
391 PatternRewriter &rewriter) const override {
401 bindDims(rewriter.getContext(), m, n, k);
408 return rewriter.notifyMatchFailure(op, "not a gemm contraction");
414 return rewriter.notifyMatchFailure(op, "contraction already prepared");
416 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
418 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
420 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
421 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
424 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
425 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
428 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
431 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
436 return rewriter.notifyMatchFailure(op, "unexpected contraction case");
438 rewriter.replaceOpWithNewOp<vector::ContractionOp>(
440 rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
455 PatternRewriter &rewriter) const override {
471 return rewriter.notifyMatchFailure(op, "no transfer read");
475 return rewriter.notifyMatchFailure(op, "0-D transfer read");
478 return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
487 rewriter
498 result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
501 result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
504 result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
508 rewriter.replaceOp(op, result);
542 convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
544 OpBuilder::InsertionGuard g(rewriter);
545 rewriter.setInsertionPoint(op);
555 return rewriter.notifyMatchFailure(op, "no stride");
583 Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
585 rewriter.getIndexAttr(*stride),
586 isTranspose ? rewriter.getUnitAttr() : UnitAttr());
594 convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
596 OpBuilder::InsertionGuard g(rewriter);
597 rewriter.setInsertionPoint(op);
604 return rewriter.notifyMatchFailure(op, "no stride");
610 return rewriter.notifyMatchFailure(op, "no mapping");
614 auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
616 rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
622 rewriter.eraseOp(op);
639 convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
641 OpBuilder::InsertionGuard g(rewriter);
642 rewriter.setInsertionPoint(op);
648 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
655 return rewriter.notifyMatchFailure(op, "not mma sync reg info");
662 return rewriter.notifyMatchFailure(op, "not a splat");
665 Value result = rewriter.create<arith::ConstantOp>(
706 creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
708 OpBuilder::InsertionGuard g(rewriter);
709 rewriter.setInsertionPoint(op);
716 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
723 return rewriter.notifyMatchFailure(op, "not mma sync reg info");
729 return rewriter.notifyMatchFailure(
741 return rewriter.notifyMatchFailure(
747 auto laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
749 nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
752 return rewriter.notifyMatchFailure(op, "no offsets");
758 getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
761 nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
768 createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
770 OpBuilder::InsertionGuard g(rewriter);
771 rewriter.setInsertionPoint(op);
777 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
781 return rewriter.notifyMatchFailure(
786 Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
793 Value fill = rewriter.create<arith::ConstantOp>(
795 rewriter.getZeroAttr(vectorType.getElementType()));
797 rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
810 rewriter, op.getLoc(), *warpMatrixInfo);
812 return rewriter.notifyMatchFailure(op, "no coords");
814 Value logicalValueId = rewriter.create<arith::ConstantOp>(
815 loc, rewriter.getIndexType(),
816 rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
819 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
821 Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
823 result = rewriter.create<vector::InsertOp>(loc, el, result, i);
833 Value logicalValueId = rewriter.create<arith::ConstantOp>(
834 loc, rewriter.getIndexType(),
835 rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
837 rewriter, op.getLoc(), *warpMatrixInfo);
839 return rewriter.notifyMatchFailure(op, "no coords");
843 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
844 Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
846 result = rewriter.create<vector::InsertOp>(
868 convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
870 OpBuilder::InsertionGuard g(rewriter);
871 rewriter.setInsertionPoint(op);
876 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
894 return createNonLdMatrixLoads(rewriter, op, valueMapping);
896 return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
900 convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
902 OpBuilder::InsertionGuard g(rewriter);
903 rewriter.setInsertionPoint(op);
908 return rewriter.notifyMatchFailure(op, "no mapping");
914 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
918 return rewriter.notifyMatchFailure(op, "not mma sync reg info");
921 Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
924 Value logicalValueId = rewriter.create<arith::ConstantOp>(
925 loc, rewriter.getIndexType(),
926 rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
928 rewriter, op.getLoc(), *warpMatrixInfo);
930 return rewriter.notifyMatchFailure(op, "no coords");
933 rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
936 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
937 rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
941 rewriter.eraseOp(op);
952 convertExtractStridedSlice(RewriterBase &rewriter,
955 OpBuilder::InsertionGuard g(rewriter);
956 rewriter.setInsertionPoint(op);
963 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
968 return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
973 return rewriter.notifyMatchFailure(op, "no transfer read");
977 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
982 return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
997 return rewriter.notifyMatchFailure(op, "no mapping");
1020 Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
1028 convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
1030 OpBuilder::InsertionGuard g(rewriter);
1031 rewriter.setInsertionPoint(op);
1038 return rewriter.notifyMatchFailure(op, "no mapping");
1040 Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
1048 convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
1050 OpBuilder::InsertionGuard g(rewriter);
1051 rewriter.setInsertionPoint(op);
1058 return rewriter.notifyMatchFailure(op, "no mapping");
1063 Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
1064 op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
1071 convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
1073 OpBuilder::InsertionGuard g(rewriter);
1074 rewriter.setInsertionPoint(op);
1081 rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
1086 auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1094 convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
1096 OpBuilder::InsertionGuard g(rewriter);
1097 rewriter.setInsertionPoint(op);
1105 auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1113 static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
1116 OpBuilder::InsertionGuard g(rewriter);
1117 rewriter.setInsertionPoint(loop);
1120 rewriter.setInsertionPoint(loop);
1123 scf::ForOp newLoop = rewriter.create<scf::ForOp>(
1126 rewriter.eraseBlock(newLoop.getBody());
1135 rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
1141 rewriter.eraseOp(loop);
1145 static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
1147 OpBuilder::InsertionGuard g(rewriter);
1148 rewriter.setInsertionPoint(op);
1163 scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
1178 convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
1180 OpBuilder::InsertionGuard g(rewriter);
1181 rewriter.setInsertionPoint(op);
1194 rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);
1197 rewriter.eraseOp(op);
1203 convertElementwiseOp(RewriterBase &rewriter, Operation *op,
1206 OpBuilder::InsertionGuard g(rewriter);
1207 rewriter.setInsertionPoint(op);
1213 return rewriter.notifyMatchFailure(op, "no mapping");
1225 Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
1242 LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
1253 res = convertTransferReadOp(rewriter, transferRead, valueMapping);
1255 res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
1257 res = convertContractOp(rewriter, contractOp, valueMapping);
1259 res = convertConstantOp(rewriter, constantOp, valueMapping);
1261 res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
1263 res = convertForOp(rewriter, forOp, valueMapping);
1265 res = convertYieldOp(rewriter, yieldOp, valueMapping);
1267 res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
1275 LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
1282 return convertTransferReadToLoads(rewriter, transferReadOp,
1286 return convertTransferWriteToStores(rewriter, transferWriteOp,
1290 return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
1294 return convertContractOpToMmaSync(rewriter, contractionOp,
1298 return convertForOp(rewriter, forOp, valueMapping);
1301 return convertYieldOp(rewriter, yieldOp, valueMapping);
1304 return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
1332 IRRewriter rewriter(&getContext());
1335 convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
1339 (void)convertVectorToMMAOps(rewriter, getOperation());