//===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Tensor dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

#include <numeric>

using namespace mlir;
using namespace tosa;

namespace {

// Infer the type to which the input of a 'tosa.reshape' op must be cast when
// lowered.
TensorType inferReshapeInputType(TypedValue<TensorType> input,
                                 ArrayRef<int64_t> newShape) {
  // No need to cast input for non-empty target shape
  if (!newShape.empty())
    return input.getType();

  // The input type must be cast into a tensor with the same rank and all
  // static dimensions set to 1. This prevents the generation of a
  // tensor.collapse_shape op that converts a dynamically shaped tensor into a
  // 0D tensor. While such a construct is not incorrect on its own,
  // bufferization cannot properly handle it at the moment, so we avoid it.
  SmallVector<int64_t> shape(input.getType().getRank(), 1);
  return input.getType().clone(shape);
}

// Infer the result type of 'tensor.expand_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
TensorType inferReshapeExpandedType(TensorType inputType,
                                    ArrayRef<int64_t> newShape) {
  // Special case for 0D output tensor. Note: Watch out when using
  // Type::clone() with just '{}', as it will invoke the incorrect overload.
  if (newShape.empty())
    return inputType.clone(ArrayRef<int64_t>{});

  // Check if the input is static, and if so, get its total size
  bool inputIsStatic = inputType.hasStaticShape();
  int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;

  // Compute result shape
  auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
    // If this is not a placeholder, do not change it
    if (size >= 0)
      return size;

    // If we do not know the total size of the tensor, keep this dimension
    // dynamic in the result shape.
    if (!inputIsStatic)
      return ShapedType::kDynamic;

    // Calculate the product of all elements in 'newShape' except for the -1
    // placeholder, which we discard by negating the result.
    int64_t totalSizeNoPlaceholder = -std::accumulate(
        newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());

    // If there is a 0 component in 'newShape', resolve the placeholder as 0.
    if (totalSizeNoPlaceholder == 0)
      return 0;

    // Resolve the placeholder as the quotient between the total tensor size
    // and the product of all other sizes.
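    // For example, a static tensor<2x3x4xf32> input with newShape {6, -1} has
    // totalSize 24 and totalSizeNoPlaceholder 6, so the placeholder resolves
    // to 4 and the inferred result shape is {6, 4}.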
    return totalSize / totalSizeNoPlaceholder;
  });

  bool resultIsStatic = !ShapedType::isDynamicShape(resultShape);

  // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
  // shaped input from being reshaped into a statically shaped result. We may
  // simply turn the first result dimension dynamic to address this.
  if (!inputIsStatic && resultIsStatic)
    resultShape[0] = ShapedType::kDynamic;

  // The 'tensor.expand_shape' op also forbids a statically shaped input from
  // being reshaped into a dynamically shaped result, but the placeholder
  // inference algorithm above guarantees that this will never be the case.
  assert(!inputIsStatic || resultIsStatic);

  // Create result type
  return inputType.clone(resultShape);
}

// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
  auto lhsShape = lhsType.getShape();
  auto rhsShape = rhsType.getShape();

  if (lhsShape.empty() || rhsShape.empty())
    return lhsType.clone(ArrayRef<int64_t>{});

  if (ShapedType::isDynamicShape(lhsShape) ||
      ShapedType::isDynamicShape(rhsShape))
    return lhsType.clone({ShapedType::kDynamic});

  SmallVector<int64_t> intermediateShape;
  unsigned currLhsDim = 0, currRhsDim = 0;
  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
    int64_t rhsSize = rhsShape[currRhsDim];
    int64_t lhsSize = lhsShape[currLhsDim];
    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
           currRhsDim < rhsShape.size()) {
      if (lhsSize < rhsSize) {
        currLhsDim++;
        if (currLhsDim < lhsShape.size()) {
          lhsSize *= lhsShape[currLhsDim];
        }
      } else {
        currRhsDim++;
        if (currRhsDim < rhsShape.size()) {
          rhsSize *= rhsShape[currRhsDim];
        }
      }
    }
    if (lhsSize == rhsSize) {
      intermediateShape.push_back(lhsSize);
    }
    currRhsDim++;
    currLhsDim++;
  }

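  // For example, for lhsShape {2, 3, 4} and rhsShape {6, 4}, the loop above
  // accumulates the intermediate shape {6, 4}.
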
  // Static shapes are guaranteed to be compatible by the op verifier, so all
  // leftover dimensions should be 1.
  for (; currLhsDim < lhsShape.size(); currLhsDim++) {
    assert(lhsShape[currLhsDim] == 1);
  }
  for (; currRhsDim < rhsShape.size(); currRhsDim++) {
    assert(rhsShape[currRhsDim] == 1);
  }

  return lhsType.clone(intermediateShape);
}

SmallVector<ReassociationExprs>
createReassociationMapForCollapse(OpBuilder &builder, Type srcType,
                                  Type dstType) {
  auto srcShape = cast<TensorType>(srcType).getShape();
  auto dstShape = cast<TensorType>(dstType).getShape();

  if (srcShape.empty() || dstShape.empty())
    return {};

  if (ShapedType::isDynamicShape(srcShape) ||
      ShapedType::isDynamicShape(dstShape)) {
    assert(dstShape.size() == 1);
    SmallVector<AffineExpr, 2> exprs;
    for (auto i : llvm::seq<int64_t>(srcShape.size()))
      exprs.push_back(builder.getAffineDimExpr(i));
    return {exprs};
  }

  SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
  unsigned currSrcDim = 0, currDstDim = 0;
  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
    int64_t dstSize = dstShape[currDstDim];
    int64_t srcSize = srcShape[currSrcDim];
    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      srcSize *= srcShape[currSrcDim];
    }
    if (srcSize == dstSize) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      // If the next dim in dstShape is not 1, collapse any subsequent unit
      // dims in srcShape into the current group.
      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
          reassociationMap[currDstDim].push_back(
              builder.getAffineDimExpr(currSrcDim++));
        }
      }
    }
    currDstDim++;
  }

  // If the source and target shapes are compatible, both iterators must have
  // reached the end. This condition is guaranteed by the op verifier for
  // static shapes.
  assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
  return reassociationMap;
}

// Create a tensor.collapse_shape op that reshapes the input into the given
// result type.
Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
                     Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, input.getType(), resultType);
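  // For example, collapsing a tensor<2x3x4xf32> input into resultType
  // tensor<6x4xf32> uses the reassociation map [[0, 1], [2]], producing
  // roughly:
  //   tensor.collapse_shape %input [[0, 1], [2]]
  //       : tensor<2x3x4xf32> into tensor<6x4xf32>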
  return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
                                                       reassociationMap);
}

// Create a tensor.expand_shape op that reshapes the input into the given
// result type.
Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
                   Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, resultType, input.getType());
  return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
                                                     reassociationMap);
}

class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = reshape.getLoc();
    auto resultType = cast_if_present<ShapedType>(
        getTypeConverter()->convertType(reshape.getType()));
    if (!resultType) {
      return rewriter.notifyMatchFailure(reshape.getLoc(),
                                         "could not convert result type");
    }
    auto input = dyn_cast<TypedValue<TensorType>>(adaptor.getInput1());
    if (!input) {
      return rewriter.notifyMatchFailure(reshape.getLoc(),
                                         "expected input type to be tensor");
    }
    auto newShape = reshape.getNewShape();

    // Infer all intermediate types
    auto inputType = inferReshapeInputType(input, newShape);
    auto expandedType = inferReshapeExpandedType(inputType, newShape);
    auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);

    // Cast input if needed
    auto castInput =
        rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);

    // Emit collapse-expand pair
    auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
    auto expanded = createExpand(rewriter, loc, expandedType, collapsed);

    // Cast to final result type if needed
    auto result =
        rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
    rewriter.replaceOp(reshape, result);
    return success();
  }
};

class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
public:
  using OpConversionPattern<tosa::SliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = sliceOp.getLoc();
    Value input = adaptor.getInput1();
    ShapedType resultType = cast<ShapedType>(sliceOp.getType());
    if (llvm::isa<UnrankedTensorType>(resultType))
      return failure();

    ElementsAttr startElems;
    ElementsAttr sizeElems;

    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");

    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");

    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    SmallVector<int64_t> strides, sizes;
    strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);

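    // A slice size of -1 denotes an extent that runs to the end of the
    // dimension; it is materialized below as dim(input, i) - start[i].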
    SmallVector<Value> dynSizes;
    for (const auto &i : llvm::enumerate(sliceSizes)) {
      int64_t size = i.value();
      size_t index = i.index();
      sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
      if (!ShapedType::isDynamic(sizes.back()))
        continue;

      auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
      auto offset = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(sliceStarts[index]));
      dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
    }

    auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
        ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts),
        rewriter.getDenseI64ArrayAttr(sizes),
        rewriter.getDenseI64ArrayAttr(strides));

    rewriter.replaceOp(sliceOp, newSliceOp.getResult());

    // Remove the const_shape ops if the slice op was their only user.
    Operation *startConstShape = sliceOp.getStart().getDefiningOp();
    if (startConstShape->getResult(0).hasOneUse())
      rewriter.eraseOp(startConstShape);

    Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
    if (sizeConstShape->getResult(0).hasOneUse())
      rewriter.eraseOp(sizeConstShape);

    return success();
  }
};

class PadConverter : public OpConversionPattern<tosa::PadOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = padOp.getLoc();
    auto input = padOp.getInput1();

    ElementsAttr paddingElems;
    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
      return rewriter.notifyMatchFailure(
          padOp, "padding must be a static shape value");
    }
    llvm::SmallVector<int64_t> paddingVals;
    for (auto idx : paddingElems.getValues<IntegerAttr>()) {
      paddingVals.push_back(static_cast<int64_t>(idx.getInt()));
    }

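    // paddingVals is interleaved per dimension as [lo0, hi0, lo1, hi1, ...].
    // For example, {1, 1, 2, 2} on a rank-2 input pads dimension 0 by 1 on
    // each side and dimension 1 by 2 on each side.
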
    ShapedType inputTy = cast<ShapedType>(input.getType());
    Type elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    // Setup the default constantAttr.

    Value padConstant;

    if (padOp.getPadConst()) {
      padConstant = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padOp.getPadConst(), ValueRange({}));
    } else {
      TypedAttr constantAttr;
      if (isa<FloatType>(elementTy)) {
        constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
      } else if (isa<IntegerType>(elementTy) && !padOp.getQuantizationInfo()) {
        constantAttr = rewriter.getIntegerAttr(elementTy, 0);
      } else if (isa<IntegerType>(elementTy) && padOp.getQuantizationInfo()) {
        int64_t value = padOp.getQuantizationInfo()->getInputZp();
        constantAttr = rewriter.getIntegerAttr(elementTy, value);
      }
      if (constantAttr)
        padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
    }

    if (!padConstant) {
      return rewriter.notifyMatchFailure(
          padOp, "tosa.pad was unable to determine the pad constant value.");
    }

    SmallVector<OpFoldResult, 3> lowValues;
    SmallVector<OpFoldResult, 3> highValues;

    lowValues.reserve(rank);
    highValues.reserve(rank);

    for (int i = 0; i < rank; i++) {
      Value lowVal = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(paddingVals[2 * i]));
      Value highVal = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
      lowValues.push_back(lowVal);
      highValues.push_back(highVal);
    }

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, padOp.getType(), input, lowValues, highValues, padConstant);

    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};

struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<RankedTensorType>(op.getType());

    Location loc = op.getLoc();
    int axis = op.getAxis();
    Value axisValue =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(axis));
    int64_t rank = resultType.getRank();

    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, op.getLoc(), adaptor.getOperands()[0]);

    // Pre-compute the offsets along the axis dimension.
    // axisOffsets will have one entry per operand plus one; the last value
    // holds the total size of the result along the 'axis' dimension.
    SmallVector<OpFoldResult> axisOffsets;
    axisOffsets.push_back(rewriter.getIndexAttr(0));
    axisOffsets.push_back(sizes[axis]);

    for (auto arg : adaptor.getOperands().drop_front()) {
      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
      auto currentOffset =
          getValueOrCreateConstantIndexOp(rewriter, loc, axisOffsets.back());
      auto total =
          rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size);
      axisOffsets.push_back(getAsOpFoldResult(total));
    }
    sizes[axis] = axisOffsets.back();

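    // For example, concatenating three operands whose 'axis' dimensions are
    // 2, 3, and 4 yields axisOffsets {0, 2, 5, 9}, so sizes[axis] becomes 9.
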
    // Compute the dynamic sizes of the tensor.empty operation.
    // This is based on the specified result type of the tosa.concat
    // operation, since we don't want to change the result type of the
    // operation during the conversion.
    SmallVector<Value> dynDims;
    for (int64_t i = 0; i < rank; ++i) {
      if (resultType.isDynamicDim(i)) {
        dynDims.push_back(
            getValueOrCreateConstantIndexOp(rewriter, loc, sizes[i]));
      }
    }

    Value result = rewriter.create<tensor::EmptyOp>(
        loc, resultType.getShape(), resultType.getElementType(), dynDims);

    for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
      auto sizes = tensor::getMixedSizes(rewriter, op.getLoc(), arg);
      offsets[axis] = offset;
      result = rewriter.createOrFold<tensor::InsertSliceOp>(
          loc, arg, result, offsets, sizes, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaToTensorConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {
  patterns
      ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
          converter, patterns->getContext());
}