-
Notifications
You must be signed in to change notification settings - Fork 11
/
merge_dnsmos.py
160 lines (131 loc) · 4.61 KB
/
merge_dnsmos.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
# 2023 (c) LINE Corporation
# Authors: Robin Scheibler
# MIT License
import argparse
import csv
import json
from collections import defaultdict
from pathlib import Path
from evaluate_mp import summarize
# Header row expected at the top of a DNSMOS CSV file, in column order.
# The first column has an empty header; parse_dnsmos_csv never reads it
# (presumably a row index written by the DNSMOS tool -- not verified here).
fieldnames = ["", "filename", "len_in_sec", "sr", "num_hops"] + [
    "OVRL_raw",
    "SIG_raw",
    "BAK_raw",
    "OVRL",
    "SIG",
    "BAK",
]

# Converter applied to each named column when a CSV row is parsed.
# Every score column (raw and calibrated OVRL/SIG/BAK) is a float.
types = {
    "filename": Path,
    "len_in_sec": float,
    "sr": int,
    "num_hops": int,
    **dict.fromkeys(
        ("OVRL_raw", "SIG_raw", "BAK_raw", "OVRL", "SIG", "BAK"),
        float,
    ),
}
def get_results_filepath(exp_path, split):
    """Locate the ``<split>.json`` results file inside ``exp_path``.

    Args:
        exp_path: Directory (a ``pathlib.Path``) holding the experiment output.
        split: Name of the evaluation split, e.g. ``"val"``.

    Returns:
        The file path if it exists, otherwise ``False``.
    """
    candidate = exp_path / f"{split}.json"
    return candidate if candidate.exists() else False
def get_dnsmos_filepath(exp_path, split):
    """Locate the ``<split>_dnsmos.csv`` DNSMOS output inside ``exp_path``.

    Args:
        exp_path: Directory (a ``pathlib.Path``) holding the experiment output.
        split: Name of the evaluation split, e.g. ``"val"``.

    Returns:
        The file path if it exists, otherwise ``False``.
    """
    candidate = exp_path / f"{split}_dnsmos.csv"
    return candidate if candidate.exists() else False
def parse_dnsmos_csv(filepath):
    """Parse a DNSMOS CSV file into per-sample, per-column channel lists.

    The first row must start with the columns in ``fieldnames``. In every
    following row, the filename column (index 1) must have a stem of the form
    ``<sample_idx>.<prefix><channel_idx>``, where ``<prefix>`` is exactly
    three characters and is dropped before ``<channel_idx>`` is parsed as an
    int. The remaining named columns are converted with ``types``.

    Args:
        filepath: Path to the DNSMOS CSV file.

    Returns:
        dict mapping ``sample_idx`` (int) to a dict mapping each column name
        in ``fieldnames[2:]`` to a list of values ordered by channel index
        ``0 .. num_chan - 1``.

    Raises:
        ValueError: if the header does not match ``fieldnames``, the file has
            no data rows, or the samples do not all have exactly the channels
            ``0 .. num_chan - 1`` (``num_chan`` being the largest channel
            count seen). Offending samples are printed before raising.
    """
    dnsmos = defaultdict(dict)
    with open(filepath, newline="") as csvfile:
        dnsmos_reader = csv.reader(csvfile, delimiter=",")
        for idx, row in enumerate(dnsmos_reader):
            if idx == 0:
                # check that this is a valid DNSMOS output file
                for f1, f2 in zip(row, fieldnames):
                    if f1 != f2:
                        raise ValueError(
                            f"There might be an error in the DNSMOS file ({f1} != {f2})"
                        )
                # zip() stops at the shorter sequence, so a truncated header
                # would otherwise pass the check above unnoticed.
                if len(row) < len(fieldnames):
                    raise ValueError(
                        "There might be an error in the DNSMOS file "
                        f"(expected {len(fieldnames)} columns, found {len(row)})"
                    )
                continue
            if not row:
                # csv.reader yields an empty list for blank lines
                continue
            sample_str, channel_str = Path(row[1]).stem.split(".")
            sample_idx = int(sample_str)
            channel_idx = int(channel_str[3:])  # strip the 3-character prefix
            dnsmos[sample_idx][channel_idx] = {
                key: types[key](val) for key, val in zip(fieldnames[2:], row[2:])
            }

    if not dnsmos:
        raise ValueError("Empty DNSMOS file")

    # Every sample must carry the same channels 0 .. num_chan - 1. Using the
    # largest channel count (rather than whichever sample was parsed last)
    # means incomplete samples are the ones reported.
    num_chan = max(len(channels) for channels in dnsmos.values())
    expected_channels = set(range(num_chan))
    errors = {
        sample_idx: sorted(channels)
        for sample_idx, channels in dnsmos.items()
        if set(channels) != expected_channels
    }
    if errors:
        print(f"Found {len(errors)} errors")
        for sample_idx, channels in errors.items():
            print(
                f" - sample {sample_idx} has channels {channels}, "
                f"expected 0 to {num_chan - 1}"
            )
        # Fail loudly here instead of with a KeyError (missing channel) or
        # silent truncation (extra channel) in the conversion below.
        raise ValueError(
            f"Inconsistent channels for {len(errors)} sample(s) in DNSMOS file"
        )

    # convert to desired output format: {sample: {column: [per-channel values]}}
    return {
        sample_idx: {
            key: [channels[chan][key] for chan in range(num_chan)]
            for key in fieldnames[2:]
        }
        for sample_idx, channels in dnsmos.items()
    }
