xref: /llvm-project/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir (revision 13bd41096286305ee603428f6adf161f52981827)
1// RUN: mlir-opt -pass-pipeline="builtin.module(func.func(convert-shape-constraints))" <%s | FileCheck %s
2
3// There's not very much useful to check here other than pasting the output.
// Test input: a function that builds a witness with shape.cstr_broadcastable
// from two tensor<?xindex> operands and returns it.
// The expectations below require the pass output to (in this order) create a
// `shape.const_witness true` value, compute shape.is_broadcastable on the same
// two function arguments, feed that result to cf.assert with the message
// "required broadcastable shapes", and return the constant witness.
4// CHECK-LABEL:   func @cstr_broadcastable(
5// CHECK-SAME:                             %[[LHS:.*]]: tensor<?xindex>,
6// CHECK-SAME:                             %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
7// CHECK:           %[[RET:.*]] = shape.const_witness true
8// CHECK:           %[[BROADCAST_IS_VALID:.*]] = shape.is_broadcastable %[[LHS]], %[[RHS]]
9// CHECK:           cf.assert %[[BROADCAST_IS_VALID]], "required broadcastable shapes"
10// CHECK:           return %[[RET]] : !shape.witness
11// CHECK:         }
12func.func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
13  %witness = shape.cstr_broadcastable %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
14  return %witness : !shape.witness
15}
16
// Test input: a function that builds a witness with shape.cstr_eq from two
// tensor<?xindex> operands and returns it.
// The expectations below require the pass output to (in this order) create a
// `shape.const_witness true` value, compute shape.shape_eq on the same two
// function arguments, feed that result to cf.assert with the message
// "required equal shapes", and return the constant witness.
17// CHECK-LABEL:   func @cstr_eq(
18// CHECK-SAME:                             %[[LHS:.*]]: tensor<?xindex>,
19// CHECK-SAME:                             %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
20// CHECK:           %[[RET:.*]] = shape.const_witness true
21// CHECK:           %[[EQUAL_IS_VALID:.*]] = shape.shape_eq %[[LHS]], %[[RHS]]
22// CHECK:           cf.assert %[[EQUAL_IS_VALID]], "required equal shapes"
23// CHECK:           return %[[RET]] : !shape.witness
24// CHECK:         }
25func.func @cstr_eq(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
26  %witness = shape.cstr_eq %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
27  return %witness : !shape.witness
28}
29
// Test input: a function that builds a witness with shape.cstr_require from an
// i1 predicate and the message "msg", then returns it.
// The expectations require the pass output to create a
// `shape.const_witness true` value, assert the function's i1 argument with the
// same message via cf.assert, and return the constant witness.
// The argument is captured with a pattern (as in the sibling tests) instead of
// matching the printer-assigned SSA name literally.
30// CHECK-LABEL: func @cstr_require
// CHECK-SAME:    %[[PRED:.*]]: i1
31func.func @cstr_require(%arg0: i1) -> !shape.witness {
32  // CHECK: %[[RET:.*]] = shape.const_witness true
33  // CHECK: cf.assert %[[PRED]], "msg"
34  // CHECK: return %[[RET]]
35  %witness = shape.cstr_require %arg0, "msg"
36  return %witness : !shape.witness
37}
38