1 //===- TosaToLinalgNamed.cpp - Lowering Tosa to Linalg Named Ops ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Linalg named ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/Linalg/IR/Linalg.h"
16 #include "mlir/Dialect/Math/IR/Math.h"
17 #include "mlir/Dialect/SCF/IR/SCF.h"
18 #include "mlir/Dialect/Tensor/IR/Tensor.h"
19 #include "mlir/Dialect/Tensor/Utils/Utils.h"
20 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
21 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
22 #include "mlir/Dialect/Utils/IndexingUtils.h"
23 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28 
29 #include "mlir/Interfaces/InferTypeOpInterface.h"
30 
31 #include <numeric>
32 #include <type_traits>
33 
34 using namespace mlir;
35 using namespace mlir::tosa;
36 
37 static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
38                             TypedAttr padAttr, OpBuilder &rewriter) {
39   // Input should be padded only if necessary.
40   if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
41     return input;
42 
43   ShapedType inputTy = cast<ShapedType>(input.getType());
44   Type inputETy = inputTy.getElementType();
45   auto inputShape = inputTy.getShape();
46 
47   assert((inputShape.size() * 2) == pad.size());
48 
49   SmallVector<int64_t, 4> paddedShape;
50   SmallVector<OpFoldResult, 8> lowIndices;
51   SmallVector<OpFoldResult, 8> highIndices;
52   for (size_t i : llvm::seq(inputShape.size())) {
53     auto lowPad = pad[i * 2];
54     auto highPad = pad[i * 2 + 1];
55     if (ShapedType::isDynamic(inputShape[i]))
56       paddedShape.push_back(inputShape[i]);
57     else
58       paddedShape.push_back(inputShape[i] + highPad + lowPad);
59     lowIndices.push_back(rewriter.getIndexAttr(lowPad));
60     highIndices.push_back(rewriter.getIndexAttr(highPad));
61   }
62 
63   Value padValue = rewriter.create<arith::ConstantOp>(loc, padAttr);
64 
65   return rewriter.create<tensor::PadOp>(
66       loc, RankedTensorType::get(paddedShape, inputETy), input, lowIndices,
67       highIndices, padValue);
68 }
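// For illustration (hypothetical shapes): with pad = {0, 0, 1, 1, 1, 1, 0, 0}
// (a low/high pair per dimension) and a tensor<1x7x7x3xf32> input, the
// tensor.pad above produces a tensor<1x9x9x3xf32> filled with `padAttr`
// outside the original bounds.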
69 
70 static mlir::Value
71 linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
72                            Value conv, Value result,
73                            ArrayRef<AffineMap> indexingMaps) {
74   ShapedType resultTy = cast<ShapedType>(conv.getType());
75   return rewriter
76       .create<linalg::GenericOp>(
77           loc, resultTy, ValueRange({bias, conv}), result, indexingMaps,
78           getNParallelLoopsAttrs(resultTy.getRank()),
79           [](OpBuilder &builder, Location loc, ValueRange args) {
80             Value biasVal = args[0];
81             Type resType = args[1].getType();
82             if (resType != biasVal.getType()) {
83               biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
84             }
85             Value added = builder.create<arith::AddIOp>(loc, biasVal, args[1]);
86             builder.create<linalg::YieldOp>(loc, added);
87           })
88       .getResult(0);
89 }
90 
91 // Construct the affine map that a linalg generic would use to broadcast the
92 // source tensor into the shape of the result tensor.
93 static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
94                                     Value result) {
95   ShapedType resultTy = cast<ShapedType>(result.getType());
96   ShapedType sourceTy = cast<ShapedType>(source.getType());
97   const int64_t resultRank = resultTy.getRank();
98   const int64_t sourceRank = sourceTy.getRank();
99 
100   // The source tensor is broadcast to all the outer dimensions of the
101   // result tensor.
102   SmallVector<AffineExpr> sourceDims;
103   // In the case of a rank-one source tensor with a single element, TOSA
104   // specifies that the value be broadcast, meaning we need an edge case for
105   // a constant map.
106   assert(sourceTy.hasStaticShape() &&
107          "Dynamic broadcasting shapes not supported!");
108   if (sourceRank == 1 && sourceTy.getDimSize(0) == 1) {
109     sourceDims.push_back(rewriter.getAffineConstantExpr(0));
110   } else {
111     for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
112       auto expr = rewriter.getAffineDimExpr(dim + resultRank - sourceRank);
113       sourceDims.push_back(expr);
114     }
115   }
116 
117   return AffineMap::get(/*dimCount=*/resultRank,
118                         /*symbolCount=*/0, sourceDims, rewriter.getContext());
119 }
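// For example (hypothetical types): broadcasting a tensor<8xi32> source into a
// tensor<1x4x4x8xi32> result yields affine_map<(d0, d1, d2, d3) -> (d3)>,
// while a single-element tensor<1xi32> source yields the constant map
// affine_map<(d0, d1, d2, d3) -> (0)>.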
120 
121 // Broadcast the source value to all the outer dimensions of the result value.
122 // If required, the element type is expanded using an arith.extsi operation.
123 static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
124                                                 Location loc, Value source,
125                                                 Value result) {
126   ShapedType resultTy = cast<ShapedType>(result.getType());
127   const int64_t resultRank = resultTy.getRank();
128   // Create maps for the input and output of the broadcast-like generic op.
129   SmallVector<AffineMap, 2> indexingMaps;
130   indexingMaps.push_back(getBroadcastingMap(rewriter, source, result));
131   indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
132 
133   // Build the broadcast-like operation as a linalg.generic.
134   return rewriter
135       .create<linalg::GenericOp>(
136           loc, resultTy, ValueRange({source}), result, indexingMaps,
137           getNParallelLoopsAttrs(resultTy.getRank()),
138           [](OpBuilder &builder, Location loc, ValueRange args) {
139             Value biasVal = args[0];
140             Type resType = args[1].getType();
141             if (resType != biasVal.getType()) {
142               biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
143             }
144             builder.create<linalg::YieldOp>(loc, biasVal);
145           })
146       .getResult(0);
147 }
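// A minimal sketch of the IR this produces, assuming an i8 bias accumulated
// into an i32 result (names and shapes hypothetical):
//   linalg.generic {indexing_maps = [#broadcast_map, #identity_map],
//                   iterator_types = ["parallel", ...]}
//       ins(%bias : tensor<8xi8>) outs(%acc : tensor<1x4x4x8xi32>) {
//   ^bb0(%in: i8, %out: i32):
//     %0 = arith.extsi %in : i8 to i32
//     linalg.yield %0 : i32
//   }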
148 
149 static mlir::Value reifyConstantDim(int64_t attr,
150                                     ImplicitLocOpBuilder &builder) {
151   return builder.create<arith::ConstantIndexOp>(attr);
152 }
153 
154 // Calculate the output width/height using the formulas:
155 // H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
156 // W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
157 
158 static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
159                                           int64_t padBeforeAttr,
160                                           int64_t padAfterAttr, Value kernelDim,
161                                           int64_t strideAttr,
162                                           int64_t dilationAttr,
163                                           OpBuilder &rewriter) {
164   ImplicitLocOpBuilder builder(loc, rewriter);
165   auto one = rewriter.create<arith::ConstantOp>(
166       loc, IntegerAttr::get(inputDim.getType(), 1));
167   Value padBefore = reifyConstantDim(padBeforeAttr, builder);
168   Value paddedBefore = builder.create<arith::AddIOp>(inputDim, padBefore);
169   Value padAfter = reifyConstantDim(padAfterAttr, builder);
170   Value paddedAfter = builder.create<arith::AddIOp>(paddedBefore, padAfter);
171 
172   Value subOne = builder.create<arith::SubIOp>(kernelDim, one);
173   Value dilation = reifyConstantDim(dilationAttr, builder);
174   Value dilated = builder.create<arith::MulIOp>(dilation, subOne);
175   Value addOne = builder.create<arith::AddIOp>(dilated, one);
176 
177   Value subtract = builder.create<arith::SubIOp>(paddedAfter, addOne);
178   Value stride = reifyConstantDim(strideAttr, builder);
179   Value divide = builder.create<arith::DivUIOp>(subtract, stride);
180   return builder.create<arith::AddIOp>(divide, one);
181 }
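// Worked example (hypothetical values): IH = 7, pad_top = pad_bottom = 1,
// KH = 3, dilation_y = 1, stride_y = 2 gives
//   H = ((7 + 1 + 1 - (1 * (3 - 1) + 1)) / 2) + 1 = (6 / 2) + 1 = 4.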
182 
183 // Creates a vector of the dynamic output dims for Conv2D and DepthwiseConv2D.
184 static SmallVector<Value> inferDynamicDimsForConv(
185     Location loc, Value input, Value weight, ShapedType resultTy,
186     ArrayRef<int64_t> padAttr, ArrayRef<int64_t> strideAttr,
187     ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
188     ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
189   ShapedType inputTy = cast<ShapedType>(input.getType());
190   int64_t inputRank = inputTy.getRank();
191 
192   SmallVector<Value> dynDims;
193   dynDims.resize(resultTy.getRank());
194 
195   for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
196     int64_t inputDim = inputSizeDims[i];
197     int64_t kernelDim = kernelSizeDims[i];
198     if (resultTy.isDynamicDim(inputDim)) {
199       auto padTop = padAttr[i * 2];
200       auto padBottom = padAttr[i * 2 + 1];
201       auto stride = strideAttr[i];
202       auto dilation = dilationAttr[i];
203       Value initDynDim = rewriter.create<tensor::DimOp>(loc, input, inputDim);
204       Value kernelDynDim =
205           rewriter.create<tensor::DimOp>(loc, weight, kernelDim);
206       // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
207       dynDims[inputDim] =
208           getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
209                                  kernelDynDim, stride, dilation, rewriter);
210     }
211   }
212 
213   // Get the batch/channel dimensions.
214   for (int i = 0; i < inputRank; i++) {
215     if (resultTy.isDynamicDim(i) && !dynDims[i])
216       dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
217   }
218 
219   SmallVector<Value> filteredDims = condenseValues(dynDims);
220   return filteredDims;
221 }
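// For example, tosa.conv2d over NHWC input with FHWC weights passes
// inputSizeDims = {1, 2} and kernelSizeDims = {1, 2}, so dynamic H/W result
// dims come from the formula above, while dynamic batch/channel dims fall
// through to plain tensor.dim ops.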
222 
223 // Creates a map that collapses the last two dimensions (channel and
224 // multiplier) of the depthwise convolution result into one TOSA channel dim.
225 static void createDepthwiseConvCollapseMap(
226     int64_t outputRank, SmallVector<ReassociationExprs, 4> &reassociationMap,
227     OpBuilder &rewriter) {
228   reassociationMap.resize(outputRank);
229   for (int i = 0; i < outputRank; i++) {
230     reassociationMap[i].push_back(rewriter.getAffineDimExpr(i));
231   }
232   reassociationMap[outputRank - 1].push_back(
233       rewriter.getAffineDimExpr(outputRank));
234 }
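// For example, with outputRank = 4 the reassociation is [[0], [1], [2], [3, 4]],
// collapsing the trailing (channel, multiplier) pair of the linalg depthwise
// result into the single TOSA channel dimension.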
235 
236 namespace {
237 
238 template <typename TosaConvOp, typename LinalgConvOp, typename LinalgConvQOp>
239 class ConvConverter : public OpConversionPattern<TosaConvOp> {
240 public:
241   using OpConversionPattern<TosaConvOp>::OpConversionPattern;
242   LogicalResult
243   matchAndRewrite(TosaConvOp op, typename TosaConvOp::Adaptor adaptor,
244                   ConversionPatternRewriter &rewriter) const final {
245     Location loc = op->getLoc();
246     Value input = op->getOperand(0);
247     Value weight = op->getOperand(1);
248     Value bias = op->getOperand(2);
249 
250     ShapedType inputTy = cast<ShapedType>(input.getType());
251     ShapedType weightTy = cast<ShapedType>(weight.getType());
252     ShapedType biasTy = cast<ShapedType>(bias.getType());
253     ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
254 
255     Type inputETy = inputTy.getElementType();
256     Type resultETy = resultTy.getElementType();
257 
258     DenseI64ArrayAttr padAttr = op.getPadAttr();
259     DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
260     DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
261     bool isQuantized = op.getQuantizationInfo().has_value();
262 
263     if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
264       return rewriter.notifyMatchFailure(
265           op, "tosa.conv ops require static shapes for weight and bias");
266 
267     if (inputETy.isUnsignedInteger())
268       return rewriter.notifyMatchFailure(
269           op, "tosa.conv ops do not support unsigned integer input");
270 
271     llvm::SmallVector<int64_t> inputSizeDims;
272     llvm::SmallVector<int64_t> kernelSizeDims;
273     for (int i = 1; i < resultTy.getRank() - 1; i++) {
274       inputSizeDims.push_back(i);
275       kernelSizeDims.push_back(i);
276     }
277 
278     SmallVector<Value> filteredDims = inferDynamicDimsForConv(
279         loc, input, weight, resultTy, padAttr.asArrayRef(),
280         strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
281         inputSizeDims, kernelSizeDims, rewriter);
282 
283     auto weightShape = weightTy.getShape();
284 
285     // Apply padding as necessary.
286     TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
287     if (isQuantized) {
288       auto quantizationInfo = *op.getQuantizationInfo();
289       int64_t iZp = quantizationInfo.getInputZp();
290 
291       int64_t intMin =
292           APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
293               .getSExtValue();
294       int64_t intMax =
295           APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
296               .getSExtValue();
297 
298       if (iZp < intMin || iZp > intMax)
299         return rewriter.notifyMatchFailure(
300             op, "tosa.conv op quantization has zp outside of input range");
301 
302       zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
303     }
304 
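    // TOSA pads only the spatial dimensions, so surround the pad attribute with
    // zero pairs for the batch and channel dims to match applyPad's per-dim
    // {low, high} layout; e.g. pad = [1, 1, 2, 2] becomes
    // {0, 0, 1, 1, 2, 2, 0, 0} for an NHWC input (values for illustration).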
305     llvm::SmallVector<int64_t> pad;
306     pad.resize(2, 0);
307     llvm::append_range(pad, padAttr.asArrayRef());
308     pad.resize(pad.size() + 2, 0);
309     input = applyPad(loc, input, pad, zeroAttr, rewriter);
310 
311     if (4 == inputTy.getRank()) {
312       // For 2D convolutions, we need to check if the target convolution op
313       // wants a HWCF kernel layout.
314       bool wantHwcf =
315           isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
316                       : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
317       if (wantHwcf) {
318         // Transpose the kernel to match dimension ordering of the linalg
319         // convolution operation.
320         // TODO(suderman): See if this can be efficiently folded - check whether
321         // the input is used anywhere else, if not fold the constant.
322         SmallVector<int32_t> weightPerm;
323         for (int i = 1; i < resultTy.getRank(); i++)
324           weightPerm.push_back(i);
325         weightPerm.push_back(0);
326 
327         SmallVector<int64_t> newWeightShape;
328         for (auto dim : weightPerm)
329           newWeightShape.push_back(weightShape[dim]);
330         auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
331         Value weightPermValue =
332             rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
333         Type newWeightTy =
334             RankedTensorType::get(newWeightShape, weightTy.getElementType());
335         weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
336                                                     weightPermValue);
337       }
338     }
339 
340     // For Conv3D, transpose the kernel to match the dimension ordering of
341     // the linalg convolution operation. Conv2D has a 1-1 mapping in linalg,
342     // so it is better to map directly and transpose later if desired.
343     if (5 == inputTy.getRank()) {
344       // TODO(suderman): See if this can be efficiently folded - check whether
345       // the input is used anywhere else, if not fold the constant.
346       SmallVector<int32_t> weightPerm;
347       for (int i = 1; i < resultTy.getRank(); i++)
348         weightPerm.push_back(i);
349       weightPerm.push_back(0);
350 
351       SmallVector<int64_t> newWeightShape;
352       for (auto dim : weightPerm)
353         newWeightShape.push_back(weightShape[dim]);
354       auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
355       Value weightPermValue =
356           rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
357       Type newWeightTy =
358           RankedTensorType::get(newWeightShape, weightTy.getElementType());
359       weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
360                                                   weightPermValue);
361     }
362 
363     // Extract the attributes for convolution.
364     ArrayRef<int64_t> stride = strideTosaAttr;
365     ArrayRef<int64_t> dilation = dilationTosaAttr;
366 
367     // Create the convolution op.
368     auto strideAttr = rewriter.getI64TensorAttr(stride);
369     auto dilationAttr = rewriter.getI64TensorAttr(dilation);
370 
371     Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
372         loc, resultTy.getShape(), resultETy, filteredDims);
373 
374     Value broadcastBias =
375         linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
376 
377     if (isQuantized) {
378       auto quantizationInfo = *op.getQuantizationInfo();
379       auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
380       auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
381 
382       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
383       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
384 
385       Value conv =
386           rewriter
387               .create<LinalgConvQOp>(
388                   loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
389                   ValueRange{broadcastBias}, strideAttr, dilationAttr)
390               ->getResult(0);
391 
392       rewriter.replaceOp(op, conv);
393       return success();
394     }
395 
396     Value conv = rewriter
397                      .create<LinalgConvOp>(
398                          loc, resultTy, ValueRange{input, weight},
399                          ValueRange{broadcastBias}, strideAttr, dilationAttr)
400                      ->getResult(0);
401 
402     rewriter.replaceOp(op, conv);
403     return success();
404   }
405 };
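// A schematic of the float path emitted for tosa.conv2d (types and attribute
// syntax abbreviated, shapes hypothetical):
//   %padded = tensor.pad %input ...                      // only if pad != 0
//   %init   = tensor.empty(...) : tensor<1x4x4x8xf32>
//   %bias_b = linalg.generic ... ins(%bias) outs(%init)  // broadcast bias
//   %conv   = linalg.conv_2d_nhwc_fhwc
//               {dilations = ..., strides = ...}
//               ins(%padded, %weight) outs(%bias_b)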
406 
407 class DepthwiseConvConverter
408     : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
409 public:
410   using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
411   LogicalResult
412   matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
413                   ConversionPatternRewriter &rewriter) const final {
414     Location loc = op->getLoc();
415     Value input = op->getOperand(0);
416     Value weight = op->getOperand(1);
417     Value bias = op->getOperand(2);
418 
419     ShapedType inputTy = cast<ShapedType>(input.getType());
420     ShapedType weightTy = cast<ShapedType>(weight.getType());
421     ShapedType biasTy = cast<ShapedType>(bias.getType());
422     ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
423     int64_t resultRank = resultTy.getRank();
424 
425     Type inputETy = inputTy.getElementType();
426     Type resultETy = resultTy.getElementType();
427 
428     auto padAttr = cast<DenseI64ArrayAttr>(op->getAttr("pad"));
429     auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
430     auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));
431 
432     if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
433       return rewriter.notifyMatchFailure(
434           op, "tosa.depthwise_conv ops require static shapes");
435 
436     // Compute output dynamic dims
437     SmallVector<Value> filteredDims = inferDynamicDimsForConv(
438         loc, input, weight, resultTy, padAttr.asArrayRef(),
439         strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
440         /*inputSizeDims=*/{1, 2},
441         /*kernelSizeDims=*/{0, 1}, rewriter);
442 
443     bool isQuantized = op->hasAttr("quantization_info");
444     IntegerAttr iZp;
445     IntegerAttr kZp;
446     if (isQuantized) {
447       auto quantizationInfo =
448           cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
449       iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
450       kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
451     }
452 
453     auto weightShape = weightTy.getShape();
454     auto resultShape = resultTy.getShape();
455 
456     // Apply padding as necessary.
457     TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
458     if (isQuantized) {
459       auto quantizationInfo =
460           cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
461       int64_t iZp = quantizationInfo.getInputZp();
462 
463       int64_t intMin =
464           APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
465               .getSExtValue();
466       int64_t intMax =
467           APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
468               .getSExtValue();
469 
470       if (iZp < intMin || iZp > intMax)
471         return rewriter.notifyMatchFailure(
472             op, "tosa.depthwise_conv op quantization has zp outside of input "
473                 "range");
474 
475       zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
476     }
477 
478     llvm::SmallVector<int64_t> pad;
479     pad.resize(2, 0);
480     llvm::append_range(pad, padAttr.asArrayRef());
481     pad.resize(pad.size() + 2, 0);
482 
483     input = applyPad(loc, input, pad, zeroAttr, rewriter);
484 
485     // Extract the attributes for convolution.
486     ArrayRef<int64_t> stride = strideTosaAttr;
487     ArrayRef<int64_t> dilation = dilationTosaAttr;
488 
489     // Create the convolution op.
490     auto strideAttr = rewriter.getI64TensorAttr(stride);
491     auto dilationAttr = rewriter.getI64TensorAttr(dilation);
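    // The linalg depthwise op keeps the channel and multiplier dims separate,
    // so build a rank-5 NHWCM accumulator type; e.g. a TOSA result of
    // tensor<1x4x4x6xf32> with weightShape = {3, 3, 2, 3} uses
    // tensor<1x4x4x2x3xf32> here and is collapsed back after the convolution
    // (shapes hypothetical).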
492     ShapedType linalgConvTy =
493         RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
494                                weightShape[2], weightShape[3]},
495                               resultETy);
496 
497     auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
498     Value emptyTensor = rewriter.create<tensor::EmptyOp>(
499         loc, linalgConvTy.getShape(), resultETy, filteredDims);
500     Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
501     Value zeroTensor = rewriter
502                            .create<linalg::FillOp>(loc, ValueRange{zero},
503                                                    ValueRange{emptyTensor})
504                            .result();
505 
506     Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
507         loc, resultTy.getShape(), resultETy, filteredDims);
508 
509     // Broadcast the initial value to the output tensor before convolving.
510     SmallVector<AffineMap, 4> indexingMaps;
511     indexingMaps.push_back(getBroadcastingMap(rewriter, bias, biasEmptyTensor));
512     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
513     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
514 
515     if (!isQuantized) {
516       Value conv = rewriter
517                        .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
518                            loc, linalgConvTy, ValueRange{input, weight},
519                            ValueRange{zeroTensor}, strideAttr, dilationAttr)
520                        .getResult(0);
521 
522       SmallVector<ReassociationExprs, 4> reassociationMap;
523       createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
524       Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
525           loc, resultTy, conv, reassociationMap);
526 
527       Value result =
528           rewriter
529               .create<linalg::GenericOp>(
530                   loc, resultTy, ValueRange({bias, convReshape}),
531                   biasEmptyTensor, indexingMaps,
532                   getNParallelLoopsAttrs(resultRank),
533                   [&](OpBuilder &nestedBuilder, Location nestedLoc,
534                       ValueRange args) {
535                     Value added = nestedBuilder.create<arith::AddFOp>(
536                         loc, args[0], args[1]);
537                     nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
538                   })
539               .getResult(0);
540       rewriter.replaceOp(op, result);
541     } else {
542       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
543       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
544       Value conv =
545           rewriter
546               .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
547                   loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
548                   ValueRange{zeroTensor}, strideAttr, dilationAttr)
549               .getResult(0);
550       SmallVector<ReassociationExprs, 4> reassociationMap;
551       createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
552       Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
553           loc, resultTy, conv, reassociationMap);
554       Value result = linalgIntBroadcastExtSIAdd(
555           rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps);
556       rewriter.replaceOp(op, result);
557     }
558     return success();
559   }
560 };
561 
562 class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
563 public:
564   using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
565   LogicalResult
566   matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
567                   ConversionPatternRewriter &rewriter) const final {
568     Location loc = op.getLoc();
569 
570     auto outputTy = cast<ShapedType>(op.getType());
571     auto outputElementTy = outputTy.getElementType();
572 
573     SmallVector<Value> dynDims;
574     dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
575 
576     if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) {
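    // tosa.matmul is a batched matmul: A is [N, H, C], B is [N, C, W], and the
    // result is [N, H, W], so dynamic N and H are taken from A and dynamic W
    // from B.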
577       dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
578     }
579 
580     if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) {
581       dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1);
582     }
583 
584     if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) {
585       dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
586     }
587 
588     SmallVector<Value> filteredDims = condenseValues(dynDims);
589 
590     auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
591     Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
592     auto emptyTensor = rewriter.create<tensor::EmptyOp>(
593         loc, outputTy.getShape(), outputTy.getElementType(), filteredDims);
594     Value zeroTensor = rewriter
595                            .create<linalg::FillOp>(loc, ValueRange{zero},
596                                                    ValueRange{emptyTensor})
597                            .result();
598     if (!op.getQuantizationInfo()) {
599       rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
600           op, TypeRange{op.getType()},
601           ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
602       return success();
603     }
604 
605     auto quantizationInfo = *op.getQuantizationInfo();
606     auto aZp = rewriter.create<arith::ConstantOp>(
607         loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp()));
608     auto bZp = rewriter.create<arith::ConstantOp>(
609         loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp()));
610     rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
611         op, TypeRange{op.getType()},
612         ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
613 
614     return success();
615   }
616 };
617 
618 class FullyConnectedConverter
619     : public OpConversionPattern<tosa::FullyConnectedOp> {
620 public:
621   using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
622   LogicalResult
623   matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor,
624                   ConversionPatternRewriter &rewriter) const final {
625     Location loc = op.getLoc();
626     auto outputTy = cast<ShapedType>(op.getType());
627     auto input = op.getInput();
628     auto inputTy = cast<ShapedType>(input.getType());
629 
630     auto bias = op.getBias();
631 
632     auto weight = op.getWeight();
633     auto weightTy = cast<ShapedType>(weight.getType());
634     auto weightShape = weightTy.getShape();
635 
636     auto outputETy = outputTy.getElementType();
637 
638     SmallVector<Value> dynDims;
639     dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
640 
641     if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) {
642       dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0);
643     }
644 
645     if (!weightTy.hasRank() || weightTy.isDynamicDim(0)) {
646       dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0);
647     }
648 
649     SmallVector<Value> filteredDims = condenseValues(dynDims);
650 
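    // tosa.fully_connected weights are laid out as [OC, IC]; linalg.matmul
    // expects the right-hand operand as [IC, OC], so transpose the weight with
    // the {1, 0} permutation before forming [N, IC] x [IC, OC] -> [N, OC].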
651     SmallVector<int64_t> permutation = {1, 0};
652     auto permutationAttr = rewriter.getI64TensorAttr(permutation);
653     Value permutationValue =
654         rewriter.create<arith::ConstantOp>(loc, permutationAttr);
655 
656     SmallVector<int64_t> newWeightShape = {weightShape[1], weightShape[0]};
657     Type newWeightTy =
658         RankedTensorType::get(newWeightShape, weightTy.getElementType());
659 
660     Value transposedWeight = rewriter.create<tosa::TransposeOp>(
661         loc, newWeightTy, weight, permutationValue);
662 
663     Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
664         loc, outputTy.getShape(), outputETy, filteredDims);
665 
666     Value broadcastBias =
667         linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
668 
669     if (!op.getQuantizationInfo()) {
670       Value matmul = rewriter
671                          .create<linalg::MatmulOp>(
672                              loc, TypeRange{op.getType()},
673                              ValueRange{input, transposedWeight}, broadcastBias)
674                          ->getResult(0);
675 
676       rewriter.replaceOp(op, matmul);
677       return success();
678     }
679 
680     auto quantizationInfo = *op.getQuantizationInfo();
681     auto inputZp = rewriter.create<arith::ConstantOp>(
682         loc, rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()));
683     auto outputZp = rewriter.create<arith::ConstantOp>(
684         loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()));
685     Value matmul =
686         rewriter
687             .create<linalg::QuantizedMatmulOp>(
688                 loc, TypeRange{op.getType()},
689                 ValueRange{input, transposedWeight, inputZp, outputZp},
690                 broadcastBias)
691             ->getResult(0);
692 
693     rewriter.replaceOp(op, matmul);
694     return success();
695   }
696 };
697 
698 class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
699 public:
700   using OpConversionPattern::OpConversionPattern;
701 
702   // Compute the dynamic output sizes of the maxpool operation.
703   static SmallVector<Value>
704   computeDynamicOutputSizes(tosa::MaxPool2dOp op, OpAdaptor adaptor,
705                             ConversionPatternRewriter &rewriter) {
706     TensorType resultTy = op.getType();
707     Location loc = op.getLoc();
708 
709     Value input = adaptor.getInput();
710     ArrayRef<int64_t> kernel = op.getKernel();
711     ArrayRef<int64_t> pad = op.getPad();
712     ArrayRef<int64_t> stride = op.getStride();
713 
714     SmallVector<Value> dynamicDims;
715 
716     // Batch dimension
717     if (resultTy.isDynamicDim(0))
718       dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
719 
720     // Height/width dimensions
721     for (int64_t dim : {1, 2}) {
722       if (!resultTy.isDynamicDim(dim))
723         continue;
724 
725       // Index into the attribute arrays
726       int64_t index = dim - 1;
727 
728       // Input height/width
729       Value ihw = rewriter.create<tensor::DimOp>(loc, input, dim);
730 
731       // Kernel height/width
732       Value khw = rewriter.create<arith::ConstantIndexOp>(loc, kernel[index]);
733 
734       // Output height/width
735       Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
736                                          pad[index * 2 + 1], khw, stride[index],
737                                          /*dilationAttr=*/1, rewriter);
738       dynamicDims.push_back(ohw);
739     }
740 
741     // Channel dimension
742     if (resultTy.isDynamicDim(3))
743       dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 3));
744 
745     return dynamicDims;
746   }
747 
748   LogicalResult
749   matchAndRewrite(tosa::MaxPool2dOp op, OpAdaptor adaptor,
750                   ConversionPatternRewriter &rewriter) const final {
751     Location loc = op.getLoc();
752     Value input = adaptor.getInput();
753     ShapedType inputTy = cast<ShapedType>(input.getType());
754 
755     bool isUnsigned = op.getType().getElementType().isUnsignedInteger();
756     ShapedType resultTy =
757         cast<ShapedType>(getTypeConverter()->convertType(op.getType()));
758     if (!resultTy)
759       return rewriter.notifyMatchFailure(op, "failed to convert type");
760     Type resultETy = inputTy.getElementType();
761 
762     SmallVector<Value> dynamicDims =
763         computeDynamicOutputSizes(op, adaptor, rewriter);
764 
765     // Determine what the initial value needs to be for the max pool op.
766     TypedAttr initialAttr;
767     if (resultETy.isF32() || resultETy.isBF16() || resultETy.isF16())
768       initialAttr = rewriter.getFloatAttr(
769           resultETy, APFloat::getLargest(
770                          cast<FloatType>(resultETy).getFloatSemantics(), true));
771 
772     else if (isUnsigned)
773       initialAttr = rewriter.getIntegerAttr(
774           resultETy, APInt::getZero(resultETy.getIntOrFloatBitWidth()));
775     else if (isa<IntegerType>(resultETy))
776       initialAttr = rewriter.getIntegerAttr(
777           resultETy,
778           APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));
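    // E.g. an i8 max pool starts from -128 (or 0 when unsigned) and an f32 max
    // pool starts from -FLT_MAX, so padding the input with this value below
    // can never win the max.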
779 
780     if (!initialAttr)
781       return rewriter.notifyMatchFailure(
782           op, "Unsupported initial value for tosa.maxpool_2d op");
783 
784     // Apply padding as necessary.
785     llvm::SmallVector<int64_t> pad;
786     pad.resize(2, 0);
787     llvm::append_range(pad, op.getPad());
788     pad.resize(pad.size() + 2, 0);
789 
790     Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
791 
792     Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);
793 
794     ArrayRef<int64_t> kernel = op.getKernel();
795     ArrayRef<int64_t> stride = op.getStride();
796 
797     Attribute strideAttr = rewriter.getI64VectorAttr(stride);
798     Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
799 
800     // Create the linalg op that performs pooling.
801     Value emptyTensor = rewriter.create<tensor::EmptyOp>(
802         loc, resultTy.getShape(), resultTy.getElementType(), dynamicDims);
803 
804     Value filledEmptyTensor =
805         rewriter.create<linalg::FillOp>(loc, initialValue, emptyTensor)
806             .result();
807 
808     Value fakeWindowDims =
809         rewriter.create<tensor::EmptyOp>(loc, kernel, resultETy);
810 
811     if (isUnsigned) {
812       rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
813           op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
814           filledEmptyTensor, strideAttr, dilationAttr);
815     } else {
816       rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
817           op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
818           filledEmptyTensor, strideAttr, dilationAttr);
819     }
820     return success();
821   }
822 };
823 
824 class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
825 public:
826   using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;
827 
828   LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
829                                 PatternRewriter &rewriter) const final {
830     Location loc = op.getLoc();
831     Value input = op.getInput();
832     ShapedType inputTy = cast<ShapedType>(input.getType());
833     Type inElementTy = inputTy.getElementType();
834 
835     ShapedType resultTy = cast<ShapedType>(op.getType());
836     Type resultETy = cast<ShapedType>(op.getType()).getElementType();
837 
838     Type accETy = op.getAccType();
839     ShapedType accTy = resultTy.clone(accETy);
840 
841     auto dynamicDimsOr =
842         checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
843     if (!dynamicDimsOr.has_value())
844       return failure();
845     SmallVector<Value> dynamicDims = *dynamicDimsOr;
846 
847     // Apply padding as necessary.
848     llvm::SmallVector<int64_t> pad;
849     pad.resize(2, 0);
850     llvm::append_range(pad, op.getPad());
851     pad.resize(pad.size() + 2, 0);
852     TypedAttr padAttr = rewriter.getZeroAttr(inElementTy);
853     // Unsupported element type
854     if (!padAttr)
855       return failure();
856     Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);
857 
858     auto initialAttr = rewriter.getZeroAttr(accETy);
859     Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);
860 
861     ArrayRef<int64_t> kernel = op.getKernel();
862     ArrayRef<int64_t> stride = op.getStride();
863 
864     Attribute strideAttr = rewriter.getI64VectorAttr(stride);
865     Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
866 
867     // Create the linalg op that performs pooling.
868     Value poolEmptyTensor = rewriter.create<tensor::EmptyOp>(
869         loc, accTy.getShape(), accETy, dynamicDims);
870 
871     Value filledEmptyTensor =
872         rewriter
873             .create<linalg::FillOp>(loc, ValueRange{initialValue},
874                                     ValueRange{poolEmptyTensor})
875             .result();
876 
877     Value fakeWindowDims =
878         rewriter.create<tensor::EmptyOp>(loc, kernel, accETy);
879 
880     // Sum across the pooled region.
881     Value poolingOp = rewriter
882                           .create<linalg::PoolingNhwcSumOp>(
883                               loc, ArrayRef<Type>{accTy},
884                               ValueRange{paddedInput, fakeWindowDims},
885                               filledEmptyTensor, strideAttr, dilationAttr)
886                           .getResult(0);
887 
888     // Normalize the summed value by the number of elements grouped in each
889     // pool.
890     Value iH = rewriter.create<tensor::DimOp>(loc, poolingOp, 1);
891     Value iW = rewriter.create<tensor::DimOp>(loc, poolingOp, 2);
892 
893     auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
894     iH = rewriter.create<arith::SubIOp>(loc, iH, one);
895     iW = rewriter.create<arith::SubIOp>(loc, iW, one);
896 
897     Value genericEmptyTensor = rewriter.create<tensor::EmptyOp>(
898         loc, resultTy.getShape(), resultETy, dynamicDims);
899 
900     auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
901     auto genericOp = rewriter.create<linalg::GenericOp>(
902         loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
903         ValueRange{genericEmptyTensor},
904         ArrayRef<AffineMap>({affineMap, affineMap}),
905         getNParallelLoopsAttrs(resultTy.getRank()),
906         [&](OpBuilder &b, Location loc, ValueRange args) {
907           auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
908 
909           // Determines what portion of the valid input is covered by the
910           // kernel.
911           auto padFn = [&](Value valid, Value pos, int64_t pad) -> Value {
912             if (pad == 0)
913               return valid;
914 
915             auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
916             Value dpos = rewriter.create<arith::SubIOp>(loc, pos, padVal);
917 
918             Value offset = rewriter.create<arith::MinSIOp>(loc, dpos, zero);
919             return rewriter.create<arith::AddIOp>(loc, valid, offset)
920                 ->getResult(0);
921           };
922 
923           auto coverageFn = [&](int64_t i, Value isize) -> Value {
924             Value strideVal =
925                 rewriter.create<arith::ConstantIndexOp>(loc, stride[i - 1]);
926             Value val =
927                 rewriter.create<arith::ConstantIndexOp>(loc, kernel[i - 1]);
928 
929             // Find the position relative to the input tensor's ends.
930             Value left = rewriter.create<linalg::IndexOp>(loc, i);
931             Value right = rewriter.create<arith::SubIOp>(loc, isize, left);
932             left = rewriter.create<arith::MulIOp>(loc, left, strideVal);
933             right = rewriter.create<arith::MulIOp>(loc, right, strideVal);
934 
935             // Determine how much padding was included.
936             val = padFn(val, left, pad[i * 2]);
937             val = padFn(val, right, pad[i * 2 + 1]);
938             return rewriter.create<arith::MaxSIOp>(loc, one, val);
939           };
940 
941           // Compute the indices from either end.
942           Value kH3 = coverageFn(1, iH);
943           Value kW3 = coverageFn(2, iW);
944 
945           // Compute the total number of elements and normalize.
946           auto count = rewriter.create<arith::IndexCastOp>(
947               loc, rewriter.getI32Type(),
948               rewriter.create<arith::MulIOp>(loc, kH3, kW3));
949 
950           // Divide by the number of summed values. For floats this is just
951           // a division; for quantized values the normalization is applied
952           // via the fixed-point multiplier and shift computed below.
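          // For the quantized path below, 1/count is approximated as a
          // fixed-point multiplier and shift; e.g. count = 9 (a hypothetical
          // 3x3 window with no padding) gives k = 4, shift = 34 and
          // multiplier = (((1 << 30) + 1) << 4) / 9.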
953           Value poolVal = args[0];
954           if (isa<FloatType>(accETy)) {
955             auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, count);
956             poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
957                           ->getResult(0);
958             if (accETy.getIntOrFloatBitWidth() >
959                 resultETy.getIntOrFloatBitWidth())
960               poolVal =
961                   rewriter.create<arith::TruncFOp>(loc, resultETy, poolVal);
962           } else {
963 
964             // If we have quantization information we need to apply an offset
965             // for the input zp value.
966             if (op.getQuantizationInfo()) {
967               auto quantizationInfo = *op.getQuantizationInfo();
968               auto inputZp = rewriter.create<arith::ConstantOp>(
969                   loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
970               Value offset =
971                   rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
972               poolVal =
973                   rewriter.create<arith::SubIOp>(loc, accETy, poolVal, offset);
974             }
975 
976             // Compute: k = 32 - count_leading_zeros(value - 1)
977             Value one32 = rewriter.create<arith::ConstantOp>(
978                 loc, rewriter.getI32IntegerAttr(1));
979             Value thirtyTwo32 = rewriter.create<arith::ConstantOp>(
980                 loc, rewriter.getI32IntegerAttr(32));
981 
982             Value countSubOne =
983                 rewriter.create<arith::SubIOp>(loc, count, one32);
984             Value leadingZeros =
985                 rewriter.create<math::CountLeadingZerosOp>(loc, countSubOne);
986             Value k =
987                 rewriter.create<arith::SubIOp>(loc, thirtyTwo32, leadingZeros);
988 
989             // Compute: numerator = ((1 << 30) + 1) << k
990             Value k64 =
991                 rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), k);
992             Value thirtyShiftPlusOne = rewriter.create<arith::ConstantOp>(
993                 loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
994             Value numerator =
995                 rewriter.create<arith::ShLIOp>(loc, thirtyShiftPlusOne, k64);
996 
997             // Compute: scale.multiplier = numerator / value;
998             Value count64 = rewriter.create<arith::ExtUIOp>(
999                 loc, rewriter.getI64Type(), count);
1000             Value multiplier =
1001                 rewriter.create<arith::DivUIOp>(loc, numerator, count64);
1002             multiplier = rewriter.create<arith::TruncIOp>(
1003                 loc, rewriter.getI32Type(), multiplier);
1004 
1005             // Compute: scale.shift = 30 + k
1006             Value k8 =
1007                 rewriter.create<arith::TruncIOp>(loc, rewriter.getI8Type(), k);
1008             Value thirty8 = rewriter.create<arith::ConstantOp>(
1009                 loc, rewriter.getI8IntegerAttr(30));
1010             Value shift = rewriter.create<arith::AddIOp>(loc, k8, thirty8);
1011 
1012             auto scaled =
1013                 rewriter
1014                     .create<tosa::ApplyScaleOp>(loc, rewriter.getI32Type(),
1015                                                 poolVal, multiplier, shift,
1016                                                 rewriter.getBoolAttr(false))
1017                     .getResult();
1018 
1019             // If we have quantization information we need to apply the
1020             // output zero point.
1021             if (op.getQuantizationInfo()) {
1022               auto quantizationInfo = *op.getQuantizationInfo();
1023               auto outputZp = rewriter.create<arith::ConstantOp>(
1024                   loc, b.getIntegerAttr(scaled.getType(),
1025                                         quantizationInfo.getOutputZp()));
1026               scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
1027                            .getResult();
1028             }
1029 
1030             // Apply Clip.
1031             int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();
1032 
1033             auto min = rewriter.create<arith::ConstantIntOp>(
1034                 loc, APInt::getSignedMinValue(outBitwidth).getSExtValue(),
1035                 accETy);
1036             auto max = rewriter.create<arith::ConstantIntOp>(
1037                 loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
1038                 accETy);
1039             auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
1040                                         /*isUnsigned=*/false);
1041 
1042             poolVal = clamp;
1043             // Convert type.
1044             if (resultETy != clamp.getType()) {
1045               poolVal =
1046                   rewriter.create<arith::TruncIOp>(loc, resultETy, poolVal);
1047             }
1048           }
1049 
1050           rewriter.create<linalg::YieldOp>(loc, poolVal);
1051         });
1052 
1053     rewriter.replaceOp(op, genericOp.getResult(0));
1054     return success();
1055   }
1056 };
1057 
1058 class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
1059 public:
1060   using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
1061 
1062   LogicalResult matchAndRewrite(tosa::TransposeOp op,
1063                                 PatternRewriter &rewriter) const final {
1064     SmallVector<int32_t> constantPerms;
1065     if (failed(op.getConstantPerms(constantPerms)))
1066       return failure();
1067 
1068     Location loc = op.getLoc();
1069     // The verifier should have made sure we have a valid TOSA permutation
1070     // tensor. isPermutationVector doesn't actually check the TOSA perms we
1071     // expect.
1072     SmallVector<OpFoldResult> inputSizes =
1073         tensor::getMixedSizes(rewriter, loc, op.getInput1());
1074     auto permutedSizes =
1075         applyTOSAPermutation<OpFoldResult>(inputSizes, constantPerms);
1076 
1077     auto permutedInit = rewriter.create<tensor::EmptyOp>(
1078         loc, permutedSizes, op.getInput1().getType().getElementType());
1079     rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
1080         op, op.getInput1(), permutedInit,
1081         llvm::to_vector(llvm::map_range(
1082             constantPerms, [](int32_t v) -> int64_t { return v; })));
1083     return success();
1084   }
1085 };
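// For example (hypothetical shapes), perms = [0, 2, 3, 1] lowers a transpose
// of tensor<1x2x3x4xf32> into a linalg.transpose whose tensor.empty init and
// result have type tensor<1x3x4x2xf32>.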
1086 } // namespace
1087 
1088 void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
1089     const TypeConverter &converter, RewritePatternSet *patterns,
1090     const TosaToLinalgNamedOptions &options) {
1091   if (options.preferConv2DKernelLayoutHWCF) {
1092     patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
1093                                 linalg::Conv2DNhwcHwcfQOp>>(
1094         patterns->getContext());
1095   } else {
1096     patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
1097                                 linalg::Conv2DNhwcFhwcQOp>>(
1098         patterns->getContext());
1099   }
1100   patterns->add<
1101       // clang-format off
1102       ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
1103       DepthwiseConvConverter,
1104       MatMulConverter,
1105       AvgPool2dConverter,
1106       FullyConnectedConverter,
1107       TransposeConverter
1108   >(patterns->getContext());
1109 
1110   patterns->add<
1111       MaxPool2dConverter
1112     >(converter, patterns->getContext());
1113   // clang-format on
1114 }
1115