xref: /netbsd-src/external/gpl3/gcc/dist/libphobos/src/std/regex/internal/kickstart.d (revision b1e838363e3c6fc78a55519254d99869742dd33c)
/*
    Kickstart is a coarse-grained "filter" engine that finds likely matches,
    which are then verified by a full-blown matcher.
*/
5 module std.regex.internal.kickstart;
6 
7 package(std.regex):
8 
9 import std.range.primitives, std.utf;
10 import std.regex.internal.ir;
11 
//Helper for ShiftOr: the minimum number of bytes of a single Char code unit
//that have to be tested. char and wchar need every byte; dchar needs only 3
//because code points never exceed 0x10FFFF, so the 4th byte is always zero.
//Any other type is rejected at compile time.
uint effectiveSize(Char)()
{
    static if (is(Char == dchar))
    {
        return 3;
    }
    else static if (is(Char == char) || is(Char == wchar))
    {
        return cast(uint) Char.sizeof;
    }
    else
    {
        static assert(0);
    }
}
24 
/*
    Kickstart engine using the ShiftOr algorithm,
    a bit-parallel technique for inexact string searching.

    The constructor walks the regex bytecode and fills a 256-entry table
    indexed by byte value: bit k of table[b] is cleared when byte b may
    appear at byte position k of a match prefix. search() then runs the
    classic shift-or loop over the raw bytes of the haystack and reports
    positions where a match is possible; callers are expected to verify
    each candidate with the real matcher.
*/
struct ShiftOr(Char)
{
private:
    //per-byte bit masks; a cleared bit means "byte allowed at this position"
    uint[] table;
    //low byte of the first code unit of a leading literal char,
    //or uint.max if the pattern does not start with a literal
    uint fChar;
    //number of leading byte positions (bits) that the table describes
    uint n_length;
    enum charSize =  effectiveSize!Char();
    //maximum number of chars in CodepointSet to process
    enum uint charsetThreshold = 32_000;

    //A path through the bytecode that is being explored by the constructor.
    static struct ShiftThread
    {
        uint[] tab;     //shared table this thread writes to
        uint mask;      //bit corresponding to the current byte position
        uint idx;       //current byte position (number of bytes consumed)
        uint pc, counter, hops; //bytecode position, repeat counter, loop passes
        this(uint newPc, uint newCounter, uint[] table)
        {
            pc = newPc;
            counter = newCounter;
            mask = 1;
            idx = 0;
            hops = 0;
            tab = table;
        }

        //sets the given bits for byte value idx
        void setMask(uint idx, uint mask)
        {
            tab[idx] |= mask;
        }

        //clears the given bits for byte value idx
        void setInvMask(uint idx, uint mask)
        {
            tab[idx] &= ~mask;
        }

        //Applies setBits to every tested byte of ch, one bit position per
        //byte, starting at the current mask. The low byte of each code unit
        //is handled before its higher bytes.
        void set(alias setBits = setInvMask)(dchar ch)
        {
            static if (charSize == 3)
            {
                uint val = ch, tmask = mask;
                setBits(val&0xFF, tmask);
                tmask <<= 1;
                val >>= 8;
                setBits(val&0xFF, tmask);
                tmask <<= 1;
                val >>= 8;
                assert(val <= 0x10);
                setBits(val, tmask);
            }
            else
            {
                Char[dchar.sizeof/Char.sizeof] buf;
                uint tmask = mask;
                size_t total = encode(buf, ch);
                for (size_t i = 0; i < total; i++, tmask<<=1)
                {
                    static if (charSize == 1)
                        setBits(buf[i], tmask);
                    else static if (charSize == 2)
                    {
                        setBits(buf[i]&0xFF, tmask);
                        tmask <<= 1;
                        setBits(buf[i]>>8, tmask);
                    }
                }
            }
        }
        //marks ch as acceptable at the current position
        void add(dchar ch){ return set!setInvMask(ch); }
        //moves the thread forward by s byte positions
        void advance(uint s)
        {
            mask <<= s;
            idx += s;
        }
        //true once the mask has been shifted out of the 32-bit word
        @property bool full(){    return !mask; }
    }

