//===- TosaToLinalgNamed.cpp - Lowering Tosa to Linalg Named Ops ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa dialect to the Linalg named ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "mlir/Interfaces/InferTypeOpInterface.h"

#include <numeric>
#include <type_traits>

using namespace mlir;
using namespace mlir::tosa;

static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                            TypedAttr padAttr, OpBuilder &rewriter) {
  // Input should be padded only if necessary.
  if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
    return input;

  ShapedType inputTy = cast<ShapedType>(input.getType());
  Type inputETy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();

  assert((inputShape.size() * 2) == pad.size());

  SmallVector<int64_t, 4> paddedShape;
  SmallVector<OpFoldResult, 8> lowIndices;
  SmallVector<OpFoldResult, 8> highIndices;
  for (size_t i : llvm::seq(inputShape.size())) {
    auto lowPad = pad[i * 2];
    auto highPad = pad[i * 2 + 1];
    if (ShapedType::isDynamic(inputShape[i]))
      paddedShape.push_back(inputShape[i]);
    else
      paddedShape.push_back(inputShape[i] + highPad + lowPad);
    lowIndices.push_back(rewriter.getIndexAttr(lowPad));
    highIndices.push_back(rewriter.getIndexAttr(highPad));
  }

  Value padValue = rewriter.create<arith::ConstantOp>(loc, padAttr);

  return rewriter.create<tensor::PadOp>(
      loc, RankedTensorType::get(paddedShape, inputETy), input, lowIndices,
      highIndices, padValue);
}

static mlir::Value
linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
                           Value conv, Value result,
                           ArrayRef<AffineMap> indexingMaps) {
  ShapedType resultTy = cast<ShapedType>(conv.getType());
  return rewriter
      .create<linalg::GenericOp>(
          loc, resultTy, ValueRange({bias, conv}), result, indexingMaps,
          getNParallelLoopsAttrs(resultTy.getRank()),
          [](OpBuilder &builder, Location loc, ValueRange args) {
            Value biasVal = args[0];
            Type resType = args[1].getType();
            if (resType != biasVal.getType()) {
              biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
            }
            Value added = builder.create<arith::AddIOp>(loc, biasVal, args[1]);
            builder.create<linalg::YieldOp>(loc, added);
          })
      .getResult(0);
}

// Construct the affine map that a linalg generic would use to broadcast the
// source tensor into the shape of the result tensor.
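// For example, broadcasting a tensor<5xf32> source into a tensor<2x3x5xf32>
// result uses the map (d0, d1, d2) -> (d2), while a single-element
// tensor<1xf32> source uses the constant map (d0, d1, d2) -> (0).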
static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
                                    Value result) {
  ShapedType resultTy = cast<ShapedType>(result.getType());
  ShapedType sourceTy = cast<ShapedType>(source.getType());
  const int64_t resultRank = resultTy.getRank();
  const int64_t sourceRank = sourceTy.getRank();

  // The source tensor is broadcast to all the outer dimensions of the
  // result tensor.
  SmallVector<AffineExpr> sourceDims;
  // In the case of a rank-1 source tensor with a single element, TOSA
  // specifies that the value be broadcast, meaning we need an edge case for a
  // constant map.
  assert(sourceTy.hasStaticShape() &&
         "Dynamic broadcasting shapes not supported!");
  if (sourceRank == 1 && sourceTy.getDimSize(0) == 1) {
    sourceDims.push_back(rewriter.getAffineConstantExpr(0));
  } else {
    for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
      auto expr = rewriter.getAffineDimExpr(dim + resultRank - sourceRank);
      sourceDims.push_back(expr);
    }
  }

  return AffineMap::get(/*dimCount=*/resultRank,
                        /*symbolCount=*/0, sourceDims, rewriter.getContext());
}

// Broadcast the source value to all the outer dimensions of the result value.
// If required, the element type is expanded using an arith.extsi operation.
static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
                                                Location loc, Value source,
                                                Value result) {
  ShapedType resultTy = cast<ShapedType>(result.getType());
  const int64_t resultRank = resultTy.getRank();
  // Creating maps for the input and output of the broadcast-like generic op.
  SmallVector<AffineMap, 2> indexingMaps;
  indexingMaps.push_back(getBroadcastingMap(rewriter, source, result));
  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

  // Build the broadcast-like operation as a linalg.generic.
  return rewriter
      .create<linalg::GenericOp>(
          loc, resultTy, ValueRange({source}), result, indexingMaps,
          getNParallelLoopsAttrs(resultTy.getRank()),
          [](OpBuilder &builder, Location loc, ValueRange args) {
            Value biasVal = args[0];
            Type resType = args[1].getType();
            if (resType != biasVal.getType()) {
              biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
            }
            builder.create<linalg::YieldOp>(loc, biasVal);
          })
      .getResult(0);
}

static mlir::Value reifyConstantDim(int64_t attr,
                                    ImplicitLocOpBuilder &builder) {
  return builder.create<arith::ConstantIndexOp>(attr);
}

// Calculating the output width/height using the formula:
// H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
// W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1

static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
                                          int64_t padBeforeAttr,
                                          int64_t padAfterAttr, Value kernelDim,
                                          int64_t strideAttr,
                                          int64_t dilationAttr,
                                          OpBuilder &rewriter) {
  ImplicitLocOpBuilder builder(loc, rewriter);
  auto one = rewriter.create<arith::ConstantOp>(
      loc, IntegerAttr::get(inputDim.getType(), 1));
  Value padBefore = reifyConstantDim(padBeforeAttr, builder);
  Value paddedBefore = builder.create<arith::AddIOp>(inputDim, padBefore);
  Value padAfter = reifyConstantDim(padAfterAttr, builder);
  Value paddedAfter = builder.create<arith::AddIOp>(paddedBefore, padAfter);

  Value subOne = builder.create<arith::SubIOp>(kernelDim, one);
  Value dilation = reifyConstantDim(dilationAttr, builder);
  Value dilated = builder.create<arith::MulIOp>(dilation, subOne);
  Value addOne = builder.create<arith::AddIOp>(dilated, one);

  Value subtract = builder.create<arith::SubIOp>(paddedAfter, addOne);
  Value stride = reifyConstantDim(strideAttr, builder);
  Value divide = builder.create<arith::DivUIOp>(subtract, stride);
  return builder.create<arith::AddIOp>(divide, one);
}

// Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D
static SmallVector<Value> inferDynamicDimsForConv(
    Location loc, Value input, Value weight, ShapedType resultTy,
    ArrayRef<int64_t> padAttr, ArrayRef<int64_t> strideAttr,
    ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
    ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
  ShapedType inputTy = cast<ShapedType>(input.getType());
  int64_t inputRank = inputTy.getRank();

  SmallVector<Value> dynDims;
  dynDims.resize(resultTy.getRank());

  for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
    int64_t inputDim = inputSizeDims[i];
    int64_t kernelDim = kernelSizeDims[i];
    if (resultTy.isDynamicDim(inputDim)) {
      auto padTop = padAttr[i * 2];
      auto padBottom = padAttr[i * 2 + 1];
      auto stride = strideAttr[i];
      auto dilation = dilationAttr[i];
      Value initDynDim = rewriter.create<tensor::DimOp>(loc, input, inputDim);
      Value kernelDynDim =
          rewriter.create<tensor::DimOp>(loc, weight, kernelDim);
      // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
      dynDims[inputDim] =
          getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
                                 kernelDynDim, stride, dilation, rewriter);
    }
  }

  // Get the batch/channels dimensions.
  for (int i = 0; i < inputRank; i++) {
    if (resultTy.isDynamicDim(i) && !dynDims[i])
      dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
  }

  SmallVector<Value> filteredDims = condenseValues(dynDims);
  return filteredDims;
}

// Creates a map to collapse the last dimension of the Depthwise convolution op
// due to a shape mismatch.
static void createDepthwiseConvCollapseMap(
    int64_t outputRank, SmallVector<ReassociationExprs, 4> &reassociationMap,
    OpBuilder &rewriter) {
  reassociationMap.resize(outputRank);
  for (int i = 0; i < outputRank; i++) {
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(i));
  }
  reassociationMap[outputRank - 1].push_back(
      rewriter.getAffineDimExpr(outputRank));
}

namespace {

template <typename TosaConvOp, typename LinalgConvOp, typename LinalgConvQOp>
class ConvConverter : public OpConversionPattern<TosaConvOp> {
public:
  using OpConversionPattern<TosaConvOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(TosaConvOp op, typename TosaConvOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    DenseI64ArrayAttr padAttr = op.getPadAttr();
    DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
    DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
    bool isQuantized = op.getQuantizationInfo().has_value();

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops require static shapes for weight and bias");

    if (inputETy.isUnsignedInteger())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops do not support unsigned integer input");

    llvm::SmallVector<int64_t> inputSizeDims;
    llvm::SmallVector<int64_t> kernelSizeDims;
    for (int i = 1; i < resultTy.getRank() - 1; i++) {
      inputSizeDims.push_back(i);
      kernelSizeDims.push_back(i);
    }

    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr.asArrayRef(),
        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
        inputSizeDims, kernelSizeDims, rewriter);

    auto weightShape = weightTy.getShape();

    // Apply padding as necessary.
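    // The TOSA pad attribute holds (before, after) padding for the spatial
    // dimensions only, while applyPad expects a low/high pair for every input
    // dimension, so zero padding is added below for the batch and channel
    // dimensions. For quantized convolutions the padding value is the input
    // zero point rather than zero.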
    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
    if (isQuantized) {
      auto quantizationInfo = *op.getQuantizationInfo();
      int64_t iZp = quantizationInfo.getInputZp();

      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (iZp < intMin || iZp > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.conv op quantization has zp outside of input range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, padAttr.asArrayRef());
    pad.resize(pad.size() + 2, 0);
    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    if (4 == inputTy.getRank()) {
      // For 2D convolutions, we need to check if the target convolution op
      // wants a HWCF kernel layout.
      bool wantHwcf =
          isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
                      : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
      if (wantHwcf) {
        // Transpose the kernel to match dimension ordering of the linalg
        // convolution operation.
        // TODO(suderman): See if this can be efficiently folded - check whether
        // the input is used anywhere else, if not fold the constant.
        SmallVector<int32_t> weightPerm;
        for (int i = 1; i < resultTy.getRank(); i++)
          weightPerm.push_back(i);
        weightPerm.push_back(0);

        SmallVector<int64_t> newWeightShape;
        for (auto dim : weightPerm)
          newWeightShape.push_back(weightShape[dim]);
        auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
        Value weightPermValue =
            rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
        Type newWeightTy =
            RankedTensorType::get(newWeightShape, weightTy.getElementType());
        weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                    weightPermValue);
      }
    }

    // For Conv3D, transpose the kernel to match the dimension ordering of the
    // linalg convolution operation. Conv2D has a 1-1 mapping in linalg, so it
    // is better to map directly and then transpose later if desired.
    if (5 == inputTy.getRank()) {
      // TODO(suderman): See if this can be efficiently folded - check whether
      // the input is used anywhere else, if not fold the constant.
      SmallVector<int32_t> weightPerm;
      for (int i = 1; i < resultTy.getRank(); i++)
        weightPerm.push_back(i);
      weightPerm.push_back(0);

      SmallVector<int64_t> newWeightShape;
      for (auto dim : weightPerm)
        newWeightShape.push_back(weightShape[dim]);
      auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
      Value weightPermValue =
          rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
      Type newWeightTy =
          RankedTensorType::get(newWeightShape, weightTy.getElementType());
      weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                  weightPermValue);
    }

    // Extract the attributes for convolution.
    ArrayRef<int64_t> stride = strideTosaAttr;
    ArrayRef<int64_t> dilation = dilationTosaAttr;

    // Create the convolution op.
    auto strideAttr = rewriter.getI64TensorAttr(stride);
    auto dilationAttr = rewriter.getI64TensorAttr(dilation);

    Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultETy, filteredDims);

    Value broadcastBias =
        linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);

    if (isQuantized) {
      auto quantizationInfo = *op.getQuantizationInfo();
      auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
      auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());

      auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);

      Value conv =
          rewriter
              .create<LinalgConvQOp>(
                  loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{broadcastBias}, strideAttr, dilationAttr)
              ->getResult(0);

      rewriter.replaceOp(op, conv);
      return success();
    }

    Value conv = rewriter
                     .create<LinalgConvOp>(
                         loc, resultTy, ValueRange{input, weight},
                         ValueRange{broadcastBias}, strideAttr, dilationAttr)
                     ->getResult(0);

    rewriter.replaceOp(op, conv);
    return success();
  }
};

class DepthwiseConvConverter
    : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
public:
  using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
    int64_t resultRank = resultTy.getRank();

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto padAttr = cast<DenseI64ArrayAttr>(op->getAttr("pad"));
    auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
    auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.depthwise_conv ops require static shapes");

    // Compute output dynamic dims.
    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr.asArrayRef(),
        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
        /*inputSizeDims=*/{1, 2},
        /*kernelSizeDims=*/{0, 1}, rewriter);

    bool isQuantized = op->hasAttr("quantization_info");
    IntegerAttr iZp;
    IntegerAttr kZp;
    if (isQuantized) {
      auto quantizationInfo =
          cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
      iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
      kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
    }

    auto weightShape = weightTy.getShape();
    auto resultShape = resultTy.getShape();

    // Apply padding as necessary.
    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
    if (isQuantized) {
      auto quantizationInfo =
          cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
      int64_t iZp = quantizationInfo.getInputZp();

      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (iZp < intMin || iZp > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.depthwise_conv op quantization has zp outside of input "
                "range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, padAttr.asArrayRef());
    pad.resize(pad.size() + 2, 0);

    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    // Extract the attributes for convolution.
    ArrayRef<int64_t> stride = strideTosaAttr;
    ArrayRef<int64_t> dilation = dilationTosaAttr;

    // Create the convolution op.
    auto strideAttr = rewriter.getI64TensorAttr(stride);
    auto dilationAttr = rewriter.getI64TensorAttr(dilation);
    ShapedType linalgConvTy =
        RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                               weightShape[2], weightShape[3]},
                              resultETy);

    auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, linalgConvTy.getShape(), resultETy, filteredDims);
    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{emptyTensor})
                           .result();

    Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultETy, filteredDims);

    // Broadcast the initial value to the output tensor before convolving.
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(getBroadcastingMap(rewriter, bias, biasEmptyTensor));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

    if (!isQuantized) {
      Value conv = rewriter
                       .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
                           loc, linalgConvTy, ValueRange{input, weight},
                           ValueRange{zeroTensor}, strideAttr, dilationAttr)
                       .getResult(0);

      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
          loc, resultTy, conv, reassociationMap);

      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, convReshape}),
                  biasEmptyTensor, indexingMaps,
                  getNParallelLoopsAttrs(resultRank),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added = nestedBuilder.create<arith::AddFOp>(
                        loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
    } else {
      auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
      Value conv =
          rewriter
              .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
                  loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
              .getResult(0);
      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
          loc, resultTy, conv, reassociationMap);
      Value result = linalgIntBroadcastExtSIAdd(
          rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps);
      rewriter.replaceOp(op, result);
    }
    return success();
  }
};

class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
public:
  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();

    auto outputTy = cast<ShapedType>(op.getType());
    auto outputElementTy = outputTy.getElementType();

    SmallVector<Value> dynDims;
    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());

    if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) {
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
    }

    if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) {
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1);
    }

    if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) {
      dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(), filteredDims);
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{emptyTensor})
                           .result();
    if (!op.getQuantizationInfo()) {
      rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
          op, TypeRange{op.getType()},
          ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
      return success();
    }

    auto quantizationInfo = *op.getQuantizationInfo();
    auto aZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp()));
    auto bZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp()));
    rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
        op, TypeRange{op.getType()},
        ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);

    return success();
  }
};

class FullyConnectedConverter
    : public OpConversionPattern<tosa::FullyConnectedOp> {
public:
  using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    auto outputTy = cast<ShapedType>(op.getType());
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());

    auto bias = op.getBias();

    auto weight = op.getWeight();
    auto weightTy = cast<ShapedType>(weight.getType());
    auto weightShape = weightTy.getShape();

    auto outputETy = outputTy.getElementType();

    SmallVector<Value> dynDims;
    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());

    if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) {
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0);
    }

    if (!weightTy.hasRank() || weightTy.isDynamicDim(0)) {
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    SmallVector<int64_t> permutation = {1, 0};
    auto permutationAttr = rewriter.getI64TensorAttr(permutation);
    Value permutationValue =
        rewriter.create<arith::ConstantOp>(loc, permutationAttr);

    SmallVector<int64_t> newWeightShape = {weightShape[1], weightShape[0]};
    Type newWeightTy =
        RankedTensorType::get(newWeightShape, weightTy.getElementType());

    Value transposedWeight = rewriter.create<tosa::TransposeOp>(
        loc, newWeightTy, weight, permutationValue);

    Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputETy, filteredDims);

    Value broadcastBias =
        linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);

    if (!op.getQuantizationInfo()) {
      Value matmul = rewriter
                         .create<linalg::MatmulOp>(
                             loc, TypeRange{op.getType()},
                             ValueRange{input, transposedWeight}, broadcastBias)
                         ->getResult(0);

      rewriter.replaceOp(op, matmul);
      return success();
    }

    auto quantizationInfo = *op.getQuantizationInfo();
    auto inputZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()));
    auto weightZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()));
    Value matmul =
        rewriter
            .create<linalg::QuantizedMatmulOp>(
                loc, TypeRange{op.getType()},
                ValueRange{input, transposedWeight, inputZp, weightZp},
                broadcastBias)
            ->getResult(0);

    rewriter.replaceOp(op, matmul);
    return success();
  }
};

class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  // Compute the dynamic output sizes of the maxpool operation.
  static SmallVector<Value>
  computeDynamicOutputSizes(tosa::MaxPool2dOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter) {
    TensorType resultTy = op.getType();
    Location loc = op.getLoc();

    Value input = adaptor.getInput();
    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> pad = op.getPad();
    ArrayRef<int64_t> stride = op.getStride();

    SmallVector<Value> dynamicDims;

    // Batch dimension
    if (resultTy.isDynamicDim(0))
      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));

    // Height/width dimensions
    for (int64_t dim : {1, 2}) {
      if (!resultTy.isDynamicDim(dim))
        continue;

      // Index into the attribute arrays
      int64_t index = dim - 1;

      // Input height/width
      Value ihw = rewriter.create<tensor::DimOp>(loc, input, dim);

      // Kernel height/width
      Value khw = rewriter.create<arith::ConstantIndexOp>(loc, kernel[index]);

      // Output height/width
      Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
                                         pad[index * 2 + 1], khw, stride[index],
                                         /*dilationAttr=*/1, rewriter);
      dynamicDims.push_back(ohw);
    }

    // Channel dimension
    if (resultTy.isDynamicDim(3))
      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 3));

    return dynamicDims;
  }

  LogicalResult
  matchAndRewrite(tosa::MaxPool2dOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = adaptor.getInput();
    ShapedType inputTy = cast<ShapedType>(input.getType());

    bool isUnsigned = op.getType().getElementType().isUnsignedInteger();
    ShapedType resultTy =
        cast<ShapedType>(getTypeConverter()->convertType(op.getType()));
    if (!resultTy)
      return rewriter.notifyMatchFailure(op, "failed to convert type");
    Type resultETy = inputTy.getElementType();

    SmallVector<Value> dynamicDims =
        computeDynamicOutputSizes(op, adaptor, rewriter);

    // Determine what the initial value needs to be for the max pool op.
    TypedAttr initialAttr;
    if (resultETy.isF32() || resultETy.isBF16() || resultETy.isF16())
      initialAttr = rewriter.getFloatAttr(
          resultETy, APFloat::getLargest(
                         cast<FloatType>(resultETy).getFloatSemantics(), true));
    else if (isUnsigned)
      initialAttr = rewriter.getIntegerAttr(
          resultETy, APInt::getZero(resultETy.getIntOrFloatBitWidth()));
    else if (isa<IntegerType>(resultETy))
      initialAttr = rewriter.getIntegerAttr(
          resultETy,
          APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));

    if (!initialAttr)
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, op.getPad());
    pad.resize(pad.size() + 2, 0);

    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);

    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> stride = op.getStride();

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
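    // The pooling window is passed as a second "fake" input tensor; the
    // pooling op only uses its shape (the kernel size), never its values.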
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultTy.getElementType(), dynamicDims);

    Value filledEmptyTensor =
        rewriter.create<linalg::FillOp>(loc, initialValue, emptyTensor)
            .result();

    Value fakeWindowDims =
        rewriter.create<tensor::EmptyOp>(loc, kernel, resultETy);

    if (isUnsigned) {
      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
          op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
          filledEmptyTensor, strideAttr, dilationAttr);
    } else {
      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
          op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
          filledEmptyTensor, strideAttr, dilationAttr);
    }
    return success();
  }
};

class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
public:
  using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.getInput();
    ShapedType inputTy = cast<ShapedType>(input.getType());
    Type inElementTy = inputTy.getElementType();

    ShapedType resultTy = cast<ShapedType>(op.getType());
    Type resultETy = cast<ShapedType>(op.getType()).getElementType();

    Type accETy = op.getAccType();
    ShapedType accTy = resultTy.clone(accETy);

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return failure();
    SmallVector<Value> dynamicDims = *dynamicDimsOr;

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, op.getPad());
    pad.resize(pad.size() + 2, 0);
    TypedAttr padAttr = rewriter.getZeroAttr(inElementTy);
    // Unsupported element type
    if (!padAttr)
      return failure();
    Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);

    auto initialAttr = rewriter.getZeroAttr(accETy);
    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> stride = op.getStride();

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value poolEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, accTy.getShape(), accETy, dynamicDims);

    Value filledEmptyTensor =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{initialValue},
                                    ValueRange{poolEmptyTensor})
            .result();

    Value fakeWindowDims =
        rewriter.create<tensor::EmptyOp>(loc, kernel, accETy);

    // Sum across the pooled region.
    Value poolingOp = rewriter
                          .create<linalg::PoolingNhwcSumOp>(
                              loc, ArrayRef<Type>{accTy},
                              ValueRange{paddedInput, fakeWindowDims},
                              filledEmptyTensor, strideAttr, dilationAttr)
                          .getResult(0);

    // Normalize the summed value by the number of elements grouped in each
    // pool.
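    // Each output element is divided by the number of input elements its
    // window actually covers; near the borders this count excludes the zero
    // padding. For float accumulators this is a plain division, while for
    // integer accumulators the reciprocal is approximated with a fixed-point
    // multiplier/shift applied via tosa.apply_scale.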
    Value iH = rewriter.create<tensor::DimOp>(loc, poolingOp, 1);
    Value iW = rewriter.create<tensor::DimOp>(loc, poolingOp, 2);

    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    iH = rewriter.create<arith::SubIOp>(loc, iH, one);
    iW = rewriter.create<arith::SubIOp>(loc, iW, one);

    Value genericEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultETy, dynamicDims);

    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
        ValueRange{genericEmptyTensor},
        ArrayRef<AffineMap>({affineMap, affineMap}),
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

          // Determines what portion of the valid input is covered by the
          // kernel.
          auto padFn = [&](Value valid, Value pos, int64_t pad) -> Value {
            if (pad == 0)
              return valid;

            auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
            Value dpos = rewriter.create<arith::SubIOp>(loc, pos, padVal);

            Value offset = rewriter.create<arith::MinSIOp>(loc, dpos, zero);
            return rewriter.create<arith::AddIOp>(loc, valid, offset)
                ->getResult(0);
          };

          auto coverageFn = [&](int64_t i, Value isize) -> Value {
            Value strideVal =
                rewriter.create<arith::ConstantIndexOp>(loc, stride[i - 1]);
            Value val =
                rewriter.create<arith::ConstantIndexOp>(loc, kernel[i - 1]);

            // Find the position relative to the input tensor's ends.
            Value left = rewriter.create<linalg::IndexOp>(loc, i);
            Value right = rewriter.create<arith::SubIOp>(loc, isize, left);
            left = rewriter.create<arith::MulIOp>(loc, left, strideVal);
            right = rewriter.create<arith::MulIOp>(loc, right, strideVal);

            // Determine how much padding was included.
            val = padFn(val, left, pad[i * 2]);
            val = padFn(val, right, pad[i * 2 + 1]);
            return rewriter.create<arith::MaxSIOp>(loc, one, val);
          };

          // Compute the kernel coverage from either end.
          Value kH3 = coverageFn(1, iH);
          Value kW3 = coverageFn(2, iW);

          // Compute the total number of elements and normalize.
          auto count = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getI32Type(),
              rewriter.create<arith::MulIOp>(loc, kH3, kW3));

          // Divide by the number of summed values. For floats this is just
          // a division; for quantized values, input normalization must be
          // applied.
          Value poolVal = args[0];
          if (isa<FloatType>(accETy)) {
            auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, count);
            poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
                          ->getResult(0);
            if (accETy.getIntOrFloatBitWidth() >
                resultETy.getIntOrFloatBitWidth())
              poolVal =
                  rewriter.create<arith::TruncFOp>(loc, resultETy, poolVal);
          } else {

            // If we have quantization information we need to apply an offset
            // for the input zp value.
            if (op.getQuantizationInfo()) {
              auto quantizationInfo = *op.getQuantizationInfo();
              auto inputZp = rewriter.create<arith::ConstantOp>(
                  loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
              Value offset =
                  rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
              poolVal =
                  rewriter.create<arith::SubIOp>(loc, accETy, poolVal, offset);
            }

            // Compute: k = 32 - count_leading_zeros(value - 1)
            Value one32 = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI32IntegerAttr(1));
            Value thirtyTwo32 = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI32IntegerAttr(32));

            Value countSubOne =
                rewriter.create<arith::SubIOp>(loc, count, one32);
            Value leadingZeros =
                rewriter.create<math::CountLeadingZerosOp>(loc, countSubOne);
            Value k =
                rewriter.create<arith::SubIOp>(loc, thirtyTwo32, leadingZeros);

            // Compute: numerator = ((1 << 30) + 1) << k
            Value k64 =
                rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), k);
            Value thirtyShiftPlusOne = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
            Value numerator =
                rewriter.create<arith::ShLIOp>(loc, thirtyShiftPlusOne, k64);

            // Compute: scale.multiplier = numerator / value;
            Value count64 = rewriter.create<arith::ExtUIOp>(
                loc, rewriter.getI64Type(), count);
            Value multiplier =
                rewriter.create<arith::DivUIOp>(loc, numerator, count64);
            multiplier = rewriter.create<arith::TruncIOp>(
                loc, rewriter.getI32Type(), multiplier);

            // Compute: scale.shift = 30 + k
            Value k8 =
                rewriter.create<arith::TruncIOp>(loc, rewriter.getI8Type(), k);
            Value thirty8 = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI8IntegerAttr(30));
            Value shift = rewriter.create<arith::AddIOp>(loc, k8, thirty8);

            auto scaled =
                rewriter
                    .create<tosa::ApplyScaleOp>(loc, rewriter.getI32Type(),
                                                poolVal, multiplier, shift,
                                                rewriter.getBoolAttr(false))
                    .getResult();

            // If we have quantization information we need to apply the output
            // zero point.
            if (op.getQuantizationInfo()) {
              auto quantizationInfo = *op.getQuantizationInfo();
              auto outputZp = rewriter.create<arith::ConstantOp>(
                  loc, b.getIntegerAttr(scaled.getType(),
                                        quantizationInfo.getOutputZp()));
              scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
                           .getResult();
            }

            // Apply Clip.
            int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();

            auto min = rewriter.create<arith::ConstantIntOp>(
                loc, APInt::getSignedMinValue(outBitwidth).getSExtValue(),
                accETy);
            auto max = rewriter.create<arith::ConstantIntOp>(
                loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
                accETy);
            auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
                                        /*isUnsigned=*/false);

            poolVal = clamp;
            // Convert type.
            if (resultETy != clamp.getType()) {
              poolVal =
                  rewriter.create<arith::TruncIOp>(loc, resultETy, poolVal);
            }
          }

          rewriter.create<linalg::YieldOp>(loc, poolVal);
        });

    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};

class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
public:
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const final {
    SmallVector<int32_t> constantPerms;
    if (failed(op.getConstantPerms(constantPerms)))
      return failure();

    Location loc = op.getLoc();
    // The verifier should have made sure we have a valid TOSA permutation
    // tensor. isPermutationVector doesn't actually check the TOSA perms we
    // expect.
    SmallVector<OpFoldResult> inputSizes =
        tensor::getMixedSizes(rewriter, loc, op.getInput1());
    auto permutedSizes =
        applyTOSAPermutation<OpFoldResult>(inputSizes, constantPerms);

    auto permutedInit = rewriter.create<tensor::EmptyOp>(
        loc, permutedSizes, op.getInput1().getType().getElementType());
    rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
        op, op.getInput1(), permutedInit,
        llvm::to_vector(llvm::map_range(
            constantPerms, [](int32_t v) -> int64_t { return v; })));
    return success();
  }
};
} // namespace

void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns,
    const TosaToLinalgNamedOptions &options) {
  if (options.preferConv2DKernelLayoutHWCF) {
    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
                                linalg::Conv2DNhwcHwcfQOp>>(
        patterns->getContext());
  } else {
    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
                                linalg::Conv2DNhwcFhwcQOp>>(
        patterns->getContext());
  }
  patterns->add<
      // clang-format off
      ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
      DepthwiseConvConverter,
      MatMulConverter,
      AvgPool2dConverter,
      FullyConnectedConverter,
      TransposeConverter
  >(patterns->getContext());

  patterns->add<
      MaxPool2dConverter
  >(converter, patterns->getContext());
  // clang-format on
}