// Source: llvm-project/mlir/test/Dialect/Linalg/transform-op-insert-slice-to-copy.mlir (revision e4384149b58f7c3d19c5d38bc46038c660b77ca9)
// RUN: mlir-opt -transform-interpreter %s --split-input-file --allow-unregistered-dialect | FileCheck %s

// Static-size case: the payload is one tensor.insert_slice with constant sizes
// [2, 3] and dynamic offsets and strides. The expectations require, in order:
// an extract_slice of the destination, a linalg.copy of the source into that
// slice, and an insert_slice into the destination, all reusing the original
// offsets, sizes and strides.
// CHECK-LABEL: func @insert_slice_to_copy
    // CHECK-SAME: %[[I:.*]]: tensor<2x3xf32>
    // CHECK-SAME: %[[O:.*]]: tensor<?x?xf32>,
    // CHECK-SAME: %[[OFF0:[0-9a-zA-Z]+]]: index,
    // CHECK-SAME: %[[OFF1:[0-9a-zA-Z]+]]: index,
    // CHECK-SAME: %[[SZ0:[0-9a-zA-Z]+]]: index,
    // CHECK-SAME: %[[SZ1:[0-9a-zA-Z]+]]: index,
    // CHECK-SAME: %[[ST0:[0-9a-zA-Z]+]]: index,
    // CHECK-SAME: %[[ST1:[0-9a-zA-Z]+]]: index)
func.func @insert_slice_to_copy(
    %I : tensor<2x3xf32>, %O : tensor<?x?xf32>,
    %off0 : index, %off1 : index,
    %sz0 : index, %sz1 : index,
    %st0 : index, %st1 : index) -> tensor<?x?xf32> {

  //      CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[O]][%[[OFF0]], %[[OFF1]]] [2, 3] [%[[ST0]], %[[ST1]]]
  // CHECK-SAME:   : tensor<?x?xf32> to tensor<2x3xf32>
  //      CHECK: linalg.copy ins(%[[I]] : tensor<2x3xf32>) outs(%[[EXTRACTED_SLICE]] : tensor<2x3xf32>) -> tensor<2x3xf32>
  //      CHECK: tensor.insert_slice %{{.*}} into %[[O]][%[[OFF0]], %[[OFF1]]] [2, 3] [%[[ST0]], %[[ST1]]]
  // CHECK-SAME:   : tensor<2x3xf32> into tensor<?x?xf32>

  // %sz0 and %sz1 are not used here: the slice sizes are the constants [2, 3].
  %0 = tensor.insert_slice %I into %O[%off0, %off1] [2, 3] [%st0, %st1]
    : tensor<2x3xf32> into tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// Transform script for the payload function of this split: match every
// tensor.insert_slice under %arg1, apply
// transform.structured.insert_slice_to_copy to the matched handle, and cast the
// returned handle to !transform.op<"linalg.copy">.
// NOTE(review): the cast presumably also verifies that the returned payload op
// is a linalg.copy -- confirm against the transform dialect documentation.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.insert_slice_to_copy %0 : (!transform.any_op) -> !transform.any_op
    transform.cast %1 : !transform.any_op to !transform.op<"linalg.copy">
    transform.yield
  }
}

// -----

// Dynamic-size case: source and destination are both tensor<?x?xf32>, sizes
// come from %sz0/%sz1, and strides are the constants [1, 1]. The expectations
// require the same extract_slice / linalg.copy / insert_slice sequence, with
// the dynamic sizes and unit strides carried over.
// CHECK-LABEL: func @insert_slice_to_copy
    // CHECK-SAME: %[[I:[0-9a-zA-Z]+]]: tensor<?x?xf32>
    // CHECK-SAME: %[[O:[0-9a-zA-Z]+]]: tensor<?x?xf32>,
    // CHECK-SAME: %[[OFF0:[0-9a-zA-Z]+]]: index,
    // CHECK-SAME: %[[OFF1:[0-9a-zA-Z]+]]: index,
    // CHECK-SAME: %[[SZ0:[0-9a-zA-Z]+]]: index,
    // CHECK-SAME: %[[SZ1:[0-9a-zA-Z]+]]: index,
    // CHECK-SAME: %[[ST0:[0-9a-zA-Z]+]]: index,
    // CHECK-SAME: %[[ST1:[0-9a-zA-Z]+]]: index)
func.func @insert_slice_to_copy(
    %I : tensor<?x?xf32>, %O : tensor<?x?xf32>,
    %off0 : index, %off1 : index,
    %sz0 : index, %sz1 : index,
    %st0 : index, %st1 : index) -> tensor<?x?xf32> {

  //      CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[O]][%[[OFF0]], %[[OFF1]]] [%[[SZ0]], %[[SZ1]]] [1, 1]
  // CHECK-SAME:   : tensor<?x?xf32> to tensor<?x?xf32>
  //      CHECK: linalg.copy ins(%[[I]] : tensor<?x?xf32>) outs(%[[EXTRACTED_SLICE]] : tensor<?x?xf32>) -> tensor<?x?xf32>
  //      CHECK: tensor.insert_slice %{{.*}} into %[[O]][%[[OFF0]], %[[OFF1]]] [%[[SZ0]], %[[SZ1]]] [1, 1]
  // CHECK-SAME:   : tensor<?x?xf32> into tensor<?x?xf32>

  // %st0 and %st1 are not used here: the strides are the constants [1, 1].
  %0 = tensor.insert_slice %I into %O[%off0, %off1] [%sz0, %sz1] [1, 1]
    : tensor<?x?xf32> into tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// Transform script for the payload function of this split: match every
// tensor.insert_slice under %arg1, apply
// transform.structured.insert_slice_to_copy to the matched handle, and cast the
// returned handle to !transform.op<"linalg.copy">.
// NOTE(review): the cast presumably also verifies that the returned payload op
// is a linalg.copy -- confirm against the transform dialect documentation.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.insert_slice_to_copy %0 : (!transform.any_op) -> !transform.any_op
    transform.cast %1 : !transform.any_op to !transform.op<"linalg.copy">
    transform.yield
  }
}

