diff --git a/libsa4py/__main__.py b/libsa4py/__main__.py index fc891cf..4bcd1e7 100644 --- a/libsa4py/__main__.py +++ b/libsa4py/__main__.py @@ -1,21 +1,27 @@ from argparse import ArgumentParser from multiprocessing import cpu_count from libsa4py.utils import find_repos_list -from libsa4py.cst_pipeline import Pipeline, TypeAnnotatingProjects +from libsa4py.cst_pipeline import Pipeline, TypeAnnotatingProjects, TypeAnnotationsRemoval from libsa4py.merge import merge_projects def process_projects(args): input_repos = find_repos_list(args.p) if args.l is None else find_repos_list(args.p)[:args.l] - p = Pipeline(args.p, args.o, not args.no_nlp, args.use_cache, args.use_pyre, args.use_tc, args.d, args.s) + p = Pipeline(args.p, args.o, not args.no_nlp, args.use_cache, args.use_pyre, args.use_tc, args.d, + args.s, args.i) p.run(input_repos, args.j) def apply_types_projects(args): - tap = TypeAnnotatingProjects(args.p, args.o) + tap = TypeAnnotatingProjects(args.p, args.o, args.dry_run) tap.run(args.j) +def remove_err_type_annotations(args): + tar = TypeAnnotationsRemoval(args.i, args.o, args.p, args.l, args.dry_run) + tar.run(args.j) + + def main(): arg_parser = ArgumentParser(description="Light-weight static analysis to extract Python's code representations") @@ -26,6 +32,7 @@ def main(): process_parser.add_argument("--o", required=True, type=str, help="Path to store JSON-based processed projects") process_parser.add_argument("--d", "--deduplicate", required=False, type=str, help="Path to duplicate files") process_parser.add_argument("--s", "--split", required=False, type=str, help="Path to the dataset split files") + process_parser.add_argument("--i", "--ignore", required=False, type=str, help="Path to the ignored files") process_parser.add_argument("--j", default=cpu_count(), type=int, help="Number of workers for processing projects") process_parser.add_argument("--l", required=False, type=int, help="Number of projects to process") process_parser.add_argument("--c", "--cache", dest='use_cache', action='store_true', help="Whether to ignore processed projects") @@ -51,8 +58,23 @@ def main(): apply_parser.add_argument("--p", required=True, type=str, help="Path to Python projects") apply_parser.add_argument("--o", required=True, type=str, help="Path to store JSON-based processed projects") apply_parser.add_argument("--j", default=cpu_count(), type=int, help="Number of workers for processing projects") + apply_parser.add_argument("--d", dest='dry_run', action='store_true', + help="Dry run does not apply types to the dataset's files") + + apply_parser.set_defaults(dry_run=False) apply_parser.set_defaults(func=apply_types_projects) + remove_parser = sub_parsers.add_parser('remove') + remove_parser.add_argument("--i", required=True, type=str, help="Path to input dataset") + remove_parser.add_argument("--o", required=True, type=str, help="Path to output dataset") + remove_parser.add_argument("--p", required=True, type=str, help="Path to JSON-formatted processed projects") + remove_parser.add_argument("--j", default=cpu_count(), type=int, help="Number of workers for processing files") + remove_parser.add_argument("--l", required=False, type=int, help="Number of projects to process") + remove_parser.add_argument("--d", dest='dry_run', action='store_true', + help="Dry run does not remove types from the dataset's files") + remove_parser.set_defaults(dry_run=False) + remove_parser.set_defaults(func=remove_err_type_annotations) + args = arg_parser.parse_args() args.func(args) diff --git 
a/libsa4py/cst_pipeline.py b/libsa4py/cst_pipeline.py index 4cbe4c3..fd8d979 100644 --- a/libsa4py/cst_pipeline.py +++ b/libsa4py/cst_pipeline.py @@ -1,20 +1,26 @@ +from libsa4py.cst_visitor import TypeAnnotationCounter import os import traceback import random import csv import time +import queue -from typing import List, Dict +from typing import List, Dict, Tuple from os.path import join +from tempfile import NamedTemporaryFile from pathlib import Path from datetime import timedelta from joblib import delayed +from multiprocessing import Manager, Process, Queue, managers +from multiprocessing.queues import Queue from dpu_utils.utils.dataloading import load_jsonl_gz from libsa4py.cst_extractor import Extractor -from libsa4py.cst_transformers import TypeApplier +from libsa4py.cst_transformers import TypeAnnotationRemover, TypeApplier from libsa4py.exceptions import ParseError, NullProjectException from libsa4py.nl_preprocessing import NLPreprocessor -from libsa4py.utils import read_file, list_files, ParallelExecutor, mk_dir_not_exist, save_json, load_json, write_file +from libsa4py.utils import read_file, list_files, ParallelExecutor, mk_dir_not_exist, save_json, load_json, write_file, \ + create_tmp_file, write_to_tmp_file, delete_tmp_file, mk_dir_cp_file from libsa4py.pyre import pyre_server_init, pyre_query_types, pyre_server_shutdown, pyre_kill_all_servers, \ clean_pyre_config from libsa4py.type_check import MypyManager, type_check_single_file @@ -32,7 +38,7 @@ class Pipeline: def __init__(self, projects_path, output_dir, nlp_transf: bool = True, use_cache: bool = True, use_pyre: bool = False, use_tc: bool = False, - dups_files_path=None, split_files_path=None): + dups_files_path=None, split_files_path=None, ignored_files_path=None): self.projects_path = projects_path self.output_dir = output_dir self.processed_projects = None @@ -54,10 +60,16 @@ def __init__(self, projects_path, output_dir, nlp_transf: bool = True, else: self.is_file_duplicate = lambda x: False + if ignored_files_path is not None: + self.ignored_files = set(read_file(ignored_files_path).splitlines()) + else: + self.ignored_files = set() + if self.use_tc: self.tc = MypyManager('mypy', MAX_TC_TIME) - self.split_dataset_files = {f:s for s, f in csv.reader(open(split_files_path, 'r'))} if split_files_path is not None else {} + self.split_dataset_files = {f: s for s, f in + csv.reader(open(split_files_path, 'r'))} if split_files_path is not None else {} # TODO: Fix the logger issue not outputting the logs into the file.
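Note: a minimal sketch of the ignore-list contract assumed by the new ignored_files_path option above: a plain-text file with one project-relative path per line, loaded into a set so that filtering is an O(1) membership test per file. The helper names and the file name below are illustrative, not part of the patch.

from pathlib import Path

def load_ignored_files(ignored_files_path: str) -> set:
    # One path per line, e.g. "author/repo/pkg/module.py"
    with open(ignored_files_path) as fh:
        return {line.strip() for line in fh if line.strip()}

def drop_ignored(project_files: list, projects_path: str, ignored: set) -> list:
    # Mirrors the relative-path filtering done in Pipeline.process_project
    root = Path(projects_path).parent
    return [f for f in project_files if str(Path(f).relative_to(root)) not in ignored]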
# logging.basicConfig(filename=join(self.err_log_dir, "pipeline_errors.log"), level=logging.DEBUG, @@ -81,17 +93,17 @@ def __setup_pipeline_logger(self, log_dir: str): logger_ch = logging.StreamHandler() logger_ch.setLevel(logging.DEBUG) - + logger_fh = logging.FileHandler(filename=log_dir) logger_fh.setLevel(logging.DEBUG) - + logger_formatter = logging.Formatter(fmt='%(asctime)s - %(name)s - %(message)s') logger_ch.setFormatter(logger_formatter) logger_fh.setFormatter(logger_formatter) logger.addHandler(logger_ch) logger.addHandler(logger_fh) - + return logger def get_project_filename(self, project) -> str: @@ -123,21 +135,24 @@ def fn_nlp_transf(fn_d: dict, nlp_prep: NLPreprocessor): fn_d['docstring']['long_descr'] = nlp_prep.process_sentence(fn_d['docstring']['long_descr']) return fn_d - extracted_module['variables'] = {self.nlp_prep.process_identifier(v): t for v, t in extracted_module['variables'].items()} + extracted_module['variables'] = {self.nlp_prep.process_identifier(v): t for v, t in + extracted_module['variables'].items()} extracted_module['mod_var_occur'] = {v: [self.nlp_prep.process_sentence(j) for i in o for j in i] for v, - o in extracted_module['mod_var_occur'].items()} + o in + extracted_module['mod_var_occur'].items()} for c in extracted_module['classes']: c['variables'] = {self.nlp_prep.process_identifier(v): t for v, t in c['variables'].items()} c['cls_var_occur'] = {v: [self.nlp_prep.process_sentence(j) for i in o for j in i] for v, - o in c['cls_var_occur'].items()} + o in + c['cls_var_occur'].items()} c['funcs'] = [fn_nlp_transf(f, self.nlp_prep) for f in c['funcs']] extracted_module['funcs'] = [fn_nlp_transf(f, self.nlp_prep) for f in extracted_module['funcs']] return extracted_module - def process_project(self, i, project): + def process_project(self, i, project, project_files: List[str]): project_id = f'{project["author"]}/{project["repo"]}' project_analyzed_files: dict = {project_id: {"src_files": {}, "type_annot_cove": 0.0}} @@ -148,14 +163,16 @@ def process_project(self, i, project): print(f'Extracting for {project_id}...') extracted_avl_types = None - project_files = list_files(join(self.projects_path, project["author"], project["repo"])) print(f"{project_id} has {len(project_files)} files before deduplication") project_files = [f for f in project_files if not self.is_file_duplicate(f)] print(f"{project_id} has {len(project_files)} files after deduplication") + project_files = [f for f in project_files if str(Path(f).relative_to(Path(self.projects_path).parent)) not in self.ignored_files] + print(f"{project_id} has {len(project_files)} files after ignoring files") project_files = [(f, str(Path(f).relative_to(Path(self.projects_path).parent))) for f in project_files] project_files = [(f, f_r, self.split_dataset_files[f_r] if f_r in self.split_dataset_files else None) for f, - f_r in project_files] + f_r + in project_files] if len(project_files) != 0: if self.use_pyre: @@ -193,10 +210,10 @@ def process_project(self, i, project): # fail the entire project processing. # TODO: A better workaround would be to have a specialized exception thrown # by the extractor, so that this exception is specialized. 
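Note: the reworked Pipeline.run above pairs each project with its file list and total size in bytes (the extra value now returned by list_files) and sorts in descending order, a greedy longest-job-first schedule that dispatches the largest projects to workers first and tends to shorten the tail of a run with a fixed-size pool. A standalone sketch of that scheduling step, with made-up projects:

repos = [
    ({"author": "a", "repo": "small"}, ["m.py"], 10_240),           # 10 KiB
    ({"author": "b", "repo": "big"}, ["x.py", "y.py"], 2_097_152),  # 2 MiB
]
repos.sort(key=lambda r: r[2], reverse=True)  # largest project first
for i, (project, files, size) in enumerate(repos):
    print(i, project["repo"], len(files), size)  # stand-in for delayed(self.process_project)(i, project, files)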
- #print(f"Could not process file {filename}") + # print(f"Could not process file {filename}") traceback.print_exc() self.logger.error("project: %s |file: %s |Exception: %s" % (project_id, filename, err)) - #logging.error("project: %s |file: %s |Exception: %s" % (project_id, filename, err)) + # logging.error("project: %s |file: %s |Exception: %s" % (project_id, filename, err)) print(f'Saving available type hints for {project_id}...') if self.avl_types_dir is not None: @@ -233,12 +250,15 @@ def process_project(self, i, project): def run(self, repos_list: List[Dict], jobs, start=0): print(f"Number of projects to be processed: {len(repos_list)}") - repos_list = [p for p in repos_list if not (os.path.exists(self.get_project_filename(p)) and self.use_cache)] + repos_list = [(p, *list_files(join(self.projects_path, p["author"], p["repo"]))) \ + for p in repos_list if not (os.path.exists(self.get_project_filename(p)) and self.use_cache)] + # Sorts projects based on total size of their files + repos_list.sort(key=lambda x: x[2], reverse=True) print(f"Number of projects to be processed after considering cache: {len(repos_list)}") start_t = time.time() ParallelExecutor(n_jobs=jobs)(total=len(repos_list))( - delayed(self.process_project)(i, project) for i, project in enumerate(repos_list, start=start)) + delayed(self.process_project)(i, p, p_files) for i, (p, p_files, p_size) in enumerate(repos_list, start=start)) print("Finished processing %d projects in %s " % (len(repos_list), str(timedelta(seconds=time.time()-start_t)))) if self.use_pyre: @@ -251,32 +271,404 @@ class TypeAnnotatingProjects: It applies the inferred type annotations to the input dataset """ - def __init__(self, projects_path: str, output_path: str, apply_nlp: bool = True): + def __init__(self, projects_path: str, output_path: str, dry_run: bool = False, + apply_nlp: bool = True): self.projects_path = projects_path self.output_path = output_path + self.dry_run = dry_run self.apply_nlp = apply_nlp def process_project(self, proj_json_path: str): proj_json = load_json(proj_json_path) + total_added_types = 0 + total_no_types = 0 for p in proj_json.keys(): for i, (f, f_d) in enumerate(proj_json[p]['src_files'].items()): - f_read = read_file(join(self.projects_path, f)) - if len(f_read) != 0: + print(f"Adding types to file {f} from project {proj_json_path}") + total_no_types += f_d['no_types_annot']['I'] + f_d['no_types_annot']['D'] + if f_d['no_types_annot']['I'] + f_d['no_types_annot']['D'] > 0: + f_read = read_file(join(self.projects_path, f)) try: f_parsed = cst.parse_module(f_read) try: - f_parsed = cst.metadata.MetadataWrapper(f_parsed).visit(TypeApplier(f_d, self.apply_nlp)) - write_file(join(self.projects_path, f), f_parsed.code) + ta = TypeApplier(f_d, self.apply_nlp) + f_parsed = cst.metadata.MetadataWrapper(f_parsed).visit(ta) + if not self.dry_run: + write_file(join(self.projects_path, f), f_parsed.code) + total_added_types += ta.no_applied_types + print(f"Applied {ta.no_applied_types} types to file {f} from project {proj_json_path}") + assert f_d['no_types_annot']['I'] + f_d['no_types_annot']['D'] <= self.__get_no_applied_types(f_parsed.code) + ta.no_failed_applied_types except KeyError as ke: print(f"A variable not found | project {proj_json_path} | file {f}", ke) traceback.print_exc() except TypeError as te: print(f"Project {proj_json_path} | file {f}", te) traceback.print_exc() + except AssertionError as te: + print(f"[AssertionError] Project {proj_json_path} | file {f}", te) except cst._exceptions.ParserSyntaxError as 
pse: print(f"Can't parsed file {f} in project {proj_json_path}", pse) + return total_added_types, total_no_types + def run(self, jobs: int): - proj_jsons = list_files(join(self.output_path, 'processed_projects'), '.json') + proj_jsons, _ = list_files(join(self.output_path, 'processed_projects'), '.json') proj_jsons.sort(key=lambda f: os.stat(f).st_size, reverse=True) - ParallelExecutor(n_jobs=jobs)(total=len(proj_jsons))(delayed(self.process_project)(p_j) for p_j in proj_jsons) + start_t = time.time() + proj_type_added = ParallelExecutor(n_jobs=jobs)(total=len(proj_jsons))(delayed(self.process_project)(p_j) \ + for p_j in proj_jsons) + print(f"Finished applying types in {str(timedelta(seconds=time.time() - start_t))}") + print(f"{sum([a for a, t in proj_type_added]):,}/{sum([t for a, t in proj_type_added]):,} types applied to the whole dataset") + + def __get_no_applied_types(self, code: str) -> int: + f_applied_p = cst.parse_module(code) + tac = TypeAnnotationCounter() + f_applied_p.visit(tac) + return tac.total_no_type_annot + +class TypeAnnotationsRemoval: + """ + Removes type annotations that cannot be type-checked by mypy + """ + + MAX_TYPE_ERRORS_PER_FILE = 500 + + def __init__(self, input_projects_path: str, output_projects_path: str, processed_projects_path: str, no_projects_limit: int = None, + dry_run: bool = False, apply_nlp: bool = True): + self.input_projects_path = input_projects_path + self.processed_projects_path = processed_projects_path + self.output_projects_path = output_projects_path + self.no_projects_limit = no_projects_limit + self.dry_run = dry_run + self.apply_nlp = apply_nlp + + #def process_file(self, f: str, f_d_repr: dict, tc_res: dict): + def process_file(self, q: Queue, is_f_loader_done, tc_res: dict, ignored_files: list): + # TODO: The initial type-checking should not be done after adding no. type errors to the representation later on. 
+ # init_tc, init_no_tc_err = type_check_single_file(join(self.projects_path, f), + # MypyManager('mypy', MAX_TC_TIME)) + + # if init_tc == False and init_no_tc_err is None: + # return + # else: + # Only files with type annotations + while not is_f_loader_done.value or q.qsize() != 0: + try: + f, f_d_repr = q.get(True, 1) + if f_d_repr['no_types_annot']['I'] + f_d_repr['no_types_annot']['D'] > 0: + try: + #tmp_f = create_tmp_file(".py") + f_read = read_file(join(self.output_projects_path, f)) + _, tc_errs, type_annot_r, tc_errors = self.remove_unchecked_type_annot(join(self.output_projects_path, f), + f_read, f_d_repr, f_d_repr['tc'][1]) + print(f"F: {f} | init_tc_errors: {f_d_repr['tc'][1]} | tc_errors: {tc_errs} | ta_r: {type_annot_r} | \ + total_ta: {f_d_repr['no_types_annot']['I'] + f_d_repr['no_types_annot']['D']} | Queue size: {q.qsize()}") + tc_res[f] = {"init_tc_errs": f_d_repr['tc'][1], "curr_tc_errs": tc_errs, "ta_rem": type_annot_r, + "total_ta": f_d_repr["no_types_annot"]['I'] + f_d_repr["no_types_annot"]['D'], + "errors": tc_errors} + # Path(join(self.output_path, Path(f).parent)).mkdir(parents=True, exist_ok=True) + if tc_errs == 0: + if self.dry_run: + write_file(join(self.output_projects_path, f), f_read) + else: + write_file(join(self.output_projects_path, f), f_read) + ignored_files.append(f) + except Exception as e: + print(f"F: {f} | e: {e}") + traceback.print_exc() + # finally: + # delete_tmp_file(tmp_f) + else: + print(f"F: {f} | init_tc_errors: {f_d_repr['tc'][1]} | total_ta: {f_d_repr['no_types_annot']['I'] + f_d_repr['no_types_annot']['D']} | Queue size: {q.qsize()}") + tc_res[f] = {"init_tc_errs": f_d_repr['tc'][1], "curr_tc_errs": f_d_repr['tc'][1], "ta_rem": None, + "total_ta": f_d_repr["no_types_annot"]['I'] + f_d_repr["no_types_annot"]['D'], + "errors": None} + ignored_files.append(f) + except queue.Empty as e: + print(f"Worker {os.getpid()} finished! 
Queue's empty!") + print(f"File loader working {is_f_loader_done.value} and queue size {q.qsize()}") + + def run(self, jobs: int): + manager = Manager() + q = manager.Queue() + is_f_loader_done = manager.Value('i', False) + ignored_files_a = manager.list() + type_checked_files = manager.list() + + file_loader = Process(target=self.__load_projects_files, args=(q, is_f_loader_done, ignored_files_a, + type_checked_files)) + file_loader.start() + #file_loader.join() + + print("File loader started!") + + # merged_projects = load_json(join(self.processed_projects_path, "merged_all_projects.json")) + # not_tced_src_f: List[Tuple[str, dict]] = [] + # for p, p_v in list(merged_projects['projects'].items()): + # for f, f_v in p_v['src_files'].items(): + # if not f_v['tc'][0] and f_v['tc'] != [False, None]: + # not_tced_src_f.append((f, f_v)) + + # del merged_projects + # # not_tced_src_f = not_tced_src_f[:250] + # # print("L:", len(not_tced_src_f)) + # manager = Manager() + time.sleep(5) + start_t = time.time() + tc_res = manager.dict() + ignored_files_b = manager.list() + file_processors = [] + for j in range(jobs): + p = Process(target=self.process_file, args=(q, is_f_loader_done, tc_res, ignored_files_b)) + p.daemon = True + file_processors.append(p) + p.start() + + for p in file_processors: + p.join() + file_loader.join() + # ParallelExecutor(n_jobs=jobs)(total=0)(delayed(self.process_file)(f, f_d, tc_res) \ + # for f, f_d in not_tced_src_f) + print(f"Finished fixing invalid types in {str(timedelta(seconds=time.time() - start_t))}") + save_json(join(self.processed_projects_path, "tc_ta_results_new.json"), tc_res.copy()) + write_file(join(self.processed_projects_path, 'ignored_files.txt'), '\n'.join(list(ignored_files_a) + list(ignored_files_b))) + write_file(join(self.processed_projects_path, 'tced_files.txt'), '\n'.join(list(type_checked_files))) + + def __load_projects_files(self, q: Queue, is_done, ignored_files: list, type_checked_files: list): + proj_jsons, _ = list_files(join(self.processed_projects_path, 'processed_projects'), '.json') + proj_jsons = proj_jsons[:self.no_projects_limit] if self.no_projects_limit is not None else proj_jsons + f_loaded = 0 + for p_j in proj_jsons: + proj_json = load_json(p_j) + for _, p_v in proj_json.items(): + for f, f_v in p_v['src_files'].items(): + if not f_v['tc'][0]: + if f_v['tc'] != [False, None, None]: + if f_v['tc'][1] <= TypeAnnotationsRemoval.MAX_TYPE_ERRORS_PER_FILE: + mk_dir_cp_file(join('/home/amir/data/MT4Py-pyre-apply', f), join(self.output_projects_path, f)) + q.put((f, f_v)) + f_loaded += 1 + print(f"Added file {f} to the analysis queue") + else: + ignored_files.append(f) + else: + ignored_files.append(f) + else: + type_checked_files.append(f) + + #print("Adding files to Queue...") + is_done.value = True + print(f"Loaded {f_loaded} Python files") + + for f in type_checked_files: + mk_dir_cp_file(join(self.input_projects_path, f), join(self.output_projects_path, f)) + print(f"Copied type-checked file: {f}") + + def remove_unchecked_type_annot(self, f_path: str, f_read: str, f_d_repr: dict, + init_no_tc_err: int) -> Tuple[str, int, List[str]]: + + type_annots_removed: List[str] = [] + no_try = 0 + MAX_TRY = 10 + + def type_check_ta(curr_no_tc_err: int, org_gt): + tc, no_tc_err, f_code, tc_errors = self.__type_check_type_annotation(f_path, f_read, f_d_repr) + nonlocal no_try + if no_tc_err is not None: + if tc: + type_annots_removed.append(org_gt) + elif no_tc_err < curr_no_tc_err: + curr_no_tc_err = no_tc_err + 
type_annots_removed.append(org_gt) + else: + no_try += 1 + else: + no_try += 1 + + return tc, no_tc_err, f_code, tc_errors + + out_f_code: str = "" + tc_errors = None + for m_v, m_v_t in f_d_repr['variables'].items(): + if m_v_t != "": + print(f"Type-checking module-level variable {m_v} with annotation {m_v_t}") + f_d_repr['variables'][m_v] = "" + # tc, no_tc_err, f_code = self.__type_check_type_annotation(f_read, f_d_repr, f_out_temp) + # if tc: + # type_annots_removed.append(m_v_t) + # return f_code, no_tc_err, type_annots_removed + # elif no_tc_err < init_no_tc_err: + # out_f_code = f_code + # init_no_tc_err = no_tc_err + # type_annots_removed.append(m_v_t) + # elif no_tc_err == init_no_tc_err: + # f_d_repr['variables'][m_v] = m_v_t + tc, no_tc_err, out_f_code, tc_errors = type_check_ta(init_no_tc_err, m_v_t) + if tc or no_try > MAX_TRY: + return out_f_code, no_tc_err, type_annots_removed, tc_errors + else: + f_d_repr['variables'][m_v] = m_v_t + + for i, fn in enumerate(f_d_repr['funcs']): + for p_n, p_t in fn['params'].items(): + if p_t != "": + print(f"Type-checking function parameter {p_n} with annotation {p_t}") + f_d_repr['funcs'][i]['params'][p_n] = "" + # tc, no_tc_err, f_code = self.__type_check_type_annotation(f_read, f_d_repr, f_out_temp) + # if tc: + # type_annots_removed.append(p_t) + # return f_code, no_tc_err, type_annots_removed + # elif no_tc_err < init_no_tc_err: + # out_f_code = f_code + # init_no_tc_err = no_tc_err + # type_annots_removed.append(p_t) + # elif no_tc_err == init_no_tc_err: + # f_d_repr['funcs'][i]['params'][p_n] = p_t + tc, no_tc_err, out_f_code, tc_errors = type_check_ta(init_no_tc_err, p_t) + if tc or no_try > MAX_TRY: + return out_f_code, no_tc_err, type_annots_removed, tc_errors + else: + f_d_repr['funcs'][i]['params'][p_n] = p_t + + for fn_v, fn_v_t in fn['variables'].items(): + if fn_v_t != "": + print(f"Type-checking function variable {fn_v} with annotation {fn_v_t}") + f_d_repr['funcs'][i]['variables'][fn_v] = "" + # tc, no_tc_err, f_code = self.__type_check_type_annotation(f_read, f_d_repr, f_out_temp) + # if tc: + # type_annots_removed.append(fn_v_t) + # return f_code, no_tc_err, type_annots_removed + # elif no_tc_err < init_no_tc_err: + # out_f_code = f_code + # init_no_tc_err = no_tc_err + # type_annots_removed.append(fn_v_t) + # elif no_tc_err == init_no_tc_err: + # f_d_repr['funcs'][i]['variables'][fn_v] = fn_v_t + tc, no_tc_err, out_f_code, tc_errors = type_check_ta(init_no_tc_err, fn_v_t) + if tc or no_try > MAX_TRY: + return out_f_code, no_tc_err, type_annots_removed, tc_errors + else: + f_d_repr['funcs'][i]['variables'][fn_v] = fn_v_t + + # The return type for module-level functions + if f_d_repr['funcs'][i]['ret_type'] != "": + org_t = f_d_repr['funcs'][i]['ret_type'] + print(f"Type-checking function {f_d_repr['funcs'][i]['name']} return with {org_t}") + f_d_repr['funcs'][i]['ret_type'] = "" + # tc, no_tc_err, f_code = self.__type_check_type_annotation(f_read, f_d_repr, f_out_temp) + # if tc: + # type_annots_removed.append(org_t) + # return f_code, no_tc_err, type_annots_removed + # elif no_tc_err < init_no_tc_err: + # out_f_code = f_code + # init_no_tc_err = no_tc_err + # type_annots_removed.append(org_t) + # elif no_tc_err == init_no_tc_err: + # f_d_repr['funcs'][i]['ret_type'] = org_t + tc, no_tc_err, out_f_code, tc_errors = type_check_ta(init_no_tc_err, org_t) + if tc or no_try > MAX_TRY: + return out_f_code, no_tc_err, type_annots_removed, tc_errors + else: + f_d_repr['funcs'][i]['ret_type'] = org_t + + # The type of class-level 
vars + for c_i, c in enumerate(f_d_repr['classes']): + for c_v, c_v_t in c['variables'].items(): + if c_v_t != "": + print(f"Type checking class variable {c_v} with annotation {c_v_t}") + f_d_repr['classes'][c_i]['variables'][c_v] = "" + # tc, no_tc_err, f_code = self.__type_check_type_annotation(f_read, f_d_repr, f_out_temp) + # if tc: + # type_annots_removed.append(c_v_t) + # return f_code, no_tc_err, type_annots_removed + # elif no_tc_err < init_no_tc_err: + # out_f_code = f_code + # init_no_tc_err = no_tc_err + # type_annots_removed.append(c_v_t) + # elif no_tc_err == init_no_tc_err: + # f_d_repr['classes'][c_i]['variables'][c_v] = c_v_t + tc, no_tc_err, out_f_code, tc_errors = type_check_ta(init_no_tc_err, c_v_t) + if tc or no_try > MAX_TRY: + return out_f_code, no_tc_err, type_annots_removed, tc_errors + else: + f_d_repr['classes'][c_i]['variables'][c_v] = c_v_t + + # The type of arguments for class-level functions + for fn_i, fn in enumerate(c['funcs']): + for p_n, p_t in fn["params"].items(): + if p_t != "": + print(f"Type-checking function parameter {p_n} with annotation {p_t}") + f_d_repr['classes'][c_i]['funcs'][fn_i]['params'][p_n] = "" + # tc, no_tc_err, f_code = self.__type_check_type_annotation(f_read, f_d_repr, f_out_temp) + # if tc: + # type_annots_removed.append(p_t) + # return f_code, no_tc_err, type_annots_removed + # elif no_tc_err < init_no_tc_err: + # out_f_code = f_code + # init_no_tc_err = no_tc_err + # type_annots_removed.append(p_t) + # elif no_tc_err == init_no_tc_err: + # f_d_repr['classes'][c_i]['funcs'][fn_i]['params'][p_n] = p_t + tc, no_tc_err, out_f_code, tc_errors = type_check_ta(init_no_tc_err, p_t) + if tc or no_try > MAX_TRY: + return out_f_code, no_tc_err, type_annots_removed, tc_errors + else: + f_d_repr['classes'][c_i]['funcs'][fn_i]['params'][p_n] = p_t + + # The type of local variables for class-level functions + for fn_v, fn_v_t in fn['variables'].items(): + if fn_v_t != "": + print(f"Type-checking function variable {fn_v} with annotation {fn_v_t}") + f_d_repr['classes'][c_i]['funcs'][fn_i]['variables'][fn_v] = "" + # tc, no_tc_err, f_code = self.__type_check_type_annotation(f_read, f_d_repr, f_out_temp) + # if tc: + # type_annots_removed.append(fn_v_t) + # return f_code, no_tc_err, type_annots_removed + # elif no_tc_err < init_no_tc_err: + # out_f_code = f_code + # init_no_tc_err = no_tc_err + # type_annots_removed.append(fn_v_t) + # elif no_tc_err == init_no_tc_err: + # f_d_repr['classes'][c_i]['funcs'][fn_i]['variables'][fn_v] = fn_v_t + tc, no_tc_err, out_f_code, tc_errors = type_check_ta(init_no_tc_err, fn_v_t) + if tc or no_try > MAX_TRY: + return out_f_code, no_tc_err, type_annots_removed, tc_errors + else: + f_d_repr['classes'][c_i]['funcs'][fn_i]['variables'][fn_v] = fn_v_t + + # The return type for class-level functions + if f_d_repr['classes'][c_i]['funcs'][fn_i]['ret_type'] != "": + org_t = f_d_repr['classes'][c_i]['funcs'][fn_i]['ret_type'] + print( + f"Annotating function {f_d_repr['classes'][c_i]['funcs'][fn_i]['name']} return with type {org_t}") + f_d_repr['classes'][c_i]['funcs'][fn_i]['ret_type'] = "" + # tc, no_tc_err, f_code = self.__type_check_type_annotation(f_read, f_d_repr, f_out_temp) + # if tc: + # type_annots_removed.append(org_t) + # return f_code, no_tc_err, type_annots_removed + # elif no_tc_err < init_no_tc_err: + # out_f_code = f_code + # init_no_tc_err = no_tc_err + # type_annots_removed.append(org_t) + # elif no_tc_err == init_no_tc_err: + # f_d_repr['classes'][c_i]['funcs'][fn_i]['ret_type'] = org_t + tc, 
no_tc_err, out_f_code, tc_errors = type_check_ta(init_no_tc_err, org_t) + if tc or no_try > MAX_TRY: + return out_f_code, no_tc_err, type_annots_removed, tc_errors + else: + f_d_repr['classes'][c_i]['funcs'][fn_i]['ret_type'] = org_t + + return out_f_code, init_no_tc_err, type_annots_removed, tc_errors + + def __type_check_type_annotation(self, f_path: str, f_read: str, f_d_repr: dict): + f_t_applied = cst.metadata.MetadataWrapper(cst.parse_module(f_read)).visit(TypeApplier(f_d_repr, + apply_nlp=self.apply_nlp)) + + # Writing applied code to temp files has an advantage which isolates the file and as a result, + # type-checking may be successful for some failed cases with the original file + # tmp_f = create_tmp_file(".py") + # write_to_tmp_file(tmp_f, f_t_applied.code) + write_file(f_path, f_t_applied.code) + tc, no_tc_err, tc_errors = type_check_single_file(f_path, MypyManager('mypy', MAX_TC_TIME)) + #delete_tmp_file(tmp_f) + return tc, no_tc_err, f_t_applied.code, tc_errors diff --git a/libsa4py/cst_transformers.py b/libsa4py/cst_transformers.py index f115b91..af078a7 100644 --- a/libsa4py/cst_transformers.py +++ b/libsa4py/cst_transformers.py @@ -853,13 +853,15 @@ def leave_SubscriptElement(self, original_node, updated_node): return updated_node +# TODO: Write two separate CSTTransformers for applying and removing type annotations class TypeApplier(cst.CSTTransformer): """ It applies (inferred) type annotations to a source code file. Specifically, it applies the type of arguments, return types, and variables' type. """ - METADATA_DEPENDENCIES = (cst.metadata.ScopeProvider, cst.metadata.QualifiedNameProvider) + METADATA_DEPENDENCIES = (cst.metadata.ScopeProvider, cst.metadata.QualifiedNameProvider, + cst.metadata.PositionProvider) def __init__(self, f_processeed_dict: dict, apply_nlp: bool=True): self.f_processed_dict = f_processeed_dict @@ -871,6 +873,10 @@ def __init__(self, f_processeed_dict: dict, apply_nlp: bool=True): self.lambda_d = 0 self.all_applied_types = set() + self.no_applied_types = 0 + self.no_failed_applied_types = 0 + + self.imported_names: List[str] = [] if apply_nlp: self.nlp_p = NLPreprocessor().process_identifier @@ -883,9 +889,15 @@ def __get_fn(self, f_node: cst.FunctionDef) -> dict: else: fns = self.f_processed_dict['funcs'] + qn = self.__get_qualified_name(f_node.name) + fn_params = set(self.__get_fn_params(f_node.params)) + fn_lc = self.__get_line_column_no(f_node) for fn in fns: - if fn['q_name'] == self.__get_qualified_name(f_node.name) and \ - set(list(fn['params'].keys())) == set(self.__get_fn_params(f_node.params)): + if (fn['fn_lc'][0][0], fn['fn_lc'][1][0]) == fn_lc: + return fn + + for fn in fns: + if fn['q_name'] == qn and set(list(fn['params'].keys())) == fn_params: return fn def __get_fn_param_type(self, param_name: str): @@ -895,11 +907,20 @@ def __get_fn_param_type(self, param_name: str): fn_param_type = self.__name2annotation(fn_param_type_resolved) if fn_param_type is not None: self.all_applied_types.add((fn_param_type_resolved, fn_param_type)) + self.no_applied_types += 1 return fn_param_type + else: + self.no_failed_applied_types += 1 def __get_cls(self, cls: cst.ClassDef) -> dict: + cls_lc = self.__get_line_column_no(cls) + cls_qn = self.__get_qualified_name(cls.name) for c in self.f_processed_dict['classes']: - if c['q_name'] == self.__get_qualified_name(cls.name): + if (c['cls_lc'][0][0], c['cls_lc'][1][0]) == cls_lc: + return c + + for c in self.f_processed_dict['classes']: + if c['q_name'] == cls_qn: return c def __get_fn_vars(self, 
var_name: str) -> dict: @@ -923,24 +944,27 @@ def __get_cls_vars(self, var_name: str) -> dict: def __get_mod_vars(self): return self.f_processed_dict['variables'] - def __get_var_type_assign_t(self, var_name: str): + def __get_var_type_assign_t(self, var_name: str, var_node): t: str = None + var_line_no = self.__get_line_column_no(var_node) if len(self.cls_visited) != 0: if len(self.fn_visited) != 0: # A class method's variable - if self.fn_visited[-1][1][var_name] == self.last_visited_assign_t_count: + if self.fn_visited[-1][0]['fn_var_ln'][var_name][0][0] == var_line_no[0]: t = self.__get_fn_vars(self.nlp_p(var_name)) else: # A class variable - if self.cls_visited[-1][1][var_name] == self.last_visited_assign_t_count: + if self.cls_visited[-1][0]["cls_var_ln"][var_name][0][0] == var_line_no[0]: t = self.__get_cls_vars(self.nlp_p(var_name)) elif len(self.fn_visited) != 0: # A module function's variable - if self.fn_visited[-1][1][var_name] == self.last_visited_assign_t_count: + #if self.fn_visited[-1][1][var_name] == self.last_visited_assign_t_count: + if self.fn_visited[-1][0]['fn_var_ln'][var_name][0][0] == var_line_no[0]: t = self.__get_fn_vars(self.nlp_p(var_name)) else: # A module's variables - t = self.__get_mod_vars()[self.nlp_p(var_name)] + if self.f_processed_dict['mod_var_ln'][var_name][0][0] == var_line_no[0]: + t = self.__get_mod_vars()[self.nlp_p(var_name)] return t def __get_var_type_an_assign(self, var_name: str): @@ -962,9 +986,17 @@ def __get_var_type_an_assign(self, var_name: str): def __get_var_names_counter(self, node, scope): vars_name = match.extractall(node, match.OneOf(match.AssignTarget(target=match.SaveMatchedNode( match.Name(value=match.DoNotCare()), "name")), match.AnnAssign(target=match.SaveMatchedNode( - match.Name(value=match.DoNotCare()), "name")))) + match.Name(value=match.DoNotCare()), "name")) + )) + attr_name = match.extractall(node, match.OneOf(match.AssignTarget( + target=match.SaveMatchedNode(match.Attribute(value=match.Name(value=match.DoNotCare()), attr= + match.Name(value=match.DoNotCare())), "attr")), + match.AnnAssign(target=match.SaveMatchedNode(match.Attribute(value=match.Name(value=match.DoNotCare()), attr= + match.Name(value=match.DoNotCare())), "attr")))) return Counter([n['name'].value for n in vars_name if isinstance(self.get_metadata(cst.metadata.ScopeProvider, - n['name']), scope)]) + n['name']), scope)] + + [n['attr'].attr.value for n in attr_name if isinstance(self.get_metadata(cst.metadata.ScopeProvider, + n['attr']), scope)]) def visit_ClassDef(self, node: cst.ClassDef): self.cls_visited.append((self.__get_cls(node), @@ -986,7 +1018,12 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu fn_ret_type = self.__name2annotation(fn_ret_type_resolved) if fn_ret_type is not None: self.all_applied_types.add((fn_ret_type_resolved, fn_ret_type)) + self.no_applied_types += 1 return updated_node.with_changes(returns=fn_ret_type) + else: + self.no_failed_applied_types += 1 + else: + return updated_node.with_changes(returns=None) return updated_node @@ -999,15 +1036,26 @@ def leave_Lambda(self, original_node: cst.Lambda, updated_node: cst.Lambda): def leave_SimpleStatementLine(self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine): + + # Untyped variables + t = None if match.matches(original_node, match.SimpleStatementLine(body=[match.Assign(targets=[match.AssignTarget( + target=match.DoNotCare())])])): + if match.matches(original_node, 
match.SimpleStatementLine(body=[match.Assign(targets=[match.AssignTarget( target=match.Name(value=match.DoNotCare()))])])): - t = self.__get_var_type_assign_t(original_node.body[0].targets[0].target.value) + t = self.__get_var_type_assign_t(original_node.body[0].targets[0].target.value, + original_node.body[0].targets[0].target) + elif match.matches(original_node, match.SimpleStatementLine(body=[match.Assign(targets=[match.AssignTarget( + target=match.Attribute(value=match.Name(value=match.DoNotCare()), attr=match.Name(value=match.DoNotCare())))])])): + t = self.__get_var_type_assign_t(original_node.body[0].targets[0].target.attr.value, + original_node.body[0].targets[0].target) if t is not None: t_annot_node_resolved = self.resolve_type_alias(t) t_annot_node = self.__name2annotation(t_annot_node_resolved) if t_annot_node is not None: self.all_applied_types.add((t_annot_node_resolved, t_annot_node)) + self.no_applied_types += 1 return updated_node.with_changes(body=[cst.AnnAssign( target=original_node.body[0].targets[0].target, value=original_node.body[0].value, @@ -1015,18 +1063,35 @@ def leave_SimpleStatementLine(self, original_node: cst.SimpleStatementLine, equal=cst.AssignEqual(whitespace_after=original_node.body[0].targets[0].whitespace_after_equal, whitespace_before=original_node.body[0].targets[0].whitespace_before_equal))] ) - elif match.matches(original_node, match.SimpleStatementLine(body=[match.AnnAssign(target=match.Name(value=match.DoNotCare()))])): - t = self.__get_var_type_an_assign(original_node.body[0].target.value) - if t is not None: + else: + self.no_failed_applied_types += 1 + # Typed variables + elif match.matches(original_node, match.SimpleStatementLine(body=[match.AnnAssign(target=match.DoNotCare(), + value=match.MatchIfTrue(lambda v: v is not None))])): + if match.matches(original_node, match.SimpleStatementLine(body=[match.AnnAssign(target=match.Name(value=match.DoNotCare()))])): + t = self.__get_var_type_an_assign(original_node.body[0].target.value) + elif match.matches(original_node, match.SimpleStatementLine(body=[match.AnnAssign(target=match.Attribute(value=match.Name(value=match.DoNotCare()), + attr=match.Name(value=match.DoNotCare())))])): + t = self.__get_var_type_an_assign(original_node.body[0].target.attr.value) + if t: t_annot_node_resolved = self.resolve_type_alias(t) t_annot_node = self.__name2annotation(t_annot_node_resolved) if t_annot_node is not None: self.all_applied_types.add((t_annot_node_resolved, t_annot_node)) + self.no_applied_types += 1 return updated_node.with_changes(body=[cst.AnnAssign( target=original_node.body[0].target, value=original_node.body[0].value, annotation=t_annot_node, equal=original_node.body[0].equal)]) + else: + self.no_failed_applied_types += 1 + else: + return updated_node.with_changes(body=[cst.Assign(targets=[cst.AssignTarget(target=original_node.body[0].target, + whitespace_before_equal=original_node.body[0].equal.whitespace_before, + whitespace_after_equal=original_node.body[0].equal.whitespace_after)], + value=original_node.body[0].value)]) + return original_node @@ -1035,6 +1100,8 @@ def leave_Param(self, original_node: cst.Param, updated_node: cst.Param): fn_param_type = self.__get_fn_param_type(original_node.name.value) if fn_param_type is not None: return updated_node.with_changes(annotation=fn_param_type) + else: + return updated_node.with_changes(annotation=None) return original_node @@ -1051,20 +1118,29 @@ def visit_AssignTarget(self, node: cst.AssignTarget): def leave_Module(self, original_node: cst.Module, 
updated_node: cst.Module): return updated_node.with_changes(body=self.__get_required_imports() + list(updated_node.body)) + def visit_ImportAlias(self, node: cst.ImportAlias): + self.imported_names.extend([n.value for n in match.findall(node.name, match.Name(value=match.DoNotCare()))]) + + # TODO: Check the imported modules before adding new ones def __get_required_imports(self): - def find_required_modules(all_types): + def find_required_modules(all_types, imported_names): req_mod = set() for _, a_node in all_types: m = match.findall(a_node.annotation, match.Attribute(value=match.DoNotCare(), attr=match.DoNotCare())) if len(m) != 0: for i in m: - req_mod.add([n.value for n in match.findall(i, match.Name(value=match.DoNotCare()))][0]) + mod_imp = [n.value for n in match.findall(i, match.Name(value=match.DoNotCare()))][0] + if mod_imp not in imported_names: + req_mod.add(mod_imp) + # if n.value not in imported_names + print(req_mod) return req_mod req_imports = [] - all_req_mods = find_required_modules(self.all_applied_types) + self.imported_names = set(self.imported_names) + all_req_mods = find_required_modules(self.all_applied_types, self.imported_names) all_type_names = set(chain.from_iterable(map(lambda t: regex.findall(r"\w+", t[0]), self.all_applied_types))) + all_type_names = all_type_names - self.imported_names typing_imports = PY_TYPING_MOD & all_type_names collection_imports = PY_COLLECTION_MOD & all_type_names @@ -1098,6 +1174,10 @@ def __get_qualified_name(self, node) -> Optional[str]: q_name = list(self.get_metadata(cst.metadata.QualifiedNameProvider, node)) return q_name[0].name if len(q_name) != 0 else None + def __get_line_column_no(self, node) -> Tuple[int, int]: + lc = self.get_metadata(cst.metadata.PositionProvider, node) + return lc.start.line, lc.end.line + def resolve_type_alias(self, t: str): type_aliases = {'^{}$|^Dict$|(?<=.*)Dict\[\](?<=.*)|(?<=.*)Dict\[Any, *?Any\](?=.*)|^Dict\[unknown, *Any\]$': 'dict', '^Set$|(?<=.*)Set\[\](?<=.*)|^Set\[Any\]$': 'set', diff --git a/libsa4py/merge.py b/libsa4py/merge.py index c5932de..8f0e64c 100644 --- a/libsa4py/merge.py +++ b/libsa4py/merge.py @@ -137,6 +137,6 @@ def merge_projects(args): """ Saves merged projects into a single JSON file and a Dataframe """ - merged_jsons = merge_jsons_to_dict(list_files(join(args.o, 'processed_projects'), ".json"), args.l) + merged_jsons = merge_jsons_to_dict(list_files(join(args.o, 'processed_projects'), ".json")[0], args.l) save_json(join(args.o, 'merged_%s_projects.json' % (str(args.l) if args.l is not None else 'all')), merged_jsons) create_dataframe_fns(args.o, merged_jsons) diff --git a/libsa4py/type_check.py b/libsa4py/type_check.py index abcfe28..4aeaebe 100644 --- a/libsa4py/type_check.py +++ b/libsa4py/type_check.py @@ -68,10 +68,12 @@ def _build_tc_cmd(self, fpath): def _type_check(self, fpath): try: - cwd = os.getcwd() - os.chdir(dirname(fpath)) + # cwd = os.getcwd() + # os.chdir(dirname(fpath)) + # Runs mypy with the file's absolute path + # It may improve detection of type errors in some cases!
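# Note: with the chdir removed, mypy now receives the file's absolute path, and the
# two flags added to _build_tc_cmd below scope the check to the file itself:
# --follow-imports=silent still analyzes imported modules but suppresses errors
# reported in them, and --ignore-missing-imports silences errors for unresolved
# third-party imports. The effective command is roughly:
#   mypy --show-error-codes --no-incremental --cache-dir=/dev/null \
#        --follow-imports=silent --ignore-missing-imports /abs/path/to/file.py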
result = subprocess.run( - self._build_tc_cmd(basename(fpath)), + self._build_tc_cmd(fpath), # basename(fpath) capture_output=True, text=True, timeout=self._timeout, @@ -81,8 +83,8 @@ def _type_check(self, fpath): return retcode, outlines except subprocess.TimeoutExpired: raise TypeCheckingTooLong - finally: - os.chdir(cwd) + # finally: + # os.chdir(cwd) @abstractmethod def _check_tc_outcome(self, returncode, outlines): @@ -124,7 +126,8 @@ def heavy_assess(self, fpath): class MypyManager(TCManager): def _build_tc_cmd(self, fpath): # Mypy needs a flag to display the error codes - return ["mypy", "--show-error-codes", "--no-incremental", "--cache-dir=/dev/null", fpath] + return ["mypy", "--show-error-codes", "--no-incremental", "--cache-dir=/dev/null", + "--follow-imports=silent", "--ignore-missing-imports", fpath] def _check_tc_outcome(self, _, outlines): if any(l.endswith(err) for l in outlines for err in self._inc_errcodes): @@ -165,13 +168,13 @@ def _report_errors(self, parsed_result): print(f"Error breaking down: {parsed_result.err_breakdown}.") -def type_check_single_file(f_path: str, tc: TCManager) -> Tuple[bool, Union[int, None]]: +def type_check_single_file(f_path: str, tc: TCManager) -> Tuple[bool, Union[int, None], Union[dict, None]]: try: no_t_err = tc.heavy_assess(f_path) if no_t_err is not None: - return (True, 0) if no_t_err.no_type_errs == 0 else (False, no_t_err.no_type_errs) + return (True, 0, no_t_err.err_breakdown) if no_t_err.no_type_errs == 0 else (False, no_t_err.no_type_errs, no_t_err.err_breakdown) else: - return False, None + return False, None, None except IndexError: print(f"f: {f_path} - No output from Mypy!") - return False, None + return False, None, None diff --git a/libsa4py/utils.py b/libsa4py/utils.py index c247c40..569f9f5 100644 --- a/libsa4py/utils.py +++ b/libsa4py/utils.py @@ -1,7 +1,9 @@ -from typing import List +import shutil +from typing import List, Tuple from tqdm import tqdm from joblib import Parallel from os.path import join, isdir +from tempfile import NamedTemporaryFile from pathlib import Path import time import os @@ -54,18 +56,20 @@ def tmp(op_iter): # return directory -def list_files(directory: str, file_ext: str = ".py") -> list: +def list_files(directory: str, file_ext: str = ".py") -> Tuple[list, int]: """ - List all files in the given directory (recursively) + Lists all files in the given directory (recursively) and also returns the total size in bytes of the matched files """ filenames = [] + dir_size = 0 for root, dirs, files in os.walk(directory): for filename in files: if filename.endswith(file_ext): filenames.append(os.path.join(root, filename)) + dir_size += Path(os.path.join(root, filename)).stat().st_size - return filenames + return filenames, dir_size def read_file(filename: str) -> str: @@ -79,6 +83,13 @@ def write_file(filename: str, content: str): with open(filename, 'w') as file: file.write(content) +def mk_dir_cp_file(src_path: str, dest_path: str): + """ + Creates directories at the destination if they do not exist and copies the given file + """ + os.makedirs(os.path.dirname(dest_path), exist_ok=True) + shutil.copy(src_path, dest_path) + def save_json(filename: str, dict_obj: dict): """ Dumps a dict object into a JSON file @@ -113,3 +124,24 @@ def find_repos_list(projects_path: str) -> List[dict]: def mk_dir_not_exist(path: str): if not isdir(path): os.mkdir(path) + + +def create_tmp_file(suffix: str): + """ + It creates a temporary file. + NOTE: the temp file should be deleted manually after creation.
+ """ + return NamedTemporaryFile(mode="w", delete=False, suffix=suffix) + + +def delete_tmp_file(tmp_f: NamedTemporaryFile): + try: + os.unlink(tmp_f.name) + except TypeError: + print("Couldn't delete ", tmp_f.name) + + +def write_to_tmp_file(tmp_f: NamedTemporaryFile, text: str): + tmp_f.write(text) + #tmp_f.close() + return tmp_f diff --git a/tests/examples/type_apply_ex.json b/tests/examples/type_apply_ex.json index 39e4227..608d44e 100644 --- a/tests/examples/type_apply_ex.json +++ b/tests/examples/type_apply_ex.json @@ -1,417 +1,605 @@ { - "tests/examples": { - "src_files": { - "type_apply.py": { - "untyped_seq": "from typing import Tuple , Dict , List , Literal [EOL] from collections import defaultdict [EOL] import pandas [EOL] import pathlib [EOL] import builtins [EOL] import collections [EOL] import typing [EOL] from pathlib import Path [EOL] x = [number] [EOL] l = [ ( [number] , [number] ) ] [EOL] c = defaultdict ( int ) [EOL] df = pd . DataFrame ( [ [number] , [number] ] ) [EOL] dff = pd . DataFrame ( [ [number] , [number] ] ) [EOL] lit = [string] [EOL] class Foo : [EOL] foo_v = [string] [EOL] class Delta : [EOL] foo_d = [string] [EOL] foo_p = Path ( [string] ) [EOL] def __init__ ( ) : [EOL] def foo_inner ( c , d ) : [EOL] pass [EOL] def foo_fn ( self , y ) : [EOL] def foo_inner ( a , b , c , d ) : [EOL] pass [EOL] d = { [string] : True } [EOL] return d [EOL] @ event . getter def get_e ( self ) : [EOL] return Foo . foo_v [EOL] @ event . setter def get_e ( self , y ) : [EOL] Foo . foo_v = y [EOL] return Foo . foo_v [EOL] foo_v = [string] [EOL] def Bar ( x = [ [string] , [string] ] ) : [EOL] v = x [EOL] l = lambda e : e + [number] [EOL] return v [EOL]", - "typed_seq": "0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 $builtins.int$ 0 0 0 $typing.List[typing.Tuple[builtins.int,builtins.int]]$ 0 0 0 0 0 0 0 0 0 $collections.defaultdict$ 0 0 0 0 0 0 $pandas.DataFrame$ 0 0 0 0 0 0 0 0 0 0 0 0 $typing.List[pandas.arrays.PandasArray]$ 0 0 0 0 0 0 0 0 0 0 0 0 $typing.Literal$ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 $pathlib.Path$ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 $typing.Dict[builtins.str,builtins.bool]$ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 $typing.Dict[builtins.str,builtins.bool]$ 0 0 0 0 0 0 0 0 $typing.Dict[builtins.str,builtins.bool]$ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 $builtins.str$ 0 0 0 0 0 0 0 $builtins.str$ 0 0 0 0 0 0 0 0 0 0 0 $typing.List[builtins.str]$ 0 $builtins.int$ 0 0 0 0 0 0 0 0 0 0 0 $builtins.int$ 0 $typing.List[typing.Tuple[builtins.int,builtins.int]]$ 0 0 0 0 0 0 0 0 0 0 0", - "imports": [ - "Tuple", - "Dict", - "List", - "Literal", - "defaultdict", - "pandas", - "pathlib", - "builtins", - "collections", - "typing", - "Path" - ], - "variables": { - "x": "builtins.int", - "l": "typing.List[typing.Tuple[builtins.int, builtins.int]]", - "c": "collections.defaultdict", - "df": "pandas.DataFrame", - "dff": "typing.List[pandas.arrays.PandasArray]", - "lit": "typing.Literal" - }, - "mod_var_occur": { - "x": [ - [ - "v", - "typing", - "List", - "builtins", - "str", - "x" - ] - ], - "l": [ - [ - "l", - "e", - "e" - ] - ], - "c": [], - "df": [], - "dff": [], - "lit": [] - }, - "classes": [ - { - "name": "Delta", - "q_name": "Foo.Delta", - "variables": { - "foo_d": "" - }, - "cls_var_occur": { - "foo_d": [] - }, - "funcs": [] - }, - { - "name": "Foo", - "q_name": "Foo", - "variables": { - "foo_v": "", - "foo_p": "pathlib.Path" - }, - "cls_var_occur": { - "foo_v": [ - [ - "Foo", - "foo_v" - ], - [ - "Foo", - "foo_v", - 
"y" - ], - [ - "Foo", - "foo_v" - ] + "tests/examples": { + "src_files": { + "type_apply.py": { + "untyped_seq": "", + "typed_seq": "", + "imports": [ + "pathlib", + "Path", + "pandas" ], - "foo_p": [] - }, - "funcs": [ - { - "name": "foo_inner", - "q_name": "Foo.__init__..foo_inner", - "fn_lc": [ + "variables": { + "x": "builtins.int", + "l": "typing.List[typing.Tuple[builtins.int, builtins.int]]", + "c": "collections.defaultdict", + "df": "pandas.DataFrame", + "dff": "typing.List[pandas.arrays.PandasArray]", + "lit": "typing.Literal" + }, + "mod_var_occur": { + "x": [ [ - 21, - 8 - ], + "v", + "typing", + "List", + "builtins", + "str", + "x" + ] + ], + "l": [ [ - 22, - 16 + "l", + "e", + "e" ] ], - "params": { - "c": "str", - "d": "" - }, - "ret_exprs": [], - "params_occur": { - "c": [], - "d": [] - }, - "ret_type": "", - "variables": {}, - "fn_var_occur": {}, - "params_descr": { - "c": "", - "d": "" - }, - "docstring": { - "func": null, - "ret": null, - "long_descr": null - } + "c": [], + "df": [], + "dff": [], + "lit": [] }, - { - "name": "__init__", - "q_name": "Foo.__init__", - "fn_lc": [ + "mod_var_ln": { + "x": [ [ - 20, - 4 + 3, + 0 ], [ - 22, - 16 + 3, + 1 ] ], - "params": {}, - "ret_exprs": [], - "params_occur": {}, - "ret_type": "", - "variables": {}, - "fn_var_occur": {}, - "params_descr": {}, - "docstring": { - "func": null, - "ret": null, - "long_descr": null - } - }, - { - "name": "foo_inner", - "q_name": "Foo.foo_fn..foo_inner", - "fn_lc": [ + "l": [ [ - 24, - 8 + 4, + 0 ], [ - 25, - 16 + 4, + 1 ] ], - "params": { - "a": "", - "b": "", - "c": "", - "d": "", - "args": "", - "kwargs": "" - }, - "ret_exprs": [], - "params_occur": { - "a": [], - "b": [], - "c": [], - "d": [] - }, - "ret_type": "", - "variables": {}, - "fn_var_occur": {}, - "params_descr": { - "a": "", - "b": "", - "c": "", - "d": "" - }, - "docstring": { - "func": null, - "ret": null, - "long_descr": null - } - }, - { - "name": "foo_fn", - "q_name": "Foo.foo_fn", - "fn_lc": [ + "c": [ [ - 23, - 4 + 5, + 0 ], [ - 27, - 16 + 5, + 1 ] ], - "params": { - "self": "", - "y": "" - }, - "ret_exprs": [ - "return d" - ], - "params_occur": { - "self": [], - "y": [] - }, - "ret_type": "typing.Dict[builtins.str, builtins.bool]", - "variables": { - "d": "typing.Dict[builtins.str, builtins.bool]" - }, - "fn_var_occur": { - "d": [ - [ - "d", - "typing", - "Dict", - "builtins", - "str", - "builtins", - "bool", - "True" - ] - ] - }, - "params_descr": { - "self": "", - "y": "" - }, - "docstring": { - "func": null, - "ret": null, - "long_descr": null - } - }, - { - "name": "get_e", - "q_name": "Foo.get_e", - "fn_lc": [ + "df": [ [ - 29, - 4 + 6, + 0 ], [ - 30, - 24 + 6, + 2 ] ], - "params": { - "self": "" - }, - "ret_exprs": [ - "return Foo.foo_v" - ], - "params_occur": { - "self": [] - }, - "ret_type": "", - "variables": {}, - "fn_var_occur": {}, - "params_descr": { - "self": "" - }, - "docstring": { - "func": null, - "ret": null, - "long_descr": null - } - }, - { - "name": "get_e", - "q_name": "Foo.get_e", - "fn_lc": [ + "dff": [ [ - 32, - 4 + 7, + 0 ], [ - 34, - 24 + 7, + 3 ] ], - "params": { - "self": "", - "y": "builtins.str" - }, - "ret_exprs": [ - "return Foo.foo_v" - ], - "params_occur": { - "self": [], - "y": [ + "lit": [ + [ + 8, + 0 + ], + [ + 8, + 3 + ] + ] + }, + "classes": [ + { + "name": "Delta", + "q_name": "Foo.Delta", + "cls_lc": [ + [ + 11, + 4 + ], [ - "Foo", - "foo_v", - "y" + 12, + 31 ] - ] - }, - "ret_type": "", - "variables": { - "foo_v": "" + ], + "variables": { + "foo_d": "" + }, + "cls_var_occur": { + 
"foo_d": [] + }, + "cls_var_ln": { + "foo_d": [ + [ + 12, + 8 + ], + [ + 12, + 13 + ] + ] + }, + "funcs": [] }, - "fn_var_occur": { - "foo_v": [ + { + "name": "Foo", + "q_name": "Foo", + "cls_lc": [ [ - "Foo", - "foo_v", - "y" + 9, + 0 ], [ - "Foo", - "foo_v" + 30, + 30 ] + ], + "variables": { + "foo_v": "builtins.str", + "foo_p": "pathlib.Path" + }, + "cls_var_occur": { + "foo_v": [ + [ + "Foo", + "foo_v" + ], + [ + "Foo", + "foo_v", + "y" + ], + [ + "Foo", + "foo_v" + ] + ], + "foo_p": [] + }, + "cls_var_ln": { + "foo_v": [ + [ + 30, + 4 + ], + [ + 30, + 9 + ] + ], + "foo_p": [ + [ + 13, + 4 + ], + [ + 13, + 9 + ] + ] + }, + "funcs": [ + { + "name": "foo_inner", + "q_name": "Foo.__init__..foo_inner", + "fn_lc": [ + [ + 16, + 8 + ], + [ + 17, + 16 + ] + ], + "params": { + "c": "builtins.str", + "d": "" + }, + "ret_exprs": [], + "params_occur": { + "c": [], + "d": [] + }, + "ret_type": "", + "variables": {}, + "fn_var_occur": {}, + "fn_var_ln": {}, + "params_descr": { + "c": "", + "d": "" + }, + "docstring": { + "func": null, + "ret": null, + "long_descr": null + } + }, + { + "name": "__init__", + "q_name": "Foo.__init__", + "fn_lc": [ + [ + 14, + 4 + ], + [ + 17, + 16 + ] + ], + "params": { + "self": "" + }, + "ret_exprs": [], + "params_occur": { + "self": [] + }, + "ret_type": "", + "variables": { + "i": "builtins.int" + }, + "fn_var_occur": { + "i": [] + }, + "fn_var_ln": { + "i": [ + [ + 15, + 8 + ], + [ + 15, + 14 + ] + ] + }, + "params_descr": { + "self": "" + }, + "docstring": { + "func": null, + "ret": null, + "long_descr": null + } + }, + { + "name": "foo_inner", + "q_name": "Foo.foo_fn..foo_inner", + "fn_lc": [ + [ + 19, + 8 + ], + [ + 20, + 16 + ] + ], + "params": { + "a": "", + "b": "", + "c": "", + "d": "", + "args": "", + "kwargs": "" + }, + "ret_exprs": [], + "params_occur": { + "a": [], + "b": [], + "c": [], + "d": [], + "args": [], + "kwargs": [] + }, + "ret_type": "", + "variables": {}, + "fn_var_occur": {}, + "fn_var_ln": {}, + "params_descr": { + "a": "", + "b": "", + "c": "", + "d": "", + "args": "", + "kwargs": "" + }, + "docstring": { + "func": null, + "ret": null, + "long_descr": null + } + }, + { + "name": "foo_fn", + "q_name": "Foo.foo_fn", + "fn_lc": [ + [ + 18, + 4 + ], + [ + 22, + 16 + ] + ], + "params": { + "self": "", + "y": "" + }, + "ret_exprs": [ + "return d" + ], + "params_occur": { + "self": [], + "y": [] + }, + "ret_type": "typing.Dict[builtins.str, builtins.bool]", + "variables": { + "d": "typing.Dict[builtins.str, builtins.bool]" + }, + "fn_var_occur": { + "d": [ + [ + "d", + "typing", + "Dict", + "builtins", + "str", + "builtins", + "bool", + "True" + ] + ] + }, + "fn_var_ln": { + "d": [ + [ + 21, + 8 + ], + [ + 21, + 9 + ] + ] + }, + "params_descr": { + "self": "", + "y": "" + }, + "docstring": { + "func": null, + "ret": null, + "long_descr": null + } + }, + { + "name": "get_e", + "q_name": "Foo.get_e", + "fn_lc": [ + [ + 24, + 4 + ], + [ + 25, + 24 + ] + ], + "params": { + "self": "" + }, + "ret_exprs": [ + "return Foo.foo_v" + ], + "params_occur": { + "self": [] + }, + "ret_type": "", + "variables": {}, + "fn_var_occur": {}, + "fn_var_ln": {}, + "params_descr": { + "self": "" + }, + "docstring": { + "func": null, + "ret": null, + "long_descr": null + } + }, + { + "name": "get_e", + "q_name": "Foo.get_e", + "fn_lc": [ + [ + 27, + 4 + ], + [ + 29, + 24 + ] + ], + "params": { + "self": "", + "y": "builtins.str" + }, + "ret_exprs": [ + "return Foo.foo_v" + ], + "params_occur": { + "self": [], + "y": [ + [ + "Foo", + "foo_v", + "y" + ] + ] + }, + 
"ret_type": "", + "variables": { + "foo_v": "" + }, + "fn_var_occur": { + "foo_v": [ + [ + "Foo", + "foo_v", + "y" + ], + [ + "Foo", + "foo_v" + ] + ] + }, + "fn_var_ln": { + "foo_v": [ + [ + 28, + 8 + ], + [ + 28, + 17 + ] + ] + }, + "params_descr": { + "self": "", + "y": "" + }, + "docstring": { + "func": null, + "ret": null, + "long_descr": null + } + } ] - }, - "params_descr": { - "self": "", - "y": "" - }, - "docstring": { - "func": null, - "ret": null, - "long_descr": null } - } - ] - } - ], - "funcs": [ - { - "name": "Bar", - "q_name": "Bar", - "fn_lc": [ - [ - 36, - 0 ], - [ - 39, - 12 - ] - ], - "params": { - "x": "typing.List[builtins.str]", - "c": "" - }, - "ret_exprs": [ - "return v" - ], - "params_occur": { - "x": [ - [ - "v", - "typing", - "List", - "builtins", - "str", - "x" - ] - ] - }, - "ret_type": "typing.List[builtins.str]", - "variables": { - "v": "typing.List[builtins.str]", - "l": "" - }, - "fn_var_occur": { - "v": [ - [ - "v", - "typing", - "List", - "builtins", - "str", - "x" - ] + "funcs": [ + { + "name": "Bar", + "q_name": "Bar", + "fn_lc": [ + [ + 31, + 0 + ], + [ + 34, + 12 + ] + ], + "params": { + "x": "typing.List[builtins.str]", + "c": "" + }, + "ret_exprs": [ + "return v" + ], + "params_occur": { + "x": [ + [ + "v", + "typing", + "List", + "builtins", + "str", + "x" + ] + ], + "c": [] + }, + "ret_type": "typing.List[builtins.str]", + "variables": { + "v": "typing.List[builtins.str]", + "l": "" + }, + "fn_var_occur": { + "v": [ + [ + "v", + "typing", + "List", + "builtins", + "str", + "x" + ] + ], + "l": [ + [ + "l", + "e", + "e" + ] + ] + }, + "fn_var_ln": { + "v": [ + [ + 32, + 4 + ], + [ + 32, + 5 + ] + ], + "l": [ + [ + 33, + 4 + ], + [ + 33, + 5 + ] + ] + }, + "params_descr": { + "x": "", + "c": "" + }, + "docstring": { + "func": null, + "ret": null, + "long_descr": null + } + } ], - "l": [ - [ - "l", - "e", - "e" - ] - ] - }, - "params_descr": { - "x": "" - }, - "docstring": { - "func": null, - "ret": null, - "long_descr": null + "set": null, + "tc": [ + false, + null + ], + "no_types_annot": { + "U": 14, + "D": 15, + "I": 0 + }, + "type_annot_cove": 0.52 } } - ], - "set": null, - "tc": false, - "no_types_annot": { - "U": 12, - "D": 13, - "I": 0 - }, - "type_annot_cove": 0.52 -} } - } } \ No newline at end of file diff --git a/tests/examples/type_apply_typed_ex.json b/tests/examples/type_apply_typed_ex.json new file mode 100644 index 0000000..4ce162c --- /dev/null +++ b/tests/examples/type_apply_typed_ex.json @@ -0,0 +1,325 @@ +{ + "tests/examples": { + "src_files": { + "type_apply_typed.py": { + "untyped_seq": "", + "typed_seq": "", + "imports": [], + "variables": { + "a": "", + "l": "", + "c": "", + "h": "builtins.dict" + }, + "mod_var_occur": { + "a": [ + [ + "self", + "a", + "a" + ] + ], + "l": [], + "c": [], + "h": [] + }, + "mod_var_ln": { + "a": [ + [ + 1, + 0 + ], + [ + 1, + 1 + ] + ], + "l": [ + [ + 2, + 0 + ], + [ + 2, + 1 + ] + ], + "c": [ + [ + 3, + 0 + ], + [ + 3, + 1 + ] + ], + "h": [ + [ + 4, + 0 + ], + [ + 4, + 1 + ] + ] + }, + "classes": [ + { + "name": "Bar", + "q_name": "Bar", + "cls_lc": [ + [ + 8, + 0 + ], + [ + 15, + 25 + ] + ], + "variables": { + "bar_var1": "", + "bar_var2": "" + }, + "cls_var_occur": { + "bar_var1": [], + "bar_var2": [] + }, + "cls_var_ln": { + "bar_var1": [ + [ + 9, + 4 + ], + [ + 9, + 12 + ] + ], + "bar_var2": [ + [ + 10, + 4 + ], + [ + 10, + 12 + ] + ] + }, + "funcs": [ + { + "name": "__init__", + "q_name": "Bar.__init__", + "fn_lc": [ + [ + 11, + 4 + ], + [ + 13, + 18 + ] + ], + "params": { + "a": "", + 
"b": "" + }, + "ret_exprs": [], + "params_occur": { + "a": [ + [ + "self", + "a", + "a" + ] + ], + "b": [ + [ + "self", + "b", + "b" + ] + ] + }, + "ret_type": "", + "variables": { + "a": "", + "b": "" + }, + "fn_var_occur": { + "a": [ + [ + "self", + "a", + "a" + ] + ], + "b": [ + [ + "self", + "b", + "b" + ] + ] + }, + "fn_var_ln": { + "a": [ + [ + 12, + 8 + ], + [ + 12, + 14 + ] + ], + "b": [ + [ + 13, + 8 + ], + [ + 13, + 14 + ] + ] + }, + "params_descr": { + "a": "", + "b": "" + }, + "docstring": { + "func": null, + "ret": null, + "long_descr": null + } + }, + { + "name": "delta", + "q_name": "Bar.delta", + "fn_lc": [ + [ + 14, + 4 + ], + [ + 15, + 25 + ] + ], + "params": { + "n": "" + }, + "ret_exprs": [ + "return [2.17] * p" + ], + "params_occur": { + "n": [] + }, + "ret_type": "", + "variables": {}, + "fn_var_occur": {}, + "fn_var_ln": {}, + "params_descr": { + "n": "" + }, + "docstring": { + "func": null, + "ret": null, + "long_descr": null + } + } + ] + } + ], + "funcs": [ + { + "name": "foo", + "q_name": "foo", + "fn_lc": [ + [ + 5, + 0 + ], + [ + 7, + 12 + ] + ], + "params": { + "x": "", + "y": "" + }, + "ret_exprs": [ + "return z" + ], + "params_occur": { + "x": [ + [ + "z", + "x", + "y" + ] + ], + "y": [ + [ + "z", + "x", + "y" + ] + ] + }, + "ret_type": "", + "variables": { + "z": "" + }, + "fn_var_occur": { + "z": [ + [ + "z", + "x", + "y" + ] + ] + }, + "fn_var_ln": { + "z": [ + [ + 6, + 4 + ], + [ + 6, + 5 + ] + ] + }, + "params_descr": { + "x": "", + "y": "" + }, + "docstring": { + "func": null, + "ret": null, + "long_descr": null + } + } + ], + "set": null, + "tc": [ + false, + null + ], + "no_types_annot": { + "U": 14, + "D": 1, + "I": 0 + }, + "type_annot_cove": 0.07 + } + } + } +} \ No newline at end of file diff --git a/tests/test_type_apply.py b/tests/test_type_apply.py index 5fb08c4..8607de1 100644 --- a/tests/test_type_apply.py +++ b/tests/test_type_apply.py @@ -5,6 +5,7 @@ import shutil test_file = """from pathlib import Path +import pandas x: int = 12 l = [(1, 2)] c = defaultdict(int) @@ -12,11 +13,12 @@ dff = pd.DataFrame([1,2]) lit = "Hello!" class Foo: - foo_v: str = 'Hello, Foo!' + foo_v = 'Hello, Foo!' class Delta: foo_d = 'Hello, Delta!' foo_p = Path('/home/foo/bar') - def __init__(): + def __init__(self): + self.i = 10 def foo_inner(c, d=lambda a,b: a == b): pass def foo_fn(self, y): @@ -40,12 +42,12 @@ def Bar(x=['apple', 'orange'], *, c): test_file_exp = """from typing import Tuple, Dict, List, Literal from collections import defaultdict -import pandas import pathlib import builtins import collections import typing from pathlib import Path +import pandas x: builtins.int = 12 l: typing.List[typing.Tuple[builtins.int, builtins.int]] = [(1, 2)] c: collections.defaultdict = defaultdict(int) @@ -53,12 +55,13 @@ def Bar(x=['apple', 'orange'], *, c): dff: typing.List[pandas.arrays.PandasArray] = pd.DataFrame([1,2]) lit: typing.Literal = "Hello!" class Foo: - foo_v: str = 'Hello, Foo!' + foo_v = 'Hello, Foo!' class Delta: foo_d = 'Hello, Delta!' 
foo_p: pathlib.Path = Path('/home/foo/bar') - def __init__(): - def foo_inner(c: str, d=lambda a,b: a == b): + def __init__(self): + self.i: builtins.int = 10 + def foo_inner(c: builtins.str, d=lambda a,b: a == b): pass def foo_fn(self, y)-> typing.Dict[builtins.str, builtins.bool]: def foo_inner(a, b, c, d, *args, **kwargs): @@ -72,13 +75,47 @@ def get_e(self): def get_e(self, y: builtins.str): Foo.foo_v = y return Foo.foo_v - foo_v = "No" + foo_v: builtins.str = "No" def Bar(x: typing.List[builtins.str]=['apple', 'orange'], *, c)-> typing.List[builtins.str]: v: typing.List[builtins.str] = x l = lambda e: e+1 return v """ +test_file_typed = """a: int = 12 +l: List[int] = [1,2,3] +c = 2.71 +h: dict +def foo(x: int, y: int) -> int: + z: int = x + y + return z +class Bar: + bar_var1: str = "Hello, Bar!" + bar_var2: float = 3.14 + def __init__(a: int, b): + self.a: int = a + self.b = b + def delta(n: int) -> List[float]: + return [2.17] * p +""" + +test_file_typed_exp = """a = 12 +l = [1,2,3] +c = 2.71 +h: dict +def foo(x, y): + z = x + y + return z +class Bar: + bar_var1 = "Hello, Bar!" + bar_var2 = 3.14 + def __init__(a, b): + self.a = a + self.b = b + def delta(n): + return [2.17] * p +""" + class TestTypeAnnotatingProjects(unittest.TestCase): """ @@ -92,12 +129,17 @@ def __init__(self, *args, **kwargs): def setUpClass(cls): mk_dir_not_exist('./tmp_ta') write_file('./tmp_ta/type_apply.py', test_file) + write_file('./tmp_ta/type_apply_typed.py', test_file_typed) + # from libsa4py.cst_extractor import Extractor - # save_json('./tmp_ta/type_apply_ex.json', Extractor.extract(read_file('./tmp_ta/type_apply.py')).to_dict()) + # save_json('./tmp_ta/type_apply_ex.json', {"tests/examples": {"src_files": {"type_apply.py": + # Extractor.extract(read_file('./tmp_ta/type_apply.py'), include_seq2seq=False).to_dict()}}}) + # save_json('./tmp_ta/type_apply_typed_ex.json', {"tests/examples": {"src_files": {"type_apply_typed.py": + # Extractor.extract(read_file('./tmp_ta/type_apply_typed.py'), include_seq2seq=False).to_dict()}}}) def test_type_apply_pipeline(self): ta = TypeAnnotatingProjects('./tmp_ta', None, apply_nlp=False) - ta.process_project('./examples/type_apply_ex.json') + total_no_added_types = ta.process_project('./examples/type_apply_ex.json') exp_split = test_file_exp.splitlines() out_split = read_file('./tmp_ta/type_apply.py').splitlines() @@ -106,9 +148,20 @@ def test_type_apply_pipeline(self): out = """{}""".format("\n".join(out_split[7:])) self.assertEqual(exp, out) + self.assertEqual(total_no_added_types[0], 16) + # The imported types from typing self.assertEqual(Counter(" ".join(exp_split[0:7])), Counter(" ".join(out_split[0:7]))) + def test_type_apply_remove_annot(self): + """ + Tests the removal of type annotations if not present in the JSON output + """ + ta = TypeAnnotatingProjects('./tmp_ta', None, apply_nlp=False) + ta.process_project('./examples/type_apply_typed_ex.json') + + self.assertEqual(test_file_typed_exp, read_file('./tmp_ta/type_apply_typed.py')) + @classmethod def tearDownClass(cls): shutil.rmtree("./tmp_ta/")
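Below is a minimal, hypothetical driver sketch (not part of the patch) showing how the flow asserted by `test_type_apply_remove_annot` can be run standalone. It assumes the string fixtures `test_file_typed` and `test_file_typed_exp` from tests/test_type_apply.py are in scope, and it only uses calls that appear in the tests above; paths mirror the test setup.

# sketch_type_apply_remove.py -- illustrative only, assumes the fixtures above.
from libsa4py.cst_pipeline import TypeAnnotatingProjects
from libsa4py.utils import mk_dir_not_exist, write_file, read_file

mk_dir_not_exist('./tmp_ta')
write_file('./tmp_ta/type_apply_typed.py', test_file_typed)  # fixture from the test module

# apply_nlp=False skips NL preprocessing, as in the tests. The JSON fixture is the
# source of truth: annotations it records (e.g. "h": "builtins.dict") are kept or
# applied, while annotations present in the file but absent from the fixture
# (e.g. `a: int`, `def foo(x: int, ...)`) are stripped from the source.
ta = TypeAnnotatingProjects('./tmp_ta', None, apply_nlp=False)
ta.process_project('./examples/type_apply_typed_ex.json')

assert read_file('./tmp_ta/type_apply_typed.py') == test_file_typed_exp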