// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s

// Selective targeting: the transform script for this section only touches the
// linalg.matmul ops that carry a matching discardable attribute
// (test.attrA => tiled by 4x4x4, test.attrC => vectorized).

// CHECK-LABEL: func.func @matmul_tensors_1(
func.func @matmul_tensors_1(
  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
  %arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32> {
  // This operation is marked for tiling only.
  // CHECK-COUNT-3: scf.for
  // CHECK-COUNT-3: tensor.extract_slice
  // CHECK: linalg.matmul
  // CHECK-SAME: -> tensor<4x4xf32>
  %0 = linalg.matmul { test.attrA }
                      ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32>
  func.return %0 : tensor<128x128xf32>
}

// CHECK-LABEL: func.func @matmul_tensors_2(
func.func @matmul_tensors_2(
  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
  %arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32> {
  // This operation is marked for tiling and vectorization.
  // CHECK-COUNT-3: scf.for
  // CHECK-COUNT-3: vector.transfer_read
  // CHECK: vector.contract
  // CHECK-NOT: linalg.matmul
  // CHECK: vector.transfer_write
  %0 = linalg.matmul { test.attrA, test.attrC }
                      ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32>
  func.return %0 : tensor<128x128xf32>
}

// CHECK-LABEL: func.func @matmul_tensors_3(
func.func @matmul_tensors_3(
  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
  %arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32> {
  // This operation is marked for vectorization only.
  // CHECK-NOT: scf.for
  // CHECK-COUNT-3: vector.transfer_read
  // CHECK: vector.contract
  // CHECK-SAME: into vector<128x128xf32>
  // CHECK: vector.transfer_write
  %0 = linalg.matmul { test.attrC }
                      ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32>
  func.return %0 : tensor<128x128xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op) {
    transform.with_pdl_patterns %root : !transform.any_op {
    ^bb0(%arg0: !transform.any_op):
      // Match linalg.matmul ops (in the @matmul_tensors_* functions) that carry
      // the test.attrA attribute, whatever its value.
      pdl.pattern @pdl_target_attrA : benefit(1) {
        %args = operands
        %results = types
        %attr = attribute
        %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrA" = %attr}-> (%results : !pdl.range<type>)
        // TODO: we don't want this, but it is the required terminator for pdl.pattern
        rewrite %0 with "transform.dialect"
      }

      // Match linalg.matmul ops (in the @matmul_tensors_* functions) that carry
      // the test.attrC attribute, whatever its value.
      pdl.pattern @pdl_target_attrC : benefit(1) {
        %args = operands
        %results = types
        %attr = attribute
        %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrC" = %attr}-> (%results : !pdl.range<type>)
        // TODO: we don't want this, but it is the required terminator for pdl.pattern
        rewrite %0 with "transform.dialect"
      }

      transform.sequence %arg0 : !transform.any_op failures(propagate) {
      ^bb1(%arg1: !transform.any_op):
        // Tile every attrA-marked matmul with tile size 4 on each of its loops.
        %0 = pdl_match @pdl_target_attrA in %arg1 : (!transform.any_op) -> !transform.any_op
        transform.structured.tile_using_for %0 tile_sizes [4, 4, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
        // Vectorize inside the closest isolated-from-above ancestor of each
        // attrC-marked matmul (the enclosing func.func in this test).
        %1 = pdl_match @pdl_target_attrC in %arg1 : (!transform.any_op) -> !transform.any_op
        %2 = get_parent_op %1 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
        transform.structured.vectorize_children_and_apply_patterns %2 : (!transform.any_op) -> !transform.any_op
      }
    }
    transform.yield
  }
}

// -----

// Only the function whose matmul carries test.attrA is vectorized; the
// unmarked matmul in @vectorize_none must survive as linalg.matmul.

// CHECK-LABEL: @vectorize_one
func.func @vectorize_one(
  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
  %arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32> {
  // CHECK: vector.contract
  %0 = linalg.matmul {test.attrA}
                     ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
                    outs(%arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32>
  func.return %0 : tensor<128x128xf32>
}

// CHECK-LABEL: @vectorize_none
func.func @vectorize_none(
  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
  %arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32> {
  // CHECK: linalg.matmul
  %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
                    outs(%arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32>
  func.return %0 : tensor<128x128xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op) {
    transform.with_pdl_patterns %root : !transform.any_op {
    ^bb0(%arg0: !transform.any_op):
      // Match linalg.matmul ops that carry the test.attrA attribute.
      pdl.pattern @pdl_target : benefit(1) {
        %args = operands
        %results = types
        %attr = attribute
        %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrA" = %attr}-> (%results : !pdl.range<type>)
        // TODO: we don't want this, but it is the required terminator for pdl.pattern
        rewrite %0 with "transform.dialect"
      }

      transform.sequence %arg0 : !transform.any_op failures(propagate) {
      ^bb1(%arg1: !transform.any_op):
        %0 = pdl_match @pdl_target in %arg1 : (!transform.any_op) -> !transform.any_op
        %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
        transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
      }
    }
    transform.yield
  }
}

// -----

// Vectorizing from the payload root reaches every matmul, including the one
// without test.attrA.

// CHECK-LABEL: @vectorize_all
func.func @vectorize_all(
  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>,
  %arg3: tensor<128x128xf32>)
    -> tensor<128x128xf32> {
  // CHECK: vector.contract
  %0 = linalg.matmul {test.attrA}
                     ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
                    outs(%arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32>
  // CHECK: vector.contract
  %1 = linalg.matmul ins(%arg0, %0: tensor<128x128xf32>, tensor<128x128xf32>)
                    outs(%arg3: tensor<128x128xf32>)
    -> tensor<128x128xf32>
  return %1 : tensor<128x128xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
    transform.structured.vectorize_children_and_apply_patterns %arg0 : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}