Lines Matching defs:tensor

1 //===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
47 // Helper to detect a sparse tensor type operand.
55 // Helper method to find zero/uninitialized tensor materialization.
65 // Check for empty tensor materialization.
66 if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
124 /// the tensor (for dynamic sizes).
126 Location loc, ShapedType stp, Value tensor) {
130 dim = builder.create<tensor::DimOp>(loc, tensor, d.index());
207 // The actual sparse tensor rewriting rules.
212 /// TODO: move it to tensor dialect instead.
214 /// Fold `tensor.concat` and `tensor.extract_slice`
216 /// %concat = tensor.concat dim(2) %t0, %t1
217 /// : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
218 /// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
219 /// : tensor<1x64x2xf32> to tensor<1x64x1xf32>
220 /// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
221 /// : tensor<1x64x2xf32> to tensor<1x64x1xf32>
227 : public OpRewritePattern<tensor::ExtractSliceOp> {
228 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
230 LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
232 auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
252 rewriter.createOrFold<tensor::DimOp>(loc, input, dim));
272 tensor::getMixedSizes(rewriter, loc, input);
335 // Yielding zero on newly materialized sparse tensor can be
363 /// into the reduction loop. However, for sparse sampling tensor S, such
435 // TODO: deal with non alloc tensor here one day
458 // Fuse a tensor cast into producing operation. Note that a tensor.cast
462 // TODO: audit the pure tensor dialect rewriting rules
463 struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
465 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
467 LogicalResult matchAndRewrite(tensor::CastOp op,
477 if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
479 if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
488 // Repair tensor casts with at least one sparse operand into the
499 /// Rewrites a sequence of operations for sparse tensor selections into
512 /// TODO: We require the tensor used for extracting conditions to be dense
513 /// to sparsify the code. To support a sparse condition tensor, we need a
593 // are directly loaded from the input tensor. We can probably admit more cases
608 // If the condition value is loaded directly from a dense tensor or
677 rewriter.create<tensor::ExtractOp>(loc, init->get(), ValueRange());
723 auto tensor = op.getTensor();
724 auto stt = getSparseTensorType(tensor);
726 auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
732 printSizes(rewriter, loc, tensor, stt.getDimRank(), /*isDim=*/true);
734 printSizes(rewriter, loc, tensor, stt.getLvlRank(), /*isDim=*/false);
736 // all typical sparse tensor components for printing.
737 foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
751 auto pos = rewriter.create<ToPositionsOp>(loc, tensor, l);
766 crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
768 crd = rewriter.create<ToCoordinatesOp>(loc, tensor, l);
775 auto val = rewriter.create<ToValuesOp>(loc, tensor);
852 static void printSizes(PatternRewriter &rewriter, Location loc, Value tensor,
861 val = rewriter.create<tensor::DimOp>(loc, tensor, idx);
863 val = rewriter.create<LvlOp>(loc, tensor, idx);
876 struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
878 using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;
880 LogicalResult matchAndRewrite(tensor::ReshapeOp op,
912 // and then expand it to match the rank of the destination tensor.
958 builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
1043 builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
1091 if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
1123 val = builder.create<tensor::InsertOp>(loc, v, val, crds);
1139 struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
1141 LogicalResult matchAndRewrite(tensor::DimOp op,
1202 // %tmp = memref.alloc : dense tensor
1213 // Builds a for op for each input tensor to append new values into the
1214 // output tensor.
1236 // Exits the ifOp, update the sparse tensor SSA value.
1274 // Trivial tensor conversion and simple element type conversion is handled
1319 // Exits the ifOp, update the sparse tensor SSA value.
1413 // Loads the value from sparse tensor using position-index;
1414 // loads the value from dense tensor using coords.
1488 // Release the temporary ordered COO tensor.
1521 // Create a sparse tensor writer and output meta data.
1538 // For each element in the source tensor, output the element.
1580 patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
1581 ReshapeRewriter<tensor::CollapseShapeOp>,
1582 Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
1583 Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,