//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"

#include <numeric>

using namespace mlir;
using namespace mlir::tosa;

// Materialize an 'arith.constant' from the integer attribute named `attrName`
// on `op`. The attribute's value is sign-extended to int64, narrowed through a
// static_cast to `T`, and emitted as an IntegerAttr of `requiredAttrType` at
// `op`'s location.
template <typename T>
static arith::ConstantOp
createConstFromIntAttribute(Operation *op, const std::string &attrName,
                            Type requiredAttrType, OpBuilder &rewriter) {
  auto castedN = static_cast<T>(
      cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
  return rewriter.create<arith::ConstantOp>(
      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}

// Emit the scalar computation implementing one element of a TOSA elementwise
// op inside a linalg.generic region. `args` are the region's scalar block
// arguments (one per op operand) and `resultTypes` the scalar result type(s).
// Returns the computed scalar Value, or nullptr (after notifying match
// failure) when the op / element-type combination is not handled.
static Value createLinalgBodyCalculationForElementwiseOp(
    Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
    ConversionPatternRewriter &rewriter) {
  Location loc = op->getLoc();
  // Dispatch below is keyed on the element type of the first operand.
  auto elementTy =
      cast<ShapedType>(op->getOperand(0).getType()).getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::AbsFOp>(loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
    // Integer |x| is computed as max(x, 0 - x).
    auto zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementTy));
    auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
    return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
  }

  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::SubIOp>(loc, resultTypes, args);

  // tosa::IntDivOp
  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);

  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
  }

  // tosa::MulOp
  if (isa<tosa::MulOp>(op)) {
    auto shift_val = cast<tosa::MulOp>(op).getShift();

    if (isa<FloatType>(elementTy)) {
      return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
    }

    if (isa<IntegerType>(elementTy)) {
      int32_t shift = 0;
      ElementsAttr shift_elem;
      if (shift_val.getImpl() &&
          matchPattern(shift_val, m_Constant(&shift_elem))) {
        // Explicit shift is set.
        shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
      }

      Value a = args[0];
      Value b = args[1];
      if (shift > 0) {
        // With a positive shift the multiply is a fixed-point rescale:
        // widen both operands to i32 and route through tosa.apply_scale,
        // then narrow back to the element type if needed.
        auto shiftConst =
            rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
        if (!a.getType().isInteger(32))
          a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);

        if (!b.getType().isInteger(32))
          b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);

        auto result = rewriter.create<tosa::ApplyScaleOp>(
            loc, rewriter.getI32Type(), a, b, shiftConst,
            rewriter.getBoolAttr(false));

        if (elementTy.isInteger(32))
          return result;

        return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
      }

      // No shift: plain integer multiply, sign-extending narrower operands
      // up to the result width first.
      int aWidth = a.getType().getIntOrFloatBitWidth();
      int bWidth = b.getType().getIntOrFloatBitWidth();
      int cWidth = resultTypes[0].getIntOrFloatBitWidth();

      if (aWidth < cWidth)
        a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
      if (bWidth < cWidth)
        b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);

      return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
    }
  }

  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::NegFOp>(loc, resultTypes, args);

  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy)) {
    // Quantized negate must account for the input/output zero points.
    int64_t inZp = 0, outZp = 0;

    if (cast<tosa::NegateOp>(op).getQuantizationInfo()) {
      auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
      inZp = quantizationInfo.value().getInputZp();
      outZp = quantizationInfo.value().getOutputZp();
    }

    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
    if (!inZp && !outZp) {
      // Non-quantized path: 0 - x in the element type directly.
      auto constant = rewriter.create<arith::ConstantOp>(
          loc, IntegerAttr::get(elementTy, 0));
      return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
                                            args[0]);
    }

    // Compute the maximum value that can occur in the intermediate buffer.
    int64_t zpAdd = inZp + outZp;
    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
                       std::abs(zpAdd) + 1;

    // Convert that maximum value into the maximum bitwidth needed to represent
    // it. We assume 48-bit numbers may be supported further in the pipeline.
    int intermediateBitWidth = 64;
    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
      intermediateBitWidth = 16;
    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
      intermediateBitWidth = 32;
    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
      intermediateBitWidth = 48;
    }

    Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
    Value zpAddValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

    // The negation can be applied by doing:
    //  outputValue = inZp + outZp - inputValue
    auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
    auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);

    // Clamp to the negation range (the signed range of the original input
    // bitwidth, evaluated in the wider intermediate type).
    Value min = rewriter.create<arith::ConstantIntOp>(
        loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
        intermediateType);
    Value max = rewriter.create<arith::ConstantIntOp>(
        loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
        intermediateType);
    auto clamp =
        clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);

    // Truncate to the final value.
    return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
  }

  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
    // Bitwise NOT lowered as XOR with an all-ones constant.
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXOrOp
  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);

  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
    auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
    auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
    if (!round) {
      return result;
    }

    // Rounding mode: add 1 to the shifted result when the shift amount is
    // positive and the last bit shifted out of input1 was set.
    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

    // Checking that input2 != 0
    auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[1], zero);

    // Checking for the last bit of input1 to be 1
    auto subtract =
        rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted =
        rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
            ->getResults();
    auto truncated =
        rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
    auto isInputOdd =
        rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);

    auto shouldRound = rewriter.create<arith::AndIOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
  }

  // tosa::ClzOp
  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
  }

  // tosa::LogicalAnd
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::LogicalNot
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    // Boolean NOT lowered as XOR with true.
    auto one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementTy, 1));
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOr
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::LogicalXor
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::SinOp
  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::SinOp>(loc, resultTypes, args);

  // tosa::CosOp
  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::CosOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::ErfOp
  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);

  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
                                          args[0], args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                          args[0], args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
                                          args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
                                          args[0], args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
                                          args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                          args[0], args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    // The element type of interest is that of the selected values (operand
    // 1), not the i1 predicate (operand 0).
    elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
    if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
      return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
  }

  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::CeilOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::FloorOp>(loc, resultTypes, args);

  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
    bool losesInfo = false;
    APFloat minApf = cast<FloatAttr>(op->getAttr("min_fp")).getValue();
    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_fp")).getValue();
    // Convert the attribute bounds into the element type's float semantics.
    minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    auto min = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
    auto max = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
    return clampFloatHelper(loc, args[0], min, max, rewriter);
  }

  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
    auto intTy = cast<IntegerType>(elementTy);
    int64_t min =
        cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
    int64_t max =
        cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();

    // Compute the widest range representable in the element type, so the
    // attribute bounds can be clipped to it below.
    int64_t minRepresentable = std::numeric_limits<int64_t>::min();
    int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
    if (intTy.isUnsignedInteger()) {
      minRepresentable = 0;
      if (intTy.getIntOrFloatBitWidth() <= 63) {
        maxRepresentable =
            (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
                .getZExtValue();
      }
    } else if (intTy.getIntOrFloatBitWidth() <= 64) {
      // Ensure that min & max fit into signed n-bit constants.
      minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
      maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
    }
    // Ensure that the bounds are representable as n-bit signed/unsigned
    // integers.
    min = std::max(min, minRepresentable);
    max = std::max(max, minRepresentable);
    min = std::min(min, maxRepresentable);
    max = std::min(max, maxRepresentable);

    auto minVal = rewriter.create<arith::ConstantIntOp>(
        loc, min, intTy.getIntOrFloatBitWidth());
    auto maxVal = rewriter.create<arith::ConstantIntOp>(
        loc, max, intTy.getIntOrFloatBitWidth());
    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
                          intTy.isUnsignedInteger());
  }

  // tosa::SigmoidOp
  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
    // sigmoid(x) = 1 / (1 + exp(-x))
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
    auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
  }

  // tosa::CastOp
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (srcTy == dstTy)
      return args.front();

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
                                            std::nullopt);

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
      return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
                                              std::nullopt);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
                                              std::nullopt);

    if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
                                             std::nullopt);

    // Unsigned integers need an unrealized cast so that they can be passed
    // to UIToFP.
    if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
      auto unrealizedCast =
          rewriter
              .create<UnrealizedConversionCastOp>(
                  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
                  args[0])
              .getResult(0);
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
                                              unrealizedCast);
    }

    // All other si-to-fp conversions should be handled by SIToFP.
    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
                                              std::nullopt);

    // Casting to boolean, floats need to only be checked as not-equal to zero.
    if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(srcTy, 0.0));
      return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
                                            args.front(), zero);
    }

    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      // Round to nearest even before converting, per TOSA cast semantics.
      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);

      const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
      // Check whether neither int min nor int max can be represented in the
      // input floating-point type due to too short exponent range.
      if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
          APFloat::semanticsMaxExponent(fltSemantics)) {
        // Use cmp + select to replace infinites by int min / int max. Other
        // integral values can be represented in the integer space.
        auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
        auto posInf = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
                                       APFloat::getInf(fltSemantics)));
        auto negInf = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(
                     getElementTypeOrSelf(srcTy),
                     APFloat::getInf(fltSemantics, /*Negative=*/true)));
        auto overflow = rewriter.create<arith::CmpFOp>(
            loc, arith::CmpFPredicate::UEQ, rounded, posInf);
        auto underflow = rewriter.create<arith::CmpFOp>(
            loc, arith::CmpFPredicate::UEQ, rounded, negInf);
        auto intMin = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIntegerAttr(
                     getElementTypeOrSelf(dstTy),
                     APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
        auto intMax = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIntegerAttr(
                     getElementTypeOrSelf(dstTy),
                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
        auto maxClamped =
            rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
        return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
                                                maxClamped);
      }

      auto intMinFP = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   getElementTypeOrSelf(srcTy),
                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      // Check whether the mantissa has enough bits to represent int max.
      if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
          dstTy.getIntOrFloatBitWidth() - 1) {
        // Int min can also be represented since it is a power of two and thus
        // consists of a single leading bit. Therefore we can clamp the input
        // in the floating-point domain.

        auto intMaxFP = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(
                     getElementTypeOrSelf(srcTy),
                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                         .getSExtValue()));

        Value clamped =
            clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
        return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
      }

      // Due to earlier check we know exponent range is big enough to represent
      // int min. We can therefore rely on int max + 1 being representable as
      // well because it's just int min with a positive sign. So clamp the min
      // value and compare against that to select the max int value if needed.
      auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   getElementTypeOrSelf(srcTy),
                   static_cast<double>(
                       APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                           .getSExtValue()) +
                       1.0f));

      auto intMax = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIntegerAttr(
                   getElementTypeOrSelf(dstTy),
                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
      auto minClampedFP =
          rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
      auto minClamped =
          rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
      auto overflow = rewriter.create<arith::CmpFOp>(
          loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
      return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
                                              minClamped);
    }

    // Casting to boolean, integers need to only be checked as not-equal to
    // zero.
    if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantIntOp>(
          loc, 0, srcTy.getIntOrFloatBitWidth());
      return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                            args.front(), zero);
    }

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
                                             std::nullopt);

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
      return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
    }
  }

  // No case above matched: report failure and return a null Value so the
  // caller can abort the pattern.
  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}