    //copy of t that continues at newPc with newCounter
    static ShiftThread fork(ShiftThread t, uint newPc, uint newCounter)
    {
        ShiftThread nt = t;
        nt.pc = newPc;
        nt.counter = newCounter;
        return nt;
    }

    //pops the last thread off the worklist
    @trusted static ShiftThread fetch(ref ShiftThread[] worklist)
    {
        auto t = worklist[$-1];
        worklist.length -= 1;
        //let later appends reuse the freed slot (not available during CTFE)
        if (!__ctfe)
            cast(void) worklist.assumeSafeAppend();
        return t;
    }

    //number of table positions (bytes) taken by code point ch
    static uint charLen(uint ch)
    {
        assert(ch <= 0x10FFFF);
        return codeLength!Char(cast(dchar) ch)*charSize;
    }

public:
    /*
        Builds the filter for re using memory as the table storage.

        Params:
            re = compiled regex whose bytecode is analysed
            memory = storage for the table, must hold exactly 256 entries
    */
    @trusted this(ref Regex!Char re, uint[] memory)
    {
        static import std.algorithm.comparison;
        import std.algorithm.searching : countUntil;
        import std.conv : text;
        import std.range : assumeSorted;
        assert(memory.length == 256);
        fChar = uint.max;
        // FNV-1a flavored hash (uses 32bits at a time)
        ulong hash(uint[] tab)
        {
            ulong h = 0xcbf29ce484222325;
            foreach (v; tab)
            {
                h ^= v;
                h *= 0x100000001b3;
            }
            return h;
        }
        //locate first fixed char if any, skipping over group markers
        //and zero-width assertions
    L_FindChar:
        for (size_t i = 0;;)
        {
            switch (re.ir[i].code)
            {
                case IR.Char:
                    fChar = re.ir[i].data;
                    static if (charSize != 3)
                    {
                        Char[dchar.sizeof/Char.sizeof] buf;
                        encode(buf, fChar);
                        fChar = buf[0];
                    }
                    //only the low byte is used (it is handed to memchr)
                    fChar = fChar & 0xFF;
                    break L_FindChar;
                case IR.GroupStart, IR.GroupEnd:
                    i += IRL!(IR.GroupStart);
                    break;
                case IR.Bof, IR.Bol, IR.Wordboundary, IR.Notwordboundary:
                    i += IRL!(IR.Bol);
                    break;
                default:
                    break L_FindChar;
            }
        }
        table = memory;
        //all bits set == nothing is allowed anywhere yet
        table[] =  uint.max;
        alias MergeTab = bool[ulong];
        // use reasonably complex hash to identify equivalent tables
        auto merge = new MergeTab[re.hotspotTableSize];
        ShiftThread[] trs;
        ShiftThread t = ShiftThread(0, 0, table);
        //one bit per byte position, a uint gives at most 32 of them
        n_length = 32;
        for (;;)
        {
        L_Eval_Thread:
            for (;;)
            {
                switch (re.ir[t.pc].code)
                {
                case IR.Char:
                    uint s = charLen(re.ir[t.pc].data);
                    if (t.idx+s > n_length)
                        goto L_StopThread;
                    t.add(re.ir[t.pc].data);
                    t.advance(s);
                    t.pc += IRL!(IR.Char);
                    break;
                case IR.OrChar://assumes IRL!(OrChar) == 1
                    uint len = re.ir[t.pc].sequence;
                    uint end = t.pc + len;
                    //distinct encoded lengths among the alternatives
                    uint[Bytecode.maxSequence] s;
                    uint numS;
                    for (uint i = 0; i < len; i++)
                    {
                        auto x = charLen(re.ir[t.pc+i].data);
                        if (countUntil(s[0 .. numS], x) < 0)
                           s[numS++] = x;
                    }
                    for (uint i = t.pc; i < end; i++)
                    {
                        t.add(re.ir[i].data);
                    }
                    //continue after the sequence once per distinct length
                    for (uint i = 0; i < numS; i++)
                    {
                        auto tx = fork(t, t.pc + len, t.counter);
                        if (tx.idx + s[i] <= n_length)
                        {
                            tx.advance(s[i]);
                            trs ~= tx;
                        }
                    }
                    if (!trs.empty)
                        t = fetch(trs);
                    else
                        goto L_StopThread;
                    break;
                case IR.CodepointSet:
                case IR.Trie:
                    auto set = re.charsets[re.ir[t.pc].data];
                    //encoded lengths (in table positions) present in the set
                    uint[4] s;
                    uint numS;
                    static if (charSize == 3)
                    {
                        s[0] = charSize;
                        numS = 1;
                    }
                    else
                    {

                        //code point ranges, one [first, last] pair per
                        //encoded length in code units
                        static if (charSize == 1)
                            static immutable codeBounds = [0x0, 0x7F, 0x80, 0x7FF, 0x800, 0xFFFF, 0x10000, 0x10FFFF];
                        else //== 2
                            static immutable codeBounds = [0x0, 0xFFFF, 0x10000, 0x10FFFF];
                        //flatten the set's intervals into a sorted array of bounds
                        uint[] arr = new uint[set.byInterval.length * 2];
                        size_t ofs = 0;
                        foreach (ival; set.byInterval)
                        {
                            arr[ofs++] = ival.a;
                            arr[ofs++] = ival.b;
                        }
                        auto srange = assumeSorted!"a <= b"(arr);
                        for (uint i = 0; i < codeBounds.length/2; i++)
                        {
                            auto start = srange.lowerBound(codeBounds[2*i]).length;
                            auto end = srange.lowerBound(codeBounds[2*i+1]).length;
                            //a bound falls inside the range, or the whole
                            //range lies inside one interval (odd index)
                            if (end > start || (end == start && (end & 1)))
                               s[numS++] = (i+1)*charSize;
                        }
                    }
                    if (numS == 0 || t.idx + s[numS-1] > n_length)
                        goto L_StopThread;
                    auto  chars = set.length;
                    if (chars > charsetThreshold)
                        goto L_StopThread;
                    foreach (ch; set.byCodepoint)
                    {
                        //avoid surrogate pairs
                        if (0xD800 <= ch && ch <= 0xDFFF)
                            continue;
                        t.add(ch);
                    }
                    for (uint i = 0; i < numS; i++)
                    {
                        auto tx =  fork(t, t.pc + IRL!(IR.CodepointSet), t.counter);
                        tx.advance(s[i]);
                        trs ~= tx;
                    }
                    if (!trs.empty)
                        t = fetch(trs);
                    else
                        goto L_StopThread;
                    break;
                case IR.Any:
                    goto L_StopThread;

                case IR.GotoEndOr:
                    t.pc += IRL!(IR.GotoEndOr)+re.ir[t.pc].data;
                    assert(re.ir[t.pc].code == IR.OrEnd);
                    goto case;
                case IR.OrEnd:
                    //stop if a thread with an identical table already
                    //passed through this slot
                    auto slot = re.ir[t.pc+1].raw+t.counter;
                    auto val = hash(t.tab);
                    if (val in merge[slot])
                        goto L_StopThread; // merge equivalent
                    merge[slot][val] = true;
                    t.pc += IRL!(IR.OrEnd);
                    break;
                case IR.OrStart:
                    t.pc += IRL!(IR.OrStart);
                    goto case;
                case IR.Option:
                    uint next = t.pc + re.ir[t.pc].data + IRL!(IR.Option);
                    //queue next Option
                    if (re.ir[next].code == IR.Option)
                    {
                        trs ~= fork(t, next, t.counter);
                    }
                    t.pc += IRL!(IR.Option);
                    break;
                case IR.RepeatStart:case IR.RepeatQStart:
                    t.pc += IRL!(IR.RepeatStart)+re.ir[t.pc].data;
                    goto case IR.RepeatEnd;
                case IR.RepeatEnd:
                case IR.RepeatQEnd:
                    auto slot = re.ir[t.pc+1].raw+t.counter;
                    auto val = hash(t.tab);
                    if (val in merge[slot])
                        goto L_StopThread; // merge equivalent
                    merge[slot][val] = true;
                    uint len = re.ir[t.pc].data;
                    uint step = re.ir[t.pc+2].raw;
                    uint min = re.ir[t.pc+3].raw;
                    if (t.counter < min)
                    {
                        //mandatory iteration: go through the body again
                        t.counter += step;
                        t.pc -= len;
                        break;
                    }
                    uint max = re.ir[t.pc+4].raw;
                    if (t.counter < max)
                    {
                        //optional iteration: explore both looping and leaving
                        trs ~= fork(t, t.pc - len, t.counter + step);
                        t.counter = t.counter%step;
                        t.pc += IRL!(IR.RepeatEnd);
                    }
                    else
                    {
                        t.counter = t.counter%step;
                        t.pc += IRL!(IR.RepeatEnd);
                    }
                    break;
                case IR.InfiniteStart, IR.InfiniteQStart:
                    t.pc += re.ir[t.pc].data + IRL!(IR.InfiniteStart);
                    goto case IR.InfiniteEnd; //both Q and non-Q
                case IR.InfiniteEnd:
                case IR.InfiniteQEnd:
                    auto slot = re.ir[t.pc+1].raw+t.counter;
                    auto val = hash(t.tab);
                    if (val in merge[slot])
                        goto L_StopThread; // merge equivalent
                    merge[slot][val] = true;
                    uint len = re.ir[t.pc].data;
                    uint pc1, pc2; //branches to take in priority order
                    //give up on a thread that keeps going through loop ends
                    if (++t.hops == 32)
                        goto L_StopThread;
                    pc1 = t.pc + IRL!(IR.InfiniteEnd);
                    pc2 = t.pc - len;
                    trs ~= fork(t, pc2, t.counter);
                    t.pc = pc1;
                    break;
                case IR.GroupStart, IR.GroupEnd:
                    t.pc += IRL!(IR.GroupStart);
                    break;
                case IR.Bof, IR.Bol, IR.Wordboundary, IR.Notwordboundary:
                    t.pc += IRL!(IR.Bol);
                    break;
                case IR.LookaheadStart, IR.NeglookaheadStart, IR.LookbehindStart, IR.NeglookbehindStart:
                    //skip the whole lookaround
                    t.pc += IRL!(IR.LookaheadStart) + IRL!(IR.LookaheadEnd) + re.ir[t.pc].data;
                    break;
                default:
                L_StopThread:
                    assert(re.ir[t.pc].code >= 0x80, text(re.ir[t.pc].code));
                    debug (std_regex_search)
                    {
                        import std.stdio : writeln;
                        writeln("ShiftOr stumbled on ",re.ir[t.pc].mnemonic);
                    }
                    //the table is only reliable up to where the shortest
                    //thread stopped
                    n_length = std.algorithm.comparison.min(t.idx, n_length);
                    break L_Eval_Thread;
                }
            }
            if (trs.empty)
                break;
            t = fetch(trs);
        }
        debug(std_regex_search)
        {
            import std.stdio : writeln;
            writeln("Min length: ", n_length);
        }
    }

