Lines Matching full:op
53 /// applied by taking into account the permutation map of the transfer op. If
76 // Return true if the contract op can be converted to an MMA matmul.
144 // Return true if the transfer op can be converted to an MMA matrix load.
168 // Return true if the transfer op can be converted to an MMA matrix store.
187 /// converted to an MMA constant matrix op.
200 /// Return true if this integer extend op can be folded into a contract op.
212 /// Return the MMA elementwise enum associated with `op` if it is supported.
215 convertElementwiseOpToMMA(Operation *op) {
216 if (isa<arith::AddFOp>(op))
218 if (isa<arith::MulFOp>(op))
220 if (isa<arith::SubFOp>(op))
222 if (isa<arith::MaximumFOp>(op))
224 if (isa<arith::MinimumFOp>(op))
226 if (isa<arith::DivFOp>(op))
228 if (isa<arith::AddIOp>(op))
230 if (isa<arith::MulIOp>(op))
232 if (isa<arith::SubIOp>(op))
234 if (isa<arith::DivSIOp>(op))
236 if (isa<arith::DivUIOp>(op))
238 if (isa<arith::NegFOp>(op))
240 if (isa<arith::ExtFOp>(op))
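The return statements of this matcher contain no bare `op` token, so the search elides them; each isa check above pairs with one gpu::MMAElementwiseOp enumerator. A minimal sketch of the elided pattern (enumerator names assumed from the GPU dialect's MMAElementwiseOp attribute):

  static std::optional<gpu::MMAElementwiseOp>
  convertElementwiseOpToMMA(Operation *op) {
    if (isa<arith::AddFOp>(op))
      return gpu::MMAElementwiseOp::ADDF;
    if (isa<arith::MulFOp>(op))
      return gpu::MMAElementwiseOp::MULF;
    // ... one enumerator per check above ...
    return std::nullopt;
  }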
245 /// Return true if the op is supported as elementwise op on MMAMatrix type.
246 static bool elementwiseSupportsMMAMatrixType(Operation *op) {
247 return convertElementwiseOpToMMA(op).has_value();
250 /// Returns true if the extract strided slice op is supported with `mma.sync`
253 extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
256 nvgpu::getWarpMatrixInfo(op);
260 FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
265 // matrixB and matrixC operands. vector.extract_strided_slice op
268 return (cast<VectorType>(op->getResult(0).getType()) ==
271 return (cast<VectorType>(op->getResult(0).getType()) ==
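The check accepts an extract_strided_slice only when its result type equals the type of the consuming vector.contract's rhs (matrixB) or acc (matrixC) operand. Illustrative IR for the acc case (shapes are hypothetical):

  %s = vector.extract_strided_slice %r
       {offsets = [0, 0], sizes = [16, 8], strides = [1, 1]}
       : vector<16x16xf16> to vector<16x8xf16>
  %d = vector.contract ... %a, %b, %s : ... into vector<16x8xf16>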
277 static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
278 if (isa<scf::ForOp, scf::YieldOp>(op))
280 if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
283 if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
286 if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
289 if (auto contract = dyn_cast<vector::ContractionOp>(op))
291 if (auto constant = dyn_cast<arith::ConstantOp>(op))
293 if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
295 if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
297 if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
299 if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
301 return elementwiseSupportsMMAMatrixType(op);
308 getSliceContract(Operation *op,
312 slice.insert(op);
343 // Analyze the slice of operations based on the convert op to figure out if the whole
345 static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
347 auto hasVectorDest = [](Operation *op) {
348 return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
353 auto hasVectorSrc = [](Operation *op) {
354 return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
360 op->walk([&](vector::ContractionOp contract) {
368 if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
369 if (!supportsMMaMatrixType(op, useNvGpu)) {
370 LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
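The analysis walks every vector.contract, gathers its backward/forward slice filtered to vector-producing and vector-consuming ops, and gives up on the whole slice if any member fails supportsMMaMatrixType. A condensed sketch of that flow, using the helpers matched above:

  SetVector<Operation *> opToConvert;
  op->walk([&](vector::ContractionOp contract) {
    SetVector<Operation *> dependentOps = getSliceContract(contract, ...);
    if (llvm::any_of(dependentOps, [&](Operation *dep) {
          return !supportsMMaMatrixType(dep, useNvGpu);
        }))
      return; // One unsupported op poisons the whole slice.
    opToConvert.insert(dependentOps.begin(), dependentOps.end());
  });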
390 LogicalResult matchAndRewrite(vector::ContractionOp op,
392 Location loc = op.getLoc();
393 Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
398 return AffineMap::inferFromExprList(m, op.getContext());
403 auto iteratorTypes = op.getIteratorTypes().getValue();
404 SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
408 return rewriter.notifyMatchFailure(op, "not a gemm contraction");
414 return rewriter.notifyMatchFailure(op, "contraction already prepared");
436 return rewriter.notifyMatchFailure(op, "unexpected contraction case");
439 op, lhs, rhs, res,
441 op.getIteratorTypes());
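The pattern normalizes any contraction variant toward the canonical row-major gemm form that the MMA lowering expects, i.e. with m, n, k as the three loop dimensions:

  lhs: affine_map<(m, n, k) -> (m, k)>
  rhs: affine_map<(m, n, k) -> (k, n)>
  acc: affine_map<(m, n, k) -> (m, n)>
  iterator_types = ["parallel", "parallel", "reduction"]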
446 // Fold transpose op into the transfer read op. NVGPU mma.sync op only supports
454 LogicalResult matchAndRewrite(vector::TransposeOp op,
457 Value source = op.getVector();
458 Type resultType = op.getType();
471 return rewriter.notifyMatchFailure(op, "no transfer read");
475 return rewriter.notifyMatchFailure(op, "0-D transfer read");
478 return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
481 AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
485 auto loc = op.getLoc();
495 // Fuse through the integer extend op.
498 result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
501 result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
504 result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
508 rewriter.replaceOp(op, result);
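Net effect of the fold, sketched on hypothetical IR (any ext op sitting between the read and the transpose is re-created on top of the new read, as the ExtSI/ExtUI/ExtF lines above show):

  %r = vector.transfer_read %m[%i, %j] ... : memref<?x?xf16>, vector<16x8xf16>
  %t = vector.transpose %r, [1, 0] : vector<16x8xf16> to vector<8x16xf16>
    ==>
  %t = vector.transfer_read %m[%i, %j]
       {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]}
       : memref<?x?xf16>, vector<8x16xf16>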
516 // Figure the right layout to use by looking at op uses.
519 static const char *inferFragType(Operation *op) {
522 if (op->hasOneUse()) {
523 Operation *userOp = *op->user_begin();
528 for (Operation *users : op->getUsers()) {
532 assert(op->getNumResults() == 1);
533 if (contract.getLhs() == op->getResult(0))
535 if (contract.getRhs() == op->getResult(0))
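The returned tag is one of the three WMMA fragment kinds: "AOp" when the value feeds the contract's lhs, "BOp" for the rhs, and "COp" otherwise (accumulator and elementwise uses). Condensed:

  if (contract.getLhs() == op->getResult(0))
    return "AOp";
  if (contract.getRhs() == op->getResult(0))
    return "BOp";
  return "COp";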
542 convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
545 rewriter.setInsertionPoint(op);
547 assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
548 assert(transferReadSupportsMMAMatrixType(op) &&
552 getStaticallyKnownRowStride(op.getShapedType());
555 return rewriter.notifyMatchFailure(op, "no stride");
558 AffineMap map = op.getPermutationMap();
567 Value mappingResult = op.getResult();
568 auto elType = op.getVectorType().getElementType();
569 const char *fragType = inferFragType(op);
570 if (op->hasOneUse()) {
571 auto *user = *op->user_begin();
575 op.getContext(), cast<IntegerType>(elType).getWidth(),
582 gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
584 op.getLoc(), type, op.getSource(), op.getIndices(),
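The read becomes a gpu.subgroup_mma_load_matrix whose leadDimension is the statically known row stride queried above. Illustrative result (hypothetical shapes):

  %a = gpu.subgroup_mma_load_matrix %mem[%i, %j] {leadDimension = 32 : index}
       : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">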
594 convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
597 rewriter.setInsertionPoint(op);
599 assert(transferWriteSupportsMMAMatrixType(op));
601 getStaticallyKnownRowStride(op.getShapedType());
604 return rewriter.notifyMatchFailure(op, "no stride");
607 auto it = valueMapping.find(op.getVector());
610 return rewriter.notifyMatchFailure(op, "no mapping");
615 op.getLoc(), matrix, op.getSource(), op.getIndices(),
621 LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
622 rewriter.eraseOp(op);
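The mirrored store, sketched on the same hypothetical shapes:

  gpu.subgroup_mma_store_matrix %c, %mem[%i, %j] {leadDimension = 32 : index}
      : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16>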
637 /// Convert a 2D splat ConstantOp to the per-thread register fragment used by mma.sync.
639 convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
642 rewriter.setInsertionPoint(op);
645 nvgpu::getWarpMatrixInfo(op);
648 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
655 return rewriter.notifyMatchFailure(op, "not mma sync reg info");
659 auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
662 return rewriter.notifyMatchFailure(op, "not a splat");
666 op.getLoc(), vectorType,
668 valueMapping[op.getResult()] = result;
679 static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
680 mlir::AffineMap map = op.getPermutationMap();
706 creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
709 rewriter.setInsertionPoint(op);
710 Location loc = op->getLoc();
713 nvgpu::getWarpMatrixInfo(op);
716 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
723 return rewriter.notifyMatchFailure(op, "not mma sync reg info");
726 FailureOr<bool> transpose = isTransposed(op);
730 op, "Op should likely not be converted to an nvgpu.ldmatrix call.");
740 << "Op should likely not be converted to an nvgpu.ldmatrix call.\n");
742 op, "failed to convert vector.transfer_read to ldmatrix; this op "
752 return rewriter.notifyMatchFailure(op, "no offsets");
758 getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
762 loc, vectorType, op.getSource(), indices, *transpose, params->numTiles);
763 valueMapping[op] = newOp->getResult(0);
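Shape of the emitted load, on hypothetical operands (numTiles and transpose come from the ldmatrix parameters resolved above; memory space 3 is workgroup/shared memory):

  %f = nvgpu.ldmatrix %smem[%row, %col] {numTiles = 4 : i32, transpose = false}
       : memref<128x64xf16, 3> -> vector<4x2xf16>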
768 createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
771 rewriter.setInsertionPoint(op);
773 Location loc = op.getLoc();
775 nvgpu::getWarpMatrixInfo(op);
777 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
782 op, "Failed to deduce register fragment type during "
794 op.getLoc(), vectorType.getElementType(),
797 rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
799 bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
810 rewriter, op.getLoc(), *warpMatrixInfo);
812 return rewriter.notifyMatchFailure(op, "no coords");
819 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
822 op.getSource(), newIndices);
837 rewriter, op.getLoc(), *warpMatrixInfo);
839 return rewriter.notifyMatchFailure(op, "no coords");
843 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
844 Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
845 op.getSource(), newIndices);
847 op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
852 valueMapping[op.getResult()] = result;
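When ldmatrix cannot be used, the fragment is assembled register by register: each lane maps (laneId, logicalValueId) to matrix coordinates, loads a scalar with memref.load, and inserts it into a splat-initialized accumulator. Condensed from the matched lines:

  Value result = rewriter.create<vector::SplatOp>(loc, fill, vectorType);
  for (int64_t i = 0; i < numRegisters; ++i) {
    // coords: per-register (row, col) from the nvgpu lane/register helpers.
    Value el = rewriter.create<memref::LoadOp>(loc, loadedElType,
                                               op.getSource(), newIndices);
    result = rewriter.create<vector::InsertOp>(
        loc, el, result, ArrayRef<int64_t>{i, innerIdx});
  }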
868 convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
871 rewriter.setInsertionPoint(op);
874 nvgpu::getWarpMatrixInfo(op);
876 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
879 isSharedMemory(cast<MemRefType>(op.getSource().getType())) &&
882 VectorType vecTy = op.getVectorType();
888 if (!op.getPermutationMap().isMinorIdentity() &&
894 return createNonLdMatrixLoads(rewriter, op, valueMapping);
896 return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
900 convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
903 rewriter.setInsertionPoint(op);
905 Location loc = op->getLoc();
906 auto it = valueMapping.find(op.getVector());
908 return rewriter.notifyMatchFailure(op, "no mapping");
912 nvgpu::getWarpMatrixInfo(op);
914 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
918 return rewriter.notifyMatchFailure(op, "not mma sync reg info");
928 rewriter, op.getLoc(), *warpMatrixInfo);
930 return rewriter.notifyMatchFailure(op, "no coords");
936 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
937 rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
940 LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
941 rewriter.eraseOp(op);
953 vector::ExtractStridedSliceOp op,
956 rewriter.setInsertionPoint(op);
958 Location loc = op->getLoc();
961 nvgpu::getWarpMatrixInfo(op);
963 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
968 return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
971 auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
973 return rewriter.notifyMatchFailure(op, "no transfer read");
977 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
982 return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
989 // Create vector.extract_strided_slice op for thread-owned fragments.
997 return rewriter.notifyMatchFailure(op, "no mapping");
1002 populateFromInt64AttrArray(op.getOffsets(), offsets);
1005 populateFromInt64AttrArray(op.getSizes(), sizes);
1006 ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
1014 return op->emitError() << "Slicing fragments in 2D is not supported.";
1023 valueMapping[op] = newOp;
1028 convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
1031 rewriter.setInsertionPoint(op);
1033 auto itA = valueMapping.find(op.getLhs());
1034 auto itB = valueMapping.find(op.getRhs());
1035 auto itC = valueMapping.find(op.getAcc());
1038 return rewriter.notifyMatchFailure(op, "no mapping");
1041 op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
1043 valueMapping[op.getResult()] = matmul;
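The mapped operands feed a single gpu.subgroup_mma_compute. Illustrative result (hypothetical 16x16x16 tile):

  %d = gpu.subgroup_mma_compute %a, %b, %c
       : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
       -> !gpu.mma_matrix<16x16xf16, "COp">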
1048 convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
1051 rewriter.setInsertionPoint(op);
1053 auto itA = valueMapping.find(op.getLhs());
1054 auto itB = valueMapping.find(op.getRhs());
1055 auto itC = valueMapping.find(op.getAcc());
1058 return rewriter.notifyMatchFailure(op, "no mapping");
1060 int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
1061 int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
1062 int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
1064 op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
1065 valueMapping[op.getResult()] = matmul;
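Note that n is read from the rhs's leading dimension because the mma.sync path expects the B operand pre-transposed (n x k). Illustrative result on a 16x8x16 tile (fragment vector shapes are hypothetical):

  %d = nvgpu.mma.sync(%a, %b, %c) {mmaShape = [16, 8, 16]}
       : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>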
1069 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
1071 convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
1074 rewriter.setInsertionPoint(op);
1076 assert(constantSupportsMMAMatrixType(op));
1079 cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
1081 rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
1082 const char *fragType = inferFragType(op);
1083 auto vecType = cast<VectorType>(op.getType());
1087 op.getLoc(), type, scalarConstant);
1088 valueMapping[op.getResult()] = matrix;
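The splat collapses to a scalar constant plus a gpu.subgroup_mma_constant_matrix. Sketch:

  %cst = arith.constant 0.0 : f16
  %c = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">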
1092 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
1094 convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
1097 rewriter.setInsertionPoint(op);
1099 assert(broadcastSupportsMMAMatrixType(op));
1101 const char *fragType = inferFragType(op);
1102 auto vecType = op.getResultVectorType();
1106 op.getLoc(), type, op.getSource());
1107 valueMapping[op.getResult()] = matrix;
1145 static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
1148 rewriter.setInsertionPoint(op);
1152 for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
1159 operand.index(), op.getInitArgs().size() + newOperands.size()));
1163 scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
1178 convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
1181 rewriter.setInsertionPoint(op);
1183 auto loop = cast<scf::ForOp>(op->getParentOp());
1184 auto yieldOperands = llvm::to_vector<4>(op.getOperands());
1185 for (const auto &operand : llvm::enumerate(op.getOperands())) {
1189 // Replace the yield of the old value with the for op argument to make it easier
1194 rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);
1196 LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
1197 rewriter.eraseOp(op);
1201 /// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
1203 convertElementwiseOp(RewriterBase &rewriter, Operation *op,
1207 rewriter.setInsertionPoint(op);
1210 for (Value operand : op->getOperands()) {
1213 return rewriter.notifyMatchFailure(op, "no mapping");
1219 auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
1226 op->getLoc(), resultType, matrixOperands, opType);
1227 valueMapping[op->getResult(0)] = newOp;
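Each supported elementwise op becomes a gpu.subgroup_mma_elementwise carrying the enum computed by convertElementwiseOpToMMA. Sketch for addf:

  %r = gpu.subgroup_mma_elementwise addf %a, %b
       : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">)
       -> !gpu.mma_matrix<16x16xf16, "COp">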
1248 for (Operation *op : ops) {
1249 LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
1252 if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
1254 } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
1256 } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
1258 } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
1260 } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
1262 } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1264 } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1266 } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
1267 res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
1279 for (Operation *op : ops) {
1280 if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
1306 .Default([&](Operation *op) {
1307 return op->emitError() << "unhandled vector to mma type: " << *op;
1310 return op->emitOpError()
1311 << "failed to convert op during vector-to-nvgpu conversion";