-
Notifications
You must be signed in to change notification settings - Fork 5
/
collect_mixture_data.py
58 lines (51 loc) · 1.95 KB
/
collect_mixture_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import pandas as pd
from copy import copy
import yaml
import os
import argparse
def read_config(config_file):
    """Load a YAML config and return the numeric ``train_*`` entries.

    Only the ``train`` section of the file is inspected. Keys that start
    with ``train_`` are kept under their full, unmodified name. Of those,
    values whose type is exactly ``float`` are rounded to 5 decimal places,
    values whose type is exactly ``int`` are kept unchanged, and values of
    any other type (strings, bools, lists, nested mappings, ...) are
    dropped.

    Args:
        config_file: Path of the YAML file to read.

    Returns:
        A flat dict mapping each retained ``train_*`` key to its number.
    """
    # NOTE(review): FullLoader is used as in the original; yaml.safe_load
    # may be enough if the configs only hold plain data -- confirm.
    with open(config_file, "r") as stream:
        parsed = yaml.load(stream, Loader=yaml.FullLoader)

    train_section = parsed["train"]
    numeric_entries = {}
    for name in list(train_section.keys()):
        if not name.startswith("train_"):
            continue
        entry = train_section[name]
        kind = type(entry)
        # Exact type checks on purpose: bool is a subclass of int and must
        # not be picked up here.
        if kind is float:
            numeric_entries[name] = round(entry, 5)
        elif kind is int:
            numeric_entries[name] = entry
    return numeric_entries
def gather_mixture_data(write_file_path, config_folder):
    """Collect the numeric training entries of every YAML config into a CSV.

    Each ``*.yaml`` file in ``config_folder`` is parsed with ``read_config``
    and becomes one row of the table; other directory entries are reported
    with a ``skip`` message and ignored. The row label is the integer
    obtained from the file name: the text before the first ``.`` with every
    ``n`` character removed (``n12.yaml`` -> ``12``). Rows are sorted by
    that label, the index column is named ``index``, and the table is
    written to ``write_file_path``.

    Args:
        write_file_path: Destination path of the CSV file.
        config_folder: Directory that holds the YAML config files.

    Raises:
        ValueError: If a YAML file name does not yield an integer index.
            The message names the offending file.
    """
    rows = {}
    # os.listdir returns entries in arbitrary order; sorting keeps the
    # "skip" messages and any same-index overwrite deterministic.
    for file_name in sorted(os.listdir(config_folder)):
        if not file_name.endswith(".yaml"):
            print("skip", file_name)
            continue

        stem = file_name.split(".")[0]
        try:
            index_name = int(stem.replace("n", ""))
        except ValueError as err:
            # A bare int() error does not say which file is misnamed.
            raise ValueError(
                "cannot derive an integer index from config file name "
                "{!r} in {!r}".format(file_name, config_folder)
            ) from err

        rows[index_name] = read_config(os.path.join(config_folder, file_name))

    # The dict-of-dicts constructor stores each config as a column, so
    # transpose to get one row per config.
    df = pd.DataFrame(rows).T
    df.index.name = "index"
    df = df.sort_index()
    df.to_csv(write_file_path)
if __name__ == "__main__":
    # Command-line entry point: read the output CSV path and the config
    # directory from the flags (both have defaults) and build the table.
    cli = argparse.ArgumentParser()
    cli.add_argument("--write_file_path", type=str, default="train_mixture_1m.csv")
    cli.add_argument("--config_folder", type=str, default="../mixture_config/config_1m")
    options = cli.parse_args()
    gather_mixture_data(options.write_file_path, options.config_folder)