// Reshape `tensor` to the given `rank` by prepending size-1 dimensions via a
// 'tensor.expand_shape' op. `rank` must be >= the tensor's current rank; the
// tensor is returned unchanged when the ranks already match.
static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
                        int64_t rank) {
  // No need to expand if we are already at the desired rank
  auto tensorType = dyn_cast<RankedTensorType>(tensor.getType());
  assert(tensorType && "expected a ranked tensor type");
  int64_t tensorRank = tensorType.getRank();
  int64_t numExtraDims = rank - tensorRank;
  assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
  if (!numExtraDims)
    return tensor;

  // Compute reassociation indices: all new leading unit dims fold into the
  // first original dimension; remaining dims map one-to-one.
  SmallVector<ReassociationIndices> reassociationIndices(tensorRank);
  int64_t index = 0;
  if (tensorRank != 0) {
    for (index = 0; index <= numExtraDims; index++)
      reassociationIndices[0].push_back(index);
    for (size_t position = 1; position < reassociationIndices.size();
         position++)
      reassociationIndices[position].push_back(index++);
  }

  // Compute result type
  SmallVector<int64_t> resultShape;
  for (index = 0; index < numExtraDims; index++)
    resultShape.push_back(1);
  for (auto size : tensorType.getShape())
    resultShape.push_back(size);
  auto resultType =
      RankedTensorType::get(resultShape, tensorType.getElementType());

  // Emit 'tensor.expand_shape' op
  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
                                                reassociationIndices);
}

// Expand every operand to the common `rank` (see expandRank above).
static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
                                           Location loc, ValueRange operands,
                                           int64_t rank) {
  return llvm::map_to_vector(operands, [&](Value operand) {
    return expandRank(rewriter, loc, operand, rank);
  });
}

// Cache of materialized index constants, keyed by their integer value.
using IndexPool = DenseMap<int64_t, Value>;

// Emit an 'arith.constant' op for the given index if it has not been created
// yet, or return an existing constant. This will prevent an excessive creation
// of redundant constants, easing readability of emitted code for unit tests.
static Value createIndex(PatternRewriter &rewriter, Location loc,
                         IndexPool &indexPool, int64_t index) {
  auto [it, inserted] = indexPool.try_emplace(index);
  if (inserted)
    it->second =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
  return it->second;
}

// Emit a 'tensor.dim' op querying dimension `index` of `tensor` at runtime.
static Value getTensorDim(PatternRewriter &rewriter, Location loc,
                          IndexPool &indexPool, Value tensor, int64_t index) {
  auto indexValue = createIndex(rewriter, loc, indexPool, index);
  return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
}

// Return dimension `index` of `tensor` as an attribute when it is static, or
// emit a 'tensor.dim' op (see getTensorDim) when it is dynamic.
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value tensor,
                                       int64_t index) {
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
  if (shapedType.isDynamicDim(index))
    return getTensorDim(rewriter, loc, indexPool, tensor, index);
  return rewriter.getIndexAttr(shapedType.getDimSize(index));
}

// True iff every operand and result of `operation` is a ranked tensor.
static bool operandsAndResultsRanked(Operation *operation) {
  auto isRanked = [](Value value) {
    return isa<RankedTensorType>(value.getType());
  };
  return llvm::all_of(operation->getOperands(), isRanked) &&
         llvm::all_of(operation->getResults(), isRanked);
}

// Compute the runtime dimension size for dimension 'dim' of the output by
// inspecting input 'operands', all of which are expected to have the same rank.
// This function returns a pair {targetSize, masterOperand}.
//
// The runtime size of the output dimension is returned either as a statically
// computed attribute or as a runtime SSA value.
//
// If the target size was inferred directly from one dominating operand, that
// operand is returned in 'masterOperand'. If the target size is inferred from
// multiple operands, 'masterOperand' is set to nullptr.
static std::pair<OpFoldResult, Value>
computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
                  ValueRange operands, int64_t dim) {
  // If any input operand contains a static size greater than 1 for this
  // dimension, that is the target size. An occurrence of an additional static
  // dimension greater than 1 with a different value is undefined behavior.
  for (auto operand : operands) {
    auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
    if (!ShapedType::isDynamic(size) && size > 1)
      return {rewriter.getIndexAttr(size), operand};
  }

  // Filter operands with dynamic dimension
  auto operandsWithDynamicDim =
      llvm::filter_to_vector(operands, [&](Value operand) {
        return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
      });

  // If no operand has a dynamic dimension, it means all sizes were 1
  if (operandsWithDynamicDim.empty())
    return {rewriter.getIndexAttr(1), operands.front()};

  // Emit code that computes the runtime size for this dimension. If there is
  // only one operand with a dynamic dimension, it is considered the master
  // operand that determines the runtime size of the output dimension.
  auto targetSize =
      getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
  if (operandsWithDynamicDim.size() == 1)
    return {targetSize, operandsWithDynamicDim[0]};

  // Calculate maximum size among all dynamic dimensions
  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
    auto nextSize =
        getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
    targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
  }
  // Multiple dynamic operands contributed: no single master operand.
  return {targetSize, nullptr};
}

// Compute the runtime output size for all dimensions. This function returns
// a pair {targetShape, masterOperands}.
static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
computeTargetShape(PatternRewriter &rewriter, Location loc,
                   IndexPool &indexPool, ValueRange operands) {
  assert(!operands.empty());
  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
  SmallVector<OpFoldResult> targetShape;
  SmallVector<Value> masterOperands;
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    auto [targetSize, masterOperand] =
        computeTargetSize(rewriter, loc, indexPool, operands, dim);
    targetShape.push_back(targetSize);
    masterOperands.push_back(masterOperand);
  }
  return {targetShape, masterOperands};
}

