xref: /llvm-project/flang/test/Evaluate/test_folding.py (revision f98ee40f4b5d7474fc67e82824bf6abbaedb7b1c)
1#!/usr/bin/env python3
2
3"""This script verifies expression folding.
4It compiles a source file with '-fdebug-dump-symbols'
5and looks for parameter declarations to check
6they have been folded as expected.
7To check folding of an expression EXPR,
8the fortran program passed to this script
9must contain the following:
10
11  logical, parameter :: test_x = <compare EXPR to expected value>
12
This script checks that all parameters
whose names start with "test_"
have been folded to .true.
16For instance, acos folding can be tested with:
17
18  real(4), parameter :: res_acos = acos(0.5_4)
19  real(4), parameter :: exp_acos = 1.047
20  logical, parameter :: test_acos = abs(res_acos - exp_acos).LE.(0.001_4)
21
22There are two kinds of failure:
23  - test_x is folded to .false..
24    This means the expression was folded
25    but the value is not as expected.
26  - test_x is not folded (it is neither .true. nor .false.).
27    This means the compiler could not fold the expression.
28
29Parameters:
    sys.argv[1]: a source file that contains the input and expected output
31    sys.argv[2]: the Flang frontend driver
32    sys.argv[3:]: Optional arguments to the Flang frontend driver"""
33
34import os
35import sys
36import tempfile
37import re
38import subprocess
39
40from difflib import unified_diff
41from pathlib import Path
42
43
def check_args(args):
    """Exit with a usage message unless the mandatory arguments are present.

    args: the full argument vector (sys.argv); args[0] is the script name,
    followed by the Fortran source and the Flang frontend driver.
    """
    minimum = 3  # script name + Fortran source + Flang driver
    if len(args) >= minimum:
        return
    print(f"Usage: {args[0]} <fortran-source> <flang-command>")
    sys.exit(1)
49
50
def set_source(source):
    """Return the Fortran source file under test as a Path.

    source: the path given on the command line (str or path-like).
    If it does not name an existing regular file, print an error naming
    that path and exit with status 1.
    """
    path = Path(source)
    if not path.is_file():
        # Report the argument itself. This function runs before any
        # module-level source text exists, so it must not refer to globals.
        print(f"File not found: {source}")
        sys.exit(1)
    return path
57
58
def set_executable(exe):
    """Return the path of the Flang frontend driver as a string.

    exe: the driver path given on the command line.
    If no such file exists, print an error and exit with status 1.
    """
    driver = Path(exe)
    if driver.is_file():
        return str(driver)
    print(f"Flang was not found: {exe}")
    sys.exit(1)
65
66
# Validate the command line and load the Fortran source under test.
check_args(sys.argv)
cwd = os.getcwd()
srcdir = set_source(sys.argv[1]).resolve()  # note: the source *file* path
with open(srcdir, "r", encoding="utf-8") as f:
    src = f.readlines()

# Text accumulated by the stages below, one record per line.
src1 = ""  # stdout of the driver (the -fdebug-dump-symbols output)
src2 = ""  # "<name> <init value>" for every PARAMETER found in src1
src3 = ""  # the entries of src2 that are test parameters
src4 = ""  # the entries of src3 that folded to .false.
messages = ""  # stderr of the driver (its diagnostics)
actual_warnings = ""  # "<line>: <text>" entries parsed from messages
expected_warnings = ""  # "<line>:<text>" entries from !WARN: comments
warning_diffs = ""  # rendered diff of actual vs. expected warnings

# argv[2] is the driver; any further arguments are forwarded unchanged.
flang_fc1 = set_executable(sys.argv[2])
flang_fc1_args = sys.argv[3:]
# Build the option list once and extend it as needed.
flang_fc1_options = ["-fdebug-dump-symbols"]
LIBPGMATH = os.getenv("LIBPGMATH")
if LIBPGMATH:
    # NOTE(review): presumably lets the Fortran input enable libpgmath-only
    # checks via preprocessor conditionals -- confirm in the test sources.
    flang_fc1_options.append("-DTEST_LIBPGMATH")
    print("Assuming libpgmath support")
else:
    print("Not assuming libpgmath support")
