-
Notifications
You must be signed in to change notification settings - Fork 0
/
ml.html
1710 lines (1537 loc) · 116 KB
/
ml.html
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
  <!-- Document metadata -->
  <meta charset="utf-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Machine Learning — nussl 1.0.0 documentation</title>

  <!-- Scripts: kept in their original order, since later files may rely on earlier ones -->
  <script type="text/javascript" src="_static/js/modernizr.min.js"></script>
  <script type="text/javascript" id="documentation_options" data-url_root="./" src="_static/documentation_options.js"></script>
  <script src="_static/jquery.js"></script>
  <script src="_static/underscore.js"></script>
  <script src="_static/doctools.js"></script>
  <script src="_static/language_data.js"></script>
  <!-- Third-party script pinned with a Subresource Integrity hash -->
  <script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
  <script type="text/javascript" src="_static/js/theme.js"></script>

  <!-- Stylesheets: order matters for the cascade, theme_overrides.css comes last -->
  <link rel="stylesheet" href="_static/css/theme.css" type="text/css">
  <link rel="stylesheet" href="_static/pygments.css" type="text/css">
  <link rel="stylesheet" href="_static/theme_overrides.css" type="text/css">

  <!-- Document relations: index, search, and neighbouring pages -->
  <link rel="index" title="Index" href="genindex.html">
  <link rel="search" title="Search" href="search.html">
  <link rel="next" title="Separation algorithms" href="separation.html">
  <link rel="prev" title="Evaluation" href="evaluation.html">
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search">
  <!-- Project home link shown at the top of the sidebar -->
  <a href="index.html" class="icon icon-home"> nussl
  </a>
  <!-- Search landmark: the form sends the query to search.html with GET -->
  <div role="search">
    <form id="rtd-search-form" class="wy-form" action="search.html" method="get">
      <!-- A placeholder is not an accessible name, so aria-label names the field
           for assistive technology. type="text" is left unchanged; theme styles
           may target it (TODO confirm before switching to type="search"). -->
      <input type="text" name="q" placeholder="Search docs" aria-label="Search docs">
      <input type="hidden" name="check_keywords" value="yes">
      <input type="hidden" name="area" value="default">
    </form>
  </div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
  <!-- Sidebar table of contents: one list per top-level entry -->
  <ul>
    <li class="toctree-l1"><a class="reference internal" href="getting_started.html">Getting Started</a></li>
  </ul>
  <ul>
    <li class="toctree-l1"><a class="reference internal" href="tutorials.html">Tutorials</a></li>
  </ul>
  <ul>
    <li class="toctree-l1"><a class="reference internal" href="examples/examples.html">Examples</a></li>
  </ul>
  <ul>
    <li class="toctree-l1"><a class="reference internal" href="recipes/recipes.html">Recipes</a></li>
  </ul>
  <!-- The "current" classes mark the path from the top level down to this page -->
  <ul class="current">
    <li class="toctree-l1 current"><a class="reference internal" href="api.html">API Documentation</a><ul class="current">
        <li class="toctree-l2"><a class="reference internal" href="core.html">Core</a></li>
        <li class="toctree-l2"><a class="reference internal" href="datasets.html">Datasets</a></li>
        <li class="toctree-l2"><a class="reference internal" href="evaluation.html">Evaluation</a></li>
        <!-- This page: its in-page sections are listed as nested entries -->
        <li class="toctree-l2 current"><a class="current reference internal" href="#">Machine Learning</a><ul>
            <li class="toctree-l3"><a class="reference internal" href="#separationmodel">SeparationModel</a></li>
            <li class="toctree-l3"><a class="reference internal" href="#module-nussl.ml.modules">Building blocks for SeparationModel</a></li>
            <li class="toctree-l3"><a class="reference internal" href="#module-nussl.ml.networks.builders">Helpers for common deep networks</a></li>
            <li class="toctree-l3"><a class="reference internal" href="#module-nussl.ml.confidence">Confidence measures</a></li>
            <li class="toctree-l3"><a class="reference internal" href="#module-nussl.ml.train">Training</a><ul>
                <li class="toctree-l4"><a class="reference internal" href="#id1">Training</a></li>
                <li class="toctree-l4"><a class="reference internal" href="#module-nussl.ml.train.loss">Loss functions</a></li>
                <li class="toctree-l4"><a class="reference internal" href="#module-nussl.ml.train.closures">Closures</a></li>
              </ul>
            </li>
          </ul>
        </li>
        <li class="toctree-l2"><a class="reference internal" href="separation.html">Separation algorithms</a></li>
      </ul>
    </li>
  </ul>
  <ul>
    <li class="toctree-l1"><a class="reference internal" href="citation.html">Citing nussl</a></li>
  </ul>
  <ul>
    <li class="toctree-l1"><a class="reference internal" href="contributing.html">Contribution Guide</a></li>
  </ul>
  <ul>
    <li class="toctree-l1"><a class="reference internal" href="changelog.html">Changelog</a></li>
  </ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
  <!-- Top bar: menu icon followed by a link back to the documentation home -->
  <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  <a href="index.html">nussl</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
  <!-- Breadcrumb trail: Docs » API Documentation » Machine Learning -->
  <ul class="wy-breadcrumbs">
    <li><a href="index.html">Docs</a> »</li>
    <li><a href="api.html">API Documentation</a> »</li>
    <li>Machine Learning</li>
    <!-- Aside entry: link to the reStructuredText source of this page -->
    <li class="wy-breadcrumbs-aside">
      <a href="_sources/ml.rst.txt" rel="nofollow"> View page source</a>
    </li>
  </ul>
  <hr>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<style>
  /* Overrides layered on top of sphinx_rtd_theme */

  /* Last notebook input/output container: 19px of margin on top of the
     existing 5px of padding gives a 24px gap */
  .nbinput.nblast.container,
  .nboutput.nblast.container {
    margin-bottom: 19px;
  }

  /* Cancel that gap when another code cell follows immediately */
  .nblast.container + .nbinput.container {
    margin-top: -19px;
  }

  /* Leave space after the exclamation icon in admonitions */
  .admonition > p:before {
    margin-right: 4px;
  }

  /* Math alignment fix, see https://github.com/rtfd/sphinx_rtd_theme/pull/686 */
  .math {
    text-align: unset;
  }
</style>
<span class="target" id="module-nussl.ml"></span><div class="section" id="machine-learning">
<h1>Machine Learning<a class="headerlink" href="#machine-learning" title="Permalink to this headline">¶</a></h1>
<div class="section" id="separationmodel">
<h2>SeparationModel<a class="headerlink" href="#separationmodel" title="Permalink to this headline">¶</a></h2>
<dl class="class">
<dt id="nussl.ml.SeparationModel">
<em class="property">class </em><code class="sig-prename descclassname">nussl.ml.</code><code class="sig-name descname">SeparationModel</code><span class="sig-paren">(</span><em class="sig-param">config</em>, <em class="sig-param">verbose=False</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/nussl/ml/networks/separation_model.html#SeparationModel"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#nussl.ml.SeparationModel" title="Permalink to this definition">¶</a></dt>
<dd><p>SeparationModel takes a configuration file or dictionary that describes the model
structure, which is some combination of MelProjection, Embedding, RecurrentStack,
ConvolutionalStack, and other modules found in <code class="docutils literal notranslate"><span class="pre">nussl.ml.networks.modules</span></code>.</p>
<p class="rubric">References</p>
<p><strong>Methods</strong></p>
<table class="longtable docutils align-default">
<colgroup>
<col style="width: 10%" />
<col style="width: 90%" />
</colgroup>
<tbody>
<tr class="row-odd"><td><p><a class="reference internal" href="#nussl.ml.SeparationModel.forward" title="nussl.ml.SeparationModel.forward"><code class="xref py py-obj docutils literal notranslate"><span class="pre">forward</span></code></a>(data)</p></td>
<td><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>data</strong> – (dict) a dictionary containing the input data for the model.</p>
</dd>
</dl>
</td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="#nussl.ml.SeparationModel.save" title="nussl.ml.SeparationModel.save"><code class="xref py py-obj docutils literal notranslate"><span class="pre">save</span></code></a>(location[, metadata])</p></td>
<td><p>Saves a SeparationModel to a location as a dictionary with the weights and model configuration.</p></td>
</tr>
</tbody>
</table>
<p>Hershey, J. R., Chen, Z., Le Roux, J., &amp; Watanabe, S. (2016, March).
Deep clustering: Discriminative embeddings for segmentation and separation.
In Acoustics, Speech and Signal Processing (ICASSP),
2016 IEEE International Conference on (pp. 31-35). IEEE.</p>
<p>Luo, Y., Chen, Z., Hershey, J. R., Le Roux, J., &amp; Mesgarani, N. (2017, March).
Deep clustering and conventional networks for music separation: Stronger together.
In Acoustics, Speech and Signal Processing (ICASSP),
2017 IEEE International Conference on (pp. 61-65). IEEE.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>config</strong> – (str, dict) Either a config dictionary that defines the model and its
connections, or the path to a json file containing the dictionary. If the
latter, the path will be loaded and used.</p>
</dd>
</dl>
<div class="admonition seealso">
<p class="admonition-title">See also</p>
<p>ml.register_module to register your custom modules with SeparationModel.</p>
</div>
<p class="rubric">Examples</p>
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">config</span> <span class="o">=</span> <span class="n">nussl</span><span class="o">.</span><span class="n">ml</span><span class="o">.</span><span class="n">networks</span><span class="o">.</span><span class="n">builders</span><span class="o">.</span><span class="n">build_recurrent_dpcl</span><span class="p">(</span>
<span class="gp">>>> </span> <span class="n">num_features</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="n">hidden_size</span><span class="o">=</span><span class="mi">300</span><span class="p">,</span> <span class="n">num_layers</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">bidirectional</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="gp">>>> </span> <span class="n">dropout</span><span class="o">=</span><span class="mf">0.3</span><span class="p">,</span> <span class="n">embedding_size</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span>
<span class="gp">>>> </span> <span class="n">embedding_activation</span><span class="o">=</span><span class="p">[</span><span class="s1">'sigmoid'</span><span class="p">,</span> <span class="s1">'unit_norm'</span><span class="p">])</span>
<span class="go">>>></span>
<span class="gp">>>> </span><span class="n">model</span> <span class="o">=</span> <span class="n">SeparationModel</span><span class="p">(</span><span class="n">config</span><span class="p">)</span>
</pre></div>
</div>
<dl class="method">
<dt id="nussl.ml.SeparationModel.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param">data</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/nussl/ml/networks/separation_model.html#SeparationModel.forward"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#nussl.ml.SeparationModel.forward" title="Permalink to this definition">¶</a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>data</strong> – (dict) a dictionary containing the input data for the model.
Should match the input_keys in self.input.</p></li>
</ul>
</dd>
</dl>
<p>Returns:</p>
</dd></dl>
<dl class="method">
<dt id="nussl.ml.SeparationModel.save">
<code class="sig-name descname">save</code><span class="sig-paren">(</span><em class="sig-param">location</em>, <em class="sig-param">metadata=None</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/nussl/ml/networks/separation_model.html#SeparationModel.save"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#nussl.ml.SeparationModel.save" title="Permalink to this definition">¶</a></dt>
<dd><p>Saves a SeparationModel to a location as a dictionary with the
weights and model configuration.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>location</strong> – (str) Where you want the model saved, as a path.</p>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p>where the model was saved.</p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>(str)</p>
</dd>
</dl>
</dd></dl>
</dd></dl>
</div>
<div class="section" id="module-nussl.ml.modules">
<span id="building-blocks-for-separationmodel"></span><h2>Building blocks for SeparationModel<a class="headerlink" href="#module-nussl.ml.modules" title="Permalink to this headline">¶</a></h2>
<span class="target" id="module-nussl.ml.cluster"></span></div>
<div class="section" id="module-nussl.ml.networks.builders">
<span id="helpers-for-common-deep-networks"></span><h2>Helpers for common deep networks<a class="headerlink" href="#module-nussl.ml.networks.builders" title="Permalink to this headline">¶</a></h2>
<p>Functions that make it easy to build commonly used source separation architectures.
Currently contains mask inference, deep clustering, and chimera networks that are
based on recurrent neural networks. These functions are a good place to start when
creating your own network topologies. Since there can be dependencies between layers
depending on input size, it’s good to work this out in a function like those below.</p>
<p><strong>Functions</strong></p>
<table class="longtable docutils align-default">
<colgroup>
<col style="width: 10%" />
<col style="width: 90%" />
</colgroup>
<tbody>
<tr class="row-odd"><td><p><a class="reference internal" href="#nussl.ml.networks.builders.build_dual_path_recurrent_end_to_end" title="nussl.ml.networks.builders.build_dual_path_recurrent_end_to_end"><code class="xref py py-obj docutils literal notranslate"><span class="pre">build_dual_path_recurrent_end_to_end</span></code></a>(…[, …])</p></td>
<td><p>Builds a config for a dual path recurrent network that operates on the time-series.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="#nussl.ml.networks.builders.build_open_unmix_like" title="nussl.ml.networks.builders.build_open_unmix_like"><code class="xref py py-obj docutils literal notranslate"><span class="pre">build_open_unmix_like</span></code></a>(num_features, …[, …])</p></td>
<td><p>This is a builder for an open-unmix LIKE (UMX) architecture for music source separation.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="#nussl.ml.networks.builders.build_recurrent_chimera" title="nussl.ml.networks.builders.build_recurrent_chimera"><code class="xref py py-obj docutils literal notranslate"><span class="pre">build_recurrent_chimera</span></code></a>(num_features, …[, …])</p></td>
<td><p>Builds a config for a Chimera network that can be passed to SeparationModel.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="#nussl.ml.networks.builders.build_recurrent_dpcl" title="nussl.ml.networks.builders.build_recurrent_dpcl"><code class="xref py py-obj docutils literal notranslate"><span class="pre">build_recurrent_dpcl</span></code></a>(num_features, …[, …])</p></td>
<td><p>Builds a config for a deep clustering network that can be passed to SeparationModel.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="#nussl.ml.networks.builders.build_recurrent_end_to_end" title="nussl.ml.networks.builders.build_recurrent_end_to_end"><code class="xref py py-obj docutils literal notranslate"><span class="pre">build_recurrent_end_to_end</span></code></a>(num_filters, …)</p></td>
<td><p>Builds a config for a BLSTM-based network that operates on the time-series.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="#nussl.ml.networks.builders.build_recurrent_mask_inference" title="nussl.ml.networks.builders.build_recurrent_mask_inference"><code class="xref py py-obj docutils literal notranslate"><span class="pre">build_recurrent_mask_inference</span></code></a>(num_features, …)</p></td>
<td><p>Builds a config for a mask inference network that can be passed to SeparationModel.</p></td>
</tr>
</tbody>
</table>
<dl class="function">
<dt id="nussl.ml.networks.builders.build_dual_path_recurrent_end_to_end">
<code class="sig-prename descclassname">nussl.ml.networks.builders.</code><code class="sig-name descname">build_dual_path_recurrent_end_to_end</code><span class="sig-paren">(</span><em class="sig-param">num_filters</em>, <em class="sig-param">filter_length</em>, <em class="sig-param">hop_length</em>, <em class="sig-param">chunk_size</em>, <em class="sig-param">hop_size</em>, <em class="sig-param">hidden_size</em>, <em class="sig-param">num_layers</em>, <em class="sig-param">bidirectional</em>, <em class="sig-param">bottleneck_size</em>, <em class="sig-param">num_sources</em>, <em class="sig-param">mask_activation</em>, <em class="sig-param">num_audio_channels=1</em>, <em class="sig-param">window_type='sqrt_hann'</em>, <em class="sig-param">skip_connection=False</em>, <em class="sig-param">rnn_type='lstm'</em>, <em class="sig-param">mix_key='mix_audio'</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/nussl/ml/networks/builders.html#build_dual_path_recurrent_end_to_end"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#nussl.ml.networks.builders.build_dual_path_recurrent_end_to_end" title="Permalink to this definition">¶</a></dt>
<dd><p>Builds a config for a dual path recurrent network that operates on the
time-series. Uses a learned filterbank within the network.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>num_filters</strong> (<em>int</em>) – Number of learnable filters in the front end network.</p></li>
<li><p><strong>filter_length</strong> (<em>int</em>) – Length of the filters.</p></li>
<li><p><strong>hop_length</strong> (<em>int</em>) – Hop length between frames.</p></li>
<li><p><strong>window_type</strong> (<em>str</em>) – Type of windowing function to apply to each frame.</p></li>
<li><p><strong>hidden_size</strong> (<em>int</em>) – Hidden size of the RNN.</p></li>
<li><p><strong>num_layers</strong> (<em>int</em>) – Number of layers in the RNN.</p></li>
<li><p><strong>bidirectional</strong> (<em>bool</em>) – Whether the RNN is bidirectional.</p></li>
<li><p><strong>dropout</strong> (<em>float</em>) – Amount of dropout to be used between layers of RNN.</p></li>
<li><p><strong>num_sources</strong> (<em>int</em>) – Number of sources to create masks for.</p></li>
<li><p><strong>mask_activation</strong> (<em>list of str</em>) – Activation of the mask (‘sigmoid’, ‘softmax’, etc.).
See <code class="docutils literal notranslate"><span class="pre">nussl.ml.networks.modules.Embedding</span></code>.</p></li>
<li><p><strong>num_audio_channels</strong> (<em>int</em>) – Number of audio channels in input (e.g. mono or stereo).
Defaults to 1.</p></li>
<li><p><strong>rnn_type</strong> (<em>str</em><em>, </em><em>optional</em>) – RNN type, either ‘lstm’ or ‘gru’. Defaults to ‘lstm’.</p></li>
<li><p><strong>normalization_class</strong> (<em>str</em><em>, </em><em>optional</em>) – Type of normalization to apply, either
‘InstanceNorm’ or ‘BatchNorm’. Defaults to ‘BatchNorm’.</p></li>
<li><p><strong>mix_key</strong> (<em>str</em><em>, </em><em>optional</em>) – The key to look for in the input dictionary that contains
the mixture audio. Defaults to ‘mix_audio’.</p></li>
</ul>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p>A TASNet configuration that can be passed to SeparationModel.</p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>dict</p>
</dd>
</dl>
</dd></dl>
<dl class="function">
<dt id="nussl.ml.networks.builders.build_open_unmix_like">
<code class="sig-prename descclassname">nussl.ml.networks.builders.</code><code class="sig-name descname">build_open_unmix_like</code><span class="sig-paren">(</span><em class="sig-param">num_features</em>, <em class="sig-param">hidden_size</em>, <em class="sig-param">num_layers</em>, <em class="sig-param">bidirectional</em>, <em class="sig-param">dropout</em>, <em class="sig-param">num_sources</em>, <em class="sig-param">num_audio_channels=1</em>, <em class="sig-param">add_embedding=False</em>, <em class="sig-param">embedding_size=20</em>, <em class="sig-param">embedding_activation='sigmoid'</em>, <em class="sig-param">rnn_type='lstm'</em>, <em class="sig-param">mix_key='mix_magnitude'</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/nussl/ml/networks/builders.html#build_open_unmix_like"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#nussl.ml.networks.builders.build_open_unmix_like" title="Permalink to this definition">¶</a></dt>
<dd><p>This is a builder for an open-unmix LIKE (UMX) architecture for music source
separation.</p>
<p>The architecture is not exactly the same but is very similar for the
most part. This architecture also has the option of having an embedding space
attached to it, making it a UMX + Chimera network that you can regularize with
a deep clustering loss.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>num_features</strong> (<em>int</em>) – Number of features in the input spectrogram (usually means
window length of STFT // 2 + 1.)</p></li>
<li><p><strong>hidden_size</strong> (<em>int</em>) – Hidden size of the RNN. Will be hidden_size // 2 if bidirectional is True.</p></li>
<li><p><strong>num_layers</strong> (<em>int</em>) – Number of layers in the RNN.</p></li>
<li><p><strong>bidirectional</strong> (<em>bool</em>) – Whether the RNN is bidirectional.</p></li>
<li><p><strong>dropout</strong> (<em>float</em>) – Amount of dropout to be used between layers of RNN.</p></li>
<li><p><strong>num_sources</strong> (<em>int</em>) – Number of sources to create masks for.</p></li>
<li><p><strong>num_audio_channels</strong> (<em>int</em>) – Number of audio channels in input (e.g. mono or stereo).
Defaults to 1.</p></li>
<li><p><strong>add_embedding</strong> (<em>bool</em>) – Whether or not to add an embedding layer to this to make this a
Chimera network. If True, then <code class="docutils literal notranslate"><span class="pre">embedding_size</span></code> and <code class="docutils literal notranslate"><span class="pre">embedding_activation</span></code> will
be used to define this.</p></li>
<li><p><strong>embedding_size</strong> (<em>int</em>) – Embedding dimensionality of the deep clustering network.</p></li>
<li><p><strong>embedding_activation</strong> (<em>list of str</em>) – Activation of the embedding (‘sigmoid’, ‘softmax’, etc.).
See <code class="docutils literal notranslate"><span class="pre">nussl.ml.networks.modules.Embedding</span></code>.</p></li>
<li><p><strong>rnn_type</strong> (<em>str</em><em>, </em><em>optional</em>) – RNN type, either ‘lstm’ or ‘gru’. Defaults to ‘lstm’.</p></li>
<li><p><strong>mix_key</strong> (<em>str</em><em>, </em><em>optional</em>) – The key to look for in the input dictionary that contains
the mixture spectrogram. Defaults to ‘mix_magnitude’.</p></li>
</ul>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p>An OpenUnmix-like configuration that can be passed to SeparationModel.</p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>dict</p>
</dd>
</dl>
</dd></dl>
<dl class="function">
<dt id="nussl.ml.networks.builders.build_recurrent_chimera">
<code class="sig-prename descclassname">nussl.ml.networks.builders.</code><code class="sig-name descname">build_recurrent_chimera</code><span class="sig-paren">(</span><em class="sig-param">num_features</em>, <em class="sig-param">hidden_size</em>, <em class="sig-param">num_layers</em>, <em class="sig-param">bidirectional</em>, <em class="sig-param">dropout</em>, <em class="sig-param">embedding_size</em>, <em class="sig-param">embedding_activation</em>, <em class="sig-param">num_sources</em>, <em class="sig-param">mask_activation</em>, <em class="sig-param">num_audio_channels=1</em>, <em class="sig-param">rnn_type='lstm'</em>, <em class="sig-param">normalization_class='BatchNorm'</em>, <em class="sig-param">mix_key='mix_magnitude'</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/nussl/ml/networks/builders.html#build_recurrent_chimera"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#nussl.ml.networks.builders.build_recurrent_chimera" title="Permalink to this definition">¶</a></dt>
<dd><p>Builds a config for a Chimera network that can be passed to SeparationModel.
Chimera networks are so-called because they have two “heads” which can be trained
via different loss functions. In traditional Chimera, one head is trained using a
deep clustering loss while the other is trained with a mask inference loss.
This Chimera network uses a recurrent neural network (RNN) to process the input
representation.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>num_features</strong> (<em>int</em>) – Number of features in the input spectrogram (usually means
window length of STFT // 2 + 1.)</p></li>
<li><p><strong>hidden_size</strong> (<em>int</em>) – Hidden size of the RNN.</p></li>
<li><p><strong>num_layers</strong> (<em>int</em>) – Number of layers in the RNN.</p></li>
<li><p><strong>bidirectional</strong> (<em>bool</em>) – Whether the RNN is bidirectional.</p></li>
<li><p><strong>dropout</strong> (<em>float</em>) – Amount of dropout to be used between layers of RNN.</p></li>
<li><p><strong>embedding_size</strong> (<em>int</em>) – Embedding dimensionality of the deep clustering network.</p></li>
<li><p><strong>embedding_activation</strong> (<em>list of str</em>) – Activation of the embedding (‘sigmoid’, ‘softmax’, etc.).
See <code class="docutils literal notranslate"><span class="pre">nussl.ml.networks.modules.Embedding</span></code>.</p></li>
<li><p><strong>num_sources</strong> (<em>int</em>) – Number of sources to create masks for.</p></li>
<li><p><strong>mask_activation</strong> (<em>list of str</em>) – Activation of the mask (‘sigmoid’, ‘softmax’, etc.).
See <code class="docutils literal notranslate"><span class="pre">nussl.ml.networks.modules.Embedding</span></code>.</p></li>
<li><p><strong>num_audio_channels</strong> (<em>int</em>) – Number of audio channels in input (e.g. mono or stereo).
Defaults to 1.</p></li>
<li><p><strong>rnn_type</strong> (<em>str</em><em>, </em><em>optional</em>) – RNN type, either ‘lstm’ or ‘gru’. Defaults to ‘lstm’.</p></li>
<li><p><strong>normalization_class</strong> (<em>str</em><em>, </em><em>optional</em>) – Type of normalization to apply, either
‘InstanceNorm’ or ‘BatchNorm’. Defaults to ‘BatchNorm’.</p></li>
<li><p><strong>mix_key</strong> (<em>str</em><em>, </em><em>optional</em>) – The key to look for in the input dictionary that contains
the mixture spectrogram. Defaults to ‘mix_magnitude’.</p></li>
</ul>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p>A recurrent Chimera network configuration that can be passed to SeparationModel.</p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>dict</p>
</dd>
</dl>
</dd></dl>
<dl class="function">
<dt id="nussl.ml.networks.builders.build_recurrent_dpcl">
<code class="sig-prename descclassname">nussl.ml.networks.builders.</code><code class="sig-name descname">build_recurrent_dpcl</code><span class="sig-paren">(</span><em class="sig-param">num_features</em>, <em class="sig-param">hidden_size</em>, <em class="sig-param">num_layers</em>, <em class="sig-param">bidirectional</em>, <em class="sig-param">dropout</em>, <em class="sig-param">embedding_size</em>, <em class="sig-param">embedding_activation</em>, <em class="sig-param">num_audio_channels=1</em>, <em class="sig-param">rnn_type='lstm'</em>, <em class="sig-param">normalization_class='BatchNorm'</em>, <em class="sig-param">mix_key='mix_magnitude'</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/nussl/ml/networks/builders.html#build_recurrent_dpcl"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#nussl.ml.networks.builders.build_recurrent_dpcl" title="Permalink to this definition">¶</a></dt>
<dd><p>Builds a config for a deep clustering network that can be passed to
SeparationModel. This deep clustering network uses a recurrent neural network (RNN)
to process the input representation.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>num_features</strong> (<em>int</em>) – Number of features in the input spectrogram (usually means
window length of STFT // 2 + 1.)</p></li>
<li><p><strong>hidden_size</strong> (<em>int</em>) – Hidden size of the RNN.</p></li>
<li><p><strong>num_layers</strong> (<em>int</em>) – Number of layers in the RNN.</p></li>
<li><p><strong>bidirectional</strong> (<em>int</em>) – Whether the RNN is bidirectional.</p></li>
<li><p><strong>dropout</strong> (<em>float</em>) – Amount of dropout to be used between layers of RNN.</p></li>
<li><p><strong>embedding_size</strong> (<em>int</em>) – Embedding dimensionality of the deep clustering network.</p></li>
<li><p><strong>embedding_activation</strong> (<em>list of str</em>) – Activation of the embedding (‘sigmoid’, ‘softmax’, etc.).
See <code class="docutils literal notranslate"><span class="pre">nussl.ml.networks.modules.Embedding</span></code>.</p></li>
<li><p><strong>num_audio_channels</strong> (<em>int</em>) – Number of audio channels in input (e.g. mono or stereo).
Defaults to 1.</p></li>
<li><p><strong>rnn_type</strong> (<em>str</em><em>, </em><em>optional</em>) – RNN type, either ‘lstm’ or ‘gru’. Defaults to ‘lstm’.</p></li>
<li><p><strong>normalization_class</strong> (<em>str</em><em>, </em><em>optional</em>) – Type of normalization to apply, either
‘InstanceNorm’ or ‘BatchNorm’. Defaults to ‘BatchNorm’.</p></li>
<li><p><strong>mix_key</strong> (<em>str</em><em>, </em><em>optional</em>) – The key to look for in the input dictionary that contains
the mixture spectrogram. Defaults to ‘mix_magnitude’.</p></li>
</ul>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p>A recurrent deep clustering network configuration that can be passed to SeparationModel.</p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>dict</p>
</dd>
</dl>
</dd></dl>
<dl class="function">
<dt id="nussl.ml.networks.builders.build_recurrent_end_to_end">
<code class="sig-prename descclassname">nussl.ml.networks.builders.</code><code class="sig-name descname">build_recurrent_end_to_end</code><span class="sig-paren">(</span><em class="sig-param">num_filters</em>, <em class="sig-param">filter_length</em>, <em class="sig-param">hop_length</em>, <em class="sig-param">window_type</em>, <em class="sig-param">hidden_size</em>, <em class="sig-param">num_layers</em>, <em class="sig-param">bidirectional</em>, <em class="sig-param">dropout</em>, <em class="sig-param">num_sources</em>, <em class="sig-param">mask_activation</em>, <em class="sig-param">num_audio_channels=1</em>, <em class="sig-param">mask_complex=False</em>, <em class="sig-param">trainable=False</em>, <em class="sig-param">rnn_type='lstm'</em>, <em class="sig-param">mix_key='mix_audio'</em>, <em class="sig-param">normalization_class='BatchNorm'</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/nussl/ml/networks/builders.html#build_recurrent_end_to_end"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#nussl.ml.networks.builders.build_recurrent_end_to_end" title="Permalink to this definition">¶</a></dt>
<dd><p>Builds a config for a BLSTM-based network that operates on the time-series.
Uses an STFT within the network and can apply the mixture phase to
the estimate, or can learn a mask on the phase as well as the magnitude.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>num_filters</strong> (<em>int</em>) – Number of learnable filters in the front end network.</p></li>
<li><p><strong>filter_length</strong> (<em>int</em>) – Length of the filters.</p></li>
<li><p><strong>hop_length</strong> (<em>int</em>) – Hop length between frames.</p></li>
<li><p><strong>window_type</strong> (<em>str</em>) – Type of windowing function to apply to each frame.</p></li>
<li><p><strong>hidden_size</strong> (<em>int</em>) – Hidden size of the RNN.</p></li>
<li><p><strong>num_layers</strong> (<em>int</em>) – Number of layers in the RNN.</p></li>
<li><p><strong>bidirectional</strong> (<em>int</em>) – Whether the RNN is bidirectional.</p></li>
<li><p><strong>dropout</strong> (<em>float</em>) – Amount of dropout to be used between layers of RNN.</p></li>
<li><p><strong>num_sources</strong> (<em>int</em>) – Number of sources to create masks for.</p></li>
<li><p><strong>mask_activation</strong> (<em>list of str</em>) – Activation of the mask (‘sigmoid’, ‘softmax’, etc.).
See <code class="docutils literal notranslate"><span class="pre">nussl.ml.networks.modules.Embedding</span></code>.</p></li>
<li><p><strong>num_audio_channels</strong> (<em>int</em>) – Number of audio channels in input (e.g. mono or stereo).
Defaults to 1.</p></li>
<li><p><strong>mask_complex</strong> (<em>bool</em><em>, </em><em>optional</em>) – Whether to also place a mask on the complex part, or
whether to just use the mixture phase.</p></li>
<li><p><strong>trainable</strong> (<em>bool</em><em>, </em><em>optional</em>) – Whether to learn the filters, which start from a
Fourier basis.</p></li>
<li><p><strong>rnn_type</strong> (<em>str</em><em>, </em><em>optional</em>) – RNN type, either ‘lstm’ or ‘gru’. Defaults to ‘lstm’.</p></li>
<li><p><strong>normalization_class</strong> (<em>str</em><em>, </em><em>optional</em>) – Type of normalization to apply, either
‘InstanceNorm’ or ‘BatchNorm’. Defaults to ‘BatchNorm’.</p></li>
<li><p><strong>mix_key</strong> (<em>str</em><em>, </em><em>optional</em>) – The key to look for in the input dictionary that contains
the mixture audio. Defaults to ‘mix_audio’.</p></li>
</ul>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p>A recurrent end-to-end network configuration that can be passed to SeparationModel.</p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>dict</p>
</dd>
</dl>
</dd></dl>
<dl class="function">
<dt id="nussl.ml.networks.builders.build_recurrent_mask_inference">
<code class="sig-prename descclassname">nussl.ml.networks.builders.</code><code class="sig-name descname">build_recurrent_mask_inference</code><span class="sig-paren">(</span><em class="sig-param">num_features</em>, <em class="sig-param">hidden_size</em>, <em class="sig-param">num_layers</em>, <em class="sig-param">bidirectional</em>, <em class="sig-param">dropout</em>, <em class="sig-param">num_sources</em>, <em class="sig-param">mask_activation</em>, <em class="sig-param">num_audio_channels=1</em>, <em class="sig-param">rnn_type='lstm'</em>, <em class="sig-param">normalization_class='BatchNorm'</em>, <em class="sig-param">mix_key='mix_magnitude'</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/nussl/ml/networks/builders.html#build_recurrent_mask_inference"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#nussl.ml.networks.builders.build_recurrent_mask_inference" title="Permalink to this definition">¶</a></dt>
<dd><p>Builds a config for a mask inference network that can be passed to
SeparationModel. This mask inference network uses a recurrent neural network (RNN)
to process the input representation.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>num_features</strong> (<em>int</em>) – Number of features in the input spectrogram (usually means
window length of STFT // 2 + 1.)</p></li>
<li><p><strong>hidden_size</strong> (<em>int</em>) – Hidden size of the RNN.</p></li>
<li><p><strong>num_layers</strong> (<em>int</em>) – Number of layers in the RNN.</p></li>
<li><p><strong>bidirectional</strong> (<em>int</em>) – Whether the RNN is bidirectional.</p></li>
<li><p><strong>dropout</strong> (<em>float</em>) – Amount of dropout to be used between layers of RNN.</p></li>
<li><p><strong>num_sources</strong> (<em>int</em>) – Number of sources to create masks for.</p></li>
<li><p><strong>mask_activation</strong> (<em>list of str</em>) – Activation of the mask (‘sigmoid’, ‘softmax’, etc.).
See <code class="docutils literal notranslate"><span class="pre">nussl.ml.networks.modules.Embedding</span></code>.</p></li>
<li><p><strong>num_audio_channels</strong> (<em>int</em>) – Number of audio channels in input (e.g. mono or stereo).
Defaults to 1.</p></li>
<li><p><strong>rnn_type</strong> (<em>str</em><em>, </em><em>optional</em>) – RNN type, either ‘lstm’ or ‘gru’. Defaults to ‘lstm’.</p></li>
<li><p><strong>normalization_class</strong> (<em>str</em><em>, </em><em>optional</em>) – Type of normalization to apply, either
‘InstanceNorm’ or ‘BatchNorm’. Defaults to ‘BatchNorm’.</p></li>
<li><p><strong>mix_key</strong> (<em>str</em><em>, </em><em>optional</em>) – The key to look for in the input dictionary that contains
the mixture spectrogram. Defaults to ‘mix_magnitude’.</p></li>
</ul>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p>A recurrent mask inference network configuration that can be passed to SeparationModel.</p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>dict</p>
</dd>
</dl>
</dd></dl>
</div>
<div class="section" id="module-nussl.ml.confidence">
<span id="confidence-measures"></span><h2>Confidence measures<a class="headerlink" href="#module-nussl.ml.confidence" title="Permalink to this headline">¶</a></h2>
<p>There are ways to measure the quality of a separated source without
requiring ground truth. These functions operate on the output of
clustering-based separation algorithms and work by analyzing
the clusterability of the feature space used to generate the
separated sources.</p>
<p><strong>Functions</strong></p>
<table class="longtable docutils align-default">
<colgroup>
<col style="width: 10%" />
<col style="width: 90%" />
</colgroup>
<tbody>
<tr class="row-odd"><td><p><a class="reference internal" href="#nussl.ml.confidence.dpcl_classic_confidence" title="nussl.ml.confidence.dpcl_classic_confidence"><code class="xref py py-obj docutils literal notranslate"><span class="pre">dpcl_classic_confidence</span></code></a>(audio_signal, …[, …])</p></td>
<td><p>Computes the clusterability in two steps:</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="#nussl.ml.confidence.jensen_shannon_confidence" title="nussl.ml.confidence.jensen_shannon_confidence"><code class="xref py py-obj docutils literal notranslate"><span class="pre">jensen_shannon_confidence</span></code></a>(audio_signal, …)</p></td>
<td><p>Calculates the clusterability of a space by comparing a K-cluster GMM with a 1-cluster GMM on the same features.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="#nussl.ml.confidence.jensen_shannon_divergence" title="nussl.ml.confidence.jensen_shannon_divergence"><code class="xref py py-obj docutils literal notranslate"><span class="pre">jensen_shannon_divergence</span></code></a>(gmm_p, gmm_q[, …])</p></td>
<td><p>Compute Jensen-Shannon (JS) divergence between two Gaussian Mixture Models via sampling.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="#nussl.ml.confidence.loudness_confidence" title="nussl.ml.confidence.loudness_confidence"><code class="xref py py-obj docutils literal notranslate"><span class="pre">loudness_confidence</span></code></a>(audio_signal, features, …)</p></td>
<td><p>Computes the clusterability of the feature space by comparing the absolute size of each cluster.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="#nussl.ml.confidence.posterior_confidence" title="nussl.ml.confidence.posterior_confidence"><code class="xref py py-obj docutils literal notranslate"><span class="pre">posterior_confidence</span></code></a>(audio_signal, features, …)</p></td>
<td><p>Calculates the clusterability of an embedding space by looking at the strength of the assignments of each point to a specific cluster.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="#nussl.ml.confidence.silhouette_confidence" title="nussl.ml.confidence.silhouette_confidence"><code class="xref py py-obj docutils literal notranslate"><span class="pre">silhouette_confidence</span></code></a>(audio_signal, …[, …])</p></td>
<td><p>Uses the silhouette score to compute the clusterability of the feature space.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="#nussl.ml.confidence.whitened_kmeans_confidence" title="nussl.ml.confidence.whitened_kmeans_confidence"><code class="xref py py-obj docutils literal notranslate"><span class="pre">whitened_kmeans_confidence</span></code></a>(audio_signal, …)</p></td>
<td><p>Computes the clusterability in two steps:</p></td>
</tr>
</tbody>
</table>
<dl class="function">
<dt id="nussl.ml.confidence.dpcl_classic_confidence">
<code class="sig-prename descclassname">nussl.ml.confidence.</code><code class="sig-name descname">dpcl_classic_confidence</code><span class="sig-paren">(</span><em class="sig-param">audio_signal</em>, <em class="sig-param">features</em>, <em class="sig-param">num_sources</em>, <em class="sig-param">threshold=95</em>, <em class="sig-param">**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/nussl/ml/confidence.html#dpcl_classic_confidence"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#nussl.ml.confidence.dpcl_classic_confidence" title="Permalink to this definition">¶</a></dt>
<dd><p>Computes the clusterability in two steps:</p>
<ol class="arabic simple">
<li><p>Cluster the feature space using KMeans into assignments</p></li>
<li><p>Compute the classic deep clustering loss between the features and the assignments.</p></li>
</ol>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>audio_signal</strong> (<a class="reference internal" href="core.html#nussl.core.AudioSignal" title="nussl.core.AudioSignal"><em>AudioSignal</em></a>) – AudioSignal object which will be used to compute
the mask over which to compute the confidence measure. This can be None, if
and only if <code class="docutils literal notranslate"><span class="pre">representation</span></code> is passed as a keyword argument to this
function.</p></li>
<li><p><strong>features</strong> (<em>np.ndarray</em>) – Numpy array containing the features to be clustered.
Should have the same dimensions as the representation.</p></li>
<li><p><strong>num_sources</strong> (<em>int</em>) – Number of sources to cluster the features into.</p></li>
<li><p><strong>threshold</strong> (<em>int</em><em>, </em><em>optional</em>) – Threshold by loudness. Points below the threshold are
excluded from being used in the confidence measure. Defaults to 95.</p></li>
<li><p><strong>kwargs</strong> – Keyword arguments to <cite>_get_loud_bins_mask</cite>. Namely, representation can
go here as a keyword argument.</p></li>
</ul>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p>Confidence given by deep clustering loss.</p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>float</p>
</dd>
</dl>
</dd></dl>
<dl class="function">
<dt id="nussl.ml.confidence.jensen_shannon_confidence">
<code class="sig-prename descclassname">nussl.ml.confidence.</code><code class="sig-name descname">jensen_shannon_confidence</code><span class="sig-paren">(</span><em class="sig-param">audio_signal</em>, <em class="sig-param">features</em>, <em class="sig-param">num_sources</em>, <em class="sig-param">threshold=95</em>, <em class="sig-param">n_samples=100000</em>, <em class="sig-param">**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/nussl/ml/confidence.html#jensen_shannon_confidence"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#nussl.ml.confidence.jensen_shannon_confidence" title="Permalink to this definition">¶</a></dt>
<dd><p>Calculates the clusterability of a space by comparing a K-cluster GMM
with a 1-cluster GMM on the same features. This function fits two
GMMs to all of the points that are above the specified threshold (defaults
to 95: 95th percentile of all the data). This saves on computation time and
also allows one to have the confidence measure only focus on the louder,
more perceptually important points.</p>
<p>References:</p>
<p>Seetharaman, Prem, Gordon Wichern, Jonathan Le Roux, and Bryan Pardo.
“Bootstrapping Single-Channel Source Separation via Unsupervised Spatial
Clustering on Stereo Mixtures”. 44th International Conference on Acoustics,
Speech, and Signal Processing, Brighton, UK, May, 2019</p>
<p>Seetharaman, Prem. Bootstrapping the Learning Process for Computer Audition.
Diss. Northwestern University, 2019.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>audio_signal</strong> (<a class="reference internal" href="core.html#nussl.core.AudioSignal" title="nussl.core.AudioSignal"><em>AudioSignal</em></a>) – AudioSignal object which will be used to compute
the mask over which to compute the confidence measure. This can be None, if
and only if <code class="docutils literal notranslate"><span class="pre">representation</span></code> is passed as a keyword argument to this
function.</p></li>
<li><p><strong>features</strong> (<em>np.ndarray</em>) – Numpy array containing the features to be clustered.
Should have the same dimensions as the representation.</p></li>
<li><p><strong>num_sources</strong> (<em>int</em>) – Number of sources to cluster the features into.</p></li>
<li><p><strong>threshold</strong> (<em>int</em><em>, </em><em>optional</em>) – Threshold by loudness. Points below the threshold are
excluded from being used in the confidence measure. Defaults to 95.</p></li>
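<li><p><strong>kwargs</strong> – Keyword arguments to <cite>_get_loud_bins_mask</cite>. Namely, representation can
go here as a keyword argument.</p></li>
<li><p><strong>n_samples</strong> (<em>int</em><em>, </em><em>optional</em>) – Number of samples to use to estimate the
Jensen-Shannon divergence. Defaults to 100000.</p></li>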
</ul>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p>Confidence given by Jensen-Shannon divergence.</p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>float</p>
</dd>
</dl>
</dd></dl>
<dl class="function">
<dt id="nussl.ml.confidence.jensen_shannon_divergence">
<code class="sig-prename descclassname">nussl.ml.confidence.</code><code class="sig-name descname">jensen_shannon_divergence</code><span class="sig-paren">(</span><em class="sig-param">gmm_p</em>, <em class="sig-param">gmm_q</em>, <em class="sig-param">n_samples=100000</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/nussl/ml/confidence.html#jensen_shannon_divergence"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#nussl.ml.confidence.jensen_shannon_divergence" title="Permalink to this definition">¶</a></dt>
<dd><p>Compute Jensen-Shannon (JS) divergence between two Gaussian Mixture Models via
sampling. JS divergence is also known as symmetric Kullback-Leibler divergence.
JS divergence has no closed form in general for GMMs, thus we use sampling to
compute it.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>gmm_p</strong> (<em>GaussianMixture</em>) – A GaussianMixture class fit to some data.</p></li>
<li><p><strong>gmm_q</strong> (<em>GaussianMixture</em>) – Another GaussianMixture class fit to some data.</p></li>
<li><p><strong>n_samples</strong> (<em>int</em>) – Number of samples to use to estimate JS divergence.</p></li>
</ul>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p>JS divergence between gmm_p and gmm_q</p>
</dd>
</dl>
</dd></dl>
<dl class="function">
<dt id="nussl.ml.confidence.loudness_confidence">
<code class="sig-prename descclassname">nussl.ml.confidence.</code><code class="sig-name descname">loudness_confidence</code><span class="sig-paren">(</span><em class="sig-param">audio_signal</em>, <em class="sig-param">features</em>, <em class="sig-param">num_sources</em>, <em class="sig-param">threshold=95</em>, <em class="sig-param">**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/nussl/ml/confidence.html#loudness_confidence"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#nussl.ml.confidence.loudness_confidence" title="Permalink to this definition">¶</a></dt>
<dd><p>Computes the clusterability of the feature space by comparing the absolute
size of each cluster.</p>
<p>References:</p>
<p>Seetharaman, Prem, Gordon Wichern, Jonathan Le Roux, and Bryan Pardo.
“Bootstrapping Single-Channel Source Separation via Unsupervised Spatial
Clustering on Stereo Mixtures”. 44th International Conference on Acoustics,
Speech, and Signal Processing, Brighton, UK, May, 2019</p>
<p>Seetharaman, Prem. Bootstrapping the Learning Process for Computer Audition.
Diss. Northwestern University, 2019.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>audio_signal</strong> (<a class="reference internal" href="core.html#nussl.core.AudioSignal" title="nussl.core.AudioSignal"><em>AudioSignal</em></a>) – AudioSignal object which will be used to compute
the mask over which to compute the confidence measure. This can be None, if
and only if <code class="docutils literal notranslate"><span class="pre">representation</span></code> is passed as a keyword argument to this
function.</p></li>
<li><p><strong>features</strong> (<em>np.ndarray</em>) – Numpy array containing the features to be clustered.
Should have the same dimensions as the representation.</p></li>
<li><p><strong>num_sources</strong> (<em>int</em>) – Number of sources to cluster the features into.</p></li>
<li><p><strong>threshold</strong> (<em>int</em><em>, </em><em>optional</em>) – Threshold by loudness. Points below the threshold are
excluded from being used in the confidence measure. Defaults to 95.</p></li>
<li><p><strong>kwargs</strong> – Keyword arguments to <cite>_get_loud_bins_mask</cite>. Namely, representation can
go here as a keyword argument.</p></li>
</ul>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p>Confidence given by size of smallest cluster.</p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>float</p>
</dd>
</dl>
</dd></dl>
<dl class="function">
<dt id="nussl.ml.confidence.posterior_confidence">
<code class="sig-prename descclassname">nussl.ml.confidence.</code><code class="sig-name descname">posterior_confidence</code><span class="sig-paren">(</span><em class="sig-param">audio_signal</em>, <em class="sig-param">features</em>, <em class="sig-param">num_sources</em>, <em class="sig-param">threshold=95</em>, <em class="sig-param">**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/nussl/ml/confidence.html#posterior_confidence"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#nussl.ml.confidence.posterior_confidence" title="Permalink to this definition">¶</a></dt>
<dd><p>Calculates the clusterability of an embedding space by looking at the
strength of the assignments of each point to a specific cluster. The
more points that are “in between” clusters (e.g. no strong assignment),
the lower the clusterability.</p>
<p>References:</p>
<p>Seetharaman, Prem, Gordon Wichern, Jonathan Le Roux, and Bryan Pardo.
“Bootstrapping Single-Channel Source Separation via Unsupervised Spatial
Clustering on Stereo Mixtures”. 44th International Conference on Acoustics,
Speech, and Signal Processing, Brighton, UK, May, 2019</p>
<p>Seetharaman, Prem. Bootstrapping the Learning Process for Computer Audition.
Diss. Northwestern University, 2019.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>audio_signal</strong> (<a class="reference internal" href="core.html#nussl.core.AudioSignal" title="nussl.core.AudioSignal"><em>AudioSignal</em></a>) – AudioSignal object which will be used to compute
the mask over which to compute the confidence measure. This can be None, if
and only if <code class="docutils literal notranslate"><span class="pre">representation</span></code> is passed as a keyword argument to this
function.</p></li>
<li><p><strong>features</strong> (<em>np.ndarray</em>) – Numpy array containing the features to be clustered.
Should have the same dimensions as the representation.</p></li>
<li><p><strong>num_sources</strong> (<em>int</em>) – Number of sources to cluster the features into.</p></li>
<li><p><strong>threshold</strong> (<em>int</em><em>, </em><em>optional</em>) – Threshold by loudness. Points below the threshold are
excluded from being used in the confidence measure. Defaults to 95.</p></li>
<li><p><strong>kwargs</strong> – Keyword arguments to <cite>_get_loud_bins_mask</cite>. Namely, representation can
go here as a keyword argument.</p></li>
</ul>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p>Confidence given by posteriors.</p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>float</p>
</dd>
</dl>
</dd></dl>
<dl class="function">
<dt id="nussl.ml.confidence.silhouette_confidence">
<code class="sig-prename descclassname">nussl.ml.confidence.</code><code class="sig-name descname">silhouette_confidence</code><span class="sig-paren">(</span><em class="sig-param">audio_signal</em>, <em class="sig-param">features</em>, <em class="sig-param">num_sources</em>, <em class="sig-param">threshold=95</em>, <em class="sig-param">max_points=1000</em>, <em class="sig-param">**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/nussl/ml/confidence.html#silhouette_confidence"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#nussl.ml.confidence.silhouette_confidence" title="Permalink to this definition">¶</a></dt>
<dd><p>Uses the silhouette score to compute the clusterability of the feature space.</p>
<p>The Silhouette Coefficient is calculated using the
mean intra-cluster distance (a) and the mean nearest-cluster distance (b)
for each sample. The Silhouette Coefficient for a sample is (b - a) / max(a, b).
To clarify, b is the distance between a sample and the nearest cluster
that the sample is not a part of. Note that Silhouette Coefficient is
only defined if number of labels is 2 &lt;= n_labels &lt;= n_samples - 1.</p>
<p>References:</p>
<p>Seetharaman, Prem. Bootstrapping the Learning Process for Computer Audition.
Diss. Northwestern University, 2019.</p>
<p>Peter J. Rousseeuw (1987). “Silhouettes: a Graphical Aid to the
Interpretation and Validation of Cluster Analysis”. Computational and
Applied Mathematics 20: 53-65.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>audio_signal</strong> (<a class="reference internal" href="core.html#nussl.core.AudioSignal" title="nussl.core.AudioSignal"><em>AudioSignal</em></a>) – AudioSignal object which will be used to compute
the mask over which to compute the confidence measure. This can be None, if
and only if <code class="docutils literal notranslate"><span class="pre">representation</span></code> is passed as a keyword argument to this
function.</p></li>
<li><p><strong>features</strong> (<em>np.ndarray</em>) – Numpy array containing the features to be clustered.
Should have the same dimensions as the representation.</p></li>
<li><p><strong>num_sources</strong> (<em>int</em>) – Number of sources to cluster the features into.</p></li>
<li><p><strong>threshold</strong> (<em>int</em><em>, </em><em>optional</em>) – Threshold by loudness. Points below the threshold are
excluded from being used in the confidence measure. Defaults to 95.</p></li>
<li><p><strong>kwargs</strong> – Keyword arguments to <cite>_get_loud_bins_mask</cite>. Namely, representation can
go here as a keyword argument.</p></li>
<li><p><strong>max_points</strong> (<em>int</em><em>, </em><em>optional</em>) – Maximum number of points to compute the Silhouette
score for. Silhouette score is a costly operation. Defaults to 1000.</p></li>
</ul>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p>Confidence given by Silhouette score.</p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>float</p>
</dd>
</dl>
</dd></dl>
<dl class="function">
<dt id="nussl.ml.confidence.whitened_kmeans_confidence">
<code class="sig-prename descclassname">nussl.ml.confidence.</code><code class="sig-name descname">whitened_kmeans_confidence</code><span class="sig-paren">(</span><em class="sig-param">audio_signal</em>, <em class="sig-param">features</em>, <em class="sig-param">num_sources</em>, <em class="sig-param">threshold=95</em>, <em class="sig-param">**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/nussl/ml/confidence.html#whitened_kmeans_confidence"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#nussl.ml.confidence.whitened_kmeans_confidence" title="Permalink to this definition">¶</a></dt>
<dd><p>Computes the clusterability in two steps:</p>
<ol class="arabic simple">
<li><p>Cluster the feature space using KMeans into assignments</p></li>
<li><p>Compute the Whitened K-Means loss between the features and the assignments.</p></li>
</ol>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>audio_signal</strong> (<a class="reference internal" href="core.html#nussl.core.AudioSignal" title="nussl.core.AudioSignal"><em>AudioSignal</em></a>) – AudioSignal object which will be used to compute
the mask over which to compute the confidence measure. This can be None, if
and only if <code class="docutils literal notranslate"><span class="pre">representation</span></code> is passed as a keyword argument to this
function.</p></li>
<li><p><strong>features</strong> (<em>np.ndarray</em>) – Numpy array containing the features to be clustered.
Should have the same dimensions as the representation.</p></li>
<li><p><strong>num_sources</strong> (<em>int</em>) – Number of sources to cluster the features into.</p></li>
<li><p><strong>threshold</strong> (<em>int</em><em>, </em><em>optional</em>) – Threshold by loudness. Points below the threshold are
excluded from being used in the confidence measure. Defaults to 95.</p></li>
<li><p><strong>kwargs</strong> – Keyword arguments to <cite>_get_loud_bins_mask</cite>. Namely, representation can
go here as a keyword argument.</p></li>
</ul>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p>Confidence given by whitened k-means loss.</p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>float</p>
</dd>
</dl>
</dd></dl>
</div>
<div class="section" id="module-nussl.ml.train">
<span id="training"></span><h2>Training<a class="headerlink" href="#module-nussl.ml.train" title="Permalink to this headline">¶</a></h2>
<div class="section" id="id1">
<h3>Training<a class="headerlink" href="#id1" title="Permalink to this headline">¶</a></h3>
<dl class="function">
<dt id="nussl.ml.train.create_train_and_validation_engines">
<code class="sig-prename descclassname">nussl.ml.train.</code><code class="sig-name descname">create_train_and_validation_engines</code><span class="sig-paren">(</span><em class="sig-param">train_func</em>, <em class="sig-param">val_func=None</em>, <em class="sig-param">device='cpu'</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/nussl/ml/train/trainer.html#create_train_and_validation_engines"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#nussl.ml.train.create_train_and_validation_engines" title="Permalink to this definition">¶</a></dt>
<dd><p>Helper function for creating an ignite Engine object with helpful defaults.
This sets up an Engine that has four handlers attached to it:</p>
<ul class="simple">
<li><p>prepare_batch: before a batch is passed to train_func or val_func, this
function runs, moving every item in the batch (which is a dictionary) to
the appropriate device (‘cpu’ or ‘cuda’).</p></li>
<li><p>book_keeping: sets up some dictionaries that are used for bookkeeping so one
can easily track the epoch and iteration losses for both training and
validation.</p></li>
<li><p>add_to_iter_history: records the iteration, epoch, and past iteration losses
into the dictionaries set up by book_keeping.</p></li>
<li><p>clear_iter_history: resets the current iteration history of losses after moving
the current iteration history into past iteration history.</p></li>
</ul>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>train_func</strong> (<em>func</em>) – Function that provides the closure for training for
a single batch.</p></li>
<li><p><strong>val_func</strong> (<em>func</em><em>, </em><em>optional</em>) – Function that provides the closure for
validating a single batch. Defaults to None.</p></li>
<li><p><strong>device</strong> (<em>str</em><em>, </em><em>optional</em>) – Device to move tensors to. Defaults to ‘cpu’.</p></li>
</ul>
</dd>
</dl>
</dd></dl>
<dl class="function">
<dt id="nussl.ml.train.add_tensorboard_handler">
<code class="sig-prename descclassname">nussl.ml.train.</code><code class="sig-name descname">add_tensorboard_handler</code><span class="sig-paren">(</span><em class="sig-param">tensorboard_folder</em>, <em class="sig-param">engine</em>, <em class="sig-param">every_iteration=False</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/nussl/ml/train/trainer.html#add_tensorboard_handler"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#nussl.ml.train.add_tensorboard_handler" title="Permalink to this definition">¶</a></dt>
<dd><p>Every key in engine.state.epoch_history[-1] is logged to TensorBoard.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>tensorboard_folder</strong> (<em>str</em>) – Where the tensorboard logs should go.</p></li>
<li><p><strong>engine</strong> (<em>ignite.Engine</em>) – The engine to log.</p></li>
<li><p><strong>every_iteration</strong> (<em>bool</em><em>, </em><em>optional</em>) – Whether to also log the values at every
iteration.</p></li>
</ul>
</dd>
</dl>
</dd></dl>
<dl class="function">
<dt id="nussl.ml.train.cache_dataset">
<code class="sig-prename descclassname">nussl.ml.train.</code><code class="sig-name descname">cache_dataset</code><span class="sig-paren">(</span><em class="sig-param">dataset</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/nussl/ml/train/trainer.html#cache_dataset"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#nussl.ml.train.cache_dataset" title="Permalink to this definition">¶</a></dt>
<dd><p>Runs through an entire dataset and caches it if nussl.datasets.transforms.Cache
is in dataset.transform. If there is no caching, or dataset.cache_populated = True,
then this function just iterates through the dataset and does nothing.</p>
<p>This function can also take a <cite>torch.utils.data.DataLoader</cite> object wrapped around
a <cite>nussl.datasets.BaseDataset</cite> object.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>dataset</strong> (<a class="reference internal" href="datasets.html#nussl.datasets.BaseDataset" title="nussl.datasets.BaseDataset"><em>nussl.datasets.BaseDataset</em></a>) – Must be a subclass of
<cite>nussl.datasets.BaseDataset</cite>.</p>
</dd>
</dl>
</dd></dl>
<dl class="function">
<dt id="nussl.ml.train.add_validate_and_checkpoint">
<code class="sig-prename descclassname">nussl.ml.train.</code><code class="sig-name descname">add_validate_and_checkpoint</code><span class="sig-paren">(</span><em class="sig-param">output_folder</em>, <em class="sig-param">model</em>, <em class="sig-param">optimizer</em>, <em class="sig-param">train_data</em>, <em class="sig-param">trainer</em>, <em class="sig-param">val_data=None</em>, <em class="sig-param">validator=None</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/nussl/ml/train/trainer.html#add_validate_and_checkpoint"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#nussl.ml.train.add_validate_and_checkpoint" title="Permalink to this definition">¶</a></dt>
<dd><p>This adds the following handler to the trainer:</p>
<ul class="simple">
<li><p>validate_and_checkpoint: this runs the validator on the validation dataset
(<code class="docutils literal notranslate"><span class="pre">val_data</span></code>) using a defined validation process function <code class="docutils literal notranslate"><span class="pre">val_func</span></code>.
These are optional. If these are not provided, then no validator is run
and the model is simply checkpointed. The model is always saved to
<code class="docutils literal notranslate"><span class="pre">{output_folder}/checkpoints/latest.model.pth</span></code>. If the model is also the
one with the lowest validation loss, then it is <em>also</em> saved to
<code class="docutils literal notranslate"><span class="pre">{output_folder}/checkpoints/best.model.pth</span></code>. This is attached to
<code class="docutils literal notranslate"><span class="pre">Events.EPOCH_COMPLETED</span></code> on the trainer. After completion, it fires a
<code class="docutils literal notranslate"><span class="pre">ValidationEvents.VALIDATION_COMPLETED</span></code> event.</p></li>
</ul>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
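<li><p><strong>output_folder</strong> – Folder under which checkpoints are saved, in
<code class="docutils literal notranslate"><span class="pre">{output_folder}/checkpoints/</span></code>.</p></li>
<li><p><strong>model</strong> (<em>torch.nn.Module</em>) – Model that is being trained (typically a SeparationModel).</p></li>
<li><p><strong>optimizer</strong> (<em>torch.optim.Optimizer</em>) – Optimizer being used to train.</p></li>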
<li><p><strong>train_data</strong> (<a class="reference internal" href="datasets.html#nussl.datasets.BaseDataset" title="nussl.datasets.BaseDataset"><em>BaseDataset</em></a>) – dataset that is being used to train the model. This is to
save additional metadata information alongside the model checkpoint such as the
STFTParams, dataset folder, length, list of transforms, etc.</p></li>
<li><p><strong>trainer</strong> (<em>ignite.Engine</em>) – Engine for trainer</p></li>
<li><p><strong>validator</strong> (<em>ignite.Engine</em><em>, </em><em>optional</em>) – Engine for validation.
Defaults to None.</p></li>
<li><p><strong>val_data</strong> (<em>torch.utils.data.Dataset</em><em>, </em><em>optional</em>) – The validation data.
Defaults to None.</p></li>
</ul>
</dd>
</dl>
</dd></dl>
<dl class="function">
<dt id="nussl.ml.train.add_stdout_handler">
<code class="sig-prename descclassname">nussl.ml.train.</code><code class="sig-name descname">add_stdout_handler</code><span class="sig-paren">(</span><em class="sig-param">trainer</em>, <em class="sig-param">validator=None</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/nussl/ml/train/trainer.html#add_stdout_handler"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#nussl.ml.train.add_stdout_handler" title="Permalink to this definition">¶</a></dt>
<dd><p>This adds the following handler to the trainer engine, and also sets up
Timers:</p>
<ul>
<li><p>log_epoch_to_stdout: This logs the results of a model after it has trained
for a single epoch on both the training and validation set. The output typically
looks like this:</p>
<div class="highlight-none notranslate"><div class="highlight"><pre><span></span>EPOCH SUMMARY
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- Epoch number: 0010 / 0010