-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
330 lines (261 loc) · 8.47 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
""" This module contains utility functions of general interest. """
import logging
import multiprocessing
import os
import shutil
import subprocess as sp
from pathlib import Path, PosixPath
from typing import Any

import numpy as np
import yaml

from dunedn.configdn import PACKAGE, get_dunedn_search_path
def check(check_instance: Any, check_list: list[Any]):
    """
    Checks that check_list contains check_instance object. If not, raises
    NotImplementedError.
    Parameters
    ----------
    check_instance: Any
        Object to check.
    check_list: list[Any]
        Available options.
    Raises
    ------
    NotImplementedError
        If ``check_instance`` is not in ``check_list``.
    """
    # `not in` is the idiomatic membership test; the message names the
    # offending value and the valid options to ease debugging.
    if check_instance not in check_list:
        raise NotImplementedError(
            f"Operation not implemented: got {check_instance!r}, "
            f"available options are {check_list}"
        )
def smooth(smoothed: list[float], scalars: list[float], weight: float) -> list[float]:
    """Appends the next element of the moving average to ``smoothed`` in place.

    Parameters
    ----------
    smoothed: list[float]
        The list of smoothed scalar quantities.
    scalars: list[float]
        The list of scalar quantities to be smoothed.
    weight: float
        The weighting factor in the (0,1) range.

    Returns
    -------
    smoothed: list[float]
        The extended list of computed smoothed scalar quantities.

    Raises
    ------
    AssertionError
        If ``scalars`` does not have exactly one element more than ``smoothed``.
    """
    assert len(scalars) - len(smoothed) == 1
    if smoothed:
        # Exponential moving average: blend previous smoothed value with the
        # newest raw scalar.
        nxt = weight * smoothed[-1] + (1 - weight) * scalars[-1]
    else:
        # First element: nothing to smooth against yet.
        nxt = scalars[0]
    smoothed.append(nxt)
    return smoothed
def moving_average(scalars: list[float], weight: float) -> list[float]:
    """Computes the moving average from a list of scalar quantities.

    Parameters
    ----------
    scalars: list[float]
        List of scalar quantities to be smoothed.
    weight: float
        The weighting factor in the (0,1) range. Higher values provide more
        smoothing power.

    Returns
    -------
    smoothed: list[float]
        The list of smoothed scalar quantities, one per input scalar.
    """
    # Inlining the recurrence avoids re-slicing `scalars[: i + 1]` on every
    # iteration, turning an accidental O(n^2) pass into O(n). Output is
    # unchanged: s[0] = x[0]; s[i] = w*s[i-1] + (1-w)*x[i].
    smoothed: list[float] = []
    for scalar in scalars:
        if smoothed:
            smoothed.append(weight * smoothed[-1] + (1 - weight) * scalar)
        else:
            smoothed.append(scalar)
    return smoothed
def median_subtraction(planes: np.ndarray) -> np.ndarray:
    """Subtracts the per-example median from the input planes.

    Parameters
    ----------
    planes: np.ndarray
        The data to be normalized, of shape=(N,C,H,W).

    Returns
    -------
    output: np.ndarray
        The median subtracted data, of shape=(N,C,H,W).
    """
    # One median per example, computed over all non-batch axes; keepdims
    # preserves rank so broadcasting subtracts it from every pixel.
    return planes - np.median(planes, axis=(1, 2, 3), keepdims=True)
def confusion_matrix(hit, no_hit, t=0.5):
    """
    Return confusion matrix elements from arrays of scores and threshold value.

    A sample is predicted positive if and only if its score is strictly
    greater than ``t``, for both populations.

    Parameters:
        hit: np.array, scores of real hits
        no_hit: np.array, scores of real no-hits
        t: float, threshold

    Returns:
        tp, fp, fn, tn
    """
    tp = np.count_nonzero(hit > t)
    fn = np.size(hit) - tp
    # Use `<= t` so that a score exactly equal to the threshold is counted as
    # a negative prediction consistently with the `hit > t` rule above
    # (previously `no_hit < t` made a score of exactly t a false positive).
    tn = np.count_nonzero(no_hit <= t)
    fp = np.size(no_hit) - tn
    return tp, fp, fn, tn
def add_info_columns(evt: np.ndarray) -> np.ndarray:
    """Prepends event identifier and channel number columns to an event.

    Events come with additional information placed in the first two columns of
    the 2D array. These must be removed to make the computation as they are
    not informative. When saving back the event, the information must be
    added again.

    Parameters
    ----------
    evt: np.ndarray
        The event w/o additional information, of shape=(nb channels, nb tdc ticks).

    Returns
    -------
    The event w additional information, of shape=(nb channels, 2 + nb tdc ticks).
    """
    nb_channels = evt.shape[0]
    # Column of sequential channel ids, and a matching column of zeros for
    # the event identifier.
    channel_ids = np.arange(nb_channels)[:, None]
    event_ids = np.zeros_like(channel_ids)
    return np.concatenate([event_ids, channel_ids, evt], axis=1)