91
# Run the frontend from a scratch directory so any files it may write do
# not land next to the test. The symbol dump is read from stdout and the
# diagnostics from stderr.
cmd = [flang_fc1, *flang_fc1_args, *flang_fc1_options, str(srcdir)]
with tempfile.TemporaryDirectory() as tmpdir:
    proc = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        check=False,
        universal_newlines=True,
        cwd=tmpdir,
    )
    src1 = proc.stdout
    messages = proc.stderr

if proc.returncode != 0:
    # Check the status explicitly. With check=True, the resulting
    # CalledProcessError traceback would not show the captured stderr, and
    # the compiler's own error messages would be lost.
    print(f"Command failed with exit status {proc.returncode}: {' '.join(cmd)}")
    print(messages)
    print("FAIL")
    sys.exit(1)
104
# Keep "<name> <init value>" for every PARAMETER in the symbol dump.
src2 = ""
for line in src1.split("\n"):
    m = re.search(r"(\w*)(?=, PARAMETER).*init:(.*)", line)
    if m:
        src2 += f"{m.group(1)} {m.group(2)}\n"

# Select the test parameters, i.e. names starting with the literal prefix
# "test_". A regex like r"test_*" would also accept any name that merely
# starts with "test", because "_*" means "zero or more underscores".
test_lines = [line for line in src2.split("\n") if line.startswith("test_")]
src3 = "".join(f"{line}\n" for line in test_lines)
passed_results = len(test_lines)

# A test fails when it was folded to .false. (e.g. ".false._4").
src4 = "".join(
    f"{line}\n" for line in test_lines if re.search(r"\.false\._.$", line)
)
119
# Reduce each "<file>:<line>:<col>: <text>" diagnostic to "<line>: <text>"
# so it can be compared against the !WARN: expectations.
diagnostic_re = re.compile(r"[^:]*:(\d*):\d*: (.*)")
for diag in messages.split("\n"):
    found = diagnostic_re.search(diag)
    if found is None:
        continue
    lineno, text = found.groups()
    actual_warnings += f"{lineno}: {text}\n"
124
# Collect expected diagnostics from "!WARN:<text>" comment lines. Each one
# applies to the next source line that is not itself a !WARN: line, and is
# recorded as "<line>:<text>" (1-based line number; <text> is everything
# after the marker, presumably starting with a space so it lines up with
# the "<line>: <text>" form of actual_warnings).
passed_warnings = 0
warnings = []
for i, line in enumerate(src, 1):
    m = re.search(r"(?:!WARN:)(.*)", line)
    if m:
        warnings.append(m.group(1))
        continue
    for x in warnings:
        passed_warnings += 1
        expected_warnings += f"{i}:{x}\n"
    warnings = []

# !WARN: lines at the very end of the file have no following line to
# attach to. Record them against the line past EOF rather than dropping
# them, so the unsatisfiable expectation shows up in the warning diff.
for x in warnings:
    passed_warnings += 1
    expected_warnings += f"{len(src) + 1}:{x}\n"
137
# Diff actual against expected diagnostics without context lines (n=0).
# Each removed/added "-N:" / "+N:" entry is relabeled as "actual at N:" /
# "expect at N:" so the failure report reads naturally.
actual_list = actual_warnings.split("\n")
expected_list = expected_warnings.split("\n")
for diff_line in unified_diff(actual_list, expected_list, n=0):
    relabeled = re.sub(r"(^\-)(\d+:)", r"\nactual at \g<2>", diff_line)
    warning_diffs += re.sub(r"(^\+)(\d+:)", r"\nexpect at \g<2>", relabeled)
144
if src4 or warning_diffs:
    print("Folding test failed:")
    # For every failed test_<x>, also print the parameters whose *name*
    # contains <x> (e.g. res_<x>, exp_<x>) so the compared values are shown.
    if src4:
        params = [entry for entry in src2.split("\n") if entry]
        for failed in src4.split("\n"):
            m = re.match(r"test_(\w+)", failed)
            if not m:
                continue
            suffix = m.group(1)
            for entry in params:
                # Match the name field only: matching the whole
                # "name value" entry also hit init values (for example,
                # every ".false." entry for a suffix such as "a").
                if suffix in entry.split(" ", 1)[0]:
                    print(entry)
    if warning_diffs:
        print(warning_diffs)
    print()
    print("FAIL")
    sys.exit(1)
else:
    print()
    print(f"All {passed_results+passed_warnings} tests passed")
    print("PASS")
166