-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add KernelHive manager code, remove KernelHive dependency
- Loading branch information
Showing
17 changed files
with
342 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
Scripts for running DeepSpeech training with Distributed TensorFlow and [KernelHive manager](https://github.com/roscisz/KernelHive/tree/manager/hive-manager). | ||
Scripts for running DeepSpeech training with Distributed TensorFlow and TensorHive. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
Scripts for running [inception](https://github.com/tensorflow/models/tree/master/inception) training with Distributed TensorFlow and [KernelHive manager](https://github.com/roscisz/KernelHive/tree/manager/hive-manager). | ||
Scripts for running [inception](https://github.com/tensorflow/models/tree/master/inception) training with Distributed TensorFlow and TensorHive. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import paramiko | ||
|
||
|
||
# TODO: consider channel limit on the connections | ||
class ConnectionManager: | ||
def __init__(self): | ||
self.connections = dict() | ||
|
||
def ensure_connection(self, node): | ||
if node not in self.connections.keys(): | ||
self.connections[node] = self.setup_ssh_client(node) | ||
return self.connections[node] | ||
|
||
def run_command(self, node, command): | ||
client = self.ensure_connection(node) | ||
client.exec_command(command) | ||
|
||
def shutdown_connections(self): | ||
for node in self.connections.keys(): | ||
self.connections[node].close() | ||
|
||
def setup_ssh_client(self, node): | ||
client = paramiko.SSHClient() | ||
client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) | ||
client.connect(node) | ||
return client | ||
|
||
def shutdown(self): | ||
print('Shutting down node connections...') | ||
for node in self.connections.keys(): | ||
self.connections[node].close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from threading import Thread | ||
|
||
from monitoring import * | ||
from connectivity import ConnectionManager | ||
from serving import HTTPJSONRPCServer | ||
|
||
|
||
class Manager(Thread): | ||
def __init__(self, hostname, port, monitors, handlers, landing_page='index.html'): | ||
Thread.__init__(self) | ||
self.monitors = monitors | ||
self.handlers = handlers | ||
|
||
self.server = HTTPJSONRPCServer(hostname, port, self.get_module_name(), landing_page) | ||
self.configure_services() | ||
self.configure_handlers() | ||
self.configure_monitors() | ||
|
||
self.connection_manager = ConnectionManager() | ||
print('hello?') | ||
self.monitoring_service = MonitoringService(self.monitors, self.handlers, self.connection_manager) | ||
|
||
def configure_monitors(self): | ||
pass | ||
|
||
def configure_handlers(self): | ||
pass | ||
|
||
def configure_services(self): | ||
self.add_service(self.add_node) | ||
self.add_service(self.get_infrastructure) | ||
|
||
def get_module_name(self): | ||
return 'manager' | ||
|
||
def add_service(self, method): | ||
self.server.add_service(method) | ||
|
||
def add_node(self, node_hostname): | ||
self.monitoring_service.add_node(node_hostname) | ||
|
||
def get_infrastructure(self): | ||
return self.monitoring_service.infrastructure | ||
|
||
def shutdown(self): | ||
print('Shutting down the manager...') | ||
self.monitoring_service.shutdown() | ||
self.monitoring_service.join() | ||
self.connection_manager.shutdown() | ||
self.server.shutdown() | ||
|
||
def run(self): | ||
self.monitoring_service.start() | ||
self.server.server_forever() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import time | ||
from utils import StoppableThread | ||
|
||
|
||
class MonitoringHandler: | ||
def handle_monitoring(self, infrastructure): | ||
raise NotImplementedError() | ||
|
||
|
||
class Monitor: | ||
def get_key(self): | ||
raise NotImplementedError() | ||
|
||
def discover(self, client): | ||
raise NotImplementedError() | ||
|
||
def monitor(self, client, output): | ||
raise NotImplementedError() | ||
|
||
|
||
class MonitoringWorker(StoppableThread): | ||
def __init__(self, client, node_data, monitors): | ||
StoppableThread.__init__(self) | ||
self.client = client | ||
self.node_data = node_data | ||
self.monitors = monitors | ||
|
||
def monitor(self, client, node_data): | ||
for monitor in self.monitors: | ||
node_data[monitor.get_key()] = monitor.monitor(client, node_data[monitor.get_key()]) | ||
|
||
def do_run(self): | ||
self.monitor(self.client, self.node_data) | ||
# TODO: sleep period as arg | ||
time.sleep(1) | ||
|
||
def finalize(self): | ||
self.client.close() | ||
|
||
|
||
class MonitoringService(StoppableThread): | ||
def __init__(self, monitors, handlers, connection_manager): | ||
StoppableThread.__init__(self) | ||
print('Starting the monitoring service...') | ||
self.monitors = monitors | ||
self.handlers = handlers | ||
self.connection_manager = connection_manager | ||
|
||
self.infrastructure = dict() | ||
self.workers = [] | ||
|
||
def discover_node(self, client, monitors): | ||
return {monitor.get_key(): monitor.discover(client) for monitor in monitors} | ||
|
||
def add_node(self, node): | ||
if node not in self.infrastructure.keys(): | ||
connection = self.connection_manager.ensure_connection(node) | ||
self.infrastructure[node] = self.discover_node(connection, self.monitors) | ||
worker = MonitoringWorker(connection, self.infrastructure[node], self.monitors) | ||
self.workers.append(worker) | ||
worker.start() | ||
|
||
def add_handler(self, handler): | ||
self.handlers.append(handler) | ||
|
||
def do_run(self): | ||
for handler in self.handlers: | ||
handler.handle_monitoring(self.infrastructure) | ||
# TODO: sleep period as arg | ||
time.sleep(5) | ||
|
||
def finalize(self): | ||
for worker in self.workers: | ||
worker.join() | ||
|
||
def shutdown(self): | ||
print('Shutting down monitoring workers...') | ||
for worker in self.workers: | ||
worker.shutdown() | ||
StoppableThread.shutdown(self) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from tensorhive.monitoring import MonitoringHandler | ||
|
||
|
||
class PrintingHandler(MonitoringHandler): | ||
def handle_monitoring(self, infrastructure): | ||
print(str(infrastructure)) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from tensorhive.monitoring import Monitor | ||
|
||
timeout_prefix = 'timeout 2 ' | ||
|
||
|
||
class GPUMonitor(Monitor): | ||
def get_key(self): | ||
return 'gpu' | ||
|
||
def discover(self, client): | ||
gpus = {} | ||
_, stdout, _ = client.exec_command('nvidia-smi -L') | ||
gpu_descrs = stdout.read().split('\n')[:-1] | ||
for gpu_descr in gpu_descrs: | ||
name, model, uuid = gpu_descr.split(': ') | ||
model = model[:-6] | ||
uuid = uuid[:-1] | ||
gpus[uuid] = {} | ||
gpus[uuid]['name'] = name | ||
gpus[uuid]['model'] = model | ||
return gpus | ||
|
||
def check_process_owner(self, client, pid): | ||
_, stdout, _ = client.exec_command('ps -o user %s' % pid) | ||
return stdout.read().split('\n')[1] | ||
|
||
def monitor_processes(self, client, uuid): | ||
processes = [] | ||
_, stdout, _ = client.exec_command('%s nvidia-smi pmon -c 1 -i %s' % (timeout_prefix, uuid)) | ||
outputs = stdout.read().split('\n')[2:-1] | ||
for output in outputs: | ||
values = output.split() | ||
if values[1] is not '-': | ||
processes.append({'pid': values[1], 'owner': self.check_process_owner(client, values[1])}) | ||
return processes | ||
|
||
def monitor_utilization(self, client, uuid): | ||
_, stdout, _ = client.exec_command('%s nvidia-smi --query-gpu=utilization.gpu --format=csv,noheader -i %s' % (timeout_prefix, uuid)) | ||
output = stdout.read().split() | ||
if not len(output): | ||
return 0 | ||
return output[0] | ||
|
||
def monitor(self, client, gpus): | ||
for uuid in gpus.keys(): | ||
gpus[uuid]['processes'] = self.monitor_processes(client, uuid) | ||
gpus[uuid]['utilization'] = self.monitor_utilization(client, uuid) | ||
return gpus | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from tensorhive.monitoring import Monitor | ||
|
||
|
||
class ProcessMonitor(Monitor): | ||
def __init__(self): | ||
self.processes = [] | ||
|
||
def get_key(self): | ||
return 'processes' | ||
|
||
def discover(self, client): | ||
process_map = dict() | ||
for process in self.processes: | ||
_, stdout, _ = client.exec_command('pgrep -f "%s"' % process) | ||
process_map[process] = stdout.read().split('\n')[:-1] | ||
return process_map | ||
|
||
def monitor(self, client, output): | ||
return self.discover(client) | ||
|
||
def add_process(self, process): | ||
self.processes.append(process) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import os | ||
import importlib | ||
from werkzeug.wrappers import Request, Response | ||
from werkzeug.wsgi import SharedDataMiddleware | ||
from werkzeug.serving import make_server | ||
from werkzeug.utils import redirect | ||
from jsonrpc import JSONRPCResponseManager, dispatcher | ||
|
||
|
||
class HTTPJSONRPCServer: | ||
def __init__(self, hostname, port, name, landing_page): | ||
self.landing_page = landing_page | ||
|
||
try: | ||
module = importlib.import_module(name) | ||
static_path = os.path.dirname(module.__file__) + '/static' | ||
except ImportError: | ||
static_path = os.path.join(os.getcwd(), 'static') | ||
|
||
self.dynamic_path = '/tmp/%s' % name | ||
if not os.path.exists(self.dynamic_path): | ||
os.mkdir(self.dynamic_path) | ||
|
||
application = SharedDataMiddleware(self.application, {'/': static_path}) | ||
self.srv = make_server(hostname, port, application) | ||
|
||
@Request.application | ||
def application(self, request): | ||
if request.path == '/' and request.method == 'GET': | ||
return redirect(self.landing_page) | ||
|
||
path = request.path.split('/') | ||
if len(path) > 2 and path[1] == 'dynamic': | ||
filename = '/'.join([self.dynamic_path, path[2]]) | ||
if os.path.isfile(filename): | ||
with open(filename, 'rb') as f: | ||
data = f.read() | ||
else: | ||
data = None | ||
response = Response(data, mimetype='application/octet-stream') | ||
response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate' | ||
response.headers['Pragma'] = 'no-cache' | ||
response.headers['Expires'] = '0' | ||
return response | ||
response = JSONRPCResponseManager.handle(request.data, dispatcher) | ||
return Response(response.json, mimetype='application/json') | ||
|
||
def add_service(self, method): | ||
dispatcher.add_method(method) | ||
|
||
def server_forever(self): | ||
print('Starting the HTTPJSONServer at http://%s:%d' % (self.srv.host, self.srv.port)) | ||
self.srv.serve_forever() | ||
|
||
def shutdown(self): | ||
print('Shutting down the HTTPJSONSever...') | ||
self.srv.shutdown() |
Oops, something went wrong.