-
Notifications
You must be signed in to change notification settings - Fork 5
/
run_random_search.py
48 lines (42 loc) · 1.46 KB
/
run_random_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#!/usr/bin/env python
# coding=utf-8
from __future__ import division, print_function, unicode_literals
from sacred.observers import MongoObserver
from dae import ex
# MongoDB database that every observer created by this script writes to.
db_name = 'binding_via_rc'

# How many independent runs are launched for each dataset in each sweep.
nr_runs_per_dataset = 100

# Dataset name -> value passed as the 'em.k' config entry for runs on it.
datasets = {
    'bars': 12,
    'corners': 5,
    'shapes': 3,
    'multi_mnist': 3,
    'mnist_shape': 2,
    'simple_superpos': 2,
}
def _run_sweep(prefix, extra_updates, skip=()):
    """Run ``ex`` ``nr_runs_per_dataset`` times for every dataset in ``datasets``.

    Replaces (does not extend) ``ex.observers`` with a single MongoObserver
    writing to ``db_name`` under ``prefix``. Every run uses the
    'random_search' named config.

    :param prefix: collection prefix passed to ``MongoObserver.create``.
    :param extra_updates: config entries added to every run on top of the
        per-dataset 'dataset.name', 'em.k' and 'verbose' entries.
    :param skip: dataset names to leave out of this sweep.
    """
    ex.observers = [MongoObserver.create(db_name=db_name, prefix=prefix)]
    for ds, k in datasets.items():
        if ds in skip:
            continue
        config_updates = {'dataset.name': ds, 'em.k': k, 'verbose': False}
        config_updates.update(extra_updates)
        for _ in range(nr_runs_per_dataset):
            # Hand each run its own copy, as the original code built a fresh
            # dict per call, so nothing a run does to it leaks into the next.
            ex.run(config_updates=dict(config_updates),
                   named_configs=['random_search'])


# Random search
_run_sweep('random_search', {})

# Multi-Train Runs. 'simple_superpos' is excluded from this sweep
# (presumably it has no 'train_multi' split -- TODO confirm).
_run_sweep('train_multi',
           {'dataset.train_set': 'train_multi', 'em.e_step': 'max'},
           skip=('simple_superpos',))

# MSE-Likelihood Runs
_run_sweep('mse_likelihood',
           {'dataset.salt_n_pepper': 0.3, 'network_spec': 'Fr250'})