forked from raimis/AToM-OpenMM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
local_openmm_transport.py
274 lines (230 loc) · 9.98 KB
/
local_openmm_transport.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
from __future__ import print_function
from __future__ import division
"""
Multiprocessing job transport for AsyncRE/OpenMM
"""
# standard library
import copy
import logging
import math
import multiprocessing as mp
#from multiprocessing import Process, Queue, Event
import os
import random
import re
import shutil
import signal
import sys
import time
from contextlib import contextmanager
from datetime import datetime

# third-party (OpenMM); star-import order preserved deliberately
from simtk import openmm as mm
from simtk.openmm.app import *
from simtk.openmm import *
from simtk.unit import *

# project-local
from ommreplica import *
from ommsystem import *
from ommworker import *
from transport import Transport
class LocalOpenMMTransport(Transport):
"""
Class to launch and monitor jobs on a set of local GPUs
"""
def __init__(self, jobname, openmm_workers, openmm_replicas):
    """Set up the transport for a pool of local GPU workers.

    jobname: identifier of the current asyncRE job.
    openmm_workers: one OpenMM worker (context) per local device.
    openmm_replicas: the replica objects managed by this transport.
    """
    Transport.__init__(self)
    self.logger = logging.getLogger("async_re.local_openmm_transport")

    # worker pool: one OpenMM context per device
    self.openmm_workers = openmm_workers
    self.nprocs = len(self.openmm_workers)

    # replica bookkeeping objects
    self.openmm_replicas = openmm_replicas

    # per-node status: None = idle, otherwise the id of the replica
    # currently running there (a negative value marks a crashed node)
    self.node_status = [None] * self.nprocs

    # per-replica job record; None = no information about where (or
    # whether) the replica is running
    self.replica_to_job = [None] * len(openmm_replicas)

    # FIFO of replica ids waiting to be dispatched to a free node
    spawn_ctx = mp.get_context('spawn')
    self.jobqueue = spawn_ctx.Queue()

    # crash accounting used by _fixnodes() to decide on restarts
    self.ncrashes = [0] * self.nprocs
    self.disabled = [False] * self.nprocs
    self.maxcrashes = 4
def _clear_resource(self, replica):
# frees up the node running a replica identified by replica id
job = {}
try:
job = self.replica_to_job[replica]
if job is None:
return None
except:
self.logger.warning("clear_resource(): unknown replica id %d",
replica)
if 'nodeid' not in job:
return None
else:
nodeid = job['nodeid']
try:
if self.node_status[nodeid] is not None and self.node_status[nodeid] >= 0: #-1 signals a crashed node that should be left alone
self.node_status[nodeid] = None
except:
self.logger.warning("clear_resource(): unable to query nodeid %d", nodeid)
return None
return nodeid
def numNodesAlive(self):
alive = [node for node in range(self.nprocs)
if self.node_status[node] is None or self.node_status[node] >= 0 ]
return len(alive)
def _fixnodes(self):
for nodeid in range(self.nprocs):
if self.node_status[nodeid] is not None and self.node_status[nodeid] < 0 and not self.disabled[nodeid]:
if self.ncrashes[nodeid] <= self.maxcrashes:
self.ncrashes[nodeid] += 1
self.logger.warning("fixnodes(): attempting to restart nodeid %d", nodeid)
res = self.openmm_workers[nodeid].start_worker()
if res is not None:
self.node_status[nodeid] = None
else:
self.logger.warning("fixnodes(): node %d has crashed too many times; it will not be restarted.", nodeid)
self.disabled[nodeid] = True
def _availableNode(self):
#returns a node at random among available nodes
available = [node for node in range(self.nprocs)
if self.node_status[node] is None]
if available == None or len(available) == 0:
return None
random.shuffle(available)
return available[0]
def launchJob(self, replica, job_info):
#Enqueues a replica for running based on provided job info.
job = job_info
job['replica'] = replica
job['start_time'] = 0
self.replica_to_job[replica] = job
self.jobqueue.put(replica)
return self.jobqueue.qsize()
def LaunchReplica(self, worker, replica, cycle, nsteps,
nheating = 0, ncooling = 0, hightemp = 0.0):
(stateid, par) = replica.get_state()
worker.set_posvel(replica.positions, replica.velocities)
worker.set_state(par)
worker.run(nsteps, nheating, ncooling, hightemp)
def ProcessJobQueue(self, mintime, maxtime):
    """Dispatch queued replicas to free nodes until maxtime elapses.

    Scans free devices and the job queue; even when the queue becomes
    empty it keeps blocking (in mintime-second increments) until
    maxtime has passed, reaping finished replicas and attempting to
    restart crashed nodes on each pass.  Returns the number of jobs
    launched during this call.
    """
    njobs_launched = 0
    nreplicas = len(self.replica_to_job)
    when_started = time.time()
    while time.time() < when_started + maxtime:
        # find an available node
        node = self._availableNode()
        # inner loop: drain the queue while both jobs and nodes exist
        while (not self.jobqueue.empty()) and (not node == None):
            # grabs job on top of the queue
            replica = self.jobqueue.get()
            job = self.replica_to_job[replica]
            # assign job to available node
            job['nodeid'] = node
            job['openmm_replica'] = self.openmm_replicas[replica]
            job['openmm_worker'] = self.openmm_workers[node]
            job['start_time'] = time.time()
            # connects node to replica
            self.replica_to_job[replica] = job
            self.node_status[node] = replica
            # optional heating/cooling schedule; missing keys mean a
            # plain constant-temperature leg
            if 'nheating' in job:
                nheating = job['nheating']
                ncooling = job['ncooling']
                hightemp = job['hightemp']
            else:
                nheating = 0
                ncooling = 0
                hightemp = 0.0
            self.LaunchReplica(job['openmm_worker'], job['openmm_replica'], job['cycle'],
                               job['nsteps'], nheating, ncooling, hightemp)
            # updates number of jobs launched
            njobs_launched += 1
            node = self._availableNode()
        # waits mintime second and rescans job queue
        time.sleep(mintime)
        # updates set of free nodes by checking for replicas that have exited
        for repl in range(nreplicas):
            self.isDone(repl,0)
        #restarts crashed nodes if any
        self._fixnodes()
    return njobs_launched
