// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s

// CHECK-LABEL: func.func @matmul_tensors_1(
func.func @matmul_tensors_1(
  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
  %arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32> {
  // This operation is marked for tiling only.
  // CHECK-COUNT-3: scf.for
  // CHECK-COUNT-3: tensor.extract_slice
  // CHECK: linalg.matmul
  // CHECK-SAME: -> tensor<4x4xf32>
  %0 = linalg.matmul { test.attrA }
                      ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32>
  func.return %0 : tensor<128x128xf32>
}

func.func @matmul_tensors_2(
  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
  %arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32> {
  // This operation is marked for tiling and vectorization.
  // CHECK-COUNT-3: scf.for
  // CHECK-COUNT-3: vector.transfer_read
  // CHECK:       vector.contract
  // CHECK-NOT:   linalg.matmul
  // CHECK:       vector.transfer_write
  %0 = linalg.matmul { test.attrA, test.attrC }
                      ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32>
  func.return %0 : tensor<128x128xf32>
}

func.func @matmul_tensors_3(
  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
  %arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32> {
  // This operation is marked for vectorization only.
  // CHECK-NOT: scf.for
  // CHECK-COUNT-3: vector.transfer_read
  // CHECK: vector.contract
  // CHECK-SAME: into vector<128x128xf32>
  // CHECK: vector.transfer_write
  %0 = linalg.matmul { test.attrC }
                      ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32>
  func.return %0 : tensor<128x128xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op) {
    transform.with_pdl_patterns %root : !transform.any_op {
    ^bb0(%arg0: !transform.any_op):
      // Match linalg.matmul operations that have test.attrA set.
      pdl.pattern @pdl_target_attrA : benefit(1) {
        %args = operands
        %results = types
        %attr = attribute
        %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrA" = %attr} -> (%results : !pdl.range<type>)
        // TODO: we don't want this, but it is the required terminator for pdl.pattern
        rewrite %0 with "transform.dialect"
      }

      // Match linalg.matmul operations that have test.attrC set.
      pdl.pattern @pdl_target_attrC : benefit(1) {
        %args = operands
        %results = types
        %attr = attribute
        %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrC" = %attr} -> (%results : !pdl.range<type>)
        // TODO: we don't want this, but it is the required terminator for pdl.pattern
        rewrite %0 with "transform.dialect"
      }

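      // Tile the matmuls marked with test.attrA by 4x4x4, then vectorize the
      // functions that contain a matmul marked with test.attrC.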
      transform.sequence %arg0 : !transform.any_op failures(propagate) {
      ^bb1(%arg1: !transform.any_op):
        %0 = pdl_match @pdl_target_attrA in %arg1 : (!transform.any_op) -> !transform.any_op
        transform.structured.tile_using_for %0 tile_sizes [4, 4, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
        %1 = pdl_match @pdl_target_attrC in %arg1 : (!transform.any_op) -> !transform.any_op
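        // Vectorization is applied to the closest isolated-from-above parent,
        // i.e. the enclosing func.func, rather than to the matched op itself.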
        %2 = get_parent_op %1 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
        transform.structured.vectorize_children_and_apply_patterns %2 : (!transform.any_op) -> !transform.any_op
      }
    }
    transform.yield
  }
}

// -----

// CHECK-LABEL: @vectorize_one
func.func @vectorize_one(
  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
  %arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32> {
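  // Marked with test.attrA, so the enclosing function is vectorized.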
  // CHECK: vector.contract
  %0 = linalg.matmul {test.attrA}
                     ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32>
  func.return %0 : tensor<128x128xf32>
}

func.func @vectorize_none(
  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
  %arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32> {
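  // Not marked with test.attrA, so this matmul must remain untouched.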
  // CHECK: linalg.matmul
  %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32>
  func.return %0 : tensor<128x128xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op) {
    transform.with_pdl_patterns %root : !transform.any_op {
    ^bb0(%arg0: !transform.any_op):
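      // Match linalg.matmul operations that have test.attrA set.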
      pdl.pattern @pdl_target : benefit(1) {
        %args = operands
        %results = types
        %attr = attribute
        %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrA" = %attr} -> (%results : !pdl.range<type>)
        // TODO: we don't want this, but it is the required terminator for pdl.pattern
        rewrite %0 with "transform.dialect"
      }

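      // Vectorize only the function that contains the matched matmul;
      // @vectorize_none has no matching op and is left untouched.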
      transform.sequence %arg0 : !transform.any_op failures(propagate) {
      ^bb1(%arg1: !transform.any_op):
        %0 = pdl_match @pdl_target in %arg1 : (!transform.any_op) -> !transform.any_op
        %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
        transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
      }
    }
    transform.yield
  }
}

// -----

// CHECK-LABEL: @vectorize_all
func.func @vectorize_all(
  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>,
  %arg3: tensor<128x128xf32>)
    -> tensor<128x128xf32> {
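  // The whole payload is targeted, so both matmuls below are vectorized,
  // with or without a marker attribute.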
  // CHECK: vector.contract
  %0 = linalg.matmul {test.attrA}
                     ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32>
  // CHECK: vector.contract
  %1 = linalg.matmul ins(%arg0, %0: tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%arg3: tensor<128x128xf32>)
    -> tensor<128x128xf32>
  func.return %1 : tensor<128x128xf32>
}

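// No pattern-based filtering here: vectorization is applied directly to the
// payload root, so every function in the module is vectorized.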
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
    transform.structured.vectorize_children_and_apply_patterns %arg0 : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}