# RUN: %PYTHON -m mlir.dialects.linalg.opdsl.dump_oplib --file %s | FileCheck %s

# Test for the argument metadata dumped by the OpDSL op-library dumper.
# Each op below is preceded by the FileCheck directives that describe the
# `args:` section expected for it in the dumped output.

from mlir.dialects.linalg.opdsl.lang import *


# Expected args for `matmul`: two input tensors and one output tensor, each
# with a shape map over the three symbols, followed by the binary-function,
# unary-function and type-function attributes with their default functions.
# CHECK: ---
# CHECK-LABEL: matmul
# CHECK: args:
# CHECK: name: A
# CHECK: kind: input_tensor
# CHECK: type_var: T
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
# CHECK: name: B
# CHECK: kind: input_tensor
# CHECK: type_var: T
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
# CHECK: name: C
# CHECK: kind: output_tensor
# CHECK: type_var: U
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
# CHECK: name: bfn
# CHECK: kind: binary_fn_attr
# CHECK: default_fn: mul
# CHECK: name: ufn
# CHECK: kind: unary_fn_attr
# CHECK: default_fn: exp
# CHECK: name: cast
# CHECK: kind: type_fn_attr
# CHECK: default_fn: cast_signed
@linalg_structured_op
def matmul(
    A=TensorDef(T, S.M, S.K),
    B=TensorDef(T, S.K, S.N),
    C=TensorDef(U, S.M, S.N, output=True),
    bfn=BinaryFnAttrDef(default=BinaryFn.mul),
    ufn=UnaryFnAttrDef(default=UnaryFn.exp),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
):
    # `ufn` is declared so that it shows up in the dumped args; it is not
    # applied in the computation itself.
    C[D.m, D.n] += bfn(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))


# Expected args for `fill`: the scalar operand `value` must be dumped with a
# type variable and without any shape map.
# CHECK: ---
# CHECK-LABEL: fill
# CHECK: args:
# CHECK: name: value
# CHECK: kind: scalar
# CHECK-NOT: shape_map:
# CHECK: type_var: T
@linalg_structured_op
def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
    # NOTE(review): O is declared over symbols (S.M, S.K) but indexed with
    # dims (D.m, D.n); only the args dump is verified here, so this looks
    # harmless, but confirm the naming mismatch is intentional.
    O[D.m, D.n] = value


# Expected args for `strided_copy`: input and output tensors with shape maps
# over six symbols, plus the `strides` index attribute, which maps to the
# last two symbols and carries the default indices [1, 2].
# CHECK: ---
# CHECK-LABEL: strided_copy
# CHECK: args:
# CHECK: name: I
# CHECK: kind: input_tensor
# CHECK: type_var: T
# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1)>
# CHECK: name: O
# CHECK: kind: output_tensor
# CHECK: type_var: T
# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2, s3)>
# CHECK: name: strides
# CHECK: kind: index_attr
# CHECK: index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s5)>
# CHECK: default_indices:
# CHECK: - 1
# CHECK: - 2
@linalg_structured_op
def strided_copy(
    I=TensorDef(T, S.IH, S.IW),
    O=TensorDef(T, S.OH, S.OW, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 2]),
):
    O[D.oh, D.ow] = I[D.oh * S.SH, D.ow * S.SW]