xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp (revision f3a8af07fa9e9dbc3bfa495b34846e3a5962cc27)
1 //===- StageSparseOperations.cpp - stage sparse ops rewriting rules -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
10 #include "mlir/Dialect/Func/IR/FuncOps.h"
11 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
12 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
13 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
14 
15 using namespace mlir;
16 using namespace mlir::sparse_tensor;
17 
18 namespace {
19 
20 struct GuardSparseAlloc
21     : public OpRewritePattern<bufferization::AllocTensorOp> {
22   using OpRewritePattern<bufferization::AllocTensorOp>::OpRewritePattern;
23 
matchAndRewrite__anon27c8d91d0111::GuardSparseAlloc24   LogicalResult matchAndRewrite(bufferization::AllocTensorOp op,
25                                 PatternRewriter &rewriter) const override {
26     // Only rewrite sparse allocations.
27     if (!getSparseTensorEncoding(op.getResult().getType()))
28       return failure();
29 
30     // Only rewrite sparse allocations that escape the method
31     // without any chance of a finalizing operation in between.
32     // Here we assume that sparse tensor setup never crosses
33     // method boundaries. The current rewriting only repairs
34     // the most obvious allocate-call/return cases.
35     if (!llvm::all_of(op->getUses(), [](OpOperand &use) {
36           return isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
37               use.getOwner());
38         }))
39       return failure();
40 
41     // Guard escaping empty sparse tensor allocations with a finalizing
42     // operation that leaves the underlying storage in a proper state
43     // before the tensor escapes across the method boundary.
44     rewriter.setInsertionPointAfter(op);
45     auto load = rewriter.create<LoadOp>(op.getLoc(), op.getResult(), true);
46     rewriter.replaceAllUsesExcept(op, load, load);
47     return success();
48   }
49 };
50 
51 template <typename StageWithSortOp>
52 struct StageUnorderedSparseOps : public OpRewritePattern<StageWithSortOp> {
53   using OpRewritePattern<StageWithSortOp>::OpRewritePattern;
54 
matchAndRewrite__anon27c8d91d0111::StageUnorderedSparseOps55   LogicalResult matchAndRewrite(StageWithSortOp op,
56                                 PatternRewriter &rewriter) const override {
57     Location loc = op.getLoc();
58     Value tmpBuf = nullptr;
59     auto itOp = llvm::cast<StageWithSortSparseOp>(op.getOperation());
60     LogicalResult stageResult = itOp.stageWithSort(rewriter, tmpBuf);
61     // Deallocate tmpBuf.
62     // TODO: Delegate to buffer deallocation pass in the future.
63     if (succeeded(stageResult) && tmpBuf)
64       rewriter.create<bufferization::DeallocTensorOp>(loc, tmpBuf);
65 
66     return stageResult;
67   }
68 };
69 } // namespace
70 
populateStageSparseOperationsPatterns(RewritePatternSet & patterns)71 void mlir::populateStageSparseOperationsPatterns(RewritePatternSet &patterns) {
72   patterns.add<GuardSparseAlloc, StageUnorderedSparseOps<ConvertOp>,
73                StageUnorderedSparseOps<ConcatenateOp>>(patterns.getContext());
74 }
75