# Module-level logger, namespaced under the package's "train" logger so that
# handlers configured on the package root also capture these messages.
logger = logging.getLogger(PACKAGE + ".train")
def path_constructor(loader, node):
    """PyYaml utility function: builds a pathlib.Path from a ``!Path`` scalar node."""
    return Path(loader.construct_scalar(node))
def load_runcard(runcard_file: Path) -> dict:
    """Load runcard from yaml file.

    Parameters
    ----------
    runcard_file: Path
        The yaml file to load the dictionary from.

    Returns
    -------
    runcard: dict
        The loaded settings dictionary.

    Note
    ----
    The pathlib.Path objects are automatically loaded if they are encoded
    with the following syntax:
    ```
    path: !Path 'path/to/file'
    ```
    """
    # Path() is idempotent, so coerce unconditionally instead of type-checking.
    runcard_file = Path(runcard_file)
    # Register the custom `!Path` tag before parsing.
    yaml.add_constructor("!Path", path_constructor)
    with runcard_file.open("r") as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)
def path_representer(dumper, data):
    """PyYaml utility function: serializes a pathlib path as a ``!Path`` scalar."""
    return dumper.represent_scalar("!Path", str(data))
def save_runcard(fname: Path, setup: dict):
    """Save runcard to yaml file.

    Parameters
    ----------
    fname: Path
        The yaml output file.
    setup: dict
        The settings dictionary to be dumped.

    Note
    ----
    pathlib.PosixPath objects are automatically dumped with the ``!Path`` tag.
    """
    # Register the representer so Path values round-trip via load_runcard.
    yaml.add_representer(PosixPath, path_representer)
    with open(fname, "w") as ostream:
        yaml.dump(setup, ostream, indent=4)
def check_in_folder(folder: Path, should_force: bool):
    """Creates the query folder.

    The ``should_force`` parameter controls the function behavior in case
    ``folder`` exists. If true, it overwrites the existent directory,
    otherwise exits.

    Parameters
    ----------
    folder: Path
        The directory to be checked.
    should_force: bool
        Whether to replace the already existing directory.

    Raises
    ------
    FileExistsError
        If output folder exists and ``should_force`` is False.
    """
    try:
        folder.mkdir()
        # Creation succeeded: the directory did not exist before.
        logger.info(f"Creating output directory at {folder}")
    except FileExistsError:
        # Guard clause: without --force an existing directory is an error.
        if not should_force:
            logger.error('Delete or run with "--force" to overwrite.')
            raise
        logger.warning(f"Overwriting output directory at {folder}")
        shutil.rmtree(folder)
        folder.mkdir()
def initialize_output_folder(output: Path, should_force: bool):
    """Creates the output directory structure.

    Parameters
    ----------
    output: Path
        The output directory.
    should_force: bool
        Whether to replace the already existing output directory.
    """
    check_in_folder(output, should_force)
    # Standard sub-directory layout for saved cards and model checkpoints.
    for subdir in ("cards", "models"):
        (output / subdir).mkdir()
def get_configcard_path(fname):
    """Retrieves the configcard path.

    .. deprecated:: 2.0.0
        this function is not used anymore.

    If the supplied path is not a valid file, looks recursively into
    directories from DUNEDN_SEARCH_PATH environment variable to find the
    first match.

    Parameters
    ----------
    fname: Path
        Path to configcard yaml file.

    Returns
    -------
    Path, the retrieved configcard path

    Raises
    ------
    FileNotFoundError, if fname is not found.
    """
    # Fast path: the given location already points at a file.
    if fname.is_file():
        return fname
    # Fall back to the DUNEDN_SEARCH_PATH directories, first match wins.
    for directory in get_dunedn_search_path():
        candidate = directory / fname.name
        if candidate.is_file():
            return candidate
    raise FileNotFoundError(
        f"Configcard {fname} not found. Please, update DUNEDN_SEARCH_PATH variable."
    )
def get_cpu_info() -> dict:
    """Parses ``lscpu`` command output to a dictionary.

    Returns
    -------
    cpu_info: dict
        The parsed command output: lower-cased field names mapped to their
        string values.
    """
    # Invoke as an argument list without a shell: the command takes no
    # arguments, and avoiding shell=True removes shell-injection/quoting
    # hazards while producing identical output.
    output = sp.check_output(["lscpu"]).decode("utf-8")
    cpu_info = {}
    for line in output.split("\n"):
        line = line.strip()
        if not line:
            continue
        # Split on the first ":" only; the value may itself contain colons.
        key, _, value = line.partition(":")
        cpu_info[key.strip().lower()] = value.strip()
    return cpu_info
def get_nb_cpu_cores() -> int:
    """Returns the number of available cpus for the current process.

    Returns
    -------
    nb_cpus: int
        The number of available cpus for the current process.
    """
    # multiprocessing.cpu_count() reports the system total, ignoring any CPU
    # affinity mask restricting this process (e.g. under taskset/cgroups).
    # sched_getaffinity matches the documented contract; it is Linux-only,
    # so fall back to the total count elsewhere.
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(0))
    return multiprocessing.cpu_count()