Extending TypeApplier to remove type annotations + Improvements #26

Open · wants to merge 46 commits into base: master
Commits (46)
39660d5
Removing type annotations that do not type check by mypy [WIP] - Pa…
mir-am Jun 30, 2021
54ba46b
Removing type annotations that do not type check by mypy [WIP] - Part 2
mir-am Jul 1, 2021
dfe612e
Removing type annotations that do not type check by mypy [WIP] - Part 3
mir-am Jul 1, 2021
d2a41f2
Fixed AttributeError when removing annotations for uninitialized vars
mir-am Jul 1, 2021
a170308
Removing type annotations that do not type check by mypy [WIP] - Part 4
mir-am Jul 1, 2021
f7b286d
Removing type annotations that do not type check by mypy [WIP] - Part 4
mir-am Jul 1, 2021
715430a
Removing type annotations that do not type check by mypy [WIP] - Part 5
mir-am Jul 1, 2021
8f46016
Removing type annotations that do not type check by mypy [WIP] - Part 6
mir-am Jul 2, 2021
00011f0
Merge branch 'extend-type-applier' of https://github.com/saltudelft/l…
mir-am Jul 5, 2021
1b48708
Merge remote-tracking branch 'origin/ln-vars' into extend-type-applier
mir-am Jul 5, 2021
c638973
Improvements to the TypeAnnotationsRemoval pipeline
mir-am Jul 12, 2021
7aa76bd
Improve type annotation removal code
mir-am Jul 12, 2021
593f96a
Merge branch 'master' into extend-type-applier
mir-am Jul 12, 2021
54f3e50
Merge remote-tracking branch 'origin/improve-imp' into extend-type-ap…
mir-am Jul 12, 2021
1c51fb9
Merge remote-tracking branch 'origin/optional-seq2seq' into extend-ty…
mir-am Jul 14, 2021
dc2412a
Merge branch 'master' into extend-type-applier
mir-am Jul 15, 2021
9d58e2b
Merge remote-tracking branch 'origin/optional-lemmatize' into extend-…
mir-am Jul 16, 2021
dde5757
Merge remote-tracking branch 'origin/optional-lemmatize' into extend-…
mir-am Jul 16, 2021
e726688
Merge remote-tracking branch 'origin/optional-lemmatize' into extend-…
mir-am Jul 18, 2021
edc8615
Merge branch 'master' into extend-type-applier
mir-am Jul 19, 2021
17f8bcc
Improve TypeApplier by matching functions by line and column no. & ma…
mir-am Jul 19, 2021
221c14c
Fix unit tests for TypeApplier when matching functions based on line …
mir-am Jul 19, 2021
57eacf6
Remove superfluous assignment line from TypeApplier
mir-am Jul 19, 2021
f0aa00e
A workaround for a very rare case where the class' QN doesn't match w…
mir-am Jul 21, 2021
19428fd
When applying types, first match functions' QN & signature first, if …
mir-am Jul 21, 2021
65130df
Count total no. of added types in TypeApplier and its pipeline
mir-am Jul 22, 2021
6166771
Improvements to the TypeAnnotatingProjects pipeline
mir-am Jul 22, 2021
95d3c7e
Fix test failure for types removal
mir-am Jul 23, 2021
361986e
Merge remote-tracking branch 'origin/annotation-counter' into extend-…
mir-am Jul 25, 2021
7927d45
Improvements to TypeApplier: (1) Better matching of function, classes…
mir-am Jul 26, 2021
d41fea3
Improvements to the pipeline of TypeApplier: (1) Dry run (2) Assertion…
mir-am Jul 26, 2021
1c23f76
Merge branch 'annotation-counter' into extend-type-applier
mir-am Jul 26, 2021
3111e84
Fix test failure for the TypeApplier
mir-am Jul 26, 2021
8efb7b1
Improvements to the TypeRemoval pipeline: (1) Dry run (2) better mul…
mir-am Jul 28, 2021
1fc8156
Run mypy with the file's abs. path, which may improve TC in some cases
mir-am Jul 28, 2021
09e5362
Merge branch 'extend-type-applier' of https://github.com/saltudelft/l…
mir-am Jul 28, 2021
a89338c
Fixing re-importing names when applying types
mir-am Aug 2, 2021
d374848
(1) Exclude source files in the ignored list for the main pipeline, (…
mir-am Aug 2, 2021
3ca89f4
Merge branch 'extend-type-applier' of https://github.com/saltudelft/l…
mir-am Aug 2, 2021
729c76d
Putting large projects at the front of the jobs' queue to reduce over…
mir-am Aug 6, 2021
3553944
ignore type errors of imported modules and missing imports when type-…
mir-am Aug 6, 2021
61b4b0c
Improvements to TypeApplier: (1) write ignored files to a separate fi…
mir-am Aug 6, 2021
e3ddcc2
In the main pipeline, sort projects based on total size of their files
mir-am Aug 6, 2021
5456c24
Merge remote-tracking branch 'origin/pipeline-jobs' into extend-type-…
mir-am Aug 6, 2021
715aea8
Improvements to TypeRemover: (1) Copying input dataset to another des…
mir-am Aug 10, 2021
060c4be
Add a utility method to copy files while making required dirs
mir-am Aug 10, 2021
28 changes: 25 additions & 3 deletions libsa4py/__main__.py
@@ -1,21 +1,27 @@
from argparse import ArgumentParser
from multiprocessing import cpu_count
from libsa4py.utils import find_repos_list
from libsa4py.cst_pipeline import Pipeline, TypeAnnotatingProjects
from libsa4py.cst_pipeline import Pipeline, TypeAnnotatingProjects, TypeAnnotationsRemoval
from libsa4py.merge import merge_projects


def process_projects(args):
input_repos = find_repos_list(args.p) if args.l is None else find_repos_list(args.p)[:args.l]
p = Pipeline(args.p, args.o, not args.no_nlp, args.use_cache, args.use_pyre, args.use_tc, args.d, args.s)
p = Pipeline(args.p, args.o, not args.no_nlp, args.use_cache, args.use_pyre, args.use_tc, args.d,
args.s, args.i)
p.run(input_repos, args.j)


def apply_types_projects(args):
tap = TypeAnnotatingProjects(args.p, args.o)
tap = TypeAnnotatingProjects(args.p, args.o, args.dry_run)
tap.run(args.j)


def remove_err_type_annotations(args):
tar = TypeAnnotationsRemoval(args.i, args.o, args.p, args.l, args.dry_run)
tar.run(args.j)


def main():

arg_parser = ArgumentParser(description="Light-weight static analysis to extract Python's code representations")
@@ -26,6 +32,7 @@ def main():
process_parser.add_argument("--o", required=True, type=str, help="Path to store JSON-based processed projects")
process_parser.add_argument("--d", "--deduplicate", required=False, type=str, help="Path to duplicate files")
process_parser.add_argument("--s", "--split", required=False, type=str, help="Path to the dataset split files")
process_parser.add_argument("--i", "--ignore", required=False, type=str, help="Path to the ignored files")
process_parser.add_argument("--j", default=cpu_count(), type=int, help="Number of workers for processing projects")
process_parser.add_argument("--l", required=False, type=int, help="Number of projects to process")
process_parser.add_argument("--c", "--cache", dest='use_cache', action='store_true', help="Whether to ignore processed projects")
@@ -51,8 +58,23 @@ def main():
apply_parser.add_argument("--p", required=True, type=str, help="Path to Python projects")
apply_parser.add_argument("--o", required=True, type=str, help="Path to store JSON-based processed projects")
apply_parser.add_argument("--j", default=cpu_count(), type=int, help="Number of workers for processing projects")
apply_parser.add_argument("--d", dest='dry_run', action='store_true',
help="Dry run does not apply types to the dataset's files")

apply_parser.set_defaults(dry_run=False)
apply_parser.set_defaults(func=apply_types_projects)

remove_parser = sub_parsers.add_parser('remove')
remove_parser.add_argument("--i", required=True, type=str, help="Path to input dataset")
remove_parser.add_argument("--o", required=True, type=str, help="Path to output dataset")
remove_parser.add_argument("--p", required=True, type=str, help="Path to JSON-formatted processed projects")
remove_parser.add_argument("--j", default=cpu_count(), type=int, help="Number of workers for processing files")
remove_parser.add_argument("--l", required=False, type=int, help="Number of projects to process")
remove_parser.add_argument("--d", dest='dry_run', action='store_true',
help="Dry run does not remove types from the dataset's files")
remove_parser.set_defaults(dry_run=False)
remove_parser.set_defaults(func=remove_err_type_annotations)

args = arg_parser.parse_args()
args.func(args)
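The new `remove` sub-command follows the same argparse wiring as the existing `process` and `apply` commands. A minimal, standalone sketch of the dry-run flag handling (flag names taken from the diff; the real handler, `remove_err_type_annotations`, is stubbed out here):

```python
from argparse import ArgumentParser

def build_parser() -> ArgumentParser:
    # Hypothetical reconstruction of the `remove` sub-command added in this PR.
    parser = ArgumentParser(prog="libsa4py")
    sub_parsers = parser.add_subparsers(dest="cmd")

    remove_parser = sub_parsers.add_parser('remove')
    remove_parser.add_argument("--i", required=True, type=str, help="Path to input dataset")
    remove_parser.add_argument("--o", required=True, type=str, help="Path to output dataset")
    remove_parser.add_argument("--p", required=True, type=str, help="Path to processed projects")
    # --d only sets a flag; the pipeline checks args.dry_run before writing any files.
    remove_parser.add_argument("--d", dest='dry_run', action='store_true')
    remove_parser.set_defaults(dry_run=False)
    return parser

args = build_parser().parse_args(["remove", "--i", "in/", "--o", "out/", "--p", "proj.json", "--d"])
print(args.dry_run)  # True
```

Omitting `--d` leaves `dry_run` at its `set_defaults` value of `False`, so the destructive path only runs when the flag is absent.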

442 changes: 417 additions & 25 deletions libsa4py/cst_pipeline.py

Large diffs are not rendered by default.

116 changes: 98 additions & 18 deletions libsa4py/cst_transformers.py
@@ -853,13 +853,15 @@ def leave_SubscriptElement(self, original_node, updated_node):
return updated_node


# TODO: Write two separate CSTTransformers for applying and removing type annotations
class TypeApplier(cst.CSTTransformer):
"""
It applies (inferred) type annotations to a source code file.
Specifically, it applies the type of arguments, return types, and variables' type.
"""

METADATA_DEPENDENCIES = (cst.metadata.ScopeProvider, cst.metadata.QualifiedNameProvider)
METADATA_DEPENDENCIES = (cst.metadata.ScopeProvider, cst.metadata.QualifiedNameProvider,
cst.metadata.PositionProvider)

def __init__(self, f_processeed_dict: dict, apply_nlp: bool=True):
self.f_processed_dict = f_processeed_dict
@@ -871,6 +873,10 @@ def __init__(self, f_processeed_dict: dict, apply_nlp: bool=True):
self.lambda_d = 0

self.all_applied_types = set()
self.no_applied_types = 0
self.no_failed_applied_types = 0

self.imported_names: List[str] = []

if apply_nlp:
self.nlp_p = NLPreprocessor().process_identifier
@@ -883,9 +889,15 @@ def __get_fn(self, f_node: cst.FunctionDef) -> dict:
else:
fns = self.f_processed_dict['funcs']

qn = self.__get_qualified_name(f_node.name)
fn_params = set(self.__get_fn_params(f_node.params))
fn_lc = self.__get_line_column_no(f_node)
for fn in fns:
if fn['q_name'] == self.__get_qualified_name(f_node.name) and \
set(list(fn['params'].keys())) == set(self.__get_fn_params(f_node.params)):
if (fn['fn_lc'][0][0], fn['fn_lc'][1][0]) == fn_lc:
return fn

for fn in fns:
if fn['q_name'] == qn and set(list(fn['params'].keys())) == fn_params:
return fn
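The matching strategy above — positions first, then qualified name plus parameter-name set — can be sketched independently of libcst. Data shapes are assumed from the diff: `fn_lc` is `[[start_line, start_col], [end_line, end_col]]`.

```python
def find_fn(fns, fn_lc, qn, param_names):
    """Two-pass lookup mirroring __get_fn: exact (start line, end line)
    match first, then fall back to qualified name + parameter-name set."""
    for fn in fns:
        if (fn['fn_lc'][0][0], fn['fn_lc'][1][0]) == fn_lc:
            return fn
    for fn in fns:
        if fn['q_name'] == qn and set(fn['params']) == set(param_names):
            return fn
    return None

fns = [
    {'q_name': 'mod.f', 'params': {'x': 'int'}, 'fn_lc': [[3, 0], [5, 10]]},
    {'q_name': 'mod.f', 'params': {'x': 'int'}, 'fn_lc': [[8, 0], [9, 10]]},
]
# Line info disambiguates two functions with identical names and signatures,
# which the old QN-plus-params check could not do.
assert find_fn(fns, (8, 9), 'mod.f', ['x']) is fns[1]
```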

def __get_fn_param_type(self, param_name: str):
@@ -895,11 +907,20 @@ def __get_fn_param_type(self, param_name: str):
fn_param_type = self.__name2annotation(fn_param_type_resolved)
if fn_param_type is not None:
self.all_applied_types.add((fn_param_type_resolved, fn_param_type))
self.no_applied_types += 1
return fn_param_type
else:
self.no_failed_applied_types += 1

def __get_cls(self, cls: cst.ClassDef) -> dict:
cls_lc = self.__get_line_column_no(cls)
cls_qn = self.__get_qualified_name(cls.name)
for c in self.f_processed_dict['classes']:
if c['q_name'] == self.__get_qualified_name(cls.name):
if (c['cls_lc'][0][0], c['cls_lc'][1][0]) == cls_lc:
return c

for c in self.f_processed_dict['classes']:
if c['q_name'] == cls_qn:
return c

def __get_fn_vars(self, var_name: str) -> dict:
@@ -923,24 +944,27 @@ def __get_cls_vars(self, var_name: str) -> dict:
def __get_mod_vars(self):
return self.f_processed_dict['variables']

def __get_var_type_assign_t(self, var_name: str):
def __get_var_type_assign_t(self, var_name: str, var_node):
t: str = None
var_line_no = self.__get_line_column_no(var_node)
if len(self.cls_visited) != 0:
if len(self.fn_visited) != 0:
# A class method's variable
if self.fn_visited[-1][1][var_name] == self.last_visited_assign_t_count:
if self.fn_visited[-1][0]['fn_var_ln'][var_name][0][0] == var_line_no[0]:
t = self.__get_fn_vars(self.nlp_p(var_name))
else:
# A class variable
if self.cls_visited[-1][1][var_name] == self.last_visited_assign_t_count:
if self.cls_visited[-1][0]["cls_var_ln"][var_name][0][0] == var_line_no[0]:
t = self.__get_cls_vars(self.nlp_p(var_name))
elif len(self.fn_visited) != 0:
# A module function's variable
if self.fn_visited[-1][1][var_name] == self.last_visited_assign_t_count:
#if self.fn_visited[-1][1][var_name] == self.last_visited_assign_t_count:
if self.fn_visited[-1][0]['fn_var_ln'][var_name][0][0] == var_line_no[0]:
t = self.__get_fn_vars(self.nlp_p(var_name))
else:
# A module's variables
t = self.__get_mod_vars()[self.nlp_p(var_name)]
if self.f_processed_dict['mod_var_ln'][var_name][0][0] == var_line_no[0]:
t = self.__get_mod_vars()[self.nlp_p(var_name)]
return t
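The revised lookup replaces the old assignment-counting heuristic with a line-number check: a stored type is applied only when the assignment node's start line equals the line recorded for that variable in the processed dataset. A minimal sketch of the module-level case (key names `mod_var_ln` and `variables` follow the diff; the nesting of the line record is assumed):

```python
def module_var_type(processed, var_name, node_start_line):
    # Apply the stored type only when this assignment sits on the exact
    # line the dataset recorded for the variable; later re-assignments
    # of the same name are skipped.
    recorded = processed['mod_var_ln'].get(var_name)
    if recorded is not None and recorded[0][0] == node_start_line:
        return processed['variables'].get(var_name)
    return None

processed = {'mod_var_ln': {'x': [[4, 0], [4, 5]]}, 'variables': {'x': 'int'}}
assert module_var_type(processed, 'x', 4) == 'int'   # the recorded assignment
assert module_var_type(processed, 'x', 9) is None    # a different assignment to x
```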

def __get_var_type_an_assign(self, var_name: str):
@@ -962,9 +986,17 @@ def __get_var_type_an_assign(self, var_name: str):
def __get_var_names_counter(self, node, scope):
vars_name = match.extractall(node, match.OneOf(match.AssignTarget(target=match.SaveMatchedNode(
match.Name(value=match.DoNotCare()), "name")), match.AnnAssign(target=match.SaveMatchedNode(
match.Name(value=match.DoNotCare()), "name"))))
match.Name(value=match.DoNotCare()), "name"))
))
attr_name = match.extractall(node, match.OneOf(match.AssignTarget(
target=match.SaveMatchedNode(match.Attribute(value=match.Name(value=match.DoNotCare()), attr=
match.Name(value=match.DoNotCare())), "attr")),
match.AnnAssign(target=match.SaveMatchedNode(match.Attribute(value=match.Name(value=match.DoNotCare()), attr=
match.Name(value=match.DoNotCare())), "attr"))))
return Counter([n['name'].value for n in vars_name if isinstance(self.get_metadata(cst.metadata.ScopeProvider,
n['name']), scope)])
n['name']), scope)] +
[n['attr'].attr.value for n in attr_name if isinstance(self.get_metadata(cst.metadata.ScopeProvider,
n['attr']), scope)])

def visit_ClassDef(self, node: cst.ClassDef):
self.cls_visited.append((self.__get_cls(node),
@@ -986,7 +1018,12 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
fn_ret_type = self.__name2annotation(fn_ret_type_resolved)
if fn_ret_type is not None:
self.all_applied_types.add((fn_ret_type_resolved, fn_ret_type))
self.no_applied_types += 1
return updated_node.with_changes(returns=fn_ret_type)
else:
self.no_failed_applied_types += 1
else:
return updated_node.with_changes(returns=None)

return updated_node

@@ -999,34 +1036,62 @@ def leave_Lambda(self, original_node: cst.Lambda, updated_node: cst.Lambda):

def leave_SimpleStatementLine(self, original_node: cst.SimpleStatementLine,
updated_node: cst.SimpleStatementLine):

# Untyped variables
t = None
if match.matches(original_node, match.SimpleStatementLine(body=[match.Assign(targets=[match.AssignTarget(
target=match.DoNotCare())])])):
if match.matches(original_node, match.SimpleStatementLine(body=[match.Assign(targets=[match.AssignTarget(
target=match.Name(value=match.DoNotCare()))])])):
t = self.__get_var_type_assign_t(original_node.body[0].targets[0].target.value)
t = self.__get_var_type_assign_t(original_node.body[0].targets[0].target.value,
original_node.body[0].targets[0].target)
elif match.matches(original_node, match.SimpleStatementLine(body=[match.Assign(targets=[match.AssignTarget(
target=match.Attribute(value=match.Name(value=match.DoNotCare()), attr=match.Name(value=match.DoNotCare())))])])):
t = self.__get_var_type_assign_t(original_node.body[0].targets[0].target.attr.value,
original_node.body[0].targets[0].target)

if t is not None:
t_annot_node_resolved = self.resolve_type_alias(t)
t_annot_node = self.__name2annotation(t_annot_node_resolved)
if t_annot_node is not None:
self.all_applied_types.add((t_annot_node_resolved, t_annot_node))
self.no_applied_types += 1
return updated_node.with_changes(body=[cst.AnnAssign(
target=original_node.body[0].targets[0].target,
value=original_node.body[0].value,
annotation=t_annot_node,
equal=cst.AssignEqual(whitespace_after=original_node.body[0].targets[0].whitespace_after_equal,
whitespace_before=original_node.body[0].targets[0].whitespace_before_equal))]
)
elif match.matches(original_node, match.SimpleStatementLine(body=[match.AnnAssign(target=match.Name(value=match.DoNotCare()))])):
t = self.__get_var_type_an_assign(original_node.body[0].target.value)
if t is not None:
else:
self.no_failed_applied_types += 1
# Typed variables
elif match.matches(original_node, match.SimpleStatementLine(body=[match.AnnAssign(target=match.DoNotCare(),
value=match.MatchIfTrue(lambda v: v is not None))])):
if match.matches(original_node, match.SimpleStatementLine(body=[match.AnnAssign(target=match.Name(value=match.DoNotCare()))])):
t = self.__get_var_type_an_assign(original_node.body[0].target.value)
elif match.matches(original_node, match.SimpleStatementLine(body=[match.AnnAssign(target=match.Attribute(value=match.Name(value=match.DoNotCare()),
attr=match.Name(value=match.DoNotCare())))])):
t = self.__get_var_type_an_assign(original_node.body[0].target.attr.value)
if t:
t_annot_node_resolved = self.resolve_type_alias(t)
t_annot_node = self.__name2annotation(t_annot_node_resolved)
if t_annot_node is not None:
self.all_applied_types.add((t_annot_node_resolved, t_annot_node))
self.no_applied_types += 1
return updated_node.with_changes(body=[cst.AnnAssign(
target=original_node.body[0].target,
value=original_node.body[0].value,
annotation=t_annot_node,
equal=original_node.body[0].equal)])
else:
self.no_failed_applied_types += 1
else:
return updated_node.with_changes(body=[cst.Assign(targets=[cst.AssignTarget(target=original_node.body[0].target,
whitespace_before_equal=original_node.body[0].equal.whitespace_before,
whitespace_after_equal=original_node.body[0].equal.whitespace_after)],
value=original_node.body[0].value)])


return original_node

@@ -1035,6 +1100,8 @@ def leave_Param(self, original_node: cst.Param, updated_node: cst.Param):
fn_param_type = self.__get_fn_param_type(original_node.name.value)
if fn_param_type is not None:
return updated_node.with_changes(annotation=fn_param_type)
else:
return updated_node.with_changes(annotation=None)

return original_node

@@ -1051,20 +1118,29 @@ def visit_AssignTarget(self, node: cst.AssignTarget):
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module):
return updated_node.with_changes(body=self.__get_required_imports() + list(updated_node.body))

def visit_ImportAlias(self, node: cst.ImportAlias):
self.imported_names.extend([n.value for n in match.findall(node.name, match.Name(value=match.DoNotCare()))])

# TODO: Check the imported modules before adding new ones
def __get_required_imports(self):
def find_required_modules(all_types):
def find_required_modules(all_types, imported_names):
req_mod = set()
for _, a_node in all_types:
m = match.findall(a_node.annotation, match.Attribute(value=match.DoNotCare(), attr=match.DoNotCare()))
if len(m) != 0:
for i in m:
req_mod.add([n.value for n in match.findall(i, match.Name(value=match.DoNotCare()))][0])
mod_imp = [n.value for n in match.findall(i, match.Name(value=match.DoNotCare()))][0]
if mod_imp not in imported_names:
req_mod.add(mod_imp)
# if n.value not in imported_names
print(req_mod)
return req_mod

req_imports = []
all_req_mods = find_required_modules(self.all_applied_types)
self.imported_names = set(self.imported_names)
all_req_mods = find_required_modules(self.all_applied_types, self.imported_names)
all_type_names = set(chain.from_iterable(map(lambda t: regex.findall(r"\w+", t[0]), self.all_applied_types)))
all_type_names = all_type_names - self.imported_names

typing_imports = PY_TYPING_MOD & all_type_names
collection_imports = PY_COLLECTION_MOD & all_type_names
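The fix for the re-importing bug reduces to a set difference: identifiers are pulled out of the applied type strings, names already seen in `visit_ImportAlias` are subtracted, and only the remainder is intersected with the known typing names. A standalone sketch using the stdlib `re` module in place of `regex` (the `PY_TYPING_MOD` contents here are an assumed subset):

```python
import re

# Assumed subset of the typing names the real pipeline tracks.
PY_TYPING_MOD = {"List", "Dict", "Set", "Tuple", "Optional", "Union", "Any"}

def missing_typing_imports(applied_types, imported_names):
    # Collect every identifier appearing in the applied annotations...
    names = set()
    for t in applied_types:
        names.update(re.findall(r"\w+", t))
    # ...then keep only typing names that are not already imported.
    return PY_TYPING_MOD & (names - set(imported_names))

print(missing_typing_imports(["List[Dict[str, int]]"], ["Dict"]))  # {'List'}
```

Because `Dict` was already imported by the file, only `List` is emitted, which is exactly the duplicate-import behavior this commit removes.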
@@ -1098,6 +1174,10 @@ def __get_qualified_name(self, node) -> Optional[str]:
q_name = list(self.get_metadata(cst.metadata.QualifiedNameProvider, node))
return q_name[0].name if len(q_name) != 0 else None

def __get_line_column_no(self, node) -> Tuple[int, int]:
lc = self.get_metadata(cst.metadata.PositionProvider, node)
return lc.start.line, lc.end.line

def resolve_type_alias(self, t: str):
type_aliases = {'^{}$|^Dict$|(?<=.*)Dict\[\](?<=.*)|(?<=.*)Dict\[Any, *?Any\](?=.*)|^Dict\[unknown, *Any\]$': 'dict',
'^Set$|(?<=.*)Set\[\](?<=.*)|^Set\[Any\]$': 'set',
2 changes: 1 addition & 1 deletion libsa4py/merge.py
@@ -137,6 +137,6 @@ def merge_projects(args):
"""
Saves merged projects into a single JSON file and a Dataframe
"""
merged_jsons = merge_jsons_to_dict(list_files(join(args.o, 'processed_projects'), ".json"), args.l)
merged_jsons = merge_jsons_to_dict(list_files(join(args.o, 'processed_projects'), ".json")[0], args.l)
save_json(join(args.o, 'merged_%s_projects.json' % (str(args.l) if args.l is not None else 'all')), merged_jsons)
create_dataframe_fns(args.o, merged_jsons)