// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics --allow-unregistered-dialect | FileCheck %s

// Payload: an `scf.if` that has only a `then` region. Asking the transform
// below to take the `else` branch must be rejected with a diagnostic.
func.func @if_no_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
  scf.if %cond {
    "some_op"(%cond, %b) : (i1, memref<?xf32>) -> ()
    scf.yield
  }
  return
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %if = transform.structured.match ops{["scf.if"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    // Using `take_else_branch` on an `scf.if` without an `else` region fails.
    // expected-error @+1 {{requires an scf.if op with a single-block `else` region}}
    transform.scf.take_assumed_branch %if take_else_branch
      : (!transform.any_op) -> ()
    transform.yield
  }
}

// -----

// CHECK-LABEL: if_no_else
// Payload: an `scf.if` with only a `then` region. The transform below takes the
// `then` branch (no `take_else_branch` flag), after which `some_op` is no
// longer nested inside an `scf.if`.
func.func @if_no_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
  scf.if %cond {
    "some_op"(%cond, %b) : (i1, memref<?xf32>) -> ()
    scf.yield
  }
  return
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %if = transform.structured.match ops{["scf.if"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    // Take a handle to the op nested in the `then` region before transforming,
    // so we can check that it survives the branch being taken.
    %some_op = transform.structured.match ops{["some_op"]} in %arg1
      : (!transform.any_op) -> !transform.any_op

    // No `take_else_branch` flag: the `then` branch is taken.
    transform.scf.take_assumed_branch %if : (!transform.any_op) -> ()

    // Handle to formerly nested `some_op` is still valid after the transform.
    transform.print %some_op : !transform.any_op
    transform.yield
  }
}

// -----

// CHECK-LABEL: tile_tensor_pad
func.func @tile_tensor_pad(
  %arg0 : tensor<?x?xf32>, %cst : f32, %low: index, %high: index)
    -> tensor<20x40xf32>
{
  // After tiling and taking the `else` branch, no `scf.if`, `tensor.generate`
  // or `else` may remain; only the `nofold` `tensor.pad` inside the loop.
  //     CHECK: scf.forall
  // CHECK-NOT:   scf.if
  // CHECK-NOT:     tensor.generate
  // CHECK-NOT:   else
  //     CHECK:     tensor.pad {{.*}} nofold
  %0 = tensor.pad %arg0 nofold low[%low, %low] high[%high, %high] {
        ^bb0(%arg9: index, %arg10: index):
          tensor.yield %cst : f32
  } : tensor<?x?xf32> to tensor<20x40xf32>
  return %0 : tensor<20x40xf32>
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    // Tile the pad with 1x1 tiles; the payload has no `scf.if` before this step,
    // so any `scf.if` matched below comes from tiling.
    %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.tile_using_forall %0 tile_sizes[1, 1]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    // Replace each `scf.if` produced by tiling with its `else` branch.
    %if = transform.structured.match ops{["scf.if"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.scf.take_assumed_branch %if take_else_branch
      : (!transform.any_op) -> ()
    transform.yield
  }
}