// Broadcast dimension `dim` of `operand` to `targetSize` when needed. For a
// dynamic dimension this emits an 'scf.if' that at runtime either broadcasts
// a size-1 dimension via a 'linalg.generic' copy or passes the operand
// through unchanged.
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value operand,
                                       int64_t dim, OpFoldResult targetSize,
                                       Value masterOperand) {
  // Nothing to do if this is a static dimension
  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
  if (!rankedTensorType.isDynamicDim(dim))
    return operand;

  // If the target size for this dimension was directly inferred by only taking
  // this operand into account, there is no need to broadcast. This is an
  // optimization that will prevent redundant control flow, and constitutes the
  // main motivation for tracking "master operands".
  if (operand == masterOperand)
    return operand;

  // Affine maps for 'linalg.generic' op: the input map pins `dim` to 0 (reads
  // the single element of the size-1 dimension), the output map is identity.
  auto rank = rankedTensorType.getRank();
  SmallVector<AffineExpr> affineExprs;
  for (auto index : llvm::seq<int64_t>(0, rank)) {
    auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
                                   : rewriter.getAffineDimExpr(index);
    affineExprs.push_back(affineExpr);
  }
  auto broadcastAffineMap =
      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};

  // Check if broadcast is necessary: only a runtime size of 1 is broadcast.
  auto one = createIndex(rewriter, loc, indexPool, 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, runtimeSize, one);

  // Emit 'then' region of 'scf.if'
  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
    // It is not safe to cache constants across regions.
    // New constants could potentially violate dominance requirements.
    IndexPool localPool;

    // Emit 'tensor.empty' op
    SmallVector<OpFoldResult> outputTensorShape;
    for (auto index : llvm::seq<int64_t>(0, rank)) {
      auto size = index == dim ? targetSize
                               : getOrFoldTensorDim(rewriter, loc, localPool,
                                                    operand, index);
      outputTensorShape.push_back(size);
    }
    Value outputTensor = opBuilder.create<tensor::EmptyOp>(
        loc, outputTensorShape, rankedTensorType.getElementType());

    // Emit 'linalg.generic' op
    auto resultTensor =
        opBuilder
            .create<linalg::GenericOp>(
                loc, outputTensor.getType(), operand, outputTensor, affineMaps,
                getNParallelLoopsAttrs(rank),
                [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
                  // Emit 'linalg.yield' op
                  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
                })
            .getResult(0);

    // Cast to original operand type if necessary
    auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
        loc, operand.getType(), resultTensor);

    // Emit 'scf.yield' op
    opBuilder.create<scf::YieldOp>(loc, castResultTensor);
  };

  // Emit 'else' region of 'scf.if'
  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
    opBuilder.create<scf::YieldOp>(loc, operand);
  };

  // Emit 'scf.if' op
  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
                                         emitThenRegion, emitElseRegion);
  return ifOp.getResult(0);
}

// Broadcast every dynamic dimension of `operand` to the corresponding entry
// of `targetShape` (see broadcastDynamicDimension above).
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                                        IndexPool &indexPool, Value operand,
                                        ArrayRef<OpFoldResult> targetShape,
                                        ArrayRef<Value> masterOperands) {
  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
  assert((int64_t)targetShape.size() == rank);
  assert((int64_t)masterOperands.size() == rank);
  for (auto index : llvm::seq<int64_t>(0, rank))
    operand =
        broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
                                  targetShape[index], masterOperands[index]);
  return operand;
}

// Broadcast all operands' dynamic dimensions to the common `targetShape`.
static SmallVector<Value>
broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                           IndexPool &indexPool, ValueRange operands,
ArrayRef<OpFoldResult> targetShape, 862 ArrayRef<Value> masterOperands) { 863 // No need to broadcast for unary operations 864 if (operands.size() == 1) 865 return operands; 866 867 // Broadcast dynamic dimensions operand by operand 868 return llvm::map_to_vector(operands, [&](Value operand) { 869 return broadcastDynamicDimensions(rewriter, loc, indexPool, operand, 870 targetShape, masterOperands); 871 }); 872 } 873 874 static LogicalResult 875 emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, 876 Operation *operation, ValueRange operands, 877 ArrayRef<OpFoldResult> targetShape, 878 const TypeConverter &converter) { 879 // Generate output tensor 880 auto resultType = cast_or_null<RankedTensorType>( 881 converter.convertType(operation->getResultTypes().front())); 882 if (!resultType) { 883 return rewriter.notifyMatchFailure(operation, "failed to convert type"); 884 } 885 Value outputTensor = rewriter.create<tensor::EmptyOp>( 886 loc, targetShape, resultType.getElementType()); 887 888 // Create affine maps. Input affine maps broadcast static dimensions of size 889 // 1. The output affine map is an identity map. 890 // 891 auto rank = resultType.getRank(); 892 auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) { 893 auto shape = cast<ShapedType>(operand.getType()).getShape(); 894 SmallVector<AffineExpr> affineExprs; 895 for (auto it : llvm::enumerate(shape)) { 896 // Prefer producting identity maps whenever possible (i.e. no broadcasting 897 // needed) because some transforms (like reshape folding) 898 // do not support affine constant exprs. 899 bool requiresBroadcast = 900 (it.value() == 1 && resultType.getDimSize(it.index()) != 1); 901 auto affineExpr = requiresBroadcast 902 ? 
rewriter.getAffineConstantExpr(0) 903 : rewriter.getAffineDimExpr(it.index()); 904 affineExprs.push_back(affineExpr); 905 } 906 return AffineMap::get(rank, 0, affineExprs, rewriter.getContext()); 907 }); 908 affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank)); 909 910 // Emit 'linalg.generic' op 911 bool encounteredError = false; 912 auto linalgOp = rewriter.create<linalg::GenericOp>( 913 loc, outputTensor.getType(), operands, outputTensor, affineMaps, 914 getNParallelLoopsAttrs(rank), 915 [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) { 916 Value opResult = createLinalgBodyCalculationForElementwiseOp( 917 operation, blockArgs.take_front(operation->getNumOperands()), 918 {resultType.getElementType()}, rewriter); 919 if (!opResult) { 920 encounteredError = true; 921 return; 922 } 923 opBuilder.create<linalg::YieldOp>(loc, opResult); 924 }); 925 if (encounteredError) 926 return rewriter.notifyMatchFailure( 927 operation, "unable to create linalg.generic body for elementwise op"); 928 929 // Cast 'linalg.generic' result into original result type if needed 930 auto castResult = rewriter.createOrFold<tensor::CastOp>( 931 loc, resultType, linalgOp->getResult(0)); 932 rewriter.replaceOp(operation, castResult); 933 return success(); 934 } 935 936 static LogicalResult 937 elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands, 938 ConversionPatternRewriter &rewriter, 939 const TypeConverter &converter) { 940 941 // Collect op properties 942 assert(operation->getNumResults() == 1 && "elementwise op expects 1 result"); 943 assert(operation->getNumOperands() >= 1 && 944 "elementwise op expects at least 1 operand"); 945 if (!operandsAndResultsRanked(operation)) 946 return rewriter.notifyMatchFailure(operation, 947 "Unranked tensors not supported"); 948 949 // Lower operation 950 IndexPool indexPool; 951 auto loc = operation->getLoc(); 952 auto rank = 953 cast<RankedTensorType>(operation->getResultTypes().front()).getRank(); 954 // For 
  // For the mul op we need to avoid expanding the rank of the optional shift
  // input.
  auto operandsToExpand =
      isa<tosa::MulOp>(operation) ? operands.take_front(2) : operands;

  auto expandedOperands =
      expandInputRanks(rewriter, loc, operandsToExpand, rank);
  auto [targetShape, masterOperands] =
      computeTargetShape(rewriter, loc, indexPool, expandedOperands);
  auto broadcastOperands = broadcastDynamicDimensions(
      rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
  return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
                                    targetShape, converter);
}

// Returns the constant initial value for a given reduction operation. The
// attribute type varies depending on the element type required.
//
// Each reduction is seeded with its identity element: 0 for sum, 1 for
// product, the largest positive value (or signed max) for min-reductions,
// the largest negative value (or signed min) for max-reductions, all-ones
// for reduce_all, and zero for reduce_any. Returns a null attribute for
// unsupported op/element-type combinations.
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  // ArgMax searches for a maximum, so it is seeded like a max-reduction.
  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  // Unsupported op/element-type combination.
  return {};
}