    //true if no byte position could be described, i.e. the filter is useless
    @property bool empty() const {  return n_length == 0; }

    //length of the described prefix in Char code units
    @property uint length() const{ return n_length/charSize; }

    // lookup compatible bit pattern in haystack, return starting index
    // has a useful trait: if supplied with valid UTF indexes,
    // returns only valid UTF indexes
    // (that given the haystack in question is valid UTF string)
    //
    // Params:
    //     haystack = text to scan
    //     idx = index (in code units) to start scanning from
    // Returns: index of the candidate match start, or haystack.length
    //     if no candidate was found. Must not be called on an empty filter.
    @trusted size_t search(const(Char)[] haystack, size_t idx) const
    {//@BUG: apparently assumes little endian machines
     //(the table is filled low byte first for each code unit, while the
     //loops below consume bytes in memory order)
        import core.stdc.string : memchr;
        import std.conv : text;
        debug(std_regex_search) import std.stdio : writefln;
        assert(!empty);
        auto p = cast(const(ubyte)*)(haystack.ptr+idx);
        uint state = uint.max;
        //bit that becomes zero once n_length bytes in a row were compatible
        uint limit = 1u<<(n_length - 1u);
        debug(std_regex_search) writefln("Limit: %32b",limit);
        if (fChar != uint.max)
        {
            const(ubyte)* end = cast(ubyte*)(haystack.ptr + haystack.length);
            const originalAlign = cast(size_t) p & (Char.sizeof-1);
            while (p != end)
            {
                if (!~state)
                {//speed up seeking first matching place
                    for (;;)
                    {
                        assert(p <= end, text(p," vs ", end));
                        p = cast(ubyte*) memchr(p, fChar, end - p);
                        if (!p)
                            return haystack.length;
                        //accept only hits at the start of a code unit
                        if ((cast(size_t) p & (Char.sizeof-1)) == originalAlign)
                            break;
                        if (++p == end)
                            return haystack.length;
                    }
                    //the byte found by memchr accounts for position 0
                    state = ~1u;
                    assert((cast(size_t) p & (Char.sizeof-1)) == originalAlign);
                    static if (charSize == 3)
                    {
                        state = (state << 1) | table[p[1]];
                        state = (state << 1) | table[p[2]];
                        p += 4;
                    }
                    else
                        p++;
                    //first char is tested, see if that's all
                    if (!(state & limit))
                        return (p-cast(ubyte*) haystack.ptr)/Char.sizeof
                            -length;
                }
                else
                {//have some bits/states for possible matches,
                 //use the usual shift-or cycle
                    static if (charSize == 3)
                    {
                        state = (state << 1) | table[p[0]];
                        state = (state << 1) | table[p[1]];
                        state = (state << 1) | table[p[2]];
                        p += 4;
                    }
                    else
                    {
                        state = (state << 1) | table[p[0]];
                        p++;
                    }
                    if (!(state & limit))
                        return (p-cast(ubyte*) haystack.ptr)/Char.sizeof
                            -length;
                }
                debug(std_regex_search) writefln("State: %32b", state);
            }
        }
        else
        {
            //normal path, partially unrolled for char/wchar
            static if (charSize == 3)
            {
                const(ubyte)* end = cast(ubyte*)(haystack.ptr + haystack.length);
                while (p != end)
                {
                    state = (state << 1) | table[p[0]];
                    state = (state << 1) | table[p[1]];
                    state = (state << 1) | table[p[2]];
                    p += 4;
                    if (!(state & limit))//division rounds down for dchar
                        return (p-cast(ubyte*) haystack.ptr)/Char.sizeof
                        -length;
                }
            }
            else
            {
                auto len = cast(ubyte*)(haystack.ptr + haystack.length) - p;
                size_t i  = 0;
                //consume one byte up front so the unrolled loop below
                //always has a whole pair to work on
                if (len & 1)
                {
                    state = (state << 1) | table[p[i++]];
                    if (!(state & limit))
                        return idx+i/Char.sizeof-length;
                }
                while (i < len)
                {
                    state = (state << 1) | table[p[i++]];
                    if (!(state & limit))
                        return idx+i/Char.sizeof
                            -length;
                    state = (state << 1) | table[p[i++]];
                    if (!(state & limit))
                        return idx+i/Char.sizeof
                            -length;
                    debug(std_regex_search) writefln("State: %32b", state);
                }
            }
        }
        return haystack.length;
    }