def DrainJobQueue(self):
#clear the job queue
while not self.jobqueue.empty():
# grabs job on top of the queue
replica = self.jobqueue.get()
self._clear_resource(replica)
self.replica_to_job[replica] = None
def _update_replica(self, job):
#update replica cycle, mdsteps, write out, etc. from worker
ommreplica = job['openmm_replica']
if job['openmm_worker'].has_crashed(): #refuses to update replica from a crashed worker
return None
(pos,vel) = job['openmm_worker'].get_posvel()
pot = job['openmm_worker'].get_energy()
if pos is None or vel is None or pot is None:
return None
for value in pot.values():
if math.isnan(value._value):
return None
for p in pos:
if math.isnan(p.x) or math.isnan(p.y) or math.isnan(p.z):
return None
for v in vel:
if math.isnan(v.x) or math.isnan(v.y) or math.isnan(v.z):
return None
cycle = ommreplica.get_cycle() + 1
ommreplica.set_cycle(cycle)
mdsteps = ommreplica.get_mdsteps() + job['nsteps']
ommreplica.set_mdsteps(mdsteps)
#update positions and velocities of openmm replica
ommreplica.set_posvel(pos,vel)
#TODO: should also update boxsize
#update energies of openmm replica
ommreplica.set_energy(pot)
#output data and trajectory file update
if mdsteps % job['nprnt'] == 0:
ommreplica.save_out()
if mdsteps % job['ntrj'] == 0:
ommreplica.save_dcd()
return 0
def isDone(self,replica,cycle):
    """
    Checks if a replica completed a run.

    If a replica is done it clears the corresponding node.
    Note that cycle is ignored by job transport. It is assumed that it is
    the latest cycle. it's kept for argument compatibility with
    hasCompleted() elsewhere.

    Returns True when the replica has no active job record or its
    worker crashed; False while the job is queued or still running.
    """
    job = self.replica_to_job[replica]
    if job == None:
        # if job has been removed we assume that the replica is done
        return True
    else:
        try:
            openmm_worker = job['openmm_worker']
        except:
            # job is in the queue but not yet launched (the worker is
            # attached only at dispatch time)
            return False
        if openmm_worker.has_crashed():
            # dead context: reap the worker without waiting and mark
            # the node as crashed (-1) so _fixnodes() can try a restart
            self.logger.warning("isDone(): replica %d has crashed", replica)
            openmm_worker.finish(wait = False)
            self.node_status[job['nodeid']] = -1 #signals dead context
            self._clear_resource(replica)
            self.replica_to_job[replica] = None
            return True
        if not openmm_worker.is_started():
            done = False
        else:
            # NOTE(review): completion requires both is_running() and
            # is_done() — presumably "dispatched and signalled
            # complete"; confirm against the worker implementation
            done = openmm_worker.is_running() and openmm_worker.is_done()
        if done:
            # update replica info; clearing the worker's private
            # running signal acknowledges the completed run
            openmm_worker._runningSignal.clear()
            retcode = self._update_replica(job)
            if retcode is None:
                # data came back missing/NaN: treat the node as dead
                self.logger.warning("isDone(): replica %d has completed with errors", replica)
                self.node_status[job['nodeid']] = -1 #signals dead context
            # disconnects replica from job and node
            self._clear_resource(replica)
            #flag replica as not linked to a job
            self.replica_to_job[replica] = None
        return done