// Creates the body calculation for a reduction. The operations vary depending
// on the input type. Returns a null Value for unsupported combinations.
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::AddFOp>(loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::AddIOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MulFOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MulIOp>(loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, args);

  // Unsupported op/element-type combination.
  return {};
}

// Performs the match and rewrite for reduction operations. This includes
// declaring a correctly sized initial value, and the linalg.generic operation
// that reduces across the specified axis.
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
  auto resultTy = cast<ShapedType>(op->getResult(0).getType());
  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  // The reduced tensor drops the reduction axis; collect the static sizes and
  // the runtime sizes of any dynamic dimensions of the remaining axes.
  SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
    }
  }

  // First fill the output buffer with the init value.
  auto emptyTensor =
      rewriter
          .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
                                   dynDims)
          .getResult();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
  auto filledTensor = rewriter
                          .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                  ValueRange{emptyTensor})
                          .result();

  // NOTE: despite its name, 'didEncounterError' is set to true when the body
  // WAS successfully created; the match failure below fires when it stays
  // false (i.e. the reduction body could not be built).
  bool didEncounterError = false;
  auto linalgOp = rewriter.create<linalg::ReduceOp>(
      loc, input, filledTensor, axis,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        auto result = createLinalgBodyCalculationForReduceOp(
            op, blockArgs, elementTy, rewriter);
        if (result)
          didEncounterError = true;

        nestedBuilder.create<linalg::YieldOp>(loc, result);
      });

  if (!didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");

  // Re-expand the dropped reduction axis as a unit dimension. Dimension i of
  // the reduced tensor maps to dimension i (before the axis) or i+1 (after
  // it) of the expanded result.
  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank =
      cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
  reassociationMap.resize(expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
  }

  if (expandInputRank != 0) {
    // Group the new unit dimension with its neighbor at the reduction axis
    // (clamped when the axis was the input's last dimension).
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
        rewriter.getAffineDimExpr(expandedDim + 1));
  }

  // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
  // since here we know which dimension to expand, and `tosa::ReshapeOp` would
  // not have access to such information. This matters when handling
  // dynamically sized tensors.
  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
      op, resultTy, linalgOp.getResults()[0], reassociationMap);
  return success();
}

namespace {

// Converts any elementwise TOSA op through the shared elementwise helper
// above.
template <typename SrcOp>
class PointwiseConverter : public OpConversionPattern<SrcOp> {
public:
  using OpConversionPattern<SrcOp>::OpConversionPattern;
  using typename OpConversionPattern<SrcOp>::OpAdaptor;

  LogicalResult
  matchAndRewrite(SrcOp op, OpAdaptor operands,
                  ConversionPatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(
        op, operands.getOperands(), rewriter, *this->getTypeConverter());
  }
};

// Lowers tosa.rescale on integer tensors to a linalg.generic that performs
// the per-element zero-point adjustment, scaling, and saturation arithmetic.
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(op.getInput().getType());
    auto outputTy = cast<ShapedType>(op.getOutput().getType());
    unsigned rank = inputTy.getRank();

    // This is an illegal configuration: terminate and log an error.
    if (op.getDoubleRound() && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    if (!isa<IntegerType>(inputTy.getElementType()))
      return rewriter.notifyMatchFailure(op, "only support integer type");

    // Collect runtime sizes of any dynamic output dimensions.
    SmallVector<Value> dynDims;
    for (int i = 0; i < outputTy.getRank(); i++) {
      if (outputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // The shift and multiplier values.
    SmallVector<int32_t> multiplierValues(op.getMultiplier());
    SmallVector<int8_t> shiftValues(op.getShift());

    // If we shift by more than the bitwidth, this just sets to 0.
    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
      if (shiftValues[i] > 63) {
        shiftValues[i] = 0;
        multiplierValues[i] = 0;
      }
    }

    // Double round only occurs if shift is greater than 31, check that this
    // is ever true.
    bool doubleRound =
        op.getDoubleRound() &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      // Per-channel: the multipliers become an extra generic input indexed by
      // the innermost (channel) dimension.
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType =
          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
                                rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));

      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));

      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel then we need to store the shift
    // values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      // Per-channel: the shifts become an extra generic input indexed by the
      // innermost (channel) dimension.
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing maps for output values.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the output tensor the linalg.generic writes into.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(),
        ArrayRef<Value>({dynDims}));

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          // Widen the intermediate computation: 48-bit for inputs wider than
          // 32 bits, otherwise 32-bit. This is not optimal but should be
          // correct for now; consider computing the exact required bit depth
          // later.
          int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;

          auto inputZp = createConstFromIntAttribute<int32_t>(
              op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
              nestedBuilder);
          auto outputZp = createConstFromIntAttribute<int32_t>(
              op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);

          // Either the single scalar constant or the per-channel block arg.
          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (op.getInputUnsigned()) {
              value = nestedBuilder.create<arith::ExtUIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            } else {
              value = nestedBuilder.create<arith::ExtSIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            }
          }

          value =
              nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              nestedBuilder.getBoolAttr(doubleRound));

          // Move to the new zero-point.
          value =
              nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          IntegerType outIntType =
              cast<IntegerType>(blockArgs.back().getType());
          unsigned outBitWidth = outIntType.getWidth();

          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different output value range.
          if (op.getOutputUnsigned()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
                                 nestedBuilder, /*isUnsigned=*/false);

          if (outIntType.getWidth() < 32) {
            value = nestedBuilder.create<arith::TruncIOp>(
                nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
                value);
          }

          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};

// Handle the resize case where the input is a 1x1 image. This case can be
// lowered without any tensor.extract operations, which would otherwise be
// much more difficult to optimize away.
class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<RankedTensorType>(input.getType());
    auto resultTy = cast<RankedTensorType>(op.getType());
    const bool isBilinear = op.getMode() == "BILINEAR";

    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    // Only the degenerate 1x1 -> 1x1 case is handled here.
    if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize is not a pure 1x1->1x1 image operation");

    // TODO(suderman): These string values should be declared in the TOSA
    // dialect.
    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    // Same type in and out: the resize is an identity operation.
    if (inputTy == resultTy) {
      rewriter.replaceOp(op, input);
      return success();
    }

    ArrayRef<int64_t> scale = op.getScale();

    // Collapse the unit width and height away.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap[1].push_back(builder.getAffineDimExpr(1));
    reassociationMap[1].push_back(builder.getAffineDimExpr(2));
    reassociationMap[1].push_back(builder.getAffineDimExpr(3));

    auto collapseTy =
        RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
                              inputTy.getElementType());
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
                                                             reassociationMap);

    // Get any dynamic shapes that appear in the input format.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    // Generate the elementwise operation for casting and scaling the input
    // value.
    auto genericTy = collapseTy.clone(resultTy.getElementType());
    Value empty = builder.create<tensor::EmptyOp>(
        genericTy.getShape(), resultTy.getElementType(), outputDynSize);
    auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
    SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
                                               utils::IteratorType::parallel);

    auto generic = builder.create<linalg::GenericOp>(
        genericTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          // This is the quantized case: widen and apply the bilinear scale
          // factors so the result matches the non-degenerate lowering.
          if (inputTy.getElementType() != resultTy.getElementType()) {
            value =
                b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);

            if (isBilinear && scale[0] != 0) {
              Value scaleY = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[0]));
              value = b.create<arith::MulIOp>(loc, value, scaleY);
            }

            if (isBilinear && scale[2] != 0) {
              Value scaleX = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[2]));
              value = b.create<arith::MulIOp>(loc, value, scaleX);
            }
          }

          b.create<linalg::YieldOp>(loc, value);
        });

    // Restore the collapsed unit height/width dimensions.
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, resultTy, generic.getResults()[0], reassociationMap);
    return success();
  }
};

// TOSA resize with width or height of 1 may be broadcasted to a wider
// dimension. This is done by materializing a new tosa.resize without
// the broadcasting behavior, and an explicit broadcast afterwards.
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    // Broadcasting happens when a unit input dimension maps to a wider output
    // dimension; bail out when neither H nor W does that.
    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any dimension that is broadcastable we generate a width of 1
    // on the output.
    llvm::SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    // Re-emit the resize with the broadcast dimensions left at size 1.
    auto resizeTy = resultTy.clone(resizeShape);
    auto resize =
        builder.create<tosa::ResizeOp>(resizeTy, input, op->getAttrs());

    // Collapse any unit result dims.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    llvm::SmallVector<int64_t> collapseShape = {batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
                                                             reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = builder.create<tensor::EmptyOp>(
        resultTy.getShape(), resultTy.getElementType(), outputDynSize);

    // The input map skips the dimensions that were collapsed away, so the
    // generic replicates them across the full output extent.
    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());

    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          b.create<linalg::YieldOp>(loc, value);
        });

    return success();
  }
};