def get_results_file(filepath):
    """Load the JSON results file at ``filepath`` and return its content."""
    with open(filepath, "r") as handle:
        return json.load(handle)
# Script entry point: for each known split that has both a <split>.json
# results file and a <split>_dnsmos.csv file in the given folder, attach the
# per-channel DNSMOS scores to each sample's metrics and write the merged
# results plus a recomputed summary.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Merge DNSMOS evaluation results into main result file"
    )
    parser.add_argument("results_path", type=Path, help="Path to result folder")
    parser.add_argument(
        "--overwrite-results",
        action="store_true",
        help=(
            "Overwrite <split>.json and <split>_summary.json in place instead of "
            "writing <split>_with_dnsmos.json and <split>_summary_with_dnsmos.json"
        ),
    )
    args = parser.parse_args()

    for split in ["val", "test", "libri-clean", "libri-noisy"]:
        if not (results_file := get_results_filepath(args.results_path, split)):
            print(f"Seems evaluate.py has not been run for {split}. Skip.")
            continue
        print(f"{split}: found results file")

        if not (dnsmos_file := get_dnsmos_filepath(args.results_path, split)):
            print(f"Seems DNSMOS evaluation has not been run for {split}. Skip.")
            continue
        print(f"{split}: found DNSMOS file")

        dnsmos = parse_dnsmos_csv(dnsmos_file)
        results = get_results_file(results_file)

        # Attach the DNSMOS columns to every sample's metrics dict.
        for key, metrics in results.items():
            sample_idx = int(key)  # JSON object keys are always strings
            if sample_idx not in dnsmos:
                raise ValueError(f"Sample {sample_idx} not found in DNSMOS file")
            metrics.update(dnsmos[sample_idx])

        # summarize comes from evaluate_mp; ignore_inf=False is kept as in the
        # original call (its exact semantics are defined there).
        summary = summarize(results, ignore_inf=False)

        if args.overwrite_results:
            output_results = results_file
            output_summary = args.results_path / f"{split}_summary.json"
        else:
            output_results = args.results_path / f"{split}_with_dnsmos.json"
            output_summary = args.results_path / f"{split}_summary_with_dnsmos.json"

        with open(output_results, "w") as f:
            json.dump(results, f, indent=2)
        with open(output_summary, "w") as f:
            json.dump(summary, f, indent=2)
        print(summary)