// -----
// Already-converted case: the input IR is already in the extract_slice +
// linalg.copy + insert_slice form. The expectations require that form in the
// output, and the negative directive after the first linalg.copy rules out a
// second linalg.copy appearing before the insert_slice.
// CHECK-LABEL: func @insert_slice_to_copy
    // CHECK-SAME: %[[I:.*]]: tensor<2x3xf32>
    // CHECK-SAME: %[[O:.*]]: tensor<?x?xf32>,
    // CHECK-SAME: %[[OFF0:[0-9a-zA-Z]+]]: index,
    // CHECK-SAME: %[[OFF1:[0-9a-zA-Z]+]]: index,
    // CHECK-SAME: %[[SZ0:[0-9a-zA-Z]+]]: index,
    // CHECK-SAME: %[[SZ1:[0-9a-zA-Z]+]]: index,
    // CHECK-SAME: %[[ST0:[0-9a-zA-Z]+]]: index,
    // CHECK-SAME: %[[ST1:[0-9a-zA-Z]+]]: index)
func.func @insert_slice_to_copy(
    %I : tensor<2x3xf32>, %O : tensor<?x?xf32>,
    %off0 : index, %off1 : index,
    %sz0 : index, %sz1 : index,
    %st0 : index, %st1 : index) -> tensor<?x?xf32> {

  //      CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[O]][%[[OFF0]], %[[OFF1]]] [2, 3] [%[[ST0]], %[[ST1]]]
  // CHECK-SAME:   : tensor<?x?xf32> to tensor<2x3xf32>
  //      CHECK: linalg.copy ins(%[[I]] : tensor<2x3xf32>) outs(%[[EXTRACTED_SLICE]] : tensor<2x3xf32>) -> tensor<2x3xf32>
  //  CHECK-NOT: linalg.copy
  //      CHECK: tensor.insert_slice %{{.*}} into %[[O]][%[[OFF0]], %[[OFF1]]] [2, 3] [%[[ST0]], %[[ST1]]]
  // CHECK-SAME:   : tensor<2x3xf32> into tensor<?x?xf32>

  // The source of the insert_slice below is already the result of a
  // linalg.copy whose output operand is a slice of the same destination %O.
  %extracted_slice = tensor.extract_slice %O[%off0, %off1] [2, 3] [%st0, %st1]
    : tensor<?x?xf32> to tensor<2x3xf32>
  %0 = linalg.copy ins(%I : tensor<2x3xf32>) outs(%extracted_slice
    : tensor<2x3xf32>) -> tensor<2x3xf32>
  %inserted_slice = tensor.insert_slice %0 into %O[%off0, %off1] [2, 3] [%st0, %st1]
    : tensor<2x3xf32> into tensor<?x?xf32>

  return %inserted_slice : tensor<?x?xf32>
}

// Transform script for the payload function of this split: match every
// tensor.insert_slice under %arg1, apply
// transform.structured.insert_slice_to_copy to the matched handle, and cast the
// returned handle to !transform.op<"linalg.copy">.
// NOTE(review): the cast presumably also verifies that the returned payload op
// is a linalg.copy -- confirm against the transform dialect documentation.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.insert_slice_to_copy %0 : (!transform.any_op) -> !transform.any_op
    transform.cast %1 : !transform.any_op to !transform.op<"linalg.copy">
    transform.yield
  }
}

// -----

// Parallel case: the payload is a tensor.parallel_insert_slice inside the
// scf.forall.in_parallel region of an scf.forall. The expectations require a
// tensor.extract_slice and a linalg.copy ahead of scf.forall.in_parallel, with
// the tensor.parallel_insert_slice still present inside it.
// CHECK-LABEL: func @parallel_insert_slice_to_copy
func.func @parallel_insert_slice_to_copy(%out : tensor<?x?xf32>, %sz0: index, %sz1: index) {
  %0 = scf.forall (%arg0, %arg1) in (27, 8) shared_outs(%arg2 = %out) -> (tensor<?x?xf32>) {
    // Opaque producer of the tensor to insert; "make_me_a_tensor" is not an op
    // of any dialect used elsewhere in this function.
    %t = "make_me_a_tensor"() : () -> (tensor<?x?xf32> )

    //      CHECK: tensor.extract_slice
    //      CHECK: linalg.copy
    //      CHECK: scf.forall.in_parallel
    //      CHECK:   tensor.parallel_insert_slice
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %t into %arg2[0, 0] [%sz0, %sz1] [1, 1]
        : tensor<?x?xf32> into tensor<?x?xf32>
    }
  }
  return
}

// Transform script for the payload function of this split: match every
// tensor.parallel_insert_slice under %arg1, apply
// transform.structured.insert_slice_to_copy to the matched handle, and cast the
// returned handle to !transform.op<"linalg.copy">.
// NOTE(review): the cast presumably also verifies that the returned payload op
// is a linalg.copy -- confirm against the transform dialect documentation.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.insert_slice_to_copy %0
      : (!transform.any_op) -> !transform.any_op
    transform.cast %1 : !transform.any_op to !transform.op<"linalg.copy">
    transform.yield
  }
}
