update filtering and decontamination #58

Open · wants to merge 1 commit into main
133 changes: 95 additions & 38 deletions preprocessing/filtering.py
@@ -23,7 +23,7 @@
ALL_FILTERS = ["basic", "basic_per_extension", "stars", "comments", "fertility", "xml", "html", "large_and_small_files"]
THRESHOLDS_FERTILITY = {"python": 2.5, "java": 2.9, "javascript": 2.6}


LANG = "language"
class MultiChoice:
def __init__(self, choices):
self.choices = choices
@@ -63,7 +63,7 @@ def parse_args():
def get_comments_ratio(examples):
"""Get ratio of comments to code in each example. Requires a language argument"""
ratio_list = []
for code, language in zip(examples["content"], examples["lang"]):
for code, language in zip(examples["content"], examples[LANG]):
ratio_list.append(get_nl_ratio(code, language.lower()))
return {"nl_ratio": ratio_list}

@@ -89,6 +89,17 @@ def basic_filters(example):
return False
return True

def add_stats(example):
"""Add extra stats:
- size of text, mean and max line length of file
- % alphanumeric characters
- extracts file extension"""
content = example["content"]
size = len(content)
line_lengths = [len(line) for line in content.splitlines()]
# guard against empty files so np.mean / max do not fail
avg_line = float(np.mean(line_lengths)) if line_lengths else 0.0
max_line = max(line_lengths) if line_lengths else 0
alpha_frac = float(np.mean([c.isalnum() for c in content])) if content else 0.0
ext = example["path"].split(".")[-1]
return {"size": size, "avg_line_length": avg_line, "max_line_length": max_line, "alphanum_fraction": alpha_frac, "ext": ext}


def basic_filters_per_extension(example, ext_to_filter):
"""Filter files based on line length and % alphanumeric characters.
@@ -97,7 +108,7 @@ def basic_filters_per_extension(example, ext_to_filter):
# extension `None` is an empty string in the csv
try:
(include, line_max, line_mean, alphanum_frac, alphabetic_frac) = ext_to_filter[(language_format_from_dataset(
example["lang"]), example["ext"] if example["ext"] is not None else ""
example[LANG]), example["ext"] if example["ext"] is not None else ""
)]
except KeyError as e:
# Some extensions are not in the csv. This happens for dockerfiles.
@@ -187,7 +198,7 @@ def char_token_ratio(examples, tokenizer):
def filter_tokenizer(examples):
"""Filter files based on char to token ratio"""
values = []
for ratio, lang in zip(examples["fertility_ratio"], examples["lang"]):
for ratio, lang in zip(examples["fertility_ratio"], examples[LANG]):
if ratio < THRESHOLDS_FERTILITY[lang.lower()]:
values.append(False)
else:
@@ -202,7 +213,7 @@ def filter_xml(example):

def filter_html(example):
"""Filter HTML files based on displayed text VS code ratio"""
assert example["lang"] == "HTML", "Filter is only for html examples"
assert example[LANG] == "HTML", "Filter is only for html examples"
html = example["content"]
try:
soup = BeautifulSoup(html, features="html.parser")
@@ -226,6 +237,8 @@ def filter_large_and_small_files(example):
def get_size_text(example):
return {"size": len(example["content"])}

def get_ext(example):
return {"ext": example["path"].split(".")[-1]}

LICENSE_COLUMNS = ['max_stars_repo_licenses', 'max_issues_repo_licenses', 'max_forks_repo_licenses']
def fix_license_cols(example):
@@ -234,6 +247,7 @@ def fix_license_cols(example):
return example



if __name__ == "__main__":
args = parse_args()
print(f"Selected filters: {args.filters}")
@@ -258,20 +272,29 @@ def fix_license_cols(example):
# Load dataset
t_start = time.time()
logger.info(f" ===== Loading {args.dataset_name} and subset {args.subset}=====")
# make sure out_path/data doesn't already exist
import os
if os.path.exists(f"{args.out_path}/data"):
    raise ValueError(f"Output path already exists: {args.out_path}/data, delete it before filtering")

dataset = load_dataset(
args.dataset_name, split=args.split, data_dir=args.subset, use_auth_token=True, num_proc=args.num_workers
args.dataset_name, split=args.split, use_auth_token=True, num_proc=args.num_workers
)
logger.info(f"Dataset loaded in {time.time() - t_start:.2f} seconds")
logger.info(f"Dataset: {dataset}")
if "size" not in dataset.column_names:
logger.info("Add text size column")
dataset = dataset.map(get_size_text)
logger.info("Add text size column, ext and line stats")
dataset = dataset.map(add_stats, num_proc=args.num_workers)
if args.fix_license_columns:
dataset = dataset.map(fix_license_cols, num_proc=args.num_workers)
logger.info(
f"Dataset size before any filtering: {len(dataset)} examples, {sum(dataset['size']) / 1e9:.2f} GB"
f"Dataset size before any filtering: {len(dataset)} examples, {sum(dataset['size']) / 1e9:.2f} GB and columns: {dataset.column_names}"
)
# filter out non-permissive data
dataset = dataset.filter(lambda x: x["license_type"] != "non_permissive")
logger.info(
f"Dataset size after non-permissive license filtering: {len(dataset)} examples, {sum(dataset['size']) / 1e9:.2f} GB"
)

# Run pre-processing if needed
if "stars" in filters:
logger.info(f"===== Processing dataset to add proper stars column=====")
@@ -335,6 +358,8 @@ def fix_license_cols(example):
elif filter == "basic_per_extension":
assert args.per_extension_filter_csv is not None
language = language_format_from_data_dir(args.subset.split("/")[-1]) if args.subset is not None else None
language = "python"
logger.info("selected language: ", language)
logger.info(
f"===== Language: {language}. Basic filtering with line_max, avg_line, alphanum_frac and alphabetic_frac given by : {args.per_extension_filter_csv} ====="
)
@@ -536,6 +561,65 @@ def fix_license_cols(example):
)
dataset = ds


# Run decontamination
if args.run_decontamination:
logger.info(
f"===== Running decontamination ====="
)
import sys
import os
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))
from decontamination.benchmark_data import FILTER_OUT

FILTER_OUT.pop('apps_docstrings', None)
FILTER_OUT.pop('gsm8k_questions', None)
logger.info(f"FILTER OUT Benchmarks: {FILTER_OUT.keys()}")
def decontaminate(samples, filter_out=FILTER_OUT):
"""
filter_out: Dict[str, List[str]] mapping from benchmark name to list of strings that need to be
filtered-out.
Return a list where each element is True if the corresponding file should be included in the dataset.
Otherwise, the element is False.
"""
output = []

for content in samples["content"]:
content = content.lower()
matched = False
for benchmark, substrings in filter_out.items():
for substring in substrings:
if substring.lower() in content:
matched = True
break
if matched:
break
# we keep files that are not matched
output.append(not matched)

return output

old_size = len(dataset)
old_size_gb = sum(dataset["size"])
dataset = dataset.filter(decontaminate, batched=True, batch_size=10_000, num_proc=64)
filtered_size_gb = sum(dataset["size"])
logger.info(
f"Removed {old_size - len(dataset)} files from {old_size} (i.e. {(old_size - len(dataset)) * 100 / old_size:.2f}%)"
)
logger.info(
f"Dataset size after decontamination: {len(dataset)} examples, {filtered_size_gb / 1e9:.2f} GB (down from {old_size_gb / 1e9:.2f} GB)"
)
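
For reference, a self-contained sketch of the same substring-based decontamination on a toy batch. The FILTER_OUT mapping below is only a stand-in for decontamination.benchmark_data.FILTER_OUT, and the keep/drop loop is condensed with any() while keeping the semantics of the decontaminate function above.

from typing import Dict, List

# stand-in for decontamination.benchmark_data.FILTER_OUT (benchmark name -> substrings)
FILTER_OUT: Dict[str, List[str]] = {"some_benchmark": ["def has_close_elements(numbers"]}

def decontaminate(samples, filter_out=FILTER_OUT):
    """Return one boolean per file: True means keep (no benchmark substring found)."""
    keep = []
    for content in samples["content"]:
        lowered = content.lower()
        matched = any(
            substring.lower() in lowered
            for substrings in filter_out.values()
            for substring in substrings
        )
        keep.append(not matched)
    return keep

batch = {"content": [
    "def has_close_elements(numbers, threshold):\n    ...",  # overlaps a benchmark string -> dropped
    "print('hello world')",                                   # clean -> kept
]}
print(decontaminate(batch))  # [False, True]

Passing this callable to dataset.filter(..., batched=True) keeps exactly the files whose flag is True.
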

if args.add_metadata:
from add_content_with_meta import content_with_meta

logger.info("===== Adding content with metadata =====")
dataset = dataset.map(
content_with_meta,
remove_columns=["content"],
num_proc=args.num_workers,
)

# Save dataset
logger.info(
f"Final dataset has {len(dataset)} samples and {sum(dataset['size']) / 1e9:.2f} GB of code"
@@ -548,7 +632,7 @@ def fix_license_cols(example):
dataset.push_to_hub(args.remote_repo)
else:
print(
f"Saving the dataset in manual shards in a clone of {args.hub_username + args.remote_repo}"
f"Saving the dataset in manual shards in a clone of {args.hub_username}/{args.remote_repo}"
)
try:
save_manual_shards(
@@ -557,30 +641,3 @@ def fix_license_cols(example):
logger.info(f"Dataset successfully saved at {args.out_path}/{args.subset} in {time.time() - t_start:.2f} seconds")
except FileExistsError:
logger.warning(f"Output dir already exists at {args.out_path}/{args.subset}. Will not save filtered data")

# Run decontamination
if args.run_decontamination:
import sys
import os
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))
from decontamination.find_substrings import SubstringFilterer

output_dir_decontaminated = f"{args.out_path}_decontaminate/{args.subset}"

filterer = SubstringFilterer(
output_dir=output_dir_decontaminated,
cached_decontamination_dir=None, # no previous cached run
split_languages=False,
cache_retrieval_key="",
data_dir=output_dir_decontaminated
)

filtered = filterer.run(dataset, args.num_workers, args.batch_size)

filtered_size_gb = sum(filtered["size"])
logger.info(
f"Removed {len(dataset) - len(filtered)} / {len(dataset)} files"
)
logger.info(
f"Dataset size after decontamination: {len(filtered)} examples, {filtered_size_gb / 1e9:.2f} GB"
)