1126e7eafSRob Suderman //===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===// 2126e7eafSRob Suderman // 3126e7eafSRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4126e7eafSRob Suderman // See https://llvm.org/LICENSE.txt for license information. 5126e7eafSRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6126e7eafSRob Suderman // 7126e7eafSRob Suderman //===----------------------------------------------------------------------===// 8126e7eafSRob Suderman // 9126e7eafSRob Suderman // These rewriters lower from the Tosa to the Tensor dialect. 10126e7eafSRob Suderman // 11126e7eafSRob Suderman //===----------------------------------------------------------------------===// 12126e7eafSRob Suderman 13126e7eafSRob Suderman #include "mlir/Conversion/TosaToTensor/TosaToTensor.h" 14abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 1586c4972fSSpenser Bauman #include "mlir/Dialect/Arith/Utils/Utils.h" 16126e7eafSRob Suderman #include "mlir/Dialect/Tensor/IR/Tensor.h" 1786c4972fSSpenser Bauman #include "mlir/Dialect/Tensor/Utils/Utils.h" 18126e7eafSRob Suderman #include "mlir/Dialect/Tosa/IR/TosaOps.h" 19126e7eafSRob Suderman #include "mlir/IR/PatternMatch.h" 20723979efSKrzysztof Drewniak #include "mlir/Transforms/DialectConversion.h" 21126e7eafSRob Suderman 2226d896f3SRafael Ubal #include <numeric> 2326d896f3SRafael Ubal 24126e7eafSRob Suderman using namespace mlir; 25126e7eafSRob Suderman using namespace tosa; 26126e7eafSRob Suderman 2726d896f3SRafael Ubal namespace { 2826d896f3SRafael Ubal 2926d896f3SRafael Ubal // Infer the type to which the input of a 'tosa.reshape' op must be cast when 3026d896f3SRafael Ubal // lowered. 3126d896f3SRafael Ubal TensorType inferReshapeInputType(TypedValue<TensorType> input, 3226d896f3SRafael Ubal ArrayRef<int64_t> newShape) { 3326d896f3SRafael Ubal // No need to cast input for non-empty target shape 3426d896f3SRafael Ubal if (!newShape.empty()) 3526d896f3SRafael Ubal return input.getType(); 3626d896f3SRafael Ubal 3726d896f3SRafael Ubal // The input type must be cast into a tensor with the same rank and all static 3826d896f3SRafael Ubal // dimensions set to 1. This prevents the generation of a tensor.collapse_shape 3926d896f3SRafael Ubal // op that converts a dynamically shaped tensor into a 0D tensor. While such 4026d896f3SRafael Ubal // construct is not incorrect on its own, bufferization cannot properly handle 4126d896f3SRafael Ubal // it at the moment, so we avoid it. 4226d896f3SRafael Ubal SmallVector<int64_t> shape(input.getType().getRank(), 1); 4326d896f3SRafael Ubal return input.getType().clone(shape); 44723979efSKrzysztof Drewniak } 45723979efSKrzysztof Drewniak 4626d896f3SRafael Ubal // Infer the result type of 'tensor.expand_shape' in the collapse-expand 4726d896f3SRafael Ubal // pair emitted for a 'tosa.reshape' op. 4826d896f3SRafael Ubal TensorType inferReshapeExpandedType(TensorType inputType, 4926d896f3SRafael Ubal ArrayRef<int64_t> newShape) { 5026d896f3SRafael Ubal // Special case for 0D output tensor. Note: Watch out when using Type::clone() 5126d896f3SRafael Ubal // with just '{}', as it will invoke the incorrect overload. 5226d896f3SRafael Ubal if (newShape.empty()) 5326d896f3SRafael Ubal return inputType.clone(ArrayRef<int64_t>{}); 5426d896f3SRafael Ubal 5526d896f3SRafael Ubal // Check if the input is static, and if so, get its total size 5626d896f3SRafael Ubal bool inputIsStatic = inputType.hasStaticShape(); 5726d896f3SRafael Ubal int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1; 5826d896f3SRafael Ubal 5926d896f3SRafael Ubal // Compute result shape 6026d896f3SRafael Ubal auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t { 6126d896f3SRafael Ubal // If this is not a placeholder, do not change it 6226d896f3SRafael Ubal if (size >= 0) 6326d896f3SRafael Ubal return size; 6426d896f3SRafael Ubal 6526d896f3SRafael Ubal // If we do not know the total size of the tensor, keep this dimension 6626d896f3SRafael Ubal // dynamic in the result shape. 679d66dcafSSpenser Bauman if (!inputIsStatic) 6826d896f3SRafael Ubal return ShapedType::kDynamic; 69723979efSKrzysztof Drewniak 7026d896f3SRafael Ubal // Calculate the product of all elements in 'newShape' except for the -1 7126d896f3SRafael Ubal // placeholder, which we discard by negating the result. 7226d896f3SRafael Ubal int64_t totalSizeNoPlaceholder = -std::accumulate( 731eaef445SKazu Hirata newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>()); 7426d896f3SRafael Ubal 7526d896f3SRafael Ubal // If there is a 0 component in 'newShape', resolve the placeholder as 0. 7626d896f3SRafael Ubal if (totalSizeNoPlaceholder == 0) 7726d896f3SRafael Ubal return 0; 7826d896f3SRafael Ubal 7926d896f3SRafael Ubal // Resolve the placeholder as the quotient between the total tensor size and 8026d896f3SRafael Ubal // the product of all other sizes. 8126d896f3SRafael Ubal return totalSize / totalSizeNoPlaceholder; 8226d896f3SRafael Ubal }); 8326d896f3SRafael Ubal 849d66dcafSSpenser Bauman bool resultIsStatic = !ShapedType::isDynamicShape(resultShape); 859d66dcafSSpenser Bauman 8626d896f3SRafael Ubal // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically 8726d896f3SRafael Ubal // shaped input from being reshaped into a statically shaped result. We may 8826d896f3SRafael Ubal // simply turn the first result dimension dynamic to address this. 8926d896f3SRafael Ubal if (!inputIsStatic && resultIsStatic) 9026d896f3SRafael Ubal resultShape[0] = ShapedType::kDynamic; 9126d896f3SRafael Ubal 9226d896f3SRafael Ubal // The 'tensor.expand_shape' op also forbids a statically shaped input from 9326d896f3SRafael Ubal // being reshaped into a dynamically shaped result, but the placeholder 9426d896f3SRafael Ubal // inference algorithm above guarantees that this will never be the case. 9526d896f3SRafael Ubal assert(!inputIsStatic || resultIsStatic); 9626d896f3SRafael Ubal 9726d896f3SRafael Ubal // Create result type 9826d896f3SRafael Ubal return inputType.clone(resultShape); 9926d896f3SRafael Ubal } 10026d896f3SRafael Ubal 10126d896f3SRafael Ubal // Infer the result type of 'tensor.collapse_shape' in the collapse-expand 10226d896f3SRafael Ubal // pair emitted for a 'tosa.reshape' op. 10326d896f3SRafael Ubal TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) { 10426d896f3SRafael Ubal auto lhsShape = lhsType.getShape(); 10526d896f3SRafael Ubal auto rhsShape = rhsType.getShape(); 10626d896f3SRafael Ubal 10726d896f3SRafael Ubal if (lhsShape.empty() || rhsShape.empty()) 10826d896f3SRafael Ubal return lhsType.clone(ArrayRef<int64_t>{}); 10926d896f3SRafael Ubal 11026d896f3SRafael Ubal if (ShapedType::isDynamicShape(lhsShape) || ShapedType::isDynamicShape(rhsShape)) 11126d896f3SRafael Ubal return lhsType.clone({ShapedType::kDynamic}); 11226d896f3SRafael Ubal 11326d896f3SRafael Ubal SmallVector<int64_t> intermediateShape; 114723979efSKrzysztof Drewniak unsigned currLhsDim = 0, currRhsDim = 0; 115723979efSKrzysztof Drewniak while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) { 116723979efSKrzysztof Drewniak int64_t rhsSize = rhsShape[currRhsDim]; 117723979efSKrzysztof Drewniak int64_t lhsSize = lhsShape[currLhsDim]; 118723979efSKrzysztof Drewniak while (lhsSize != rhsSize && currLhsDim < lhsShape.size() && 119723979efSKrzysztof Drewniak currRhsDim < rhsShape.size()) { 120723979efSKrzysztof Drewniak if (lhsSize < rhsSize) { 121723979efSKrzysztof Drewniak currLhsDim++; 122723979efSKrzysztof Drewniak if (currLhsDim < lhsShape.size()) { 123723979efSKrzysztof Drewniak lhsSize *= lhsShape[currLhsDim]; 124723979efSKrzysztof Drewniak } 125723979efSKrzysztof Drewniak } else { 126723979efSKrzysztof Drewniak currRhsDim++; 127723979efSKrzysztof Drewniak if (currRhsDim < rhsShape.size()) { 128723979efSKrzysztof Drewniak rhsSize *= rhsShape[currRhsDim]; 129723979efSKrzysztof Drewniak } 130723979efSKrzysztof Drewniak } 131723979efSKrzysztof Drewniak } 132723979efSKrzysztof Drewniak if (lhsSize == rhsSize) { 133723979efSKrzysztof Drewniak intermediateShape.push_back(lhsSize); 134723979efSKrzysztof Drewniak } 135723979efSKrzysztof Drewniak currRhsDim++; 136723979efSKrzysztof Drewniak currLhsDim++; 137723979efSKrzysztof Drewniak } 138723979efSKrzysztof Drewniak 13926d896f3SRafael Ubal // Static shapes are guaranteed to be compatible by the op verifier, so all 14026d896f3SRafael Ubal // leftover dimensions should be 1. 14126d896f3SRafael Ubal for (; currLhsDim < lhsShape.size(); currLhsDim++) { 14226d896f3SRafael Ubal assert(lhsShape[currLhsDim] == 1); 143723979efSKrzysztof Drewniak } 14426d896f3SRafael Ubal for (; currRhsDim < rhsShape.size(); currRhsDim++) { 14526d896f3SRafael Ubal assert(rhsShape[currRhsDim] == 1); 146723979efSKrzysztof Drewniak } 147723979efSKrzysztof Drewniak 14826d896f3SRafael Ubal return lhsType.clone(intermediateShape); 149723979efSKrzysztof Drewniak } 150723979efSKrzysztof Drewniak 15126d896f3SRafael Ubal SmallVector<ReassociationExprs> 15226d896f3SRafael Ubal createReassociationMapForCollapse(OpBuilder &builder, Type srcType, Type dstType) { 15326d896f3SRafael Ubal auto srcShape = cast<TensorType>(srcType).getShape(); 15426d896f3SRafael Ubal auto dstShape = cast<TensorType>(dstType).getShape(); 155723979efSKrzysztof Drewniak 15626d896f3SRafael Ubal if (srcShape.empty() || dstShape.empty()) 15726d896f3SRafael Ubal return {}; 158723979efSKrzysztof Drewniak 15926d896f3SRafael Ubal if (ShapedType::isDynamicShape(srcShape) || ShapedType::isDynamicShape(dstShape)) { 16026d896f3SRafael Ubal assert(dstShape.size() == 1); 161723979efSKrzysztof Drewniak SmallVector<AffineExpr, 2> exprs; 16226d896f3SRafael Ubal for (auto i : llvm::seq<int64_t>(srcShape.size())) 16326d896f3SRafael Ubal exprs.push_back(builder.getAffineDimExpr(i)); 16426d896f3SRafael Ubal return {exprs}; 165723979efSKrzysztof Drewniak } 166723979efSKrzysztof Drewniak 16726d896f3SRafael Ubal SmallVector<ReassociationExprs> reassociationMap(dstShape.size()); 168723979efSKrzysztof Drewniak unsigned currSrcDim = 0, currDstDim = 0; 169723979efSKrzysztof Drewniak while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) { 170723979efSKrzysztof Drewniak int64_t dstSize = dstShape[currDstDim]; 171723979efSKrzysztof Drewniak int64_t srcSize = srcShape[currSrcDim]; 172723979efSKrzysztof Drewniak while (srcSize < dstSize && currSrcDim < srcShape.size()) { 173723979efSKrzysztof Drewniak reassociationMap[currDstDim].push_back( 17426d896f3SRafael Ubal builder.getAffineDimExpr(currSrcDim++)); 175723979efSKrzysztof Drewniak srcSize *= srcShape[currSrcDim]; 176723979efSKrzysztof Drewniak } 177723979efSKrzysztof Drewniak if (srcSize == dstSize) { 178723979efSKrzysztof Drewniak reassociationMap[currDstDim].push_back( 17926d896f3SRafael Ubal builder.getAffineDimExpr(currSrcDim++)); 180723979efSKrzysztof Drewniak // If the next dim in collapsedShape is not 1, treat subsequent dims in 181723979efSKrzysztof Drewniak // expandedShape which are 1 to be collapsed. 182723979efSKrzysztof Drewniak if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) { 183723979efSKrzysztof Drewniak while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) { 184723979efSKrzysztof Drewniak reassociationMap[currDstDim].push_back( 18526d896f3SRafael Ubal builder.getAffineDimExpr(currSrcDim++)); 186723979efSKrzysztof Drewniak } 187723979efSKrzysztof Drewniak } 188723979efSKrzysztof Drewniak } 189723979efSKrzysztof Drewniak currDstDim++; 190723979efSKrzysztof Drewniak } 191723979efSKrzysztof Drewniak 19226d896f3SRafael Ubal // If the source and target shapes are compatible, both iterators must have 19326d896f3SRafael Ubal // reached the end. This condition is guaranteed by the op verifier for 19426d896f3SRafael Ubal // static shapes. 19526d896f3SRafael Ubal assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size()); 19626d896f3SRafael Ubal return reassociationMap; 197723979efSKrzysztof Drewniak } 198723979efSKrzysztof Drewniak 19926d896f3SRafael Ubal // Create a tensor.collapse_shape op that reshapes the input into the given 20026d896f3SRafael Ubal // result type. 20126d896f3SRafael Ubal Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType, 20226d896f3SRafael Ubal Value input) { 20326d896f3SRafael Ubal auto reassociationMap = 20426d896f3SRafael Ubal createReassociationMapForCollapse(builder, input.getType(), resultType); 20526d896f3SRafael Ubal return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input, 2060ebb0503SMatthias Gehre reassociationMap); 207723979efSKrzysztof Drewniak } 208723979efSKrzysztof Drewniak 20926d896f3SRafael Ubal // Create a tensor.expand_shape op that reshapes the input into the given result 21026d896f3SRafael Ubal // type. 21126d896f3SRafael Ubal Value createExpand(OpBuilder &builder, Location loc, TensorType resultType, 21226d896f3SRafael Ubal Value input) { 21326d896f3SRafael Ubal auto reassociationMap = 21426d896f3SRafael Ubal createReassociationMapForCollapse(builder, resultType, input.getType()); 21526d896f3SRafael Ubal return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input, 2160ebb0503SMatthias Gehre reassociationMap); 217723979efSKrzysztof Drewniak } 218723979efSKrzysztof Drewniak 21926d896f3SRafael Ubal class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> { 220723979efSKrzysztof Drewniak public: 221723979efSKrzysztof Drewniak using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern; 222723979efSKrzysztof Drewniak 223723979efSKrzysztof Drewniak LogicalResult 224723979efSKrzysztof Drewniak matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor, 225723979efSKrzysztof Drewniak ConversionPatternRewriter &rewriter) const final { 22626d896f3SRafael Ubal auto loc = reshape.getLoc(); 227af22e274SMatthias Gehre auto resultType = cast_if_present<ShapedType>( 228af22e274SMatthias Gehre getTypeConverter()->convertType(reshape.getType())); 229af22e274SMatthias Gehre if (!resultType) { 230af22e274SMatthias Gehre return rewriter.notifyMatchFailure(reshape.getLoc(), 231af22e274SMatthias Gehre "could not convert result type"); 232af22e274SMatthias Gehre } 233af22e274SMatthias Gehre auto input = dyn_cast<TypedValue<TensorType>>(adaptor.getInput1()); 234af22e274SMatthias Gehre if (!input) { 235af22e274SMatthias Gehre return rewriter.notifyMatchFailure(reshape.getLoc(), 236af22e274SMatthias Gehre "expected input type to be tensor"); 237af22e274SMatthias Gehre } 23826d896f3SRafael Ubal auto newShape = reshape.getNewShape(); 239723979efSKrzysztof Drewniak 24026d896f3SRafael Ubal // Infer all intermediate types 24126d896f3SRafael Ubal auto inputType = inferReshapeInputType(input, newShape); 24226d896f3SRafael Ubal auto expandedType = inferReshapeExpandedType(inputType, newShape); 24326d896f3SRafael Ubal auto collapsedType = inferReshapeCollapsedType(inputType, expandedType); 244723979efSKrzysztof Drewniak 24526d896f3SRafael Ubal // Cast input if needed 24626d896f3SRafael Ubal auto castInput = rewriter.createOrFold<tensor::CastOp>(loc, inputType, input); 2470ebb0503SMatthias Gehre 24826d896f3SRafael Ubal // Emit collaspe-expand pair 24926d896f3SRafael Ubal auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput); 25026d896f3SRafael Ubal auto expanded = createExpand(rewriter, loc, expandedType, collapsed); 2510ebb0503SMatthias Gehre 25226d896f3SRafael Ubal // Cast to final result type if needed 25326d896f3SRafael Ubal auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded); 25426d896f3SRafael Ubal rewriter.replaceOp(reshape, result); 255723979efSKrzysztof Drewniak return success(); 256723979efSKrzysztof Drewniak } 257723979efSKrzysztof Drewniak }; 258723979efSKrzysztof Drewniak 259723979efSKrzysztof Drewniak class SliceConverter : public OpConversionPattern<tosa::SliceOp> { 260723979efSKrzysztof Drewniak public: 261723979efSKrzysztof Drewniak using OpConversionPattern<tosa::SliceOp>::OpConversionPattern; 262723979efSKrzysztof Drewniak 263723979efSKrzysztof Drewniak LogicalResult 264723979efSKrzysztof Drewniak matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor, 265723979efSKrzysztof Drewniak ConversionPatternRewriter &rewriter) const final { 266640973f2SRob Suderman Location loc = sliceOp.getLoc(); 267c6876b4eSJerry-Ge Value input = adaptor.getInput1(); 2689ab732f6SLiqinWeng ShapedType resultType = cast<ShapedType>(sliceOp.getType()); 269d37056c6SLiqinWeng if (llvm::isa<UnrankedTensorType>(resultType)) 2709ab732f6SLiqinWeng return failure(); 271*956c0707SJerry-Ge 272*956c0707SJerry-Ge ElementsAttr startElems; 273*956c0707SJerry-Ge ElementsAttr sizeElems; 274*956c0707SJerry-Ge 275*956c0707SJerry-Ge if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems))) 276*956c0707SJerry-Ge return rewriter.notifyMatchFailure( 277*956c0707SJerry-Ge sliceOp, "start of slice must be a static ranked shape"); 278*956c0707SJerry-Ge 279*956c0707SJerry-Ge if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) 280*956c0707SJerry-Ge return rewriter.notifyMatchFailure( 281*956c0707SJerry-Ge sliceOp, "size of slice must be a static ranked shape"); 282*956c0707SJerry-Ge 283*956c0707SJerry-Ge llvm::SmallVector<int64_t> sliceStarts = 284*956c0707SJerry-Ge llvm::to_vector(startElems.getValues<int64_t>()); 285*956c0707SJerry-Ge llvm::SmallVector<int64_t> sliceSizes = 286*956c0707SJerry-Ge llvm::to_vector(sizeElems.getValues<int64_t>()); 287*956c0707SJerry-Ge 2889e1a3441SAlexander Shaposhnikov SmallVector<int64_t> strides, sizes; 2895550c821STres Popp strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1); 290126e7eafSRob Suderman 291640973f2SRob Suderman SmallVector<Value> dynSizes; 292*956c0707SJerry-Ge for (const auto &i : llvm::enumerate(sliceSizes)) { 2939e1a3441SAlexander Shaposhnikov int64_t size = i.value(); 294640973f2SRob Suderman size_t index = i.index(); 295399638f9SAliia Khasanova sizes.push_back(size == -1 ? ShapedType::kDynamic : size); 296fb4cedccSAliia Khasanova if (!ShapedType::isDynamic(sizes.back())) 297640973f2SRob Suderman continue; 298640973f2SRob Suderman 299640973f2SRob Suderman auto dim = rewriter.create<tensor::DimOp>(loc, input, index); 300640973f2SRob Suderman auto offset = rewriter.create<arith::ConstantOp>( 301*956c0707SJerry-Ge loc, rewriter.getIndexAttr(sliceStarts[index])); 302640973f2SRob Suderman dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset)); 303640973f2SRob Suderman } 304640973f2SRob Suderman 305640973f2SRob Suderman auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>( 306640973f2SRob Suderman sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes, 307*956c0707SJerry-Ge ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts), 308a9733b8aSLorenzo Chelini rewriter.getDenseI64ArrayAttr(sizes), 309a9733b8aSLorenzo Chelini rewriter.getDenseI64ArrayAttr(strides)); 310640973f2SRob Suderman 311640973f2SRob Suderman rewriter.replaceOp(sliceOp, newSliceOp.getResult()); 312*956c0707SJerry-Ge 313*956c0707SJerry-Ge // Remove const_shape ops when it no longer has use point. 314*956c0707SJerry-Ge Operation *startConstShape = sliceOp.getStart().getDefiningOp(); 315*956c0707SJerry-Ge if (startConstShape->getResult(0).hasOneUse()) 316*956c0707SJerry-Ge rewriter.eraseOp(startConstShape); 317*956c0707SJerry-Ge 318*956c0707SJerry-Ge Operation *sizeConstShape = sliceOp.getSize().getDefiningOp(); 319*956c0707SJerry-Ge if (sizeConstShape->getResult(0).hasOneUse()) 320*956c0707SJerry-Ge rewriter.eraseOp(sizeConstShape); 321*956c0707SJerry-Ge 322126e7eafSRob Suderman return success(); 323126e7eafSRob Suderman } 324126e7eafSRob Suderman }; 325126e7eafSRob Suderman 326af22e274SMatthias Gehre class PadConverter : public OpConversionPattern<tosa::PadOp> { 3272a196254SRamkumar Ramachandra public: 328af22e274SMatthias Gehre using OpConversionPattern::OpConversionPattern; 3292a196254SRamkumar Ramachandra 330af22e274SMatthias Gehre LogicalResult 331af22e274SMatthias Gehre matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor, 332af22e274SMatthias Gehre ConversionPatternRewriter &rewriter) const final { 3332a196254SRamkumar Ramachandra auto loc = padOp.getLoc(); 3342a196254SRamkumar Ramachandra auto input = padOp.getInput1(); 3357e622b61SJerry-Ge 3367e622b61SJerry-Ge ElementsAttr paddingElems; 3377e622b61SJerry-Ge if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) { 3387e622b61SJerry-Ge return rewriter.notifyMatchFailure( 3397e622b61SJerry-Ge padOp, "padding must be a static shape value"); 3407e622b61SJerry-Ge } 3417e622b61SJerry-Ge llvm::SmallVector<int64_t> paddingVals; 3427e622b61SJerry-Ge for (auto idx : paddingElems.getValues<IntegerAttr>()) { 3437e622b61SJerry-Ge paddingVals.push_back(static_cast<int64_t>(idx.getInt())); 3447e622b61SJerry-Ge } 3452a196254SRamkumar Ramachandra 3465550c821STres Popp ShapedType inputTy = cast<ShapedType>(input.getType()); 3472a196254SRamkumar Ramachandra Type elementTy = inputTy.getElementType(); 3482a196254SRamkumar Ramachandra int64_t rank = inputTy.getRank(); 3492a196254SRamkumar Ramachandra 3502a196254SRamkumar Ramachandra // Setup the default constantAttr. 3512a196254SRamkumar Ramachandra 3522a196254SRamkumar Ramachandra Value padConstant; 3532a196254SRamkumar Ramachandra 3542a196254SRamkumar Ramachandra if (padOp.getPadConst()) { 3552a196254SRamkumar Ramachandra padConstant = rewriter.createOrFold<tensor::ExtractOp>( 3562a196254SRamkumar Ramachandra loc, padOp.getPadConst(), ValueRange({})); 3572a196254SRamkumar Ramachandra } else { 3586089d612SRahul Kayaith TypedAttr constantAttr; 3595550c821STres Popp if (isa<FloatType>(elementTy)) { 3602a196254SRamkumar Ramachandra constantAttr = rewriter.getFloatAttr(elementTy, 0.0); 3615550c821STres Popp } else if (isa<IntegerType>(elementTy) && !padOp.getQuantizationInfo()) { 3622a196254SRamkumar Ramachandra constantAttr = rewriter.getIntegerAttr(elementTy, 0); 3635550c821STres Popp } else if (isa<IntegerType>(elementTy) && padOp.getQuantizationInfo()) { 3642a196254SRamkumar Ramachandra int64_t value = padOp.getQuantizationInfo()->getInputZp(); 3652a196254SRamkumar Ramachandra constantAttr = rewriter.getIntegerAttr(elementTy, value); 3662a196254SRamkumar Ramachandra } 3672a196254SRamkumar Ramachandra if (constantAttr) 3682a196254SRamkumar Ramachandra padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr); 3692a196254SRamkumar Ramachandra } 3702a196254SRamkumar Ramachandra 3712a196254SRamkumar Ramachandra if (!padConstant) { 3722a196254SRamkumar Ramachandra return rewriter.notifyMatchFailure( 3732a196254SRamkumar Ramachandra padOp, "tosa.pad was unable to determine the pad constant value."); 3742a196254SRamkumar Ramachandra } 3752a196254SRamkumar Ramachandra 3762a196254SRamkumar Ramachandra SmallVector<OpFoldResult, 3> lowValues; 3772a196254SRamkumar Ramachandra SmallVector<OpFoldResult, 3> highValues; 3782a196254SRamkumar Ramachandra 3792a196254SRamkumar Ramachandra lowValues.reserve(rank); 3802a196254SRamkumar Ramachandra highValues.reserve(rank); 3812a196254SRamkumar Ramachandra 3822a196254SRamkumar Ramachandra for (int i = 0; i < rank; i++) { 3837e622b61SJerry-Ge Value lowVal = rewriter.create<arith::ConstantOp>( 3847e622b61SJerry-Ge loc, rewriter.getIndexAttr(paddingVals[2 * i])); 3857e622b61SJerry-Ge Value highVal = rewriter.create<arith::ConstantOp>( 3867e622b61SJerry-Ge loc, rewriter.getIndexAttr(paddingVals[2 * i + 1])); 3872a196254SRamkumar Ramachandra lowValues.push_back(lowVal); 3882a196254SRamkumar Ramachandra highValues.push_back(highVal); 3892a196254SRamkumar Ramachandra } 3902a196254SRamkumar Ramachandra 3912a196254SRamkumar Ramachandra auto newPadOp = rewriter.create<tensor::PadOp>( 3922a196254SRamkumar Ramachandra loc, padOp.getType(), input, lowValues, highValues, padConstant); 3932a196254SRamkumar Ramachandra 3942a196254SRamkumar Ramachandra rewriter.replaceOp(padOp, newPadOp.getResult()); 3952a196254SRamkumar Ramachandra return success(); 3962a196254SRamkumar Ramachandra } 3972a196254SRamkumar Ramachandra }; 3982a196254SRamkumar Ramachandra 399e377520aSMaya Amrami struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> { 400e377520aSMaya Amrami using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern; 401e377520aSMaya Amrami 402e377520aSMaya Amrami LogicalResult 403e377520aSMaya Amrami matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor, 404e377520aSMaya Amrami ConversionPatternRewriter &rewriter) const override { 4055550c821STres Popp auto resultType = dyn_cast<RankedTensorType>(op.getType()); 406e377520aSMaya Amrami 407e377520aSMaya Amrami Location loc = op.getLoc(); 408e377520aSMaya Amrami int axis = op.getAxis(); 40965066c02SHugo Trachino Value axisValue = 41065066c02SHugo Trachino rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(axis)); 41186c4972fSSpenser Bauman int64_t rank = resultType.getRank(); 412e377520aSMaya Amrami 41386c4972fSSpenser Bauman SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1)); 41486c4972fSSpenser Bauman SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0)); 4156596b0ddSMatthias Springer SmallVector<OpFoldResult> sizes = 4166596b0ddSMatthias Springer tensor::getMixedSizes(rewriter, op.getLoc(), adaptor.getOperands()[0]); 417e377520aSMaya Amrami 41886c4972fSSpenser Bauman // Pre-compute the offsets along the axis dimension. 41986c4972fSSpenser Bauman // The axisOffsets will be of size rank + 1, where the last value 42086c4972fSSpenser Bauman // will hold the total size of the tensor along the 'axis' dimension. 42186c4972fSSpenser Bauman SmallVector<OpFoldResult> axisOffsets; 42286c4972fSSpenser Bauman axisOffsets.push_back(rewriter.getIndexAttr(0)); 42386c4972fSSpenser Bauman axisOffsets.push_back(sizes[axis]); 42486c4972fSSpenser Bauman 425e377520aSMaya Amrami for (auto arg : adaptor.getOperands().drop_front()) { 426e377520aSMaya Amrami auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue); 42786c4972fSSpenser Bauman auto currentOffset = 42886c4972fSSpenser Bauman getValueOrCreateConstantIndexOp(rewriter, loc, axisOffsets.back()); 42986c4972fSSpenser Bauman auto total = 43086c4972fSSpenser Bauman rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size); 43186c4972fSSpenser Bauman axisOffsets.push_back(getAsOpFoldResult(total)); 432e377520aSMaya Amrami } 43386c4972fSSpenser Bauman sizes[axis] = axisOffsets.back(); 434e377520aSMaya Amrami 43586c4972fSSpenser Bauman // Compute the dynamic sizes of the tensor.empty operation. 43686c4972fSSpenser Bauman // This is based off of the specified result type of the tosa.concat 43786c4972fSSpenser Bauman // operation, since we don't want to change the result type of the operation 43886c4972fSSpenser Bauman // during the conversion. 43986c4972fSSpenser Bauman SmallVector<Value> dynDims; 44086c4972fSSpenser Bauman for (int64_t i = 0; i < rank; ++i) { 44186c4972fSSpenser Bauman if (resultType.isDynamicDim(i)) { 44286c4972fSSpenser Bauman dynDims.push_back( 44386c4972fSSpenser Bauman getValueOrCreateConstantIndexOp(rewriter, loc, sizes[i])); 44486c4972fSSpenser Bauman } 44586c4972fSSpenser Bauman } 44686c4972fSSpenser Bauman 44786c4972fSSpenser Bauman Value result = rewriter.create<tensor::EmptyOp>( 448e377520aSMaya Amrami loc, resultType.getShape(), resultType.getElementType(), dynDims); 449e377520aSMaya Amrami 45086c4972fSSpenser Bauman for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) { 4516596b0ddSMatthias Springer auto sizes = tensor::getMixedSizes(rewriter, op.getLoc(), arg); 45286c4972fSSpenser Bauman offsets[axis] = offset; 453e377520aSMaya Amrami result = rewriter.createOrFold<tensor::InsertSliceOp>( 45486c4972fSSpenser Bauman loc, arg, result, offsets, sizes, strides); 455e377520aSMaya Amrami } 456e377520aSMaya Amrami rewriter.replaceOp(op, result); 457e377520aSMaya Amrami return success(); 458e377520aSMaya Amrami } 459e377520aSMaya Amrami }; 460e377520aSMaya Amrami 461126e7eafSRob Suderman } // namespace 462126e7eafSRob Suderman 463126e7eafSRob Suderman void mlir::tosa::populateTosaToTensorConversionPatterns( 464206fad0eSMatthias Springer const TypeConverter &converter, RewritePatternSet *patterns) { 465af22e274SMatthias Gehre patterns 466af22e274SMatthias Gehre ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>( 467af22e274SMatthias Gehre converter, patterns->getContext()); 468126e7eafSRob Suderman } 469