# SPDX-License-Identifier: GPL-2.0-or-later
"""
Example:
./source/tools/utils/code_clean.py /src/cmake_debug --match ".*/editmesh_.*" --fix=use_const_vars
Note: currently this is limited to paths in "source/" and "intern/",
we could change this if it's needed.
"""
import re
import subprocess
import sys
import os
import string
import time
from typing import (
Any,
Dict,
Generator,
List,
Optional,
Sequence,
Tuple,
Type,
)
# List of (source_file, all_arguments)
ProcessedCommands = List[Tuple[str, str]]
USE_MULTIPROCESS = True
VERBOSE = False
# Print the output of the compiler (_very_ noisy, only useful for troubleshooting compiler issues).
VERBOSE_COMPILER = False
# Print the outcome of each attempted edit, one of:
# - Causes code not to compile.
# - Compiles but changes the resulting behavior.
# - Succeeds.
VERBOSE_EDIT_ACTION = False
BASE_DIR = os.path.abspath(os.path.dirname(__file__))
SOURCE_DIR = os.path.normpath(os.path.join(BASE_DIR, "..", "..", ".."))
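# With this script located in `source/tools/utils/`, SOURCE_DIR resolves to the repository root.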
# -----------------------------------------------------------------------------
# Generic Constants
IDENTIFIER_CHARS = set(string.ascii_letters + "_" + string.digits)
# -----------------------------------------------------------------------------
# General Utilities
# Note that we could use a hash, however there is no advantage, simply compare the file's contents.
def file_as_bytes(filename: str) -> bytes:
with open(filename, 'rb') as fh:
return fh.read()
def line_from_span(text: str, start: int, end: int) -> str:
while start > 0 and text[start - 1] != '\n':
start -= 1
while end < len(text) and text[end] != '\n':
end += 1
return text[start:end]
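# For example (illustrative): `line_from_span("ab\ncd\nef", 3, 4)` expands the span
# to the surrounding line and returns "cd".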
def files_recursive_with_ext(path: str, ext: Tuple[str, ...]) -> Generator[str, None, None]:
for dirpath, dirnames, filenames in os.walk(path):
# Skip '.git' and other dot-directories.
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
for filename in filenames:
if filename.endswith(ext):
yield os.path.join(dirpath, filename)
def text_matching_bracket_forward(
data: str,
pos_beg: int,
pos_limit: int,
beg_bracket: str,
end_bracket: str,
) -> int:
"""
Return the matching bracket or -1.
.. note:: This is not sophisticated, brackets in strings will confuse the function.
"""
level = 1
# The next bracket.
pos = pos_beg + 1
# Clamp the limit.
limit = min(pos_beg + pos_limit, len(data))
while pos < limit:
c = data[pos]
if c == beg_bracket:
level += 1
elif c == end_bracket:
level -= 1
if level == 0:
return pos
pos += 1
return -1
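# For example (illustrative): in the string "a(b(c))" the bracket opened at index 1
# matches the ")" at index 6, so:
#   text_matching_bracket_forward("a(b(c))", 1, 100, "(", ")") == 6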
def text_matching_bracket_backward(
data: str,
pos_end: int,
pos_limit: int,
beg_bracket: str,
end_bracket: str,
) -> int:
"""
Return the matching bracket or -1.
.. note:: This is not sophisticated, brackets in strings will confuse the function.
"""
level = 1
# The previous bracket.
pos = pos_end - 1
# Clamp the limit.
limit = max(0, pos_limit)
while pos >= limit:
c = data[pos]
if c == end_bracket:
level += 1
elif c == beg_bracket:
level -= 1
if level == 0:
return pos
pos -= 1
return -1
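# For example (illustrative): in the string "a(b(c))" the ")" at index 6
# matches the "(" at index 1, so:
#   text_matching_bracket_backward("a(b(c))", 6, 0, "(", ")") == 1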
# -----------------------------------------------------------------------------
# Execution Wrappers
def run(args: str, quiet: bool = False) -> int:
# NOTE: the signature and the `VERBOSE_COMPILER` branch below are reconstructed (assumed),
# based on call-sites such as `run(build_args, quiet=expect_failure)` and the option above.
if VERBOSE_COMPILER and not quiet:
out = None  # Inherit `stdout`/`stderr` so the compiler output is visible.
else:
out = subprocess.DEVNULL
import shlex
p = subprocess.Popen(shlex.split(args), stdout=out, stderr=out)
p.wait()
return p.returncode
# -----------------------------------------------------------------------------
# Build System Access
def cmake_cache_var(cmake_dir: str, var: str) -> Optional[str]:
with open(os.path.join(cmake_dir, "CMakeCache.txt"), encoding='utf-8') as cache_file:
lines = [
l_strip for l in cache_file
if (l_strip := l.strip())
if not l_strip.startswith(("//", "#"))
]
for l in lines:
if l.split(":")[0] == var:
return l.split("=", 1)[-1]
return None
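# For example, a `CMakeCache.txt` entry such as:
#   CMAKE_C_COMPILER:FILEPATH=/usr/bin/cc
# makes `cmake_cache_var(cmake_dir, "CMAKE_C_COMPILER")` return "/usr/bin/cc".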
RE_CFILE_SEARCH = re.compile(r"\s\-c\s([\S]+)")
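# For example, in a compile command such as `cc -O2 -c foo/bar.c -o bar.o` (illustrative),
# the expression above captures "foo/bar.c".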
def process_commands(cmake_dir: str, data: Sequence[str]) -> Optional[ProcessedCommands]:
compiler_c = cmake_cache_var(cmake_dir, "CMAKE_C_COMPILER")
compiler_cxx = cmake_cache_var(cmake_dir, "CMAKE_CXX_COMPILER")
if compiler_c is None:
sys.stderr.write("Can't find C compiler in %r" % cmake_dir)
return None
if compiler_cxx is None:
sys.stderr.write("Can't find C++ compiler in %r" % cmake_dir)
return None
file_args = []
for l in data:
if (
(compiler_c in l) or
(compiler_cxx in l)
):
# Extract:
# -c SOME_FILE
c_file_search = re.search(RE_CFILE_SEARCH, l)
if c_file_search is not None:
c_file = c_file_search.group(1)
file_args.append((c_file, l))
else:
# could print, NO C FILE FOUND?
pass
file_args.sort()
return file_args
def find_build_args_ninja(build_dir: str) -> Optional[ProcessedCommands]:
cmake_dir = build_dir
make_exe = "ninja"
process = subprocess.Popen(
[make_exe, "-t", "commands"],
stdout=subprocess.PIPE,
cwd=build_dir,
)
while process.poll():
time.sleep(1)
assert process.stdout is not None
out = process.stdout.read()
process.stdout.close()
# print("done!", len(out), "bytes")
data = out.decode("utf-8", errors="ignore").split("\n")
return process_commands(cmake_dir, data)
def find_build_args_make(build_dir: str) -> Optional[ProcessedCommands]:
make_exe = "make"
process = subprocess.Popen(
[make_exe, "--always-make", "--dry-run", "--keep-going", "VERBOSE=1"],
stdout=subprocess.PIPE,
cwd=build_dir,
)
while process.poll():
time.sleep(1)
assert process.stdout is not None
out = process.stdout.read()
process.stdout.close()
# print("done!", len(out), "bytes")
data = out.decode("utf-8", errors="ignore").split("\n")
return process_commands(build_dir, data)
# -----------------------------------------------------------------------------
# Create Edit Lists
# Create an edit list from a file, in the format:
#
# [((start_index, end_index), text_to_replace), ...]
#
# Note that edits should not overlap. In the _very_ rare case overlapping edits are needed,
# this can simply be run multiple times on the same code-base, although that doesn't
# seem to be a common use-case.
from collections import namedtuple
Edit = namedtuple(
"Edit", (
# Keep first, for sorting.
"span",
"content",
"content_fail",
# Optional.
"extra_build_args",
),
defaults=(
# `extra_build_args`.
None,
)
)
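# An illustrative instance, replacing characters 10-14 of a file with "uint":
#   Edit(span=(10, 14), content='uint', content_fail='__ALWAYS_FAIL__')
# `content_fail` is a substitution that must fail to compile, presumably so edits that land
# in code which never takes part in the build can be detected.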
del namedtuple
class EditGenerator:
__slots__ = ()
def __new__(cls, *args: Tuple[Any], **kwargs: Dict[str, Any]) -> Any:
raise RuntimeError("%s should not be instantiated" % cls)
@staticmethod
def edit_list_from_file(_source: str, _data: str, _shared_edit_data: Any) -> List[Edit]:
raise RuntimeError("This function must be overridden by it's subclass!")
return []
@staticmethod
def setup() -> Any:
return None
@staticmethod
def teardown(_shared_edit_data: Any) -> None:
pass
class edit_generators:
# fake module.
class sizeof_fixed_array(EditGenerator):
"""
Use fixed size array syntax with `sizeof`:
Replace:
sizeof(float) * 4 * 4
With:
sizeof(float[4][4])
"""
@staticmethod
def edit_list_from_file(_source: str, data: str, _shared_edit_data: Any) -> List[Edit]:
edits = []
for match in re.finditer(r"sizeof\(([a-zA-Z_]+)\) \* (\d+) \* (\d+)", data):
edits.append(Edit(
span=match.span(),
content='sizeof(%s[%s][%s])' % (match.group(1), match.group(2), match.group(3)),
content_fail='__ALWAYS_FAIL__',
))
for match in re.finditer(r"sizeof\(([a-zA-Z_]+)\) \* (\d+)", data):
edits.append(Edit(
span=match.span(),
content='sizeof(%s[%s])' % (match.group(1), match.group(2)),
content_fail='__ALWAYS_FAIL__',
))
for match in re.finditer(r"\b(\d+) \* sizeof\(([a-zA-Z_]+)\)", data):
edits.append(Edit(
span=match.span(),
content='sizeof(%s[%s])' % (match.group(2), match.group(1)),
content_fail='__ALWAYS_FAIL__',
))
return edits
class use_const(EditGenerator):
"""
Use const variables:
Replace:
float abc[3] = {0, 1, 2};
With:
const float abc[3] = {0, 1, 2};
Replace:
float abc[3]
With:
const float abc[3]
As well as casts.
Replace:
(float *)
With:
(const float *)
Replace:
(float (*))
With:
(const float (*))
"""
@staticmethod
def edit_list_from_file(_source: str, data: str, _shared_edit_data: Any) -> List[Edit]:
edits = []
# `float abc[3] = {0, 1, 2};` -> `const float abc[3] = {0, 1, 2};`
for match in re.finditer(r"(\(|, | )([a-zA-Z_0-9]+ [a-zA-Z_0-9]+\[)\b([^\n]+ = )", data):
edits.append(Edit(
span=match.span(),
content='%s const %s%s' % (match.group(1), match.group(2), match.group(3)),
content_fail='__ALWAYS_FAIL__',
))
# `float abc[3]` -> `const float abc[3]`
for match in re.finditer(r"(\(|, )([a-zA-Z_0-9]+ [a-zA-Z_0-9]+\[)", data):
edits.append(Edit(
span=match.span(),
content='%s const %s' % (match.group(1), match.group(2)),
content_fail='__ALWAYS_FAIL__',
))
# `(float *)` -> `(const float *)`
# `(float (*))` -> `(const float (*))`
# `(float (*)[4])` -> `(const float (*)[4])`
for match in re.finditer(
r"(\()"
r"([a-zA-Z_0-9]+\s*)"
r"(\*+\)|\(\*+\))"
r"(|\[[a-zA-Z_0-9]+\])",
data,
):
edits.append(Edit(
span=match.span(),
content='%sconst %s%s%s' % (match.group(1), match.group(2), match.group(3), match.group(4)),
content_fail='__ALWAYS_FAIL__',
))
return edits
class use_zero_before_float_suffix(EditGenerator):
"""
Use zero before the float suffix.
Replace:
1.f
With:
1.0f
Replace:
1.0F
With:
1.0f
"""
@staticmethod
def edit_list_from_file(_source: str, data: str, _shared_edit_data: Any) -> List[Edit]:
edits = []
# `1.f` -> `1.0f`
for match in re.finditer(r"\b(\d+)\.([fF])\b", data):
edits.append(Edit(
span=match.span(),
content='%s.0%s' % (match.group(1), match.group(2)),
content_fail='__ALWAYS_FAIL__',
))
# `1.0F` -> `1.0f`
for match in re.finditer(r"\b(\d+\.\d+)F\b", data):
edits.append(Edit(
span=match.span(),
content='%sf' % (match.group(1),),
content_fail='__ALWAYS_FAIL__',
))
return edits
class use_brief_types(EditGenerator):
"""
Use brief unsigned type names:
Replace:
unsigned int
With:
uint
"""
@staticmethod
def edit_list_from_file(_source: str, data: str, _shared_edit_data: Any) -> List[Edit]:
edits = []
# `unsigned char` -> `uchar`.
for match in re.finditer(r"(unsigned)\s+([a-z]+)", data):
edits.append(Edit(
span=match.span(),
content='u%s' % match.group(2),
content_fail='__ALWAYS_FAIL__',
))
# There may be some remaining uses of `unsigned` without any integer type afterwards.
# `unsigned` -> `uint`.
for match in re.finditer(r"\bunsigned\b", data):
edits.append(Edit(
span=match.span(),
content='uint',
content_fail='__ALWAYS_FAIL__',
))
return edits
class use_elem_macro(EditGenerator):
"""
Use the `ELEM` macro for more abbreviated expressions.
Replace:
(a == b || a == c)
(a != b && a != c)
With:
(ELEM(a, b, c))
(!ELEM(a, b, c))
"""
@staticmethod
def edit_list_from_file(_source: str, data: str, _shared_edit_data: Any) -> List[Edit]:
edits = []
for use_brackets in (True, False):
test_equal = (
r'([^\|\(\)]+)' # group 1 (no (|))
r'\s+==\s+'
r'([^\|\(\)]+)' # group 2 (no (|))
)
test_not_equal = (
r'([^\|\(\)]+)' # group 1 (no (|))
r'\s+!=\s+'
r'([^\|\(\)]+)' # group 2 (no (|))
)
if use_brackets:
test_equal = r'\(' + test_equal + r'\)'
test_not_equal = r'\(' + test_not_equal + r'\)'
for is_equal in (True, False):
for n in reversed(range(2, 64)):
if is_equal:
re_str = r'\(' + r'\s+\|\|\s+'.join([test_equal] * n) + r'\)'
else:
re_str = r'\(' + r'\s+\&\&\s+'.join([test_not_equal] * n) + r'\)'
for match in re.finditer(re_str, data):
var = match.group(1)
var_rest = []
groups = match.groups()
groups_paired = [(groups[i * 2], groups[i * 2 + 1]) for i in range(len(groups) // 2)]
found = True
for a, b in groups_paired:
# Unlikely but possible the checks are swapped.
if b == var and a != var:
a, b = b, a
if a != var:
found = False
break
var_rest.append(b)
if found:
edits.append(Edit(
span=match.span(),
content='(%sELEM(%s, %s))' % (
('' if is_equal else '!'),
var,
', '.join(var_rest),
),
# Use same expression otherwise this can change values
# inside assert when it shouldn't.
content_fail='(%s__ALWAYS_FAIL__(%s, %s))' % (
('' if is_equal else '!'),
var,
', '.join(var_rest),
),
))
return edits
class use_str_elem_macro(EditGenerator):
"""
Use `STR_ELEM` macro:
Replace:
(STREQ(a, b) || STREQ(a, c))
With:
(STR_ELEM(a, b, c))
"""
@staticmethod
def edit_list_from_file(_source: str, data: str, _shared_edit_data: Any) -> List[Edit]:
edits = []
for use_brackets in (True, False):
test_equal = (
r'STREQ'
r'\('
r'([^\|\(\),]+)' # group 1 (no (|,))
r',\s+'
r'([^\|\(\),]+)' # group 2 (no (|,))
r'\)'
)
test_not_equal = (
'!' # Only difference.
r'STREQ'
r'\('
r'([^\|\(\),]+)' # group 1 (no (|,))
r',\s+'
r'([^\|\(\),]+)' # group 2 (no (|,))
r'\)'
)
if use_brackets:
test_equal = r'\(' + test_equal + r'\)'
test_not_equal = r'\(' + test_not_equal + r'\)'
for is_equal in (True, False):
for n in reversed(range(2, 64)):
if is_equal:
re_str = r'\(' + r'\s+\|\|\s+'.join([test_equal] * n) + r'\)'
else:
re_str = r'\(' + r'\s+\&\&\s+'.join([test_not_equal] * n) + r'\)'
for match in re.finditer(re_str, data):
var = match.group(1)
var_rest = []
groups = match.groups()
groups_paired = [(groups[i * 2], groups[i * 2 + 1]) for i in range(len(groups) // 2)]
found = True
for a, b in groups_paired:
# Unlikely but possible the checks are swapped.
if b == var and a != var:
a, b = b, a
if a != var:
found = False
break
var_rest.append(b)
if found:
edits.append(Edit(
span=match.span(),
content='(%sSTR_ELEM(%s, %s))' % (
('' if is_equal else '!'),
var,
', '.join(var_rest),
),
# Use same expression otherwise this can change values
# inside assert when it shouldn't.
content_fail='(%s__ALWAYS_FAIL__(%s, %s))' % (
('' if is_equal else '!'),
var,
', '.join(var_rest),
),
))
return edits
class use_const_vars(EditGenerator):
"""
Use `const` where possible:
Replace:
float abc[3] = {0, 1, 2};
With:
const float abc[3] = {0, 1, 2};
"""
@staticmethod
def edit_list_from_file(_source: str, data: str, _shared_edit_data: Any) -> List[Edit]:
edits = []
# for match in re.finditer(r"( [a-zA-Z0-9_]+ [a-zA-Z0-9_]+ = [A-Z][A-Z_0-9_]*;)", data):
# edits.append(Edit(
# span=match.span(),
# content='const %s' % (match.group(1).lstrip()),
# content_fail='__ALWAYS_FAIL__',
# ))
for match in re.finditer(r"( [a-zA-Z0-9_]+ [a-zA-Z0-9_]+ = .*;)", data):
edits.append(Edit(
span=match.span(),
content='const %s' % (match.group(1).lstrip()),
content_fail='__ALWAYS_FAIL__',
))
return edits
class remove_return_parens(EditGenerator):
"""
Remove redundant parenthesis around return arguments:
Replace:
return (value);
With:
return value;
"""
@staticmethod
def edit_list_from_file(_source: str, data: str, _shared_edit_data: Any) -> List[Edit]:
edits = []
# Remove `return (NULL);`
for match in re.finditer(r"return \(([a-zA-Z_0-9]+)\);", data):
edits.append(Edit(
span=match.span(),
content='return %s;' % (match.group(1)),
content_fail='return __ALWAYS_FAIL__;',
))
return edits
class use_streq_macro(EditGenerator):
"""
Use `STREQ` macro:
Replace:
strcmp(a, b) == 0
With:
STREQ(a, b)
Replace:
strcmp(a, b) != 0
With:
!STREQ(a, b)
"""
@staticmethod
def edit_list_from_file(_source: str, data: str, _shared_edit_data: Any) -> List[Edit]:
edits = []
# `strcmp(a, b) == 0` -> `STREQ(a, b)`
for match in re.finditer(r"\bstrcmp\((.*)\) == 0", data):
edits.append(Edit(
span=match.span(),
content='STREQ(%s)' % (match.group(1)),
content_fail='__ALWAYS_FAIL__',
))
for match in re.finditer(r"!strcmp\((.*)\)", data):
edits.append(Edit(
span=match.span(),
content='STREQ(%s)' % (match.group(1)),
content_fail='__ALWAYS_FAIL__',
))
# `strcmp(a, b) != 0` -> `!STREQ(a, b)`
for match in re.finditer(r"\bstrcmp\((.*)\) != 0", data):
edits.append(Edit(
span=match.span(),
content='!STREQ(%s)' % (match.group(1)),
content_fail='__ALWAYS_FAIL__',
))
for match in re.finditer(r"\bstrcmp\((.*)\)", data):
edits.append(Edit(
span=match.span(),
content='!STREQ(%s)' % (match.group(1)),
content_fail='__ALWAYS_FAIL__',
))
return edits
class use_array_size_macro(EditGenerator):
"""
Use macro for an error checked array size:
Replace:
sizeof(foo) / sizeof(*foo)
With:
ARRAY_SIZE(foo)
"""
@staticmethod
def edit_list_from_file(_source: str, data: str, _shared_edit_data: Any) -> List[Edit]:
edits = []
# Note that this replacement is only valid in some cases,
# so only apply with validation that binary output matches.
for match in re.finditer(r"\bsizeof\((.*)\) / sizeof\([^\)]+\)", data):
edits.append(Edit(
span=match.span(),
content='ARRAY_SIZE(%s)' % match.group(1),
content_fail='__ALWAYS_FAIL__',
))
return edits
class parenthesis_cleanup(EditGenerator):
"""
Remove redundant parentheses:
Replace:
((a + b))
With:
(a + b)
Replace:
(func(a + b))
With:
func(a + b)
Note that the `CFLAGS` should be set so parentheses removed from around assignments cause an error instead of a warning:
With GCC: `-Werror=parentheses`
"""
@staticmethod
def edit_list_from_file(_source: str, data: str, _shared_edit_data: Any) -> List[Edit]:
edits = []
# Give up after searching for a bracket this many characters and finding none.
bracket_seek_limit = 4000
# Don't match double brackets directly because that would miss multiple overlapping matches,
# where 3 brackets should be checked as two separate pairs.
for match in re.finditer(r"(\()", data):
outer_beg = match.span()[0]
inner_beg = outer_beg + 1
if data[inner_beg] != "(":
continue
inner_end = text_matching_bracket_forward(data, inner_beg, inner_beg + bracket_seek_limit, "(", ")")
if inner_end == -1:
continue
outer_beg = inner_beg - 1
outer_end = text_matching_bracket_forward(data, outer_beg, inner_end + 1, "(", ")")
if outer_end != inner_end + 1:
continue
text = data[inner_beg:inner_end + 1]
edits.append(Edit(
span=(outer_beg, outer_end + 1),
content=text,
content_fail='(__ALWAYS_FAIL__)',
))
# Handle `(func(a + b))` -> `func(a + b)`
for match in re.finditer(r"(\))", data):
inner_end = match.span()[0]
outer_end = inner_end + 1
if data[outer_end] != ")":
continue
inner_beg = text_matching_bracket_backward(data, inner_end, inner_end - bracket_seek_limit, "(", ")")
if inner_beg == -1:
continue
outer_beg = text_matching_bracket_backward(data, outer_end, outer_end - bracket_seek_limit, "(", ")")
if outer_beg == -1:
continue
# The text between the first two opening brackets:
# `(function_name(a + b))` -> `function_name`.
text = data[outer_beg + 1:inner_beg]
# Handled in the first loop looking for forward brackets.
if text == "":
continue
# Don't convert `prefix(func(a + b))` -> `prefixfunc(a + b)`
if data[outer_beg - 1] in IDENTIFIER_CHARS:
continue
# Don't convert `static_cast<float>(foo(bar))` -> `static_cast<float>foo(bar)`
# While this will always fail to compile it slows down tests.
if data[outer_beg - 1] == ">":
continue
# The exact rule here is arbitrary; in general though, spaces mean there are operations
# that can make use of the brackets.
if " " in text:
continue
# Search back an arbitrary number of characters; 8 should be enough,
# but manual formatting can add additional white-space, so increase
# the size to account for that.
prefix = data[max(outer_beg - 20, 0):outer_beg].strip()
if prefix:
# Avoid `if (SOME_MACRO(..)) {..}` -> `if SOME_MACRO(..) {..}`
# While correct it relies on parenthesis within the macro which isn't ideal.
if prefix.split()[-1] in {"if", "while", "switch"}:
continue
# Avoid `*(--foo)` -> `*--foo`.
# While correct it reads badly.
if data[outer_beg - 1] == "*":
continue
text_no_parens = data[outer_beg + 1: outer_end]
edits.append(Edit(
span=(outer_beg, outer_end + 1),
content=text_no_parens,
content_fail='__ALWAYS_FAIL__',
))
return edits
class header_clean(EditGenerator):
"""
Clean headers, ensuring that the headers removed are not used directly or indirectly.
Note that the `CFLAGS` should be set so missing prototypes error instead of warn:
With GCC: `-Werror=missing-prototypes`
"""
@staticmethod
def _header_guard_from_filename(f: str) -> str:
return '__%s__' % os.path.basename(f).replace('.', '_').upper()
@classmethod
def setup(cls) -> Any:
# For each file replace `pragma once` with an old-style header guard.
# This is needed so we can remove the header with the knowledge the source file didn't use it indirectly.
files: List[Tuple[str, str, str, str]] = []
shared_edit_data = {
'files': files,
}
for f in files_recursive_with_ext(
os.path.join(SOURCE_DIR, 'source'),
('.h', '.hh', '.inl', '.hpp', '.hxx'),
):
with open(f, 'r', encoding='utf-8') as fh:
data = fh.read()
for match in re.finditer(r'^[ \t]*#\s*(pragma\s+once)\b', data, flags=re.MULTILINE):
header_guard = cls._header_guard_from_filename(f)
start, end = match.span()
src = data[start:end]
dst = (
'#ifndef %s\n#define %s' % (header_guard, header_guard)
)
dst_footer = '\n#endif /* %s */\n' % header_guard
files.append((f, src, dst, dst_footer))
data = data[:start] + dst + data[end:] + dst_footer
with open(f, 'w', encoding='utf-8') as fh:
fh.write(data)
break
return shared_edit_data
@staticmethod
def teardown(shared_edit_data: Any) -> None:
files = shared_edit_data['files']
for f, src, dst, dst_footer in files:
with open(f, 'r', encoding='utf-8') as fh:
data = fh.read()
data = data.replace(
dst, src,
).replace(
dst_footer, '',
)
with open(f, 'w', encoding='utf-8') as fh:
fh.write(data)
@classmethod
def edit_list_from_file(cls, _source: str, data: str, _shared_edit_data: Any) -> List[Edit]:
edits = []
# Remove include.
for match in re.finditer(r"^(([ \t]*#\s*include\s+\")([^\"]+)(\"[^\n]*\n))", data, flags=re.MULTILINE):
header_name = match.group(3)
header_guard = cls._header_guard_from_filename(header_name)
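# Defining the old-style guard on the command line (`extra_build_args`) turns every remaining
# inclusion of this header into a no-op, so a successful build shows the removed `#include`
# wasn't relied upon, directly or indirectly.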
edits.append(Edit(
span=match.span(),
content='',  # Remove the include line entirely.
content_fail='%s__ALWAYS_FAIL__%s' % (match.group(2), match.group(4)),
extra_build_args=('-D' + header_guard),
))
return edits
def test_edit(
source: str,
output: str,
output_bytes: Optional[bytes],
build_args: str,
data: str,
data_test: str,
keep_edits: bool = True,
expect_failure: bool = False,
) -> bool:
"""
Return true if `data_test` has the same object output as `data`.
"""
if os.path.exists(output):
os.remove(output)
with open(source, 'w', encoding='utf-8') as fh:
fh.write(data_test)
ret = run(build_args, quiet=expect_failure)
if ret == 0:
output_bytes_test = file_as_bytes(output)
if (output_bytes is None) or (output_bytes_test == output_bytes):
if not keep_edits:
with open(source, 'w', encoding='utf-8') as fh:
fh.write(data)
return True
else:
if VERBOSE_EDIT_ACTION:
print("Changed code, skip...", hex(hash(output_bytes)), hex(hash(output_bytes_test)))
if VERBOSE_EDIT_ACTION:
print("Failed to compile, skip...")
with open(source, 'w', encoding='utf-8') as fh:
fh.write(data)
return False
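# Illustrative use of `test_edit` above: with `data` holding the original source text and
# `data_test` the same text with one candidate edit applied:
#   if test_edit(source, output, output_bytes, build_args, data, data_test, keep_edits=False):
#       pass  # The edit compiled and produced matching output, so it can be accepted.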
# -----------------------------------------------------------------------------
# List Fix Functions
fixes: List[str] = []
for name in dir(edit_generators):
value = getattr(edit_generators, name)
if type(value) is type and issubclass(value, EditGenerator):
fixes.append(name)
def edit_class_from_id(name: str) -> Type[EditGenerator]:
result = getattr(edit_generators, name)
assert issubclass(result, EditGenerator)
# MYPY 0.812 doesn't recognize the assert above.
return result # type: ignore
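# For example (illustrative file name):
#   edit_class_from_id("use_streq_macro").edit_list_from_file("foo.c", data, None)
# returns the list of proposed edits for `data`.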
# -----------------------------------------------------------------------------