# xref: /llvm-project/mlir/python/mlir/dialects/scf.py (revision ad89e617c703239518187912540b8ea811dc2eda)
#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5
6from ._scf_ops_gen import *
7from ._scf_ops_gen import _Dialect
8from .arith import constant
9
10try:
11    from ..ir import *
12    from ._ods_common import (
13        get_op_result_or_value as _get_op_result_or_value,
14        get_op_results_or_values as _get_op_results_or_values,
15        _cext as _ods_cext,
16    )
17except ImportError as e:
18    raise RuntimeError("Error loading imports from extension module") from e
19
20from typing import Optional, Sequence, Union
21
22
@_ods_cext.register_operation(_Dialect, replace=True)
class ForOp(ForOp):
    """Python-side extension of the generated SCF `for` op class."""

    def __init__(
        self,
        lower_bound,
        upper_bound,
        step,
        iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Build an SCF `for` op together with its body block.

        - `lower_bound`: value used as the loop's lower bound; its type is
          also the type given to the induction variable.
        - `upper_bound`: value used as the loop's upper bound.
        - `step`: value used as the loop step.
        - `iter_args`: optional loop-carried initial values, either as a
          sequence of values or as an operation whose results supply them.
          Their types are passed as the op's result types.

        The body block is created with its arguments but no operations
        (in particular, no terminator is emitted here).
        """
        carried = _get_op_results_or_values([] if iter_args is None else iter_args)
        carried_types = [value.type for value in carried]
        super().__init__(
            carried_types,
            lower_bound,
            upper_bound,
            step,
            carried,
            loc=loc,
            ip=ip,
        )
        # Block arguments: the induction variable (typed after operand 0, the
        # lower bound) followed by one argument per loop-carried value.
        self.regions[0].blocks.append(self.operands[0].type, *carried_types)

    @property
    def body(self):
        """The first (and only constructed) block of the loop region."""
        return self.regions[0].blocks[0]

    @property
    def induction_variable(self):
        """The first argument of the body block: the loop induction variable."""
        return self.body.arguments[0]

    @property
    def inner_iter_args(self):
        """Body block arguments that carry loop state inside the loop.

        These are the block arguments following the induction variable. They
        are distinct from the values passed as `iter_args` at construction,
        which are operands of the op itself rather than block arguments.
        """
        return self.body.arguments[1:]
72
73
@_ods_cext.register_operation(_Dialect, replace=True)
class IfOp(IfOp):
    """Specialization for the SCF if op class."""

    def __init__(self, cond, results_=None, *, hasElse=False, loc=None, ip=None):
        """Creates an SCF `if` operation.

        - `cond` is a MLIR value of 'i1' type to determine which regions of
          code will be executed.
        - `results_` is an optional sequence of result types, forwarded to the
          generated builder; when omitted the op is built with no results.
        - `hasElse` determines whether the if operation has the else branch.

        One argument-less block is appended to the then region and, when
        `hasElse` is set, to the else region. No terminators are created here.
        """
        # Copy into a fresh list so the caller's sequence is never aliased.
        results = [] if results_ is None else list(results_)
        super().__init__(results, cond, loc=loc, ip=ip)
        self.regions[0].blocks.append()
        if hasElse:
            self.regions[1].blocks.append()

    @property
    def then_block(self):
        """Returns the then block of the if operation."""
        return self.regions[0].blocks[0]

    @property
    def else_block(self):
        """Returns the else block of the if operation.

        Only valid when the else region has a block, e.g. when the op was
        built with `hasElse=True`; otherwise the block lookup fails.
        """
        return self.regions[1].blocks[0]
104
105
def for_(
    start,
    stop=None,
    step=None,
    iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
    *,
    loc=None,
    ip=None,
):
    """Build an SCF `for` loop and yield handles for populating its body.

    Argument handling mirrors Python's `range`: `for_(n)` is treated as
    `for_(0, n)` and `step` defaults to 1. Plain `int` bounds/steps are
    materialized as `index`-typed `arith.constant` ops (at the same `loc`/`ip`
    as the loop); any other value is passed through to `ForOp` unchanged.

    This is a generator that yields exactly once, while an insertion point
    into the loop body is active. It yields:
      - `iv` when there are no loop-carried values;
      - `(iv, inner_arg, result)` when there is exactly one;
      - `(iv, inner_args, results)` when there are several.
    No terminator is created in the body by this function.

    Raises:
      ValueError: if `start`, `stop` or `step` is a `float`.
    """
    if step is None:
        step = 1
    if stop is None:
        # Single-argument form: for_(n) means for_(0, n).
        start, stop = 0, start

    def _as_index(p):
        """Materialize an int as an index constant; reject floats."""
        if isinstance(p, int):
            # Emit the constant at the same location/insertion point as the
            # loop so it precedes the `scf.for` even when `ip` is explicit,
            # instead of landing at whatever insertion point is ambient.
            return constant(IndexType.get(), p, loc=loc, ip=ip)
        if isinstance(p, float):
            raise ValueError(f"{p=} must be int.")
        return p

    start, stop, step = (_as_index(p) for p in (start, stop, step))

    for_op = ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
    iv = for_op.induction_variable
    inner_args = tuple(for_op.inner_iter_args)
    with InsertionPoint(for_op.body):
        if len(inner_args) > 1:
            yield iv, inner_args, for_op.results
        elif len(inner_args) == 1:
            yield iv, inner_args[0], for_op.results[0]
        else:
            yield iv
140