# xref: /llvm-project/mlir/test/python/dialects/linalg/opdsl/assignments.py (revision f9008e6366c2496b1ca1785b891d5578174ad63e)
# RUN: %PYTHON -m mlir.dialects.linalg.opdsl.dump_oplib --file %s | FileCheck %s

# Each op defined below is dumped to YAML by dump_oplib, and the FileCheck
# directives that precede it pin the structure of its `assignments` section.

from mlir.dialects.linalg.opdsl.lang import *


# matmul: `mul` and `cast` are attribute parameters, so the dump refers to them
# by attr_name; the `+=` accumulation is expected as a binary `add` at the root.
# CHECK: ---
# CHECK-LABEL: matmul
# CHECK: assignments:
# CHECK:  -
# CHECK:    arg: C
# CHECK:    value:
# CHECK:      scalar_fn:
# CHECK:        kind: binary
# CHECK:        fn_name: add
# CHECK:        operands:
# CHECK:          scalar_fn:
# CHECK:            kind: binary
# CHECK:            attr_name: mul
# CHECK:            operands:
# CHECK:              scalar_fn:
# CHECK:                kind: type
# CHECK:                attr_name: cast
# CHECK:                type_var: U
# CHECK:                operands:
# CHECK:                  scalar_arg: A
# CHECK:              scalar_fn:
# CHECK:                kind: type
# CHECK:                attr_name: cast
# CHECK:                type_var: U
# CHECK:                operands:
# CHECK:                  scalar_arg: B
@linalg_structured_op
def matmul(
    A=TensorDef(T, S.M, S.K),
    B=TensorDef(T, S.K, S.N),
    C=TensorDef(U, S.M, S.N, output=True),
    mul=BinaryFnAttrDef(default=BinaryFn.mul),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
):
    # Both inputs are cast to C's element type U before being multiplied; the
    # expression is kept in this exact form because it is traced into the YAML.
    C[D.m, D.n] += mul(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))


# constants: Python numbers wrapped in const() are expected as typed
# scalar_const values (f64 for the floats, i64 for the int). Calls made through
# the enums (TypeFn.cast_signed, UnaryFn.exp) serialize with fn_name, while the
# `exp` attribute parameter serializes with attr_name.
# CHECK: ---
# CHECK-LABEL: constants
# CHECK: assignments:
# CHECK:  -
# CHECK:    arg: O
# CHECK:    value:
# CHECK:      scalar_fn:
# CHECK:        kind: binary
# CHECK:        fn_name: sub
# CHECK:        operands:
# CHECK:          scalar_fn:
# CHECK:            kind: binary
# CHECK:            fn_name: add
# CHECK:            operands:
# CHECK:              scalar_fn:
# CHECK:                kind: unary
# CHECK:                fn_name: exp
# CHECK:                operands:
# CHECK:                  scalar_fn:
# CHECK:                    kind: type
# CHECK:                    fn_name: cast_signed
# CHECK:                    type_var: T
# CHECK:                    operands:
# CHECK:                      scalar_const: '3.1415926535897931 : f64'
# CHECK:              scalar_fn:
# CHECK:                kind: type
# CHECK:                fn_name: cast_signed
# CHECK:                type_var: T
# CHECK:                operands:
# CHECK:                  scalar_const: '42 : i64'
# CHECK:          scalar_fn:
# CHECK:            kind: type
# CHECK:            fn_name: cast_signed
# CHECK:            type_var: T
# CHECK:            operands:
# CHECK:              scalar_fn:
# CHECK:                kind: unary
# CHECK:                attr_name: exp
# CHECK:                operands:
# CHECK:                  scalar_const: '1.{{[0]*}}e+03 : f64'
@linalg_structured_op
def constants(
    O=TensorDef(T, S.M, S.K, output=True), exp=UnaryFnAttrDef(default=UnaryFn.exp)
):
    # Every constant is cast to O's element type T before it is combined.
    pi = TypeFn.cast_signed(T, const(3.1415926535897931))
    cst42 = TypeFn.cast_signed(T, const(42))
    # `exp` here is the attribute parameter (attr_name), unlike the
    # UnaryFn.exp enum call on the next statement (fn_name).
    cst1000 = TypeFn.cast_signed(T, exp(const(1e3)))
    # Expected tree: sub(add(exp(pi), cst42), cst1000).
    O[D.m, D.n] = UnaryFn.exp(pi) + cst42 - cst1000


# indices: index(...) on a dimension serializes as scalar_index; here D.n is
# expected as index 1 and D.m as index 0.
# CHECK: ---
# CHECK-LABEL: indices
# CHECK: assignments:
# CHECK:  -
# CHECK:    arg: O
# CHECK:    value:
# CHECK:      scalar_fn:
# CHECK:        kind: binary
# CHECK:        fn_name: add
# CHECK:        operands:
# CHECK:          scalar_index: 1
# CHECK:          scalar_index: 0
@linalg_structured_op
def indices(O=TensorDef(T, S.M, S.K, output=True)):
    # Operand order matters: index(D.n) is the first operand of the add.
    O[D.m, D.n] = index(D.n) + index(D.m)


# fill: a ScalarDef operand assigned directly to the output serializes as a
# bare scalar_arg, with no scalar_fn wrapper.
# CHECK: ---
# CHECK-LABEL: fill
# CHECK: assignments:
# CHECK:  -
# CHECK:    arg: O
# CHECK:    value:
# CHECK:      scalar_arg: value
@linalg_structured_op
def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
    O[D.m, D.n] = value