xref: /llvm-project/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.td (revision 5a9bdd85ee4d8527e2cedf44f3ce26ff414f9b6a)
1//===- LoopExtensionOps.td - Transform dialect operations --*- tablegen -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS
10#define MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS
11
12include "mlir/Dialect/Transform/IR/TransformDialect.td"
13include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
14include "mlir/Interfaces/SideEffectInterfaces.td"
15
16def HoistLoopInvariantSubsetsOp
17    : TransformDialectOp<"loop.hoist_loop_invariant_subsets",
18        [TransformOpInterface, TransformEachOpTrait,
19         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
20         ReportTrackingListenerFailuresOpTrait]> {
21  let summary = "Hoist loop invariant subset ops";
22  let description = [{
23    This transform hoists loop-invariant subset ops out of the targeted
24    loop-like op. It looks for matching subset extraction/insertion op pairs and
25    hoists them. The loop body operates on a newly introduced region iter_arg.
26
27    Subset ops are hoisted only from the targeted op. If subset ops should be
28    hoisted from an entire loop nest, this transformation must be applied to
29    each loop-like op of the loop nest, starting with the innermost loop and
30    ending with the outermost loop.
31
32    Example:
33    ```
34    %r = scf.for ... iter_args(%t = %a) -> (tensor<?xf32>) {
35      %0 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
36      %1 = "test.foo"(%0) : (tensor<5xf32>) -> (tensor<5xf32>)
37      %2 = tensor.insert_slice %1 into %t[0][5][1]
38          : tensor<5xf32> into tensor<?xf32>
39      scf.yield %2 : tensor<?xf32>
40    }
41    ```
42    Is transformed to:
43    ```
44    %0 = tensor.extract_slice %a[0][5][1] : tensor<?xf32> to tensor<5xf32>
45    %new_loop:2 = scf.for ... iter_args(%t = %a, %h = %0) -> (tensor<?xf32>, tensor<5xf32>) {
46      %1 = "test.foo"(%h) : (tensor<5xf32>) -> (tensor<5xf32>)
47      scf.yield %t, %1 : tensor<?xf32>, tensor<5xf32>
48    }
49    %r = tensor.insert_slice %new_loop#1 into %new_loop#0
50        : tensor<5xf32> into tensor<?xf32>
51    ```
52
53    Subset ops are hoisted only if there are no conflicting subset ops. E.g.,
54    if there were a second overlapping extraction in the above example, no ops
55    could be hoisted safely.
56
57    This transform reads the target handle and modifies the payload. This
58    transform does not invalidate any handles, but loop-like ops are replaced
59    with new loop-like ops when a subset op is hoisted. The transform rewriter
60    updates all handles accordingly.
61  }];
62
63  let arguments = (ins TransformHandleTypeInterface:$target);
64  let results = (outs);
65  let assemblyFormat = "$target attr-dict `:` type($target)";
66
67  let extraClassDeclaration = [{
68    ::mlir::DiagnosedSilenceableFailure applyToOne(
69      ::mlir::transform::TransformRewriter &rewriter,
70      ::mlir::LoopLikeOpInterface loopLikeOp,
71      ::mlir::transform::ApplyToEachResultList &results,
72      ::mlir::transform::TransformState &state);
73  }];
74}
75
76#endif // MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS
77