    //Prints table in binary, four entries per row (debug builds only).
    @system debug static void dump(uint[] table)
    {//@@@BUG@@@ writef(ln) is @system
        import std.stdio : writef, writeln;
        foreach (i, entry; table)
        {
            writef("%32b", entry);
            //end the row after every 4th entry and after the last entry,
            //so a length that is not a multiple of 4 is never read past
            if ((i & 3) == 3 || i + 1 == table.length)
                writeln();
            else
                writef(" ");
        }
    }
}
514 
//Exercises the kickstart engine for all three character types:
//test_fixed checks the computed prefix length for patterns with a fixed-length
//prefix, test_flex checks searching with sets, alternations and loops.
@system unittest
{
    import std.conv, std.regex;
    //do not rely on AliasSeq leaking in through another module's imports
    import std.meta : AliasSeq;
    @trusted void test_fixed(alias Kick)()
    {
        static foreach (i, v; AliasSeq!(char, wchar, dchar))
        {{
            alias Char = v;
            alias String = immutable(v)[];
            auto r = regex(to!String(`abc$`));
            auto kick = Kick!Char(r, new uint[256]);
            assert(kick.length == 3, text(Kick.stringof," ",v.stringof, " == ", kick.length));
            auto r2 = regex(to!String(`(abc){2}a+`));
            kick = Kick!Char(r2, new uint[256]);
            assert(kick.length == 7, text(Kick.stringof,v.stringof," == ", kick.length));
            auto r3 = regex(to!String(`\b(a{2}b{3}){2,4}`));
            kick = Kick!Char(r3, new uint[256]);
            assert(kick.length == 10, text(Kick.stringof,v.stringof," == ", kick.length));
            auto r4 = regex(to!String(`\ba{2}c\bxyz`));
            kick = Kick!Char(r4, new uint[256]);
            assert(kick.length == 6, text(Kick.stringof,v.stringof, " == ", kick.length));
            auto r5 = regex(to!String(`\ba{2}c\b`));
            kick = Kick!Char(r5, new uint[256]);
            size_t x = kick.search("aabaacaa", 0);
            //report the value that was actually compared
            assert(x == 3, text(Kick.stringof,v.stringof," == ", x));
            x = kick.search("aabaacaa", x+1);
            //no further candidate: search returns the haystack length
            assert(x == 8, text(Kick.stringof,v.stringof," == ", x));
        }}
    }
    @trusted void test_flex(alias Kick)()
    {
        static foreach (i, v; AliasSeq!(char, wchar, dchar))
        {{
            alias Char = v;
            alias String = immutable(v)[];
            auto r = regex(to!String(`abc[a-z]`));
            auto kick = Kick!Char(r, new uint[256]);
            auto x = kick.search(to!String("abbabca"), 0);
            assert(x == 3, text("real x is ", x, " ",v.stringof));

            auto r2 = regex(to!String(`(ax|bd|cdy)`));
            String s2 = to!String("abdcdyabax");
            kick = Kick!Char(r2, new uint[256]);
            x = kick.search(s2, 0);
            assert(x == 1, text("real x is ", x));
            x = kick.search(s2, x+1);
            assert(x == 3, text("real x is ", x));
            x = kick.search(s2, x+1);
            assert(x == 8, text("real x is ", x));
            //a pattern starting with '.' gives the filter nothing to work with
            auto rdot = regex(to!String(`...`));
            kick = Kick!Char(rdot, new uint[256]);
            assert(kick.length == 0);
            auto rN = regex(to!String(`a(b+|c+)x`));
            kick = Kick!Char(rN, new uint[256]);
            assert(kick.length == 3, to!string(kick.length));
            assert(kick.search("ababx",0) == 2);
            assert(kick.search("abaacba",0) == 3);//expected inexact

        }}
    }
    test_fixed!(ShiftOr)();
    test_flex!(ShiftOr)();
}
578 
//Kickstart is the name of the kickstart engine in use; it currently
//resolves to ShiftOr.
alias Kickstart = ShiftOr;
580