//===- StageSparseOperations.cpp - stage sparse ops rewriting rules -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

struct GuardSparseAlloc
    : public OpRewritePattern<bufferization::AllocTensorOp> {
  using OpRewritePattern<bufferization::AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(bufferization::AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    // Only rewrite sparse allocations.
    if (!getSparseTensorEncoding(op.getResult().getType()))
      return failure();

    // Only rewrite sparse allocations that escape the method
    // without any chance of a finalizing operation in between.
    // Here we assume that sparse tensor setup never crosses
    // method boundaries. The current rewriting only repairs
    // the most obvious allocate-call/return cases.
    if (!llvm::all_of(op->getUses(), [](OpOperand &use) {
          return isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
              use.getOwner());
        }))
      return failure();

    // Guard escaping empty sparse tensor allocations with a finalizing
    // operation that leaves the underlying storage in a proper state
    // before the tensor escapes across the method boundary.
    rewriter.setInsertionPointAfter(op);
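    // The trailing 'true' sets the hasInserts flag on the created
    // sparse_tensor.load, so that pending insertions into the underlying
    // sparse storage are finalized.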
    auto load = rewriter.create<LoadOp>(op.getLoc(), op.getResult(), true);
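    // Reroute all other uses of the allocation (the return/call operands)
    // through the finalized tensor, leaving the load's own operand intact.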
    rewriter.replaceAllUsesExcept(op, load, load);
    return success();
  }
};

template <typename StageWithSortOp>
struct StageUnorderedSparseOps : public OpRewritePattern<StageWithSortOp> {
  using OpRewritePattern<StageWithSortOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(StageWithSortOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value tmpBuf = nullptr;
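    // Every op staged by this pattern implements the StageWithSortSparseOp
    // interface; stageWithSort is expected to split the operation into an
    // unordered variant followed by an explicit sort, setting tmpBuf when
    // an intermediate buffer is created along the way.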
    auto itOp = llvm::cast<StageWithSortSparseOp>(op.getOperation());
    LogicalResult stageResult = itOp.stageWithSort(rewriter, tmpBuf);
    // Deallocate tmpBuf.
    // TODO: Delegate to buffer deallocation pass in the future.
    if (succeeded(stageResult) && tmpBuf)
      rewriter.create<bufferization::DeallocTensorOp>(loc, tmpBuf);

    return stageResult;
  }
};
} // namespace

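// Installs the rewriting rules defined above: guarding escaping sparse
// allocations and staging sparse operations that require an explicit sort.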
void mlir::populateStageSparseOperationsPatterns(RewritePatternSet &patterns) {
  patterns.add<GuardSparseAlloc, StageUnorderedSparseOps<ConvertOp>,
               StageUnorderedSparseOps<ConcatenateOp>>(patterns.getContext());
}