//===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Tensor dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

#include <numeric>

using namespace mlir;
using namespace tosa;

namespace {

// Infer the type to which the input of a 'tosa.reshape' op must be cast when
// lowered.
TensorType inferReshapeInputType(TypedValue<TensorType> input,
                                 ArrayRef<int64_t> newShape) {
  // No need to cast input for non-empty target shape
  if (!newShape.empty())
    return input.getType();

  // The input type must be cast into a tensor with the same rank and all
  // static dimensions set to 1. This prevents the generation of a
  // tensor.collapse_shape op that converts a dynamically shaped tensor into a
  // 0D tensor. While such a construct is not incorrect on its own,
  // bufferization cannot properly handle it at the moment, so we avoid it.
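  // For example, when lowering a reshape of tensor<?x?xf32> to an empty
  // target shape, the input is first cast to tensor<1x1xf32>.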
  SmallVector<int64_t> shape(input.getType().getRank(), 1);
  return input.getType().clone(shape);
}

// Infer the result type of 'tensor.expand_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
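// For example, with a static input of type tensor<6xf32> and newShape
// [2, -1], the -1 placeholder resolves to 3 and the expanded type is
// tensor<2x3xf32>; for a dynamically shaped input the placeholder stays
// dynamic.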
TensorType inferReshapeExpandedType(TensorType inputType,
                                    ArrayRef<int64_t> newShape) {
  // Special case for 0D output tensor. Note: Watch out when using
  // Type::clone() with just '{}', as it will invoke the incorrect overload.
  if (newShape.empty())
    return inputType.clone(ArrayRef<int64_t>{});

  // Check if the input is static, and if so, get its total size
  bool inputIsStatic = inputType.hasStaticShape();
  int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;

  // Compute result shape
  auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
    // If this is not a placeholder, do not change it
    if (size >= 0)
      return size;

    // If we do not know the total size of the tensor, keep this dimension
    // dynamic in the result shape.
    if (!inputIsStatic)
      return ShapedType::kDynamic;

    // Calculate the product of all elements in 'newShape' except for the -1
    // placeholder, which we discard by negating the result.
    int64_t totalSizeNoPlaceholder = -std::accumulate(
        newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());

    // If there is a 0 component in 'newShape', resolve the placeholder as 0.
    if (totalSizeNoPlaceholder == 0)
      return 0;

    // Resolve the placeholder as the quotient between the total tensor size
    // and the product of all other sizes.
    return totalSize / totalSizeNoPlaceholder;
  });

  bool resultIsStatic = !ShapedType::isDynamicShape(resultShape);

  // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
  // shaped input from being reshaped into a statically shaped result. We may
  // simply turn the first result dimension dynamic to address this.
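  // For example, reshaping tensor<?xf32> with newShape [2, 3] yields
  // tensor<?x3xf32> here; the conversion pattern later casts the expanded
  // value to the static result type.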
  if (!inputIsStatic && resultIsStatic)
    resultShape[0] = ShapedType::kDynamic;

  // The 'tensor.expand_shape' op also forbids a statically shaped input from
  // being reshaped into a dynamically shaped result, but the placeholder
  // inference algorithm above guarantees that this will never be the case.
  assert(!inputIsStatic || resultIsStatic);

  // Create result type
  return inputType.clone(resultShape);
}

// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
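// For example, for an input type tensor<2x3x4xf32> and an expanded type
// tensor<4x6xf32>, the inferred intermediate type is tensor<24xf32>.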
TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
  auto lhsShape = lhsType.getShape();
  auto rhsShape = rhsType.getShape();

  if (lhsShape.empty() || rhsShape.empty())
    return lhsType.clone(ArrayRef<int64_t>{});

  if (ShapedType::isDynamicShape(lhsShape) ||
      ShapedType::isDynamicShape(rhsShape))
    return lhsType.clone({ShapedType::kDynamic});

  SmallVector<int64_t> intermediateShape;
  unsigned currLhsDim = 0, currRhsDim = 0;
  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
    int64_t rhsSize = rhsShape[currRhsDim];
    int64_t lhsSize = lhsShape[currLhsDim];
    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
           currRhsDim < rhsShape.size()) {
      if (lhsSize < rhsSize) {
        currLhsDim++;
        if (currLhsDim < lhsShape.size()) {
          lhsSize *= lhsShape[currLhsDim];
        }
      } else {
        currRhsDim++;
        if (currRhsDim < rhsShape.size()) {
          rhsSize *= rhsShape[currRhsDim];
        }
      }
    }
    if (lhsSize == rhsSize) {
      intermediateShape.push_back(lhsSize);
    }
    currRhsDim++;
    currLhsDim++;
  }

  // Static shapes are guaranteed to be compatible by the op verifier, so all
  // leftover dimensions should be 1.
  for (; currLhsDim < lhsShape.size(); currLhsDim++) {
    assert(lhsShape[currLhsDim] == 1);
  }
  for (; currRhsDim < rhsShape.size(); currRhsDim++) {
    assert(rhsShape[currRhsDim] == 1);
  }

  return lhsType.clone(intermediateShape);
}

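// Create the reassociation map that collapses 'srcType' into 'dstType'. The
// same map is used (with the roles swapped) to build the 'tensor.expand_shape'
// op. For example, collapsing tensor<2x3x4xf32> into tensor<6x4xf32> yields
// the reassociation [[d0, d1], [d2]].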
SmallVector<ReassociationExprs>
createReassociationMapForCollapse(OpBuilder &builder, Type srcType,
                                  Type dstType) {
  auto srcShape = cast<TensorType>(srcType).getShape();
  auto dstShape = cast<TensorType>(dstType).getShape();

  if (srcShape.empty() || dstShape.empty())
    return {};

  if (ShapedType::isDynamicShape(srcShape) ||
      ShapedType::isDynamicShape(dstShape)) {
    assert(dstShape.size() == 1);
    SmallVector<AffineExpr, 2> exprs;
    for (auto i : llvm::seq<int64_t>(srcShape.size()))
      exprs.push_back(builder.getAffineDimExpr(i));
    return {exprs};
  }

  SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
  unsigned currSrcDim = 0, currDstDim = 0;
  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
    int64_t dstSize = dstShape[currDstDim];
    int64_t srcSize = srcShape[currSrcDim];
    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      srcSize *= srcShape[currSrcDim];
    }
    if (srcSize == dstSize) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      // If the next dim of the collapsed (dst) shape is not 1, also fold any
      // subsequent unit dims of the expanded (src) shape into the current
      // group.
      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
          reassociationMap[currDstDim].push_back(
              builder.getAffineDimExpr(currSrcDim++));
        }
      }
    }
    currDstDim++;
  }

  // If the source and target shapes are compatible, both iterators must have
  // reached the end. This condition is guaranteed by the op verifier for
  // static shapes.
  assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
  return reassociationMap;
}

// Create a tensor.collapse_shape op that reshapes the input into the given
// result type.
Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
                     Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, input.getType(), resultType);
  return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
                                                       reassociationMap);
}

// Create a tensor.expand_shape op that reshapes the input into the given
// result type.
Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
                   Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, resultType, input.getType());
  return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
                                                     reassociationMap);
}

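// Lower 'tosa.reshape' to a cast / collapse / expand / cast sequence: the
// input is cast to an all-static-1 type if needed, collapsed to a common
// intermediate shape, expanded to the target shape, and finally cast to the
// converted result type.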
class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = reshape.getLoc();
    auto resultType = cast_if_present<ShapedType>(
        getTypeConverter()->convertType(reshape.getType()));
    if (!resultType) {
      return rewriter.notifyMatchFailure(reshape.getLoc(),
                                         "could not convert result type");
    }
    auto input = dyn_cast<TypedValue<TensorType>>(adaptor.getInput1());
    if (!input) {
      return rewriter.notifyMatchFailure(reshape.getLoc(),
                                         "expected input type to be tensor");
    }
    auto newShape = reshape.getNewShape();

    // Infer all intermediate types
    auto inputType = inferReshapeInputType(input, newShape);
    auto expandedType = inferReshapeExpandedType(inputType, newShape);
    auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);

    // Cast input if needed
    auto castInput =
        rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);

    // Emit collapse-expand pair
    auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
    auto expanded = createExpand(rewriter, loc, expandedType, collapsed);

    // Cast to final result type if needed
    auto result =
        rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
    rewriter.replaceOp(reshape, result);
    return success();
  }
};

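// Lower 'tosa.slice' to 'tensor.extract_slice'. The start and size operands
// must be constants so that static offsets and sizes can be extracted from
// them.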
class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
public:
  using OpConversionPattern<tosa::SliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = sliceOp.getLoc();
    Value input = adaptor.getInput1();
    ShapedType resultType = cast<ShapedType>(sliceOp.getType());
    if (llvm::isa<UnrankedTensorType>(resultType))
      return failure();

    ElementsAttr startElems;
    ElementsAttr sizeElems;

    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");

    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");

    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    SmallVector<int64_t> strides, sizes;
    strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);

    SmallVector<Value> dynSizes;
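    // A slice size of -1 means "up to the end of the dimension"; materialize
    // such sizes dynamically as dim(input, i) - start[i].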
    for (const auto &i : llvm::enumerate(sliceSizes)) {
      int64_t size = i.value();
      size_t index = i.index();
      sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
      if (!ShapedType::isDynamic(sizes.back()))
        continue;

      auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
      auto offset = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(sliceStarts[index]));
      dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
    }

    auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
        ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts),
        rewriter.getDenseI64ArrayAttr(sizes),
        rewriter.getDenseI64ArrayAttr(strides));

    rewriter.replaceOp(sliceOp, newSliceOp.getResult());

    // Remove the const_shape ops if this slice was their only user.
    Operation *startConstShape = sliceOp.getStart().getDefiningOp();
    if (startConstShape->getResult(0).hasOneUse())
      rewriter.eraseOp(startConstShape);

    Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
    if (sizeConstShape->getResult(0).hasOneUse())
      rewriter.eraseOp(sizeConstShape);

    return success();
  }
};

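// Lower 'tosa.pad' to 'tensor.pad'. The padding operand must be constant and
// holds two values per dimension, interleaved as [lo0, hi0, lo1, hi1, ...].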
class PadConverter : public OpConversionPattern<tosa::PadOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = padOp.getLoc();
    auto input = padOp.getInput1();

    ElementsAttr paddingElems;
    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
      return rewriter.notifyMatchFailure(
          padOp, "padding must be a static shape value");
    }
    llvm::SmallVector<int64_t> paddingVals;
    for (auto idx : paddingElems.getValues<IntegerAttr>()) {
      paddingVals.push_back(static_cast<int64_t>(idx.getInt()));
    }

    ShapedType inputTy = cast<ShapedType>(input.getType());
    Type elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    // Setup the default constantAttr.

    Value padConstant;

    if (padOp.getPadConst()) {
      padConstant = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padOp.getPadConst(), ValueRange({}));
    } else {
      TypedAttr constantAttr;
      if (isa<FloatType>(elementTy)) {
        constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
      } else if (isa<IntegerType>(elementTy) && !padOp.getQuantizationInfo()) {
        constantAttr = rewriter.getIntegerAttr(elementTy, 0);
      } else if (isa<IntegerType>(elementTy) && padOp.getQuantizationInfo()) {
        int64_t value = padOp.getQuantizationInfo()->getInputZp();
        constantAttr = rewriter.getIntegerAttr(elementTy, value);
      }
      if (constantAttr)
        padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
    }

    if (!padConstant) {
      return rewriter.notifyMatchFailure(
          padOp, "tosa.pad was unable to determine the pad constant value.");
    }

    SmallVector<OpFoldResult, 3> lowValues;
    SmallVector<OpFoldResult, 3> highValues;

    lowValues.reserve(rank);
    highValues.reserve(rank);

    for (int i = 0; i < rank; i++) {
      Value lowVal = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(paddingVals[2 * i]));
      Value highVal = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
      lowValues.push_back(lowVal);
      highValues.push_back(highVal);
    }

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, padOp.getType(), input, lowValues, highValues, padConstant);

    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};

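// Lower 'tosa.concat' to a 'tensor.empty' op followed by one
// 'tensor.insert_slice' per operand, each inserted at its running offset
// along the concatenation axis.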
struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<RankedTensorType>(op.getType());

    Location loc = op.getLoc();
    int axis = op.getAxis();
    Value axisValue =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(axis));
    int64_t rank = resultType.getRank();

    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, op.getLoc(), adaptor.getOperands()[0]);

    // Pre-compute the offsets along the axis dimension. The axisOffsets
    // vector holds one entry per operand plus one; the last value is the
    // total size of the result along the 'axis' dimension.
    SmallVector<OpFoldResult> axisOffsets;
    axisOffsets.push_back(rewriter.getIndexAttr(0));
    axisOffsets.push_back(sizes[axis]);

    for (auto arg : adaptor.getOperands().drop_front()) {
      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
      auto currentOffset =
          getValueOrCreateConstantIndexOp(rewriter, loc, axisOffsets.back());
      auto total =
          rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size);
      axisOffsets.push_back(getAsOpFoldResult(total));
    }
    sizes[axis] = axisOffsets.back();

    // Compute the dynamic sizes of the tensor.empty operation.
    // This is based off of the specified result type of the tosa.concat
    // operation, since we don't want to change the result type of the
    // operation during the conversion.
    SmallVector<Value> dynDims;
    for (int64_t i = 0; i < rank; ++i) {
      if (resultType.isDynamicDim(i)) {
        dynDims.push_back(
            getValueOrCreateConstantIndexOp(rewriter, loc, sizes[i]));
      }
    }

    Value result = rewriter.create<tensor::EmptyOp>(
        loc, resultType.getShape(), resultType.getElementType(), dynDims);

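    // Insert each operand into the result at its precomputed offset along the
    // concatenation axis.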
    for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
      auto sizes = tensor::getMixedSizes(rewriter, op.getLoc(), arg);
      offsets[axis] = offset;
      result = rewriter.createOrFold<tensor::InsertSliceOp>(
          loc, arg, result, offsets, sizes, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaToTensorConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {
  patterns
      ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
          converter, patterns->getContext());
}