// General lowering of tosa.resize: computes the source coordinates for each
// output element and gathers from the input via tensor.extract.
class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto resultETy = resultTy.getElementType();

    bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
    auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return rewriter.notifyMatchFailure(
          op, "unable to get dynamic dimensions of tosa.resize");

    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
    auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
                                                 *dynamicDimsOr);
    auto genericOp = b.create<linalg::GenericOp>(
        resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    Value resize = genericOp.getResult(0);

    // Build the generic's body block manually; the insertion guard restores
    // the builder position afterwards.
    {
      OpBuilder::InsertionGuard regionGuard(b);
      b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
                    TypeRange({resultETy}), loc);
      Value batch = b.create<linalg::IndexOp>(0);
      Value y = b.create<linalg::IndexOp>(1);
      Value x = b.create<linalg::IndexOp>(2);
      Value channel = b.create<linalg::IndexOp>(3);

      Value zeroI32 =
          b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
      Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
      Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
      Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));

      Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
      Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);

      ArrayRef<int64_t> offset = op.getOffset();
      ArrayRef<int64_t> border = op.getBorder();
      ArrayRef<int64_t> scale = op.getScale();

      Value yScaleN, yScaleD, xScaleN, xScaleD;
      yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
      yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
      xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
      xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));

      Value yOffset, xOffset, yBorder, xBorder;
      yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
      xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
      yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
      xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));

      // Compute the ix and dx values for both the X and Y dimensions.
