Lines Matching defs:tensor
37 using namespace mlir::tensor;
56 OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
61 return builder.createOrFold<tensor::DimOp>(loc, value, dim);
66 SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
75 FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
78 assert(tensorType && "expected tensor type");
86 // Otherwise, create a new destination tensor with the same shape.
104 // Create empty tensor.
106 b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
110 LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
124 bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
134 /// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
135 /// rank-extending tensor.insert_slice op.
176 /// Given a ranked tensor type and a range of values that defines its dynamic
230 /// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
264 /// Returns true if `target` is a ranked tensor type that preserves static
265 /// information available in the `source` ranked tensor type.
266 bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
296 /// Determines whether tensor::CastOp casts to a more dynamic version of the
297 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
299 /// consume the results of tensor.cast operations. Such foldable tensor.cast
305 /// 2. the tensor type has more static information than the result
309 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
310 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
316 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
318 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
328 /// Determines whether the tensor::CastOp casts to a more static version of the
329 /// source tensor. This is useful to fold into a producing op and implement
330 /// canonicalization patterns with the `tensor.cast` op as the root, but producer
337 /// %1 = producer ... : tensor<?x?xf32>
338 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
344 /// %2 = producer ... : tensor<8x16xf32>
348 bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
355 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
357 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
360 auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
361 if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
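The lines above show the consumer-side idiom end to end: `foldTensorCast` walks the op's operands and, for every operand defined by a `tensor.cast` that passes `canFoldIntoConsumerOp`, substitutes the cast's more static source. A minimal sketch of how an op would use it, assuming a hypothetical single-result consumer op `MyConsumerOp` (not an in-tree op):

    // Sketch only: `MyConsumerOp` is hypothetical. foldTensorCast mutates
    // the operands in place, so a single-result op reports the in-place
    // fold by returning its own result.
    OpFoldResult MyConsumerOp::fold(FoldAdaptor) {
      return succeeded(tensor::foldTensorCast(getOperation()))
                 ? getResult()
                 : OpFoldResult();
    }

    // Per operand, the helper reduces to roughly:
    for (OpOperand &operand : op->getOpOperands())
      if (auto castOp = operand.get().getDefiningOp<tensor::CastOp>())
        if (tensor::canFoldIntoConsumerOp(castOp))
          operand.set(castOp.getSource()); // use the more static source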
418 /// Replaces chains of two tensor.cast operations by a single tensor.cast
457 /// Fold tensor.cast into tensor.extract_slice producer.
460 /// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
461 /// tensor<128x512xf32> to tensor<?x512xf32>
462 /// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
466 /// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
467 /// tensor<128x512xf32> to tensor<16x512xf32>
477 // Cannot fold cast to unranked tensor.
580 return emitOpError("concatenation dim must be less than the tensor rank");
634 tensor::getMixedSizes(builder, input.getLoc(), input);
647 Value replacement = builder.create<tensor::EmptyOp>(
656 auto insertSlice = builder.create<tensor::InsertSliceOp>(
661 replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
692 builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
700 builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
704 builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
810 // Fold dim to the operand of tensor.generate.
811 if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
830 if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
847 /// Fold dim of a cast into the dim of the source of the tensor cast.
883 /// Fold dim of a tensor reshape operation to an extract into the reshape's shape
991 /// Change the type of the result of a `tensor.empty` by making the result
997 /// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
1001 /// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
1017 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
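The snippets above are the core of the canonicalization: constant dynamic sizes of `tensor.empty` are promoted into the result type, and a `tensor.cast` restores the original type for existing users. A self-contained sketch of such a pattern, assuming `getConstantIntValue` from StaticValueUtils (not the in-tree code verbatim):

    struct FoldConstantSizesIntoEmpty : public OpRewritePattern<tensor::EmptyOp> {
      using OpRewritePattern<tensor::EmptyOp>::OpRewritePattern;

      LogicalResult matchAndRewrite(tensor::EmptyOp op,
                                    PatternRewriter &rewriter) const override {
        auto resultType = llvm::cast<RankedTensorType>(op.getType());
        SmallVector<int64_t> staticShape = llvm::to_vector(resultType.getShape());
        SmallVector<Value> remainingSizes;
        unsigned dynIdx = 0;
        bool changed = false;
        for (int64_t i = 0, e = resultType.getRank(); i < e; ++i) {
          if (!resultType.isDynamicDim(i))
            continue;
          Value size = op.getDynamicSizes()[dynIdx++];
          // Promote constant dynamic sizes into the static shape.
          if (std::optional<int64_t> cst = getConstantIntValue(size)) {
            staticShape[i] = *cst;
            changed = true;
          } else {
            remainingSizes.push_back(size);
          }
        }
        if (!changed)
          return failure();
        auto newType =
            RankedTensorType::get(staticShape, resultType.getElementType());
        Value newEmpty =
            rewriter.create<tensor::EmptyOp>(op.getLoc(), newType, remainingSizes);
        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newEmpty);
        return success();
      }
    };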
1025 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1045 /// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
1046 /// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
1052 /// %0 = tensor.empty(%d1) : tensor<4x?xf32>
1079 // Case 1: The empty tensor dim is static. Check that the tensor cast
1085 // than the empty tensor result shape (enforced by
1088 producer, "mismatch in static value of shape of empty tensor "
1095    // Case 2: The tensor cast shape is static, but empty tensor result
1102    // Case 3: The tensor cast shape is dynamic and empty tensor result
1103 // shape is dynamic. Use the dynamic value from the empty tensor op.
1107 // TODO: Do not drop tensor encoding.
1122 /// Try to remove a tensor operation if it would only reshape a constant.
1144 ///   %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
1145 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
1149 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
1150 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1151 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1153 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1155 auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1160 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
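Pieced together, the whole pattern is short. A sketch of how the fragments above fit (details may differ from the in-tree version):

    struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
      using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

      LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                    PatternRewriter &rewriter) const override {
        auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
        if (!tensorCast)
          return failure();
        // The indices are unaffected by the cast; extracting from the cast
        // source is valid as long as it is still a ranked tensor.
        if (!isa<RankedTensorType>(tensorCast.getSource().getType()))
          return failure();
        rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
            extract, tensorCast.getSource(), extract.getIndices());
        return success();
      }
    };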
1182 if (Attribute tensor = adaptor.getTensor()) {
1185 if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1189 if (isa<DenseResourceElementsAttr>(tensor))
1193 // Collect the constant indices into the tensor.
1222 if (Attribute tensor = adaptor.getTensor()) {
1223 auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1265 // %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1266 // %extract = tensor.extract %cast[%index] : tensor<1xindex>
1270 // %extract = tensor.extract %tensor[%index] : tensor<1xindex>
1275 // Consider expanding this to a template and handle all tensor cast
1278 : public OpRewritePattern<tensor::ExtractOp> {
1279 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1281 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1290 auto newExtract = rewriter.create<tensor::ExtractOp>(
1317 /// - sourceType is the type of the source tensor gathered from
1323 /// The leading dimensions of the index tensor give the result tensor its
1325 /// The trailing dimensions of the result tensor are obtained from the source
1326 /// tensor by setting the dimensions specified in gather_dims to `1` (if
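A hypothetical helper making the gather shape rule concrete (names are illustrative; the rank-reduced variant cut off at "(if" above is ignored here):

    // Leading result dims come from the index tensor minus its trailing
    // coordinate dimension; trailing dims are the source dims with every
    // gathered dim collapsed to 1.
    static SmallVector<int64_t>
    computeGatherResultShape(RankedTensorType sourceType,
                             RankedTensorType indicesType,
                             ArrayRef<int64_t> gatherDims) {
      SmallVector<int64_t> shape =
          llvm::to_vector(indicesType.getShape().drop_back());
      llvm::SmallDenseSet<int64_t, 4> gathered(gatherDims.begin(),
                                               gatherDims.end());
      for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i)
        shape.push_back(gathered.contains(i) ? 1 : sourceType.getDimSize(i));
      return shape;
    }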
1460 // Ensure that the tensor type has as many dynamic dimensions as are
1483 "body must be terminated with a `yield` operation of the tensor "
1508 /// Canonicalizes tensor.generate operations with a constant
1531 rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
1539 /// %tensor = tensor.generate %x {
1543 /// } : tensor<?xindex>
1544 /// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xi32>
1547 /// tensor.generate operation has no side-effects.
1548 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1549 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1551 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1574 // TODO: Move extract pattern to tensor::ExtractOp.
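The extract-from-generate fold works by inlining: the body's index block arguments are mapped to the extract indices, the body ops are cloned, and the yielded value replaces the extract (valid because `tensor.generate` has no side effects, as noted above). A sketch close to the in-tree pattern:

    struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
      using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

      LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                    PatternRewriter &rewriter) const override {
        auto generateOp = extract.getTensor().getDefiningOp<tensor::GenerateOp>();
        if (!generateOp)
          return failure();

        // Substitute the extract indices for the body's index arguments
        // and clone the body ops at the extract's position.
        IRMapping mapping;
        Block *body = &generateOp.getBody().front();
        mapping.map(body->getArguments(), extract.getIndices());
        for (Operation &bodyOp : body->without_terminator())
          rewriter.clone(bodyOp, mapping);

        auto yield = cast<tensor::YieldOp>(body->getTerminator());
        rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
        return success();
      }
    };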
1616 return emitOpError("element types of source and destination tensor "
1628 return emitOpError("source and destination tensor should have the "
1633 "reshape to statically-ranked tensor type");
1636 "length of shape operand differs from the result's tensor rank");
1647 // If the producer of operand 'source' is another 'tensor.reshape' op, use the
1648 // producer's input instead as the original tensor to reshape. This could
1662 // reshape has no effect, even if the tensor is dynamically shaped.
1666 if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
1678 if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
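These fragments implement the identity-reshape fold: if the shape operand is a `tensor.from_elements` whose i-th element is `tensor.dim %source, i`, the reshape cannot change the shape, dynamic or not. A sketch of that check (the caller must also verify the element count matches the source rank):

    static bool isIdentityReshape(Value source, tensor::FromElementsOp shape) {
      int64_t expectedDim = 0;
      for (Value element : shape.getElements()) {
        auto dimOp = element.getDefiningOp<tensor::DimOp>();
        if (!dimOp || dimOp.getSource() != source)
          return false;
        // The dim index must be the constant `expectedDim`, in order.
        std::optional<int64_t> index = dimOp.getConstantIndex();
        if (!index || *index != expectedDim++)
          return false;
      }
      return true;
    }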
1901 auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
1905 rewriter.replaceOpWithNewOp<tensor::SplatOp>(
1940 auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
1941 if (!tensor::canFoldIntoConsumerOp(castOp))
1957 rewriter.replaceOpWithNewOp<tensor::CastOp>(
2048 /// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
2050 /// `tensor.expand_shape` more static and creates a consumer cast that can be
2075 // corresponding expanded dimensions. `tensor.expand_shape` requires at
2145 tensor::DimOp, RankedTensorType>,
2348 assert(sourceTensorType && "not a ranked tensor type");
2376 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2382 /// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2383 /// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2384 /// tensor<3x4xf32>
2388 /// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
2389 /// tensor<3x4xf32> %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
2454 /// Fold arith.constant and tensor.extract_slice into arith.constant. The
2544 void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2568 replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
2631 Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2632 OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2633 auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2636 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2638 return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2721 /// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2722 /// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2728 /// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2748 /// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2749 /// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
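Both folds above come down to comparing the offset/size/stride lists of two slice ops. A sketch of the comparison and the roundtrip case, using `isEqualConstantIntOrValue` from StaticValueUtils (the in-tree logic performs additional type checks):

    static bool coversSameSlice(tensor::ExtractSliceOp extractOp,
                                tensor::InsertSliceOp insertOp) {
      auto equal = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
        if (lhs.size() != rhs.size())
          return false;
        for (auto [l, r] : llvm::zip(lhs, rhs))
          if (!isEqualConstantIntOrValue(l, r))
            return false;
        return true;
      };
      return equal(extractOp.getMixedOffsets(), insertOp.getMixedOffsets()) &&
             equal(extractOp.getMixedSizes(), insertOp.getMixedSizes()) &&
             equal(extractOp.getMixedStrides(), insertOp.getMixedStrides());
    }

    // Roundtrip: inserting a slice back where it was extracted is a no-op.
    if (auto extractOp =
            insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>())
      if (extractOp.getSource() == insertOp.getDest() &&
          coversSameSlice(extractOp, insertOp))
        rewriter.replaceOp(insertOp, insertOp.getDest());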
2781 reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2819 toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2830 /// destination tensor is a tensor.cast that removes static type information,
2834 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
2835 /// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
2841 /// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
2844 /// Note: When folding a cast on the destination tensor, the result of the
2861 auto castOp = v.getDefiningOp<tensor::CastOp>();
2881 // The tensor.cast source could have additional static information not seen
2891 // that are not static in the tensor.cast source (i.e., when the cast op
2917 replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2934 /// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
2935 /// : tensor<?x?xf32> into ...
2941 /// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
2942 /// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
2943 /// : tensor<64x64xf32> into ...
2974 !tensor::CastOp::areCastCompatible(srcType, newSrcType))
2988 Value cast = rewriter.create<tensor::CastOp>(
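A sketch of the cast-insertion step for the non-rank-reducing case: derive a more static source type from the insert_slice's static sizes, then cast only when that is a strict improvement, reusing the checks quoted above:

    RankedTensorType srcType = insertSliceOp.getSourceType();
    // This sketch only handles the non-rank-reducing form.
    if (srcType.getRank() !=
        static_cast<int64_t>(insertSliceOp.getMixedSizes().size()))
      return failure();
    SmallVector<int64_t> newShape;
    for (OpFoldResult size : insertSliceOp.getMixedSizes()) {
      // Static sizes become static dims; everything else stays dynamic.
      newShape.push_back(
          getConstantIntValue(size).value_or(ShapedType::kDynamic));
    }
    auto newSrcType = RankedTensorType::get(newShape, srcType.getElementType());
    if (srcType == newSrcType ||
        !preservesStaticInformation(srcType, newSrcType) ||
        !tensor::CastOp::areCastCompatible(srcType, newSrcType))
      return failure();
    Value cast = rewriter.create<tensor::CastOp>(
        insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
    // A rebuilt insert_slice then consumes `cast` as its source.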
3010 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
3012 Value tensor,
3019 return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
3187 b.create<tensor::YieldOp>(result.location, constantPadValue);
3203 // Folds tensor.pad when padding is static zeros and the attribute
3214 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3227 auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3228 if (!tensor::canFoldIntoConsumerOp(castOp))
3249 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3266 dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3269 if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
3287 /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3290 /// 1) the tensor::ExtractSliceOps are not rank-reducing,
3291 /// 2) the tensor::ExtractSliceOps have only unit-strides,
3292 /// 3) the tensor::PadOps perform only high-padding,
3293 /// 4) the tensor::PadOps have the same constant padding value,
3294 /// 5) the tensor::PadOps do not have common padding dimensions,
3295 /// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3297 /// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
3304 /// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3305 /// : tensor<64x64xf32> to tensor<?x64xf32>
3306 /// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3307 /// } : tensor<?x64xf32> to tensor<8x64xf32>
3308 /// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3309 /// : tensor<8x64xf32> to tensor<8x?xf32>
3310 /// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3311 /// } : tensor<8x?xf32> to tensor<8x4xf32>
3317 /// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3318 /// : tensor<64x64xf32> to tensor<?x?xf32>
3319 /// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3320 /// } : tensor<?x?xf32> to tensor<8x4xf32>
3344 // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
3350 // 3) Fail if the tensor::PadOps have non-zero low padding.
3356    // 4) Fail if the tensor::PadOps' padding values do not match.
3368 // 5) Fail if a dimension is padded by both tensor::PadOps.
3376 // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
3377 // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3379 // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3399 // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
3400 // of the outer tensor::ExtractSliceOp for the dimensions padded by the
3401 // outer tensor::PadOp and fail if the size of the inner
3402 // tensor::ExtractSliceOp does not match the size of the padded dimension.
3403 // Otherwise, take the size of the inner tensor::ExtractSliceOp.
3420 // Combine the high paddings of the two tensor::PadOps.
3429 // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
3538 rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3545 /// Folds a chain of `tensor.pad` ops with the same constant padding value.
3550 /// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3551 /// tensor.yield %val
3552 /// } : tensor<1x2xf32> to tensor<2x5xf32>
3553 /// %res = tensor.pad %1 low[0, 2] high[3, 0] {
3554 /// tensor.yield %val
3555 /// } : tensor<1x5xf32> to tensor<5x7xf32>
3561 /// %res = tensor.pad %0 low[0, 3] high[3, 2] {
3562 /// tensor.yield %val
3563 /// } : tensor<1x2xf32> to tensor<5x7xf32>
3565 struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
3566 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
3568 LogicalResult matchAndRewrite(tensor::PadOp padOp,
3574 auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3577 padOp, "producer is not a foldable tensor.pad op");
3580    // Fail if the tensor::PadOps' padding values do not match.
3594 // Combine the low/high paddings of the two tensor::PadOps.
3611 auto newPadOp = rewriter.create<tensor::PadOp>(
3614 getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
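The core of the fold is summing the two pads' paddings elementwise. A sketch for the all-static case (the in-tree pattern also composes dynamic paddings and, per the check above, first verifies both pads yield the same constant padding value):

    SmallVector<int64_t> newStaticLow, newStaticHigh;
    for (auto [producerLow, consumerLow] :
         llvm::zip_equal(producerPad.getStaticLow(), padOp.getStaticLow()))
      newStaticLow.push_back(producerLow + consumerLow);
    for (auto [producerHigh, consumerHigh] :
         llvm::zip_equal(producerPad.getStaticHigh(), padOp.getStaticHigh()))
      newStaticHigh.push_back(producerHigh + consumerHigh);

The replacement pad (created at 3611 above) is then built over `producerPad`'s source with these combined paddings.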
3764 // tensor dims than the dest dims. If this is not the case, the unique
3864 tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
4314 return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4460    // Insert tensor.cast ops if static shape inference is available.
4468 rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
4476 rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
4487 rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
4499 static_assert(std::is_same<PackOrUnpackOp, tensor::PackOp>::value ||
4500 std::is_same<PackOrUnpackOp, tensor::UnPackOp>::value,
4648 return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4701 if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
4721    // Insert tensor.cast ops if static shape inference is available.
4728 source = rewriter.create<tensor::CastOp>(loc, newSrcType,
4735 rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
4737 Value newOp = rewriter.create<tensor::UnPackOp>(
4740 rewriter.replaceOpWithNewOp<tensor::CastOp>(
4765 // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
4772 // If no operand comes from a tensor::CastOp and can be folded then fail.
4777 auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
4792 auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
4840 /// Folds a tensor.cast op into a consuming tensor::PackOp op if the
4841 /// `tensor.cast` has a source that is more static than the consuming op.
4845 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4846 /// %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
4852 /// %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
4882 ? rewriter.create<tensor::CastOp>(
4892 /// Folds a tensor.cast op into a consuming tensor::UnPackOp op if the
4893 /// `tensor.cast` has a source that is more static than the consuming op.
4897 /// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
4898 /// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
4904 /// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
4935 ? rewriter.create<tensor::CastOp>(
4945 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4946 /// the `tensor.cast` has a source that is more static than the consuming op.
4950 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4951 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
4957 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
4969    // Reject tensor::PackOp - there's a dedicated pattern for that instead.
4971 isa<tensor::PackOp, tensor::UnPackOp>(*op))
4985 replacements.push_back(rewriter.create<tensor::CastOp>(
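Putting the generic pattern together: operands reaching the op through foldable casts are replaced by their sources, result types follow the updated DPS init operands, and a trailing cast per result restores the original types for existing users. A condensed sketch, assuming an op without regions:

    auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op);
    if (!dpsOp || isa<tensor::PackOp, tensor::UnPackOp>(*op))
      return failure();

    SmallVector<Value> newOperands;
    SmallVector<Type> newResultTypes;
    bool folded = false;
    for (OpOperand &opOperand : op->getOpOperands()) {
      auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
      bool fold = castOp && tensor::canFoldIntoConsumerOp(castOp);
      folded |= fold;
      newOperands.push_back(fold ? castOp.getSource() : opOperand.get());
      // Each DPS result is tied to an init operand; its type follows suit.
      if (dpsOp.isDpsInit(&opOperand))
        newResultTypes.push_back(newOperands.back().getType());
    }
    if (!folded)
      return failure();

    // Recreate the op (region-free sketch) and cast results back for users.
    OperationState state(op->getLoc(), op->getName().getStringRef(),
                         newOperands, newResultTypes, op->getAttrs());
    Operation *newOp = rewriter.create(state);
    SmallVector<Value> replacements;
    for (auto [oldResult, newResult] :
         llvm::zip(op->getResults(), newOp->getResults())) {
      Value replacement = newResult;
      if (replacement.getType() != oldResult.getType())
        replacement = rewriter.create<tensor::CastOp>(
            op->getLoc(), oldResult.getType(), replacement);
      replacements.push_back(replacement);
    }
    rewriter.replaceOp(op, replacements);
    return success();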