# RUN: %PYTHON -m mlir.dialects.linalg.opdsl.dump_oplib --file %s | FileCheck %s

from mlir.dialects.linalg.opdsl.lang import *

# Each op below is dumped to YAML by `dump_oplib`; the FileCheck patterns in
# front of an op match the `assignments` section of that op's dump. The
# indentation inside the patterns only mirrors the nesting of the dumped
# expression tree: without --strict-whitespace, FileCheck ignores leading
# whitespace in a pattern.


# `+=` dumps as a binary `add` identified by `fn_name`, whereas `mul` and
# `cast` are function attributes and therefore dump with `attr_name`.
# CHECK: ---
# CHECK-LABEL: matmul
# CHECK: assignments:
# CHECK:  -
# CHECK:    arg: C
# CHECK:    value:
# CHECK:      scalar_fn:
# CHECK:        kind: binary
# CHECK:        fn_name: add
# CHECK:        operands:
# CHECK:          scalar_fn:
# CHECK:            kind: binary
# CHECK:            attr_name: mul
# CHECK:            operands:
# CHECK:              scalar_fn:
# CHECK:                kind: type
# CHECK:                attr_name: cast
# CHECK:                type_var: U
# CHECK:                operands:
# CHECK:                  scalar_arg: A
# CHECK:              scalar_fn:
# CHECK:                kind: type
# CHECK:                attr_name: cast
# CHECK:                type_var: U
# CHECK:                operands:
# CHECK:                  scalar_arg: B
@linalg_structured_op
def matmul(
    A=TensorDef(T, S.M, S.K),
    B=TensorDef(T, S.K, S.N),
    C=TensorDef(U, S.M, S.N, output=True),
    mul=BinaryFnAttrDef(default=BinaryFn.mul),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
):
    C[D.m, D.n] += mul(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))


# Direct `TypeFn.cast_signed` / `UnaryFn.exp` calls dump with `fn_name`, while
# calling the `exp` function attribute dumps with `attr_name`. Every
# `cast_signed` below is a direct call, so each one pins its `fn_name`. The
# Python float and int literals print as `f64` and `i64` scalar constants; the
# regex on the last constant tolerates any number of zeros in the mantissa.
# CHECK: ---
# CHECK-LABEL: constants
# CHECK: assignments:
# CHECK:  -
# CHECK:    arg: O
# CHECK:    value:
# CHECK:      scalar_fn:
# CHECK:        kind: binary
# CHECK:        fn_name: sub
# CHECK:        operands:
# CHECK:          scalar_fn:
# CHECK:            kind: binary
# CHECK:            fn_name: add
# CHECK:            operands:
# CHECK:              scalar_fn:
# CHECK:                kind: unary
# CHECK:                fn_name: exp
# CHECK:                operands:
# CHECK:                  scalar_fn:
# CHECK:                    kind: type
# CHECK:                    fn_name: cast_signed
# CHECK:                    type_var: T
# CHECK:                    operands:
# CHECK:                      scalar_const: '3.1415926535897931 : f64'
# CHECK:              scalar_fn:
# CHECK:                kind: type
# CHECK:                fn_name: cast_signed
# CHECK:                type_var: T
# CHECK:                operands:
# CHECK:                  scalar_const: '42 : i64'
# CHECK:          scalar_fn:
# CHECK:            kind: type
# CHECK:            fn_name: cast_signed
# CHECK:            type_var: T
# CHECK:            operands:
# CHECK:              scalar_fn:
# CHECK:                kind: unary
# CHECK:                attr_name: exp
# CHECK:                operands:
# CHECK:                  scalar_const: '1.{{[0]*}}e+03 : f64'
@linalg_structured_op
def constants(
    O=TensorDef(T, S.M, S.K, output=True), exp=UnaryFnAttrDef(default=UnaryFn.exp)
):
    pi = TypeFn.cast_signed(T, const(3.1415926535897931))
    cst42 = TypeFn.cast_signed(T, const(42))
    # `exp` here is the function attribute parameter, not `UnaryFn.exp`, so this
    # call is the one that dumps with `attr_name`.
    cst1000 = TypeFn.cast_signed(T, exp(const(1e3)))
    O[D.m, D.n] = UnaryFn.exp(pi) + cst42 - cst1000


# `index(D.n)` and `index(D.m)` dump as `scalar_index` carrying the position of
# the dimension (`n` -> 1, `m` -> 0), in source operand order.
# CHECK: ---
# CHECK-LABEL: indices
# CHECK: assignments:
# CHECK:  -
# CHECK:    arg: O
# CHECK:    value:
# CHECK:      scalar_fn:
# CHECK:        kind: binary
# CHECK:        fn_name: add
# CHECK:        operands:
# CHECK:          scalar_index: 1
# CHECK:          scalar_index: 0
@linalg_structured_op
def indices(O=TensorDef(T, S.M, S.K, output=True)):
    O[D.m, D.n] = index(D.n) + index(D.m)


# Assigning a scalar operand directly dumps as a bare `scalar_arg`.
# CHECK: ---
# CHECK-LABEL: fill
# CHECK: assignments:
# CHECK:  -
# CHECK:    arg: O
# CHECK:    value:
# CHECK:      scalar_arg: value
@linalg_structured_op
def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
    O[D.m, D.n] = value