xref: /llvm-project/mlir/test/python/dialects/linalg/opdsl/arguments.py (revision f9008e6366c2496b1ca1785b891d5578174ad63e)
19a2769dbSTobias Gysi# RUN: %PYTHON -m mlir.dialects.linalg.opdsl.dump_oplib --file %s | FileCheck %s
29a2769dbSTobias Gysi
39a2769dbSTobias Gysifrom mlir.dialects.linalg.opdsl.lang import *
49a2769dbSTobias Gysi
59a2769dbSTobias Gysi
69a2769dbSTobias Gysi# CHECK: ---
79a2769dbSTobias Gysi# CHECK-LABEL: matmul
89a2769dbSTobias Gysi# CHECK: args:
99a2769dbSTobias Gysi# CHECK:     name: A
1051fdd802Sgysit# CHECK:     kind: input_tensor
11662f9bffSTobias Gysi# CHECK:     type_var: T
124361bd9bSTobias Gysi# CHECK:     shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
139a2769dbSTobias Gysi# CHECK:     name: B
1451fdd802Sgysit# CHECK:     kind: input_tensor
15662f9bffSTobias Gysi# CHECK:     type_var: T
164361bd9bSTobias Gysi# CHECK:     shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
179a2769dbSTobias Gysi# CHECK:     name: C
1851fdd802Sgysit# CHECK:     kind: output_tensor
19662f9bffSTobias Gysi# CHECK:     type_var: U
204361bd9bSTobias Gysi# CHECK:     shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
2124357fecSgysit# CHECK:     name: bfn
2224357fecSgysit# CHECK:     kind: binary_fn_attr
2324357fecSgysit# CHECK:     default_fn: mul
2424357fecSgysit# CHECK:     name: ufn
2524357fecSgysit# CHECK:     kind: unary_fn_attr
2624357fecSgysit# CHECK:     default_fn: exp
2751fdd802Sgysit# CHECK:     name: cast
2851fdd802Sgysit# CHECK:     kind: type_fn_attr
29e9085d0dSgysit# CHECK:     default_fn: cast_signed
# Fixture exercising tensor arguments and function-attribute arguments:
# two input tensors (A, B), one output tensor (C), and one attribute each of
# the binary, unary and type function kinds (bfn, ufn, cast). The expected
# dump for this op is spelled out in the FileCheck directives directly above.
# Shape symbols are numbered in order of first use across the signature
# (M -> s0, K -> s1, N -> s2), matching the expected shape_map lines.
309a2769dbSTobias Gysi@linalg_structured_op
319a2769dbSTobias Gysidef matmul(
    # A and B share type variable T; the output C uses type variable U.
329a2769dbSTobias Gysi    A=TensorDef(T, S.M, S.K),
339a2769dbSTobias Gysi    B=TensorDef(T, S.K, S.N),
3451fdd802Sgysit    C=TensorDef(U, S.M, S.N, output=True),
3524357fecSgysit    bfn=BinaryFnAttrDef(default=BinaryFn.mul),
    # ufn is never called in the body below; presumably it is declared only
    # so that its attribute metadata (kind, default_fn) shows up in the dump.
3624357fecSgysit    ufn=UnaryFnAttrDef(default=UnaryFn.exp),
37*f9008e63STobias Hieta    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
38*f9008e63STobias Hieta):
    # Both operands are converted to the output type variable U via `cast`
    # before `bfn` combines them, and `+=` accumulates the result into C.
    # D.k indexes only the inputs, never C.
3924357fecSgysit    C[D.m, D.n] += bfn(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
409a2769dbSTobias Gysi
419a2769dbSTobias Gysi
429a2769dbSTobias Gysi# CHECK: ---
439a2769dbSTobias Gysi# CHECK-LABEL: fill
44662f9bffSTobias Gysi# CHECK: args:
459a2769dbSTobias Gysi# CHECK:     name: value
4651fdd802Sgysit# CHECK:     kind: scalar
4731f888eaSTobias Gysi# CHECK-NOT: shape_map:
489a2769dbSTobias Gysi# CHECK:     type_var: T
# Fixture for a scalar argument. ScalarDef(T) receives a type variable but no
# shape symbols, and the directives above expect `value` to be dumped with
# kind scalar and without any shape_map.
# NOTE(review): O's shape is declared with S.M, S.K while the body indexes it
# with D.m, D.n. Only the `value` argument is verified for this op, so the
# naming mismatch is not exercised here.
499a2769dbSTobias Gysi@linalg_structured_op
50662f9bffSTobias Gysidef fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
    # Write the scalar `value` into O at index (m, n).
519a2769dbSTobias Gysi    O[D.m, D.n] = value
5231f888eaSTobias Gysi
5331f888eaSTobias Gysi
5431f888eaSTobias Gysi# CHECK: ---
5531f888eaSTobias Gysi# CHECK-LABEL: strided_copy
5631f888eaSTobias Gysi# CHECK: args:
5731f888eaSTobias Gysi# CHECK:     name: I
5851fdd802Sgysit# CHECK:     kind: input_tensor
5931f888eaSTobias Gysi# CHECK:     type_var: T
604361bd9bSTobias Gysi# CHECK:     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1)>
6131f888eaSTobias Gysi# CHECK:     name: O
6251fdd802Sgysit# CHECK:     kind: output_tensor
6331f888eaSTobias Gysi# CHECK:     type_var: T
644361bd9bSTobias Gysi# CHECK:     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2, s3)>
6531f888eaSTobias Gysi# CHECK:     name: strides
6651fdd802Sgysit# CHECK:     kind: index_attr
67d50571abSgysit# CHECK:     index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s5)>
6851fdd802Sgysit# CHECK:     default_indices:
69d50571abSgysit# CHECK:     - 1
70d50571abSgysit# CHECK:     - 2
# Fixture for an index attribute alongside input/output tensors. Shape
# symbols IH, IW, OH, OW, SH, SW are numbered s0..s5 in order of first use,
# so the `strides` attribute's symbols SH and SW correspond to the s4 and s5
# in the expected index_attr_map.
7131f888eaSTobias Gysi@linalg_structured_op
7231f888eaSTobias Gysidef strided_copy(
7325bb6164STobias Gysi    I=TensorDef(T, S.IH, S.IW),
7431f888eaSTobias Gysi    O=TensorDef(T, S.OH, S.OW, output=True),
    # The default list [1, 2] supplies the values the directives above expect
    # under default_indices.
75*f9008e63STobias Hieta    strides=IndexAttrDef(S.SH, S.SW, default=[1, 2]),
76*f9008e63STobias Hieta):
    # Read I at the output indices scaled by the stride symbols SH and SW
    # declared by `strides`, and write the value to O at (oh, ow).
774361bd9bSTobias Gysi    O[D.oh, D.ow] = I[D.oh * S.SH, D.ow * S.SW]
78