xref: /llvm-project/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp (revision 6b4b63a832f105039442fc983d0b309abe5261d5)
//===- TosaToSCF.cpp - Lowering Tosa to SCF Dialect -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewrite patterns lower Tosa if, while, and scatter operations to the
// SCF dialect.
//
//===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
14 #include "mlir/Dialect/SCF/IR/SCF.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
17 #include "mlir/IR/IRMapping.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 using namespace mlir;
22 using namespace tosa;
23 
inlineIfCase(Region & srcRegion,Region & dstRegion,OperandRange operands,PatternRewriter & rewriter)24 static void inlineIfCase(Region &srcRegion, Region &dstRegion,
25                          OperandRange operands, PatternRewriter &rewriter) {
26   rewriter.cloneRegionBefore(srcRegion, &dstRegion.front());
27   rewriter.eraseBlock(&dstRegion.back());
28 
29   Block *headBlock = &dstRegion.front();
30   for (auto it : llvm::zip(headBlock->getArguments(), operands))
31     std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
32 
33   auto yield = cast<YieldOp>(headBlock->getTerminator());
34   rewriter.setInsertionPoint(yield);
35   rewriter.create<scf::YieldOp>(yield.getLoc(), yield.getInputs());
36   rewriter.eraseOp(yield);
37 
38   headBlock->eraseArguments(0, headBlock->getNumArguments());
39 }
40 
inlineWhileCase(Region & srcRegion,Region & dstRegion,PatternRewriter & rewriter,bool isCond)41 static void inlineWhileCase(Region &srcRegion, Region &dstRegion,
42                             PatternRewriter &rewriter, bool isCond) {
43   rewriter.cloneRegionBefore(srcRegion, &dstRegion.back());
44   rewriter.eraseBlock(&dstRegion.back());
45 
46   Block *headBlock = &dstRegion.front();
47 
48   auto yield = cast<YieldOp>(headBlock->getTerminator());
49   rewriter.setInsertionPoint(yield);
50   if (isCond) {
51     auto condition =
52         rewriter.create<tensor::ExtractOp>(yield.getLoc(), yield.getOperand(0));
53     rewriter.create<scf::ConditionOp>(yield.getLoc(), condition,
54                                       headBlock->getArguments());
55   } else {
56     rewriter.setInsertionPoint(yield);
57     rewriter.create<scf::YieldOp>(yield.getLoc(), yield.getInputs());
58   }
59   rewriter.eraseOp(yield);
60 }
61 
62 namespace {
63 
64 class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
65 public:
66   using OpRewritePattern<tosa::IfOp>::OpRewritePattern;
67 
matchAndRewrite(tosa::IfOp op,PatternRewriter & rewriter) const68   LogicalResult matchAndRewrite(tosa::IfOp op,
69                                 PatternRewriter &rewriter) const final {
70     auto condition =
71         rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCond());
72     auto newIf = rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(),
73                                             condition, true);
74 
75     inlineIfCase(op.getThenBranch(), newIf.getThenRegion(), op.getInputs(),
76                  rewriter);
77     inlineIfCase(op.getElseBranch(), newIf.getElseRegion(), op.getInputs(),
78                  rewriter);
79 
80     rewriter.replaceOp(op, newIf.getResults());
81     return success();
82   }
83 };
84 
85 class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
createTensorDim(OpBuilder & builder,Location loc,Value tensor,int64_t dim)86   static Value createTensorDim(OpBuilder &builder, Location loc, Value tensor,
87                                int64_t dim) {
88     return builder.createOrFold<tensor::DimOp>(loc, tensor, dim);
89   }
90 
createIndexConst(OpBuilder & builder,Location loc,int64_t value)91   static Value createIndexConst(OpBuilder &builder, Location loc,
92                                 int64_t value) {
93     return builder.create<arith::ConstantIndexOp>(loc, value);
94   }
95 
96 public:
97   using OpRewritePattern<tosa::ScatterOp>::OpRewritePattern;
98 
matchAndRewrite(tosa::ScatterOp scatter,PatternRewriter & rewriter) const99   LogicalResult matchAndRewrite(tosa::ScatterOp scatter,
100                                 PatternRewriter &rewriter) const final {
101     auto valuesIn = scatter.getValuesIn();
102     auto indices = scatter.getIndices();
103     auto input = scatter.getInput();
104     auto loc = scatter.getLoc();
105 
106     // N, W, C are chosen to match the TOSA spec
107     auto dimN = createTensorDim(rewriter, loc, input, 0);
108     auto dimW = createTensorDim(rewriter, loc, input, 1);
109     auto dimC = createTensorDim(rewriter, loc, input, 2);
110 
111     auto zero = createIndexConst(rewriter, loc, 0);
112     auto one = createIndexConst(rewriter, loc, 1);
113 
114     // Loop bounds
115     auto lbs = llvm::SmallVector<Value>(2, zero);
116     auto steps = llvm::SmallVector<Value>(2, one);
117     auto ubs = llvm::SmallVector<Value>{{dimN, dimW}};
118 
119     auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
120                          ValueRange args) -> scf::ValueVector {
121       auto n = ivs[0];
122 
123       // Read the index and cast it to index type
124       auto index = builder.create<tensor::ExtractOp>(loc, indices, ivs);
125       auto castIndex = builder.create<arith::IndexCastOp>(
126           loc, builder.getIndexType(), index);
127 
128       // Offset, sizes, and strides for the input tensor
129       auto inputOffset = llvm::to_vector(ivs);
130       inputOffset.push_back(zero);
131 
132       llvm::SmallVector<Value> sizes = {one, one, dimC};
133       llvm::SmallVector<Value> strides = {one, one, one};
134 
135       auto slice = builder.create<tensor::ExtractSliceOp>(
136           loc, input, inputOffset, sizes, strides);
137 
138       // Insert the slice into the output accumulator tensor.
139       llvm::SmallVector<Value> outputOffset = {n, castIndex, zero};
140       auto updated = builder.create<tensor::InsertSliceOp>(
141           loc, slice, args[0], outputOffset, sizes, strides);
142 
143       return {updated};
144     };
145 
146     auto loops = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps,
147                                     ValueRange{valuesIn}, buildBody);
148     rewriter.replaceOp(scatter, loops.results);
149 
150     return success();
151   }
152 };
153 
154 class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
155 public:
156   using OpRewritePattern<tosa::WhileOp>::OpRewritePattern;
157 
matchAndRewrite(tosa::WhileOp op,PatternRewriter & rewriter) const158   LogicalResult matchAndRewrite(tosa::WhileOp op,
159                                 PatternRewriter &rewriter) const final {
160     auto newWhile = rewriter.create<scf::WhileOp>(
161         op.getLoc(), op.getResultTypes(), op.getInputs());
162     rewriter.createBlock(&newWhile.getBefore());
163     rewriter.createBlock(&newWhile.getAfter());
164 
165     inlineWhileCase(op.getCond(), newWhile.getBefore(), rewriter, true);
166     inlineWhileCase(op.getBody(), newWhile.getAfter(), rewriter, false);
167 
168     rewriter.replaceOp(op, newWhile.getResults());
169 
170     return success();
171   }
172 };
173 
174 } // namespace
175 
populateTosaToSCFConversionPatterns(RewritePatternSet * patterns)176 void mlir::tosa::populateTosaToSCFConversionPatterns(
177     RewritePatternSet *patterns) {
178   patterns->add<IfOpConverter, ScatterOpConverter, WhileOpConverter>(
179       patterns->getContext());
180 }
181