//===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Tensor dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

#include <numeric>

using namespace mlir;
using namespace tosa;

namespace {

// Infer the type to which the input of a 'tosa.reshape' op must be cast when
// lowered.
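// For example, a 'tensor<?x?xf32>' input with an empty target shape is cast
// to 'tensor<1x1xf32>' before the collapse is emitted.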
TensorType inferReshapeInputType(TypedValue<TensorType> input,
                                 ArrayRef<int64_t> newShape) {
  // No need to cast input for non-empty target shape
  if (!newShape.empty())
    return input.getType();

  // The input type must be cast into a tensor with the same rank and all static
  // dimensions set to 1. This prevents the generation of a tensor.collapse_shape
  // op that converts a dynamically shaped tensor into a 0D tensor. While such a
  // construct is not incorrect on its own, bufferization cannot properly handle
  // it at the moment, so we avoid it.
  SmallVector<int64_t> shape(input.getType().getRank(), 1);
  return input.getType().clone(shape);
}

// Infer the result type of 'tensor.expand_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
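// For example, a static input 'tensor<6xf32>' reshaped with newShape [2, -1]
// expands to 'tensor<2x3xf32>'; for a dynamic input the -1 placeholder remains
// a dynamic dimension.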
TensorType inferReshapeExpandedType(TensorType inputType,
                                    ArrayRef<int64_t> newShape) {
  // Special case for 0D output tensor. Note: Watch out when using Type::clone()
  // with just '{}', as it will invoke the incorrect overload.
  if (newShape.empty())
    return inputType.clone(ArrayRef<int64_t>{});

  // Check if the input is static, and if so, get its total size
  bool inputIsStatic = inputType.hasStaticShape();
  int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;

  // Compute result shape
  auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
    // If this is not a placeholder, do not change it
    if (size >= 0)
      return size;

    // If we do not know the total size of the tensor, keep this dimension
    // dynamic in the result shape.
    if (!inputIsStatic)
      return ShapedType::kDynamic;
    // Calculate the product of all elements in 'newShape'. The -1 placeholder
    // contributes a factor of -1, which is cancelled out by negating the
    // accumulated product. Accumulate over int64_t to avoid overflowing 'int'.
    int64_t totalSizeNoPlaceholder =
        -std::accumulate(newShape.begin(), newShape.end(), int64_t{1},
                         std::multiplies<int64_t>());

    // If there is a 0 component in 'newShape', resolve the placeholder as 0.
    if (totalSizeNoPlaceholder == 0)
      return 0;

    // Resolve the placeholder as the quotient between the total tensor size and
    // the product of all other sizes.
    return totalSize / totalSizeNoPlaceholder;
  });

  bool resultIsStatic = !ShapedType::isDynamicShape(resultShape);

  // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
  // shaped input from being reshaped into a statically shaped result. We may
  // simply turn the first result dimension dynamic to address this.
  if (!inputIsStatic && resultIsStatic)
    resultShape[0] = ShapedType::kDynamic;

  // The 'tensor.expand_shape' op also forbids a statically shaped input from
  // being reshaped into a dynamically shaped result, but the placeholder
  // inference algorithm above guarantees that this will never be the case.
  assert(!inputIsStatic || resultIsStatic);

  // Create result type
  return inputType.clone(resultShape);
}

// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
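// For example, an input of type 'tensor<2x3xf32>' and an expanded type of
// 'tensor<3x2xf32>' collapse through the intermediate type 'tensor<6xf32>'.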
TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
  auto lhsShape = lhsType.getShape();
  auto rhsShape = rhsType.getShape();

  if (lhsShape.empty() || rhsShape.empty())
    return lhsType.clone(ArrayRef<int64_t>{});

  if (ShapedType::isDynamicShape(lhsShape) || ShapedType::isDynamicShape(rhsShape))
    return lhsType.clone({ShapedType::kDynamic});

  SmallVector<int64_t> intermediateShape;
  unsigned currLhsDim = 0, currRhsDim = 0;
  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
    int64_t rhsSize = rhsShape[currRhsDim];
    int64_t lhsSize = lhsShape[currLhsDim];
    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
           currRhsDim < rhsShape.size()) {
      if (lhsSize < rhsSize) {
        currLhsDim++;
        if (currLhsDim < lhsShape.size()) {
          lhsSize *= lhsShape[currLhsDim];
        }
      } else {
        currRhsDim++;
        if (currRhsDim < rhsShape.size()) {
          rhsSize *= rhsShape[currRhsDim];
        }
      }
    }
    if (lhsSize == rhsSize) {
      intermediateShape.push_back(lhsSize);
    }
    currRhsDim++;
    currLhsDim++;
  }

  // Static shapes are guaranteed to be compatible by the op verifier, so all
  // leftover dimensions should be 1.
  for (; currLhsDim < lhsShape.size(); currLhsDim++) {
    assert(lhsShape[currLhsDim] == 1);
  }
  for (; currRhsDim < rhsShape.size(); currRhsDim++) {
    assert(rhsShape[currRhsDim] == 1);
  }

  return lhsType.clone(intermediateShape);
}

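// Compute the reassociation indices used to collapse 'srcType' into 'dstType'
// (also used, with the arguments swapped, to build the matching
// 'tensor.expand_shape'). For example, collapsing 'tensor<2x3x4xf32>' into
// 'tensor<6x4xf32>' yields the groups [d0, d1] and [d2].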
SmallVector<ReassociationExprs>
createReassociationMapForCollapse(OpBuilder &builder, Type srcType, Type dstType) {
  auto srcShape = cast<TensorType>(srcType).getShape();
  auto dstShape = cast<TensorType>(dstType).getShape();

  if (srcShape.empty() || dstShape.empty())
    return {};

  if (ShapedType::isDynamicShape(srcShape) || ShapedType::isDynamicShape(dstShape)) {
    assert(dstShape.size() == 1);
    SmallVector<AffineExpr, 2> exprs;
    for (auto i : llvm::seq<int64_t>(srcShape.size()))
      exprs.push_back(builder.getAffineDimExpr(i));
    return {exprs};
  }

  SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
  unsigned currSrcDim = 0, currDstDim = 0;
  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
    int64_t dstSize = dstShape[currDstDim];
    int64_t srcSize = srcShape[currSrcDim];
    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      srcSize *= srcShape[currSrcDim];
    }
    if (srcSize == dstSize) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      // If this is the last dim in dstShape, or the next dim is not 1, fold
      // any subsequent size-1 dims of srcShape into the current group.
      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
          reassociationMap[currDstDim].push_back(
              builder.getAffineDimExpr(currSrcDim++));
        }
      }
    }
    currDstDim++;
  }

  // If the source and target shapes are compatible, both iterators must have
  // reached the end. This condition is guaranteed by the op verifier for
  // static shapes.
  assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
  return reassociationMap;
}

// Create a tensor.collapse_shape op that reshapes the input into the given
// result type.
Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
                     Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, input.getType(), resultType);
  return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
                                                       reassociationMap);
}

// Create a tensor.expand_shape op that reshapes the input into the given result
// type.
Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
                   Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, resultType, input.getType());
  return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
                                                     reassociationMap);
}

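// Lower 'tosa.reshape' to a 'tensor.collapse_shape' + 'tensor.expand_shape'
// pair through an inferred intermediate type, with 'tensor.cast' ops inserted
// where the static/dynamic shapes disagree. For example, reshaping
// 'tensor<2x3xf32>' with newShape [3, 2] collapses to 'tensor<6xf32>' and then
// expands to 'tensor<3x2xf32>'.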
class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = reshape.getLoc();
    auto resultType = cast_if_present<ShapedType>(
        getTypeConverter()->convertType(reshape.getType()));
    if (!resultType) {
      return rewriter.notifyMatchFailure(reshape.getLoc(),
                                         "could not convert result type");
    }
    auto input = dyn_cast<TypedValue<TensorType>>(adaptor.getInput1());
    if (!input) {
      return rewriter.notifyMatchFailure(reshape.getLoc(),
                                         "expected input type to be tensor");
    }
    auto newShape = reshape.getNewShape();

    // Infer all intermediate types
    auto inputType = inferReshapeInputType(input, newShape);
    auto expandedType = inferReshapeExpandedType(inputType, newShape);
    auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);

    // Cast input if needed
    auto castInput = rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);

    // Emit collapse-expand pair
    auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
    auto expanded = createExpand(rewriter, loc, expandedType, collapsed);

    // Cast to final result type if needed
    auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
    rewriter.replaceOp(reshape, result);
    return success();
  }
};

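// Lower 'tosa.slice' to 'tensor.extract_slice', using static offsets and sizes
// where possible and materializing dynamic sizes otherwise.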
class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
public:
  using OpConversionPattern<tosa::SliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = sliceOp.getLoc();
    Value input = adaptor.getInput1();
    ShapedType resultType = cast<ShapedType>(sliceOp.getType());
    if (llvm::isa<UnrankedTensorType>(resultType))
      return failure();

    ElementsAttr startElems;
    ElementsAttr sizeElems;

    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");

    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");

    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    SmallVector<int64_t> strides, sizes;
    strides.resize(resultType.getRank(), 1);

    SmallVector<Value> dynSizes;
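    // A slice size of -1 means "all remaining elements" along that dimension;
    // it is lowered to a dynamic size computed as dim(input, i) - start[i].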
    for (const auto &i : llvm::enumerate(sliceSizes)) {
      int64_t size = i.value();
      size_t index = i.index();
      sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
      if (!ShapedType::isDynamic(sizes.back()))
        continue;

      auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
      auto offset = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(sliceStarts[index]));
      dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
    }

    auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
        ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts),
        rewriter.getDenseI64ArrayAttr(sizes),
        rewriter.getDenseI64ArrayAttr(strides));

    rewriter.replaceOp(sliceOp, newSliceOp.getResult());

    // Remove the start/size const_shape ops if the slice was their only user.
    Operation *startConstShape = sliceOp.getStart().getDefiningOp();
    if (startConstShape->getResult(0).hasOneUse())
      rewriter.eraseOp(startConstShape);

    Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
    if (sizeConstShape->getResult(0).hasOneUse())
      rewriter.eraseOp(sizeConstShape);

    return success();
  }
};

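// Lower 'tosa.pad' to 'tensor.pad'. The pad value is taken from the optional
// 'pad_const' operand when present; otherwise a zero (or quantization zero
// point) constant is materialized.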
class PadConverter : public OpConversionPattern<tosa::PadOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = padOp.getLoc();
    auto input = padOp.getInput1();

    ElementsAttr paddingElems;
    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
      return rewriter.notifyMatchFailure(
          padOp, "padding must be a static shape value");
    }
    llvm::SmallVector<int64_t> paddingVals;
    for (auto padAttr : paddingElems.getValues<IntegerAttr>()) {
      paddingVals.push_back(padAttr.getInt());
    }

    ShapedType inputTy = cast<ShapedType>(input.getType());
    Type elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    // Determine the value to pad with, either from the explicit pad_const
    // operand or from a default constant based on the element type.
    Value padConstant;

    if (padOp.getPadConst()) {
      padConstant = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padOp.getPadConst(), ValueRange({}));
    } else {
      TypedAttr constantAttr;
      if (isa<FloatType>(elementTy)) {
        constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
      } else if (isa<IntegerType>(elementTy) && !padOp.getQuantizationInfo()) {
        constantAttr = rewriter.getIntegerAttr(elementTy, 0);
      } else if (isa<IntegerType>(elementTy) && padOp.getQuantizationInfo()) {
        int64_t value = padOp.getQuantizationInfo()->getInputZp();
        constantAttr = rewriter.getIntegerAttr(elementTy, value);
      }
      if (constantAttr)
        padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
    }

    if (!padConstant) {
      return rewriter.notifyMatchFailure(
          padOp, "tosa.pad was unable to determine the pad constant value.");
    }

    SmallVector<OpFoldResult, 3> lowValues;
    SmallVector<OpFoldResult, 3> highValues;

    lowValues.reserve(rank);
    highValues.reserve(rank);

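    // 'paddingVals' holds interleaved (low, high) pairs: paddingVals[2 * i]
    // pads the start of dimension i and paddingVals[2 * i + 1] pads its end.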
    for (int i = 0; i < rank; i++) {
      Value lowVal = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(paddingVals[2 * i]));
      Value highVal = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
      lowValues.push_back(lowVal);
      highValues.push_back(highVal);
    }

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, padOp.getType(), input, lowValues, highValues, padConstant);

    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};

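// Lower 'tosa.concat' to a 'tensor.empty' op of the result shape followed by
// one 'tensor.insert_slice' per operand, each offset along the concatenation
// axis.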
struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<RankedTensorType>(op.getType());
    if (!resultType)
      return rewriter.notifyMatchFailure(op, "expected ranked tensor result");

    Location loc = op.getLoc();
    int axis = op.getAxis();
    Value axisValue =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(axis));
    int64_t rank = resultType.getRank();

    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, op.getLoc(), adaptor.getOperands()[0]);

    // Pre-compute the offsets along the axis dimension.
    // 'axisOffsets' will hold one entry per operand plus one, where the last
    // value is the total size of the result along the 'axis' dimension.
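    // For example, concatenating operands with axis sizes 2, 3, and 4 gives
    // axisOffsets = [0, 2, 5, 9].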
    SmallVector<OpFoldResult> axisOffsets;
    axisOffsets.push_back(rewriter.getIndexAttr(0));
    axisOffsets.push_back(sizes[axis]);

    for (auto arg : adaptor.getOperands().drop_front()) {
      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
      auto currentOffset =
          getValueOrCreateConstantIndexOp(rewriter, loc, axisOffsets.back());
      auto total =
          rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size);
      axisOffsets.push_back(getAsOpFoldResult(total));
    }
    sizes[axis] = axisOffsets.back();

    // Compute the dynamic sizes of the tensor.empty operation.
    // This is based on the specified result type of the tosa.concat operation,
    // since we don't want to change the result type of the operation during
    // the conversion.
    SmallVector<Value> dynDims;
    for (int64_t i = 0; i < rank; ++i) {
      if (resultType.isDynamicDim(i)) {
        dynDims.push_back(
            getValueOrCreateConstantIndexOp(rewriter, loc, sizes[i]));
      }
    }

    Value result = rewriter.create<tensor::EmptyOp>(
        loc, resultType.getShape(), resultType.getElementType(), dynDims);

    for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
      auto sizes = tensor::getMixedSizes(rewriter, op.getLoc(), arg);
      offsets[axis] = offset;
      result = rewriter.createOrFold<tensor::InsertSliceOp>(
          loc, arg, result, offsets, sizes, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaToTensorConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {
  patterns
      ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
          converter, patterns->getContext());
}
469