1615 auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in, 1616 Value scaleN, Value scaleD, Value offset, 1617 int size, ImplicitLocOpBuilder &b) { 1618 if (size == 1) { 1619 index = zeroI32; 1620 delta = zeroFp; 1621 return; 1622 } 1623 // x = x * scale_d + offset; 1624 // ix = floor(x / scale_n) 1625 Value val = b.create<arith::MulIOp>(in, scaleD); 1626 val = b.create<arith::AddIOp>(val, offset); 1627 index = b.create<arith::FloorDivSIOp>(val, scaleN); 1628 1629 // rx = x % scale_n 1630 // dx = rx / scale_n 1631 Value r = b.create<arith::RemSIOp>(val, scaleN); 1632 Value rFp = b.create<arith::SIToFPOp>(floatTy, r); 1633 Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN); 1634 delta = b.create<arith::DivFOp>(rFp, scaleNfp); 1635 }; 1636 1637 // Compute the ix and dx values for the X and Y dimensions - int case. 1638 auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in, 1639 Value scaleN, Value scaleD, Value offset, 1640 int size, ImplicitLocOpBuilder &b) { 1641 if (size == 1) { 1642 index = zeroI32; 1643 delta = zeroI32; 1644 return; 1645 } 1646 // x = x * scale_d + offset; 1647 // ix = floor(x / scale_n) 1648 // dx = x - ix * scale_n; 1649 Value val = b.create<arith::MulIOp>(in, scaleD); 1650 val = b.create<arith::AddIOp>(val, offset); 1651 index = b.create<arith::DivSIOp>(val, scaleN); 1652 delta = b.create<arith::MulIOp>(index, scaleN); 1653 delta = b.create<arith::SubIOp>(val, delta); 1654 }; 1655 1656 Value ix, iy, dx, dy; 1657 if (floatingPointMode) { 1658 getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b); 1659 getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b); 1660 } else { 1661 getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b); 1662 getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b); 1663 } 1664 1665 if (op.getMode() == "NEAREST_NEIGHBOR") { 1666 auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1)); 1667 1668 auto 
getNearestIndexAndClamp = [&](Value val, Value dval, Value scale, 1669 Value max, int size, 1670 ImplicitLocOpBuilder &b) -> Value { 1671 if (size == 1) { 1672 return b.create<arith::ConstantIndexOp>(0); 1673 } 1674 1675 Value pred; 1676 if (floatingPointMode) { 1677 auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f)); 1678 pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h); 1679 } else { 1680 Value dvalDouble = b.create<arith::ShLIOp>(dval, one); 1681 pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, 1682 dvalDouble, scale); 1683 } 1684 1685 auto offset = b.create<arith::SelectOp>(pred, one, zeroI32); 1686 val = b.create<arith::AddIOp>(val, offset); 1687 val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false); 1688 return b.create<arith::IndexCastOp>(b.getIndexType(), val); 1689 }; 1690 1691 iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b); 1692 ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b); 1693 1694 Value result = b.create<tensor::ExtractOp>( 1695 input, ValueRange{batch, iy, ix, channel}); 1696 1697 b.create<linalg::YieldOp>(result); 1698 } else { 1699 // The mode here must be BILINEAR. 
1700 assert(op.getMode() == "BILINEAR"); 1701 1702 auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1)); 1703 1704 auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in, 1705 Value max, ImplicitLocOpBuilder &b) { 1706 val0 = in; 1707 val1 = b.create<arith::AddIOp>(val0, oneVal); 1708 val0 = 1709 clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false); 1710 val1 = 1711 clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false); 1712 val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0); 1713 val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1); 1714 }; 1715 1716 // Linalg equivalent to the section below: 1717 // int16_t iy0 = apply_max(iy, 0); 1718 // int16_t iy1 = apply_min(iy + 1, IH - 1); 1719 // int16_t ix0 = apply_max(ix, 0); 1720 // int16_t ix1 = apply_min(ix + 1, IW - 1); 1721 Value x0, x1, y0, y1; 1722 getClampedIdxs(y0, y1, imageH, iy, hMax, b); 1723 getClampedIdxs(x0, x1, imageW, ix, wMax, b); 1724 1725 Value y0x0 = b.create<tensor::ExtractOp>( 1726 input, ValueRange{batch, y0, x0, channel}); 1727 Value y0x1 = b.create<tensor::ExtractOp>( 1728 input, ValueRange{batch, y0, x1, channel}); 1729 Value y1x0 = b.create<tensor::ExtractOp>( 1730 input, ValueRange{batch, y1, x0, channel}); 1731 Value y1x1 = b.create<tensor::ExtractOp>( 1732 input, ValueRange{batch, y1, x1, channel}); 1733 1734 if (floatingPointMode) { 1735 auto oneVal = 1736 b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f)); 1737 auto interpolate = [&](Value val0, Value val1, Value delta, 1738 int inputSize, 1739 ImplicitLocOpBuilder &b) -> Value { 1740 if (inputSize == 1) 1741 return val0; 1742 Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta); 1743 Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta); 1744 Value mul1 = b.create<arith::MulFOp>(val1, delta); 1745 return b.create<arith::AddFOp>(mul0, mul1); 1746 }; 1747 1748 // Linalg equivalent to the section below: 1749 // topAcc = v00 * (unit_x - dx); 1750 // 
topAcc += v01 * dx; 1751 Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b); 1752 1753 // Linalg equivalent to the section below: 1754 // bottomAcc = v10 * (unit_x - dx); 1755 // bottomAcc += v11 * dx; 1756 Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b); 1757 1758 // Linalg equivalent to the section below: 1759 // result = topAcc * (unit_y - dy) + bottomAcc * dy 1760 Value result = interpolate(topAcc, bottomAcc, dy, imageH, b); 1761 b.create<linalg::YieldOp>(result); 1762 } else { 1763 // Perform in quantized space. 1764 y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0); 1765 y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1); 1766 y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0); 1767 y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1); 1768 1769 const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth(); 1770 if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) { 1771 dx = b.create<arith::ExtSIOp>(resultETy, dx); 1772 dy = b.create<arith::ExtSIOp>(resultETy, dy); 1773 } 1774 1775 Value yScaleNExt = yScaleN; 1776 Value xScaleNExt = xScaleN; 1777 1778 const int64_t scaleBitwidth = 1779 xScaleN.getType().getIntOrFloatBitWidth(); 1780 if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) { 1781 yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN); 1782 xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN); 1783 } 1784 1785 auto interpolate = [](Value val0, Value val1, Value weight1, 1786 Value scale, int inputSize, 1787 ImplicitLocOpBuilder &b) -> Value { 1788 if (inputSize == 1) 1789 return b.create<arith::MulIOp>(val0, scale); 1790 Value weight0 = b.create<arith::SubIOp>(scale, weight1); 1791 Value mul0 = b.create<arith::MulIOp>(val0, weight0); 1792 Value mul1 = b.create<arith::MulIOp>(val1, weight1); 1793 return b.create<arith::AddIOp>(mul0, mul1); 1794 }; 1795 1796 Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b); 1797 Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b); 1798 Value result = 1799 
interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b); 1800 b.create<linalg::YieldOp>(result); 1801 } 1802 } 1803 } 1804 1805 rewriter.replaceOp(op, resize); 1806 return success(); 1807 } 1808 }; 1809 1810 // At the codegen level any identity operations should be removed. Any cases 1811 // where identity is load-bearing (e.g. cross device computation) should be 1812 // handled before lowering to codegen. 1813 template <typename SrcOp> 1814 class IdentityNConverter : public OpRewritePattern<SrcOp> { 1815 public: 1816 using OpRewritePattern<SrcOp>::OpRewritePattern; 1817 1818 LogicalResult matchAndRewrite(SrcOp op, 1819 PatternRewriter &rewriter) const final { 1820 rewriter.replaceOp(op, op.getOperation()->getOperands()); 1821 return success(); 1822 } 1823 }; 1824 1825 template <typename SrcOp> 1826 class ReduceConverter : public OpRewritePattern<SrcOp> { 1827 public: 1828 using OpRewritePattern<SrcOp>::OpRewritePattern; 1829 1830 LogicalResult matchAndRewrite(SrcOp reduceOp, 1831 PatternRewriter &rewriter) const final { 1832 return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter); 1833 } 1834 }; 1835 1836 class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> { 1837 public: 1838 using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern; 1839 1840 LogicalResult matchAndRewrite(tosa::ReverseOp op, 1841 PatternRewriter &rewriter) const final { 1842 auto loc = op.getLoc(); 1843 Value input = op.getInput1(); 1844 auto inputTy = cast<ShapedType>(input.getType()); 1845 auto resultTy = cast<ShapedType>(op.getType()); 1846 auto axis = op.getAxis(); 1847 1848 SmallVector<Value> dynDims; 1849 for (int i = 0; i < inputTy.getRank(); i++) { 1850 if (inputTy.isDynamicDim(i)) { 1851 dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i)); 1852 } 1853 } 1854 1855 Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis); 1856 1857 // First fill the output buffer with the init value. 
1858 auto emptyTensor = rewriter 1859 .create<tensor::EmptyOp>(loc, inputTy.getShape(), 1860 inputTy.getElementType(), 1861 ArrayRef<Value>({dynDims})) 1862 .getResult(); 1863 SmallVector<AffineMap, 2> affineMaps = { 1864 rewriter.getMultiDimIdentityMap(resultTy.getRank())}; 1865 1866 rewriter.replaceOpWithNewOp<linalg::GenericOp>( 1867 op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps, 1868 getNParallelLoopsAttrs(resultTy.getRank()), 1869 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { 1870 llvm::SmallVector<Value> indices; 1871 for (unsigned int i = 0; i < inputTy.getRank(); i++) { 1872 Value index = 1873 rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult(); 1874 if (i == axis) { 1875 auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1); 1876 auto sizeMinusOne = 1877 rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one); 1878 index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne, 1879 index); 1880 } 1881 1882 indices.push_back(index); 1883 } 1884 1885 auto extract = nestedBuilder.create<tensor::ExtractOp>( 1886 nestedLoc, input, indices); 1887 nestedBuilder.create<linalg::YieldOp>(op.getLoc(), 1888 extract.getResult()); 1889 }); 1890 return success(); 1891 } 1892 }; 1893 1894 // This converter translate a tile operation to a reshape, broadcast, reshape. 1895 // The first reshape minimally expands each tiled dimension to include a 1896 // proceding size-1 dim. This dim is then broadcasted to the appropriate 1897 // multiple. 
1898 struct TileConverter : public OpConversionPattern<tosa::TileOp> { 1899 using OpConversionPattern<tosa::TileOp>::OpConversionPattern; 1900 1901 LogicalResult 1902 matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor, 1903 ConversionPatternRewriter &rewriter) const override { 1904 auto loc = op.getLoc(); 1905 auto input = op.getInput1(); 1906 auto inputTy = cast<ShapedType>(input.getType()); 1907 auto inputShape = inputTy.getShape(); 1908 auto resultTy = cast<ShapedType>(op.getType()); 1909 auto elementTy = inputTy.getElementType(); 1910 int64_t rank = inputTy.getRank(); 1911 1912 SmallVector<int64_t> multiples; 1913 if (failed(op.getConstantMultiples(multiples))) 1914 return failure(); 1915 1916 // Broadcast the newly added dimensions to their appropriate multiple. 1917 SmallVector<int64_t, 2> genericShape; 1918 for (int i = 0; i < rank; i++) { 1919 int64_t dim = multiples[i]; 1920 genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim); 1921 genericShape.push_back(inputShape[i]); 1922 } 1923 1924 SmallVector<Value> dynDims; 1925 for (int i = 0; i < inputTy.getRank(); i++) { 1926 if (inputTy.isDynamicDim(i) || multiples[i] == -1) { 1927 dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i)); 1928 } 1929 } 1930 1931 auto emptyTensor = rewriter.create<tensor::EmptyOp>( 1932 op.getLoc(), genericShape, elementTy, dynDims); 1933 1934 // We needs to map the input shape to the non-broadcasted dimensions. 
1935 SmallVector<AffineExpr, 4> dimExprs; 1936 dimExprs.reserve(rank); 1937 for (unsigned i = 0; i < rank; ++i) 1938 dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1)); 1939 1940 auto readAffineMap = 1941 AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs, 1942 rewriter.getContext()); 1943 1944 SmallVector<AffineMap, 2> affineMaps = { 1945 readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())}; 1946 1947 auto genericOp = rewriter.create<linalg::GenericOp>( 1948 loc, RankedTensorType::get(genericShape, elementTy), input, 1949 ValueRange{emptyTensor}, affineMaps, 1950 getNParallelLoopsAttrs(genericShape.size()), 1951 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { 1952 nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin()); 1953 }); 1954 1955 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>( 1956 op, resultTy, genericOp.getResult(0), 1957 rewriter.getDenseI64ArrayAttr(resultTy.getShape())); 1958 return success(); 1959 } 1960 }; 1961 1962 // Tosa argmax lowering represents the ArgMax op as an linalg.indexed_generic 1963 // op, producing two output buffers. 1964 // 1965 // The first output buffer contains the index of the found maximum value. It is 1966 // initialized to 0 and is resulting integer type. 1967 // 1968 // The second output buffer contains the maximum value found. It is initialized 1969 // to the minimum representable value of the input element type. After being 1970 // populated by indexed_generic, this buffer is disgarded as only the index is 1971 // requested. 1972 // 1973 // The indexed_generic op updates both the maximum value and index if the 1974 // current value exceeds the running max. 
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.getAxis();
    // Type of the discarded running-max output (result shape, input element).
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!isa<IntegerType>(outElementTy))
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    // Dynamic sizes of the result: every dynamic input dim except the reduced
    // axis (which is dropped from the result).
    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) && i != axis) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // First fill the output buffer for the index.
    auto emptyTensorIdx = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       outElementTy, dynDims)
                              .getResult();
    auto fillValueIdx = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
                                    ValueRange{emptyTensorIdx})
            .result();

    // Second fill the output buffer for the running max.
    auto emptyTensorMax = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       inElementTy, dynDims)
                              .getResult();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax =
        rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
    auto filledTensorMax =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
                                    ValueRange{emptyTensorMax})
            .result();

    // We need to reduce along the arg-max axis, with parallel operations along
    // the rest.
    SmallVector<utils::IteratorType, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
    iteratorTypes[axis] = utils::IteratorType::reduction;

    // Source map keeps all dims; both outputs drop the reduced axis.
    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    // Set from inside the body lambda when the element type is neither float
    // nor integer; checked after the op is built.
    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
                                             rewriter.getContext());
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          // Candidate index: the current position along the reduced axis.
          Value newIndex = rewriter.create<arith::IndexCastOp>(
              nestedLoc, oldIndex.getType(),
              rewriter.create<linalg::IndexOp>(loc, axis));

          Value predicate;
          if (isa<FloatType>(inElementTy)) {
            predicate = rewriter.create<arith::CmpFOp>(
                nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
          } else if (isa<IntegerType>(inElementTy)) {
            predicate = rewriter.create<arith::CmpIOp>(
                nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
          } else {
            didEncounterError = true;
            return;
          }

          // Update both index and max only when a strictly greater value is
          // seen (strict '>' keeps the first occurrence on ties).
          auto resultMax = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newIndex, oldIndex);
          nestedBuilder.create<linalg::YieldOp>(
              nestedLoc, ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    // Only the index result is used; the running-max tensor is discarded.
    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};

// Lowers tosa.gather to a linalg.generic over the indices tensor whose body
// extracts values[batch, indices[batch, i], channel].
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = adaptor.getOperands()[0];
    auto indices = adaptor.getOperands()[1];

    auto valuesTy =
        dyn_cast_or_null<RankedTensorType>(op.getValues().getType());
    auto resultTy = cast<ShapedType>(op.getType());

    if (!valuesTy)
      return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

    auto dynamicDims = inferDynamicDimsForGather(
        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();
    auto emptyTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
                                     dynamicDims)
            .getResult();

    // The indices tensor (N, W) maps to the first two result dims; the
    // result itself uses the identity map over (N, W, C).
    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
          Value index1 = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getIndexType(), indexValue);
          auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
          Value extract = rewriter.create<tensor::ExtractOp>(
              loc, input, ValueRange{index0, index1, index2});
          rewriter.create<linalg::YieldOp>(loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }

  // Returns tensor.dim values for the result's dynamic dims, in (N, W, C)
  // order: N and C from the values tensor, W from the indices tensor.
  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
                                                            Location loc,
                                                            Value values,
                                                            Value indices) {
    llvm::SmallVector<Value> results;

    auto addDynamicDimension = [&](Value source, int64_t dim) {
      auto sz = tensor::getMixedSize(builder, loc, source, dim);
      // Only dynamic extents come back as Values; static ones are attributes
      // and are skipped.
      if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
        results.push_back(dimValue);
    };

    addDynamicDimension(values, 0);
    addDynamicDimension(indices, 1);
    addDynamicDimension(values, 2);
    return results;
  }
};

