xref: /llvm-project/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp (revision a5757c5b65f1894de16f549212b1c37793312703)
1 //===- SparseTensorInterfaces.cpp - SparseTensor interfaces impl ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
10 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
11 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
12 #include "mlir/IR/PatternMatch.h"
13 
14 using namespace mlir;
15 using namespace mlir::sparse_tensor;
16 
17 #include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc"
18 
19 /// Stage the operations into a sequence of simple operations as follow:
20 /// op -> unsorted_coo +
21 /// unsorted_coo -> sorted_coo +
22 /// sorted_coo -> dstTp.
23 ///
24 /// return `tmpBuf` if a intermediate memory is allocated.
stageWithSortImpl(StageWithSortSparseOp op,PatternRewriter & rewriter,Value & tmpBufs)25 LogicalResult sparse_tensor::detail::stageWithSortImpl(
26     StageWithSortSparseOp op, PatternRewriter &rewriter, Value &tmpBufs) {
27   if (!op.needsExtraSort())
28     return failure();
29 
30   Location loc = op.getLoc();
31   Type finalTp = op->getOpResult(0).getType();
32   SparseTensorType dstStt(cast<RankedTensorType>(finalTp));
33   Type srcCOOTp = dstStt.getCOOType(/*ordered=*/false);
34 
35   // Clones the original operation but changing the output to an unordered COO.
36   Operation *cloned = rewriter.clone(*op.getOperation());
37   rewriter.modifyOpInPlace(cloned, [cloned, srcCOOTp]() {
38     cloned->getOpResult(0).setType(srcCOOTp);
39   });
40   Value srcCOO = cloned->getOpResult(0);
41 
42   // -> sort
43   Type dstCOOTp = dstStt.getCOOType(/*ordered=*/true);
44   Value dstCOO = rewriter.create<ReorderCOOOp>(
45       loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort);
46 
47   // -> dest.
48   if (dstCOO.getType() == finalTp) {
49     rewriter.replaceOp(op, dstCOO);
50   } else {
51     // Need an extra conversion if the target type is not COO.
52     auto c = rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO);
53     rewriter.setInsertionPointAfter(c);
54     // Informs the caller about the intermediate buffer we allocated. We can not
55     // create a bufferization::DeallocateTensorOp here because it would
56     // introduce cyclic dependency between the SparseTensorDialect and the
57     // BufferizationDialect. Besides, whether the buffer need to be deallocated
58     // by SparseTensorDialect or by BufferDeallocationPass is still TBD.
59     tmpBufs = dstCOO;
60   }
61 
62   return success();
63 }
64