# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception


from ._scf_ops_gen import *
from ._scf_ops_gen import _Dialect
from .arith import constant

try:
    from ..ir import *
    from ._ods_common import (
        get_op_result_or_value as _get_op_result_or_value,
        get_op_results_or_values as _get_op_results_or_values,
        _cext as _ods_cext,
    )
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, Sequence, Union


@_ods_cext.register_operation(_Dialect, replace=True)
class ForOp(ForOp):
    """Specialization for the SCF for op class."""

    def __init__(
        self,
        lower_bound,
        upper_bound,
        step,
        iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Creates an SCF `for` operation.

        - `lower_bound` is the value to use as lower bound of the loop.
        - `upper_bound` is the value to use as upper bound of the loop.
        - `step` is the value to use as loop step.
        - `iter_args` is a list of additional loop-carried arguments or an operation
          producing them as results.

        The body block is created eagerly. Its first argument is the induction
        variable, with the type of the first operand (the lower bound). It then
        has one argument per loop-carried value.
        """
        if iter_args is None:
            iter_args = []
        iter_args = _get_op_results_or_values(iter_args)

        # One op result per loop-carried value, with the same type.
        results = [arg.type for arg in iter_args]
        super().__init__(
            results, lower_bound, upper_bound, step, iter_args, loc=loc, ip=ip
        )
        # Block arguments: induction variable first, then the inner iter_args.
        self.regions[0].blocks.append(self.operands[0].type, *results)

    @property
    def body(self):
        """Returns the body (block) of the loop."""
        return self.regions[0].blocks[0]

    @property
    def induction_variable(self):
        """Returns the induction variable of the loop."""
        return self.body.arguments[0]

    @property
    def inner_iter_args(self):
        """Returns the loop-carried arguments usable within the loop.

        To obtain the loop-carried operands, use `iter_args`.
        """
        return self.body.arguments[1:]


@_ods_cext.register_operation(_Dialect, replace=True)
class IfOp(IfOp):
    """Specialization for the SCF if op class."""

    def __init__(self, cond, results_=None, *, hasElse=False, loc=None, ip=None):
        """Creates an SCF `if` operation.

        - `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed.
        - `results_` is an optional sequence of result types for the operation
          (defaults to no results).
        - `hasElse` determines whether the if operation has the else branch.

        An empty, argument-less block is created in the then region and, only
        when `hasElse` is true, in the else region.
        """
        results = [] if results_ is None else list(results_)
        super().__init__(results, cond, loc=loc, ip=ip)
        self.regions[0].blocks.append()
        if hasElse:
            self.regions[1].blocks.append()

    @property
    def then_block(self):
        """Returns the then block of the if operation."""
        return self.regions[0].blocks[0]

    @property
    def else_block(self):
        """Returns the else block of the if operation.

        Only valid when the op was created with `hasElse=True`. Otherwise the
        else region holds no block and this lookup fails (presumably with an
        indexing error raised by the bindings; not verified here).
        """
        return self.regions[1].blocks[0]


def for_(
    start,
    stop=None,
    step=None,
    iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
    *,
    loc=None,
    ip=None,
):
    """Builds an `scf.for` op and yields handles for populating its body.

    The bounds follow Python's `range` calling convention:
    - `for_(stop)` uses lower bound 0 and upper bound `stop`.
    - `for_(start, stop[, step])` uses explicit bounds.
    - `step` defaults to 1.

    Each Python `int` bound or step is materialized as an `index`-typed
    `arith.constant`. A Python `float` is rejected. Any other value is passed
    to `ForOp` unchanged. `iter_args` is forwarded to `ForOp` as-is.

    This is a generator that yields exactly once. While suspended, the
    insertion point is the loop body, so ops created by the caller are placed
    inside the loop. What it yields depends on `iter_args`:
      - no iter args: the induction variable;
      - one iter arg: `(iv, inner_iter_arg, result)`;
      - several iter args: `(iv, inner_iter_args_tuple, results)`.

    Raises:
        ValueError: if a bound or the step is a Python `float`.
    """
    if step is None:
        step = 1
    if stop is None:
        # Single-argument form: `start` is really the upper bound.
        stop = start
        start = 0
    params = [start, stop, step]
    for i, p in enumerate(params):
        if isinstance(p, int):
            p = constant(IndexType.get(), p)
        elif isinstance(p, float):
            raise ValueError(f"{p=} must be int.")
        params[i] = p

    start, stop, step = params

    for_op = ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
    iv = for_op.induction_variable
    inner_args = tuple(for_op.inner_iter_args)
    with InsertionPoint(for_op.body):
        if len(inner_args) > 1:
            yield iv, inner_args, for_op.results
        elif len(inner_args) == 1:
            yield iv, inner_args[0], for_op.results[0]
        else:
            yield iv