// Lowers the TableOp to a series of gathers and numeric operations. This
// includes interpolation between the high/low values. For the I8 variant, this
// simplifies to a single gather operation.
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    Value table = op.getTable();
    auto inputTy = cast<ShapedType>(input.getType());
    auto tableTy = cast<ShapedType>(table.getType());
    auto resultTy = cast<ShapedType>(op.getType());

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    SmallVector<Value> dynDims;
    for (int i = 0; i < resultTy.getRank(); ++i) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(
            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
      }
    }

    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                    resultElementTy, dynDims)
                           .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    // Build the body for one of the two supported type combinations; the
    // early `return success()` inside the guard exits once a body is built.
    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block = rewriter.createBlock(
          &genericOp.getRegion(), genericOp.getRegion().end(),
          TypeRange({inputElementTy, resultElementTy}), {loc, loc});

      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);
      // i8 -> i8 via an i8 table: a single gather. The +128 rebases the
      // signed input [-128, 127] onto table positions [0, 255].
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
        Value index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), inputValue);
        Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
        index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
                                               index, offset);
        Value extract =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        rewriter.create<linalg::YieldOp>(loc, extract);
        return success();
      }

      // i16 -> i32 via an i16 table: gather two adjacent entries and
      // linearly interpolate with the low 7 bits of the rebased input.
      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        Value extend = rewriter.create<arith::ExtSIOp>(
            loc, rewriter.getI32Type(), inputValue);

        auto offset = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(32768));
        auto seven = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(7));
        auto one = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(1));
        auto b1111111 = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(127));

        // Compute the index and fractional part from the input value:
        //   value = value + 32768
        //   index = value >> 7;
        //   fraction = value & 0b1111111 (= 0x7f, the low 7 bits)
        auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
        Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
        Value fraction =
            rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);

        // Extract the base and next values from the table.
        //   base = (int32_t) table[index];
        //   next = (int32_t) table[index + 1];
        Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);

        index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), index);
        indexPlusOne = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        Value next = rewriter.create<tensor::ExtractOp>(
            loc, table, ValueRange{indexPlusOne});

        base =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
        next =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between the input values:
        //   result = (base << 7) + (next - base) * fraction
        Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
        Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
        Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
        Value result =
            rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);

        rewriter.create<linalg::YieldOp>(loc, result);

        return success();
      }
    }

    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};

// Lowers tosa.rfft2d to a linalg.generic that evaluates the real DFT
// directly: a 5-d iteration space (batch, out_y, out_x parallel; in_y, in_x
// reduction) accumulating valReal * cos(angle) / -valReal * sin(angle).
struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
  using OpRewritePattern<RFFT2dOp>::OpRewritePattern;

  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }

  // Returns ofr / 2 + 1 (the RFFT output width for input width ofr),
  // folding to a constant when possible.
  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
                                  OpFoldResult ofr) {
    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
    auto two = builder.create<arith::ConstantIndexOp>(loc, 2);

    auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
    auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
    auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
    return getAsOpFoldResult(plusOne);
  }

  // Computes the [N, H, W/2 + 1] result type of an rfft2d, appending Values
  // for any dynamic extents to `dynamicSizes`.
  static RankedTensorType
  computeOutputShape(OpBuilder &builder, Location loc, Value input,
                     llvm::SmallVectorImpl<Value> &dynamicSizes) {
    // Get [N, H, W]
    auto dims = tensor::getMixedSizes(builder, loc, input);

    // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
    // output tensors.
    dims[2] = halfPlusOne(builder, loc, dims[2]);

    llvm::SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
    return RankedTensorType::get(staticSizes, elementType);
  }

  // Materializes a zero-filled tensor of the given type to accumulate into.
  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
                                RankedTensorType type,
                                llvm::ArrayRef<Value> dynamicSizes) {
    auto emptyTensor =
        rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
    auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
    auto filledTensor = rewriter
                            .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                    ValueRange{emptyTensor})
                            .result();
    return filledTensor;
  }

  // index -> float via an unsigned integer of matching (>= 32-bit) width.
  static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                FloatType type, Value value) {
    auto integerVal = builder.create<arith::IndexCastUIOp>(
        loc,
        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
                                          : builder.getI32Type(),
        value);

    return builder.create<arith::UIToFPOp>(loc, type, integerVal);
  }

  // Emits linalg.index `index` converted to the given float type.
  static Value createLinalgIndex(OpBuilder &builder, Location loc,
                                 FloatType type, int64_t index) {
    auto indexVal = builder.create<linalg::IndexOp>(loc, index);
    return castIndexToFloat(builder, loc, type, indexVal);
  }

  // Convenience: builds {d_args...} as affine dim expressions.
  template <typename... Args>
  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
                                                         Args... args) {
    return {builder.getAffineDimExpr(args)...};
  }

  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports ranked tensors");
    }

    auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInput();
    auto elementType =
        dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
    if (!elementType)
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports float element types");

    // Compute the output type and set of dynamic sizes
    llvm::SmallVector<Value> dynamicSizes;
    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);

    // Iterator types for the linalg.generic implementation
    // (N, out_y, out_x parallel; in_y, in_x reduction).
    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation: one zero-filled
    // accumulator each for the real and imaginary results.
    llvm::SmallVector<Value> genericOpInputs = {input};
    llvm::SmallVector<Value> genericOpOutputs = {
        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};

    // Indexing maps for input and output tensors
    auto indexingMaps = AffineMap::inferFromExprList(
        llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
                       affineDimsExpr(rewriter, 0, 1, 2),
                       affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);

    // Constants and dimension sizes
    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value sumReal = args[1];
      Value sumImag = args[2];

      // Indices for angle computation
      Value oy = builder.create<linalg::IndexOp>(loc, 1);
      Value ox = builder.create<linalg::IndexOp>(loc, 2);
      Value iy = builder.create<linalg::IndexOp>(loc, 3);
      Value ix = builder.create<linalg::IndexOp>(loc, 4);

      // Calculating angle without integer parts of components as sin/cos are
      // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
      // / W);
      auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
      auto ixXox = builder.create<index::MulOp>(loc, ix, ox);

      auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
      auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);

      auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
      auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);

      auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
      auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
      auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);

      // realComponent = valReal * cos(angle)
      // imagComponent = valReal * sin(angle)
      auto cosAngle = builder.create<math::CosOp>(loc, angle);
      auto sinAngle = builder.create<math::SinOp>(loc, angle);
      auto realComponent =
          builder.create<arith::MulFOp>(loc, valReal, cosAngle);
      auto imagComponent =
          builder.create<arith::MulFOp>(loc, valReal, sinAngle);

      // outReal = sumReal + realComponent
      // outImag = sumImag - imagComponent
      auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
      auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);

      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};

