1 /* 2 * Copyright 2011 INRIA Saclay 3 * Copyright 2012 Ecole Normale Superieure 4 * Copyright 2020 Cerebras Systems 5 * 6 * Use of this software is governed by the MIT license 7 * 8 * Written by Sven Verdoolaege, INRIA Saclay - Ile-de-France, 9 * Parc Club Orsay Universite, ZAC des vignes, 4 rue Jacques Monod, 10 * 91893 Orsay, France 11 * and Ecole Normale Superieure, 45 rue d'Ulm, 75230 Paris, France 12 * and Cerebras Systems, 175 S San Antonio Rd, Los Altos, CA, USA 13 */ 14 15 #include <isl_pw_macro.h> 16 17 /* Given a function "cmp" that returns the set of elements where 18 * "el1" is "better" than "el2", return this set. 19 */ 20 static __isl_give isl_set *FN(PW,better)(__isl_keep EL *el1, __isl_keep EL *el2, 21 __isl_give isl_set *(*cmp)(__isl_take EL *el1, __isl_take EL *el2)) 22 { 23 return cmp(FN(EL,copy)(el1), FN(EL,copy)(el2)); 24 } 25 26 /* Return a list containing the domains of the pieces of "pw". 27 */ 28 static __isl_give isl_set_list *FN(PW,extract_domains)(__isl_keep PW *pw) 29 { 30 int i; 31 isl_ctx *ctx; 32 isl_set_list *list; 33 34 if (!pw) 35 return NULL; 36 ctx = FN(PW,get_ctx)(pw); 37 list = isl_set_list_alloc(ctx, pw->n); 38 for (i = 0; i < pw->n; ++i) 39 list = isl_set_list_add(list, isl_set_copy(pw->p[i].set)); 40 41 return list; 42 } 43 44 /* Given sets B ("set"), C ("better") and A' ("out"), return 45 * 46 * (B \cap C) \cup ((B \setminus C) \setminus A') 47 */ 48 static __isl_give isl_set *FN(PW,better_or_out)(__isl_take isl_set *set, 49 __isl_take isl_set *better, __isl_take isl_set *out) 50 { 51 isl_set *set_better, *set_out; 52 53 set_better = isl_set_intersect(isl_set_copy(set), isl_set_copy(better)); 54 set_out = isl_set_subtract(isl_set_subtract(set, better), out); 55 56 return isl_set_union(set_better, set_out); 57 } 58 59 /* Given sets A ("set"), C ("better") and B' ("out"), return 60 * 61 * (A \setminus C) \cup ((A \cap C) \setminus B') 62 */ 63 static __isl_give isl_set *FN(PW,worse_or_out)(__isl_take isl_set *set, 64 __isl_take isl_set *better, __isl_take isl_set *out) 65 { 66 isl_set *set_worse, *set_out; 67 68 set_worse = isl_set_subtract(isl_set_copy(set), isl_set_copy(better)); 69 set_out = isl_set_subtract(isl_set_intersect(set, better), out); 70 71 return isl_set_union(set_worse, set_out); 72 } 73 74 /* Internal data structure used by isl_pw_*_union_opt_cmp 75 * that keeps track of a piecewise expression with updated cells. 76 * "pw" holds the original piecewise expression. 77 * "list" holds the updated cells. 78 */ 79 S(PW,union_opt_cmp_data) { 80 PW *pw; 81 isl_set_list *cell; 82 }; 83 84 /* Free all memory allocated for "data". 85 */ 86 static void FN(PW,union_opt_cmp_data_clear)(S(PW,union_opt_cmp_data) *data) 87 { 88 isl_set_list_free(data->cell); 89 FN(PW,free)(data->pw); 90 } 91 92 /* Given (potentially) updated cells "i" of data_i->pw and "j" of data_j->pw and 93 * a set "better" where the piece from data_j->pw is better 94 * than the piece from data_i->pw, 95 * (further) update the specified cells such that only the better elements 96 * remain on the (non-empty) intersection. 97 * 98 * Let C be the set "better". 99 * Let A be the cell data_i->cell[i] and B the cell data_j->cell[j]. 100 * 101 * The elements in C need to be removed from A, except for those parts 102 * that lie outside of B. That is, 103 * 104 * A <- (A \setminus C) \cup ((A \cap C) \setminus B') 105 * 106 * Conversely, the elements in B need to be restricted to C, except 107 * for those parts that lie outside of A. That is 108 * 109 * B <- (B \cap C) \cup ((B \setminus C) \setminus A') 110 * 111 * Since all pairs of pieces are considered, the domains are updated 112 * several times. A and B refer to these updated domains 113 * (kept track of in data_i->cell[i] and data_j->cell[j]), while A' and B' refer 114 * to the original domains of the pieces. It is safe to use these 115 * original domains because the difference between, say, A' and A is 116 * the domains of pw2-pieces that have been removed before and 117 * those domains are disjoint from B. A' is used instead of A 118 * because the continued updating of A may result in this domain 119 * getting broken up into more disjuncts. 120 */ 121 static isl_stat FN(PW,union_opt_cmp_split)(S(PW,union_opt_cmp_data) *data_i, 122 int i, S(PW,union_opt_cmp_data) *data_j, int j, 123 __isl_take isl_set *better) 124 { 125 isl_set *set_i, *set_j; 126 127 set_i = isl_set_list_get_set(data_i->cell, i); 128 set_j = FN(PW,get_domain_at)(data_j->pw, j); 129 set_i = FN(PW,worse_or_out)(set_i, isl_set_copy(better), set_j); 130 data_i->cell = isl_set_list_set_set(data_i->cell, i, set_i); 131 set_i = FN(PW,get_domain_at)(data_i->pw, i); 132 set_j = isl_set_list_get_set(data_j->cell, j); 133 set_j = FN(PW,better_or_out)(set_j, better, set_i); 134 data_j->cell = isl_set_list_set_set(data_j->cell, j, set_j); 135 136 return isl_stat_ok; 137 } 138 139 /* Given (potentially) updated cells "i" of data_i->pw and "j" of data_j->pw and 140 * a function "cmp" that returns the set of elements where 141 * "el1" is "better" than "el2", 142 * (further) update the specified cells such that only the "better" elements 143 * remain on the (non-empty) intersection. 144 */ 145 static isl_stat FN(PW,union_opt_cmp_pair)(S(PW,union_opt_cmp_data) *data_i, 146 int i, S(PW,union_opt_cmp_data) *data_j, int j, 147 __isl_give isl_set *(*cmp)(__isl_take EL *el1, __isl_take EL *el2)) 148 { 149 isl_set *better; 150 EL *el_i, *el_j; 151 152 el_i = FN(PW,peek_base_at)(data_i->pw, i); 153 el_j = FN(PW,peek_base_at)(data_j->pw, j); 154 better = FN(PW,better)(el_j, el_i, cmp); 155 return FN(PW,union_opt_cmp_split)(data_i, i, data_j, j, better); 156 } 157 158 /* Given (potentially) updated cells "i" of data_i->pw and "j" of data_j->pw and 159 * a function "cmp" that returns the set of elements where 160 * "el1" is "better" than "el2", 161 * (further) update the specified cells such that only the "better" elements 162 * remain on the (non-empty) intersection. 163 * 164 * The base computation is performed by isl_pw_*_union_opt_cmp_pair, 165 * which splits the cells according to the set of elements 166 * where the piece from data_j->pw is better than the piece from data_i->pw. 167 * 168 * In some cases, there may be a subset of the intersection 169 * where both pieces have the same value and can therefore 170 * both be considered to be "better" than the other. 171 * This can result in unnecessary splitting on this subset. 172 * Avoid some of these cases by checking whether 173 * data_i->pw is always better than data_j->pw on the intersection. 174 * In particular, do this for the special case where this intersection 175 * is equal to the cell "j" and data_i->pw is better on its entire cell. 176 * 177 * Similarly, if data_i->pw is never better than data_j->pw, 178 * then no splitting will occur and there is no need to check 179 * where data_j->pw is better than data_i->pw. 180 */ 181 static isl_stat FN(PW,union_opt_cmp_two)(S(PW,union_opt_cmp_data) *data_i, 182 int i, S(PW,union_opt_cmp_data) *data_j, int j, 183 __isl_give isl_set *(*cmp)(__isl_take EL *el1, __isl_take EL *el2)) 184 { 185 isl_bool is_subset, is_empty; 186 isl_set *better, *set_i, *set_j; 187 EL *el_i, *el_j; 188 189 set_i = FN(PW,peek_domain_at)(data_i->pw, i); 190 set_j = FN(PW,peek_domain_at)(data_j->pw, j); 191 is_subset = isl_set_is_subset(set_j, set_i); 192 if (is_subset < 0) 193 return isl_stat_error; 194 if (!is_subset) 195 return FN(PW,union_opt_cmp_pair)(data_i, i, data_j, j, cmp); 196 197 el_i = FN(PW,peek_base_at)(data_i->pw, i); 198 el_j = FN(PW,peek_base_at)(data_j->pw, j); 199 better = FN(PW,better)(el_i, el_j, cmp); 200 is_empty = isl_set_is_empty(better); 201 if (is_empty >= 0 && is_empty) 202 return FN(PW,union_opt_cmp_split)(data_j, j, data_i, i, better); 203 is_subset = isl_set_is_subset(set_i, better); 204 if (is_subset >= 0 && is_subset) 205 return FN(PW,union_opt_cmp_split)(data_j, j, data_i, i, better); 206 isl_set_free(better); 207 if (is_empty < 0 || is_subset < 0) 208 return isl_stat_error; 209 210 return FN(PW,union_opt_cmp_pair)(data_i, i, data_j, j, cmp); 211 } 212 213 /* Given two piecewise expressions data1->pw and data2->pw, replace 214 * their domains 215 * by the sets in data1->cell and data2->cell and combine the results into 216 * a single piecewise expression. 217 * The pieces of data1->pw and data2->pw are assumed to have been sorted 218 * according to the function value expressions. 219 * The pieces of the result are also sorted in this way. 220 * 221 * Run through the pieces of data1->pw and data2->pw in order until they 222 * have both been exhausted, picking the piece from data1->pw or data2->pw 223 * depending on which should come first, together with the corresponding 224 * domain from data1->cell or data2->cell. In cases where the next pieces 225 * in both data1->pw and data2->pw have the same function value expression, 226 * construct only a single piece in the result with as domain 227 * the union of the domains in data1->cell and data2->cell. 228 */ 229 static __isl_give PW *FN(PW,merge)(S(PW,union_opt_cmp_data) *data1, 230 S(PW,union_opt_cmp_data) *data2) 231 { 232 int i, j; 233 PW *res; 234 PW *pw1 = data1->pw; 235 PW *pw2 = data2->pw; 236 isl_set_list *list1 = data1->cell; 237 isl_set_list *list2 = data2->cell; 238 239 if (!pw1 || !pw2) 240 return NULL; 241 242 res = FN(PW,alloc_size)(isl_space_copy(pw1->dim), pw1->n + pw2->n); 243 244 i = 0; j = 0; 245 while (i < pw1->n || j < pw2->n) { 246 int cmp; 247 isl_set *set; 248 EL *el; 249 250 if (i < pw1->n && j < pw2->n) 251 cmp = FN(EL,plain_cmp)(pw1->p[i].FIELD, 252 pw2->p[j].FIELD); 253 else 254 cmp = i < pw1->n ? -1 : 1; 255 256 if (cmp < 0) { 257 set = isl_set_list_get_set(list1, i); 258 el = FN(EL,copy)(pw1->p[i].FIELD); 259 ++i; 260 } else if (cmp > 0) { 261 set = isl_set_list_get_set(list2, j); 262 el = FN(EL,copy)(pw2->p[j].FIELD); 263 ++j; 264 } else { 265 set = isl_set_union(isl_set_list_get_set(list1, i), 266 isl_set_list_get_set(list2, j)); 267 el = FN(EL,copy)(pw1->p[i].FIELD); 268 ++i; 269 ++j; 270 } 271 res = FN(PW,add_piece)(res, set, el); 272 } 273 274 return res; 275 } 276 277 /* Given a function "cmp" that returns the set of elements where 278 * "el1" is "better" than "el2", return a piecewise 279 * expression defined on the union of the definition domains 280 * of "pw1" and "pw2" that maps to the "best" of "pw1" and 281 * "pw2" on each cell. If only one of the two input functions 282 * is defined on a given cell, then it is considered the best. 283 * 284 * Run through all pairs of pieces in "pw1" and "pw2". 285 * If the domains of these pieces intersect, then the intersection 286 * needs to be distributed over the two pieces based on "cmp". 287 * 288 * After the updated domains have been computed, the result is constructed 289 * from "pw1", "pw2", data[0].cell and data[1].cell. If there are any pieces 290 * in "pw1" and "pw2" with the same function value expression, then 291 * they are combined into a single piece in the result. 292 * In order to be able to do this efficiently, the pieces of "pw1" and 293 * "pw2" are first sorted according to their function value expressions. 294 */ 295 static __isl_give PW *FN(PW,union_opt_cmp)( 296 __isl_take PW *pw1, __isl_take PW *pw2, 297 __isl_give isl_set *(*cmp)(__isl_take EL *el1, __isl_take EL *el2)) 298 { 299 S(PW,union_opt_cmp_data) data[2] = { { pw1, NULL }, { pw2, NULL } }; 300 int i, j; 301 isl_size n1, n2; 302 PW *res = NULL; 303 isl_ctx *ctx; 304 305 if (!pw1 || !pw2) 306 goto error; 307 308 ctx = isl_space_get_ctx(pw1->dim); 309 if (!isl_space_is_equal(pw1->dim, pw2->dim)) 310 isl_die(ctx, isl_error_invalid, 311 "arguments should live in the same space", goto error); 312 313 if (FN(PW,is_empty)(pw1)) { 314 FN(PW,free)(pw1); 315 return pw2; 316 } 317 318 if (FN(PW,is_empty)(pw2)) { 319 FN(PW,free)(pw2); 320 return pw1; 321 } 322 323 for (i = 0; i < 2; ++i) { 324 data[i].pw = FN(PW,sort_unique)(data[i].pw); 325 data[i].cell = FN(PW,extract_domains)(data[i].pw); 326 } 327 328 n1 = FN(PW,n_piece)(data[0].pw); 329 n2 = FN(PW,n_piece)(data[1].pw); 330 if (n1 < 0 || n2 < 0) 331 goto error; 332 for (i = 0; i < n1; ++i) { 333 for (j = 0; j < n2; ++j) { 334 isl_bool disjoint; 335 isl_set *set_i, *set_j; 336 337 set_i = FN(PW,peek_domain_at)(data[0].pw, i); 338 set_j = FN(PW,peek_domain_at)(data[1].pw, j); 339 disjoint = isl_set_is_disjoint(set_i, set_j); 340 if (disjoint < 0) 341 goto error; 342 if (disjoint) 343 continue; 344 if (FN(PW,union_opt_cmp_two)(&data[0], i, 345 &data[1], j, cmp) < 0) 346 goto error; 347 } 348 } 349 350 res = FN(PW,merge)(&data[0], &data[1]); 351 for (i = 0; i < 2; ++i) 352 FN(PW,union_opt_cmp_data_clear)(&data[i]); 353 354 return res; 355 error: 356 for (i = 0; i < 2; ++i) 357 FN(PW,union_opt_cmp_data_clear)(&data[i]); 358 return FN(PW,free)(res); 359 } 360