// Lowers tosa.fft2d the same way as RFFT2dConverter, but with complex
// (real + imaginary) input and a full-width output.
struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(FFT2dOp fft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(fft2d->getOperandTypes(),
                      RFFT2dConverter::isRankedTensor) ||
        !llvm::all_of(fft2d->getResultTypes(),
                      RFFT2dConverter::isRankedTensor)) {
      return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
    }

    Location loc = fft2d.getLoc();
    Value input_real = fft2d.getInputReal();
    Value input_imag = fft2d.getInputImag();
    BoolAttr inverse = fft2d.getInverseAttr();

    auto real_el_ty = cast<FloatType>(
        cast<ShapedType>(input_real.getType()).getElementType());
    [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
        cast<ShapedType>(input_imag.getType()).getElementType());

    assert(real_el_ty == imag_el_ty);

    // Compute the output type and set of dynamic sizes
    SmallVector<Value> dynamicSizes;

    // Get [N, H, W]
    auto dims = tensor::getMixedSizes(rewriter, loc, input_real);

    SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto outputType = RankedTensorType::get(staticSizes, real_el_ty);

    // Iterator types for the linalg.generic implementation
    SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation
    SmallVector<Value> genericOpInputs = {input_real, input_imag};
    SmallVector<Value> genericOpOutputs = {
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes),
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes)};

    // Indexing maps for input and output tensors
    auto indexingMaps = AffineMap::inferFromExprList(
        ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
2521 auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1); 2522 auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2); 2523 2524 // Constants and dimension sizes 2525 auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586); 2526 auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr); 2527 Value constH = 2528 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH); 2529 Value constW = 2530 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW); 2531 2532 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) { 2533 Value valReal = args[0]; 2534 Value valImag = args[1]; 2535 Value sumReal = args[2]; 2536 Value sumImag = args[3]; 2537 2538 // Indices for angle computation 2539 Value oy = builder.create<linalg::IndexOp>(loc, 1); 2540 Value ox = builder.create<linalg::IndexOp>(loc, 2); 2541 Value iy = builder.create<linalg::IndexOp>(loc, 3); 2542 Value ix = builder.create<linalg::IndexOp>(loc, 4); 2543 2544 // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * 2545 // ox) % W ) / W); 2546 auto iyXoy = builder.create<index::MulOp>(loc, iy, oy); 2547 auto ixXox = builder.create<index::MulOp>(loc, ix, ox); 2548 2549 auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH); 2550 auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW); 2551 2552 auto iyRemFloat = 2553 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem); 2554 auto ixRemFloat = 2555 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem); 2556 2557 auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH); 2558 auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW); 2559 2560 auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent); 2561 auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY); 2562 2563 if (inverse.getValue()) { 2564 angle = builder.create<arith::MulFOp>( 2565 loc, angle, 2566 
rewriter.create<arith::ConstantOp>( 2567 loc, rewriter.getFloatAttr(real_el_ty, -1.0))); 2568 } 2569 2570 // realComponent = val_real * cos(a) + val_imag * sin(a); 2571 // imagComponent = -val_real * sin(a) + val_imag * cos(a); 2572 auto cosAngle = builder.create<math::CosOp>(loc, angle); 2573 auto sinAngle = builder.create<math::SinOp>(loc, angle); 2574 2575 auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle); 2576 auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle); 2577 auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin); 2578 2579 auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle); 2580 auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle); 2581 2582 auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin); 2583 2584 // outReal = sumReal + realComponent 2585 // outImag = sumImag - imagComponent 2586 auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent); 2587 auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent); 2588 2589 builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag}); 2590 }; 2591 2592 rewriter.replaceOpWithNewOp<linalg::GenericOp>( 2593 fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs, 2594 indexingMaps, iteratorTypes, buildBody); 2595 2596 return success(); 2597 } 2598 }; 2599 2600 } // namespace 2601 2602 void mlir::tosa::populateTosaToLinalgConversionPatterns( 2603 const TypeConverter &converter, RewritePatternSet *patterns) { 2604 2605 // We have multiple resize coverters to handle degenerate cases. 
2606 patterns->add<GenericResizeConverter>(patterns->getContext(), 2607 /*benefit=*/100); 2608 patterns->add<ResizeUnaryConverter>(patterns->getContext(), 2609 /*benefit=*/200); 2610 patterns->add<MaterializeResizeBroadcast>(patterns->getContext(), 2611 /*benefit=*/300); 2612 2613 patterns->add< 2614 // clang-format off 2615 PointwiseConverter<tosa::AddOp>, 2616 PointwiseConverter<tosa::SubOp>, 2617 PointwiseConverter<tosa::MulOp>, 2618 PointwiseConverter<tosa::IntDivOp>, 2619 PointwiseConverter<tosa::NegateOp>, 2620 PointwiseConverter<tosa::PowOp>, 2621 PointwiseConverter<tosa::ReciprocalOp>, 2622 PointwiseConverter<tosa::RsqrtOp>, 2623 PointwiseConverter<tosa::LogOp>, 2624 PointwiseConverter<tosa::ExpOp>, 2625 PointwiseConverter<tosa::AbsOp>, 2626 PointwiseConverter<tosa::SinOp>, 2627 PointwiseConverter<tosa::CosOp>, 2628 PointwiseConverter<tosa::TanhOp>, 2629 PointwiseConverter<tosa::ErfOp>, 2630 PointwiseConverter<tosa::BitwiseAndOp>, 2631 PointwiseConverter<tosa::BitwiseOrOp>, 2632 PointwiseConverter<tosa::BitwiseNotOp>, 2633 PointwiseConverter<tosa::BitwiseXorOp>, 2634 PointwiseConverter<tosa::LogicalAndOp>, 2635 PointwiseConverter<tosa::LogicalNotOp>, 2636 PointwiseConverter<tosa::LogicalOrOp>, 2637 PointwiseConverter<tosa::LogicalXorOp>, 2638 PointwiseConverter<tosa::CastOp>, 2639 PointwiseConverter<tosa::LogicalLeftShiftOp>, 2640 PointwiseConverter<tosa::LogicalRightShiftOp>, 2641 PointwiseConverter<tosa::ArithmeticRightShiftOp>, 2642 PointwiseConverter<tosa::ClzOp>, 2643 PointwiseConverter<tosa::SelectOp>, 2644 PointwiseConverter<tosa::GreaterOp>, 2645 PointwiseConverter<tosa::GreaterEqualOp>, 2646 PointwiseConverter<tosa::EqualOp>, 2647 PointwiseConverter<tosa::MaximumOp>, 2648 PointwiseConverter<tosa::MinimumOp>, 2649 PointwiseConverter<tosa::CeilOp>, 2650 PointwiseConverter<tosa::FloorOp>, 2651 PointwiseConverter<tosa::ClampOp>, 2652 PointwiseConverter<tosa::SigmoidOp> 2653 >(converter, patterns->getContext()); 2654 2655 patterns->add< 2656 
IdentityNConverter<tosa::IdentityOp>, 2657 ReduceConverter<tosa::ReduceAllOp>, 2658 ReduceConverter<tosa::ReduceAnyOp>, 2659 ReduceConverter<tosa::ReduceMinOp>, 2660 ReduceConverter<tosa::ReduceMaxOp>, 2661 ReduceConverter<tosa::ReduceSumOp>, 2662 ReduceConverter<tosa::ReduceProdOp>, 2663 ArgMaxConverter, 2664 GatherConverter, 2665 RescaleConverter, 2666 ReverseConverter, 2667 RFFT2dConverter, 2668 FFT2dConverter, 2669 TableConverter, 2670 TileConverter>(patterns->getContext()); 2671 // clang-format on 2672 } 2673