Skip to content

Commit

Permalink
bugfix: race-condition on python-level task.state
Browse files Browse the repository at this point in the history
  • Loading branch information
wlruys committed Feb 16, 2024
1 parent 6bd3166 commit 20e6321
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 84 deletions.
11 changes: 9 additions & 2 deletions src/c/backend/include/runtime.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,8 @@ class InnerTask {
If this task's access permission to the parray is read-only, it pulls
this list of the dependencies to this map.
*/
std::unordered_map<uint64_t, std::vector<InnerTask*>> parray_dependencies_map;
std::unordered_map<uint64_t, std::vector<InnerTask *>>
parray_dependencies_map;

InnerTask();
InnerTask(long long int id, void *py_task);
Expand Down Expand Up @@ -600,6 +601,12 @@ class InnerTask {
/* Set the task state */
TaskState set_state(TaskState state);

/* Get the task state */
int get_state_int() const {
const TaskState state = this->state.load();
return static_cast<int>(state);
}

/* Get the task state */
TaskState get_state() const {
const TaskState state = this->state.load();
Expand Down Expand Up @@ -632,7 +639,7 @@ class InnerTask {
void begin_multidev_req_addition();
void end_multidev_req_addition();

std::vector<InnerTask*>& get_parray_dependencies(uint64_t parray_parent_id) {
std::vector<InnerTask *> &get_parray_dependencies(uint64_t parray_parent_id) {
return this->parray_dependencies_map[parray_parent_id];
}

Expand Down
9 changes: 7 additions & 2 deletions src/c/backend/scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ void InnerWorker::wait() {
std::unique_lock<std::mutex> lck(mtx);
// std::cout << "Waiting for task (C++) " << this->thread_idx << std::endl;
cv.wait(lck, [this] { return this->notified; });
// std::cout << "Task assigned (C++) " << this->thread_idx << " "
// << this->ready << std::endl;
// std::cout << "Task assigned (C++) " << this->thread_idx << " " <<
// this->ready
// << std::endl;
// std::cout << "Task assigned (C++) " << this->task->get_name() << ": "
// << this->task->instance << std::endl;
this->scheduler->increase_num_notified_workers();
}

Expand Down Expand Up @@ -209,6 +212,8 @@ void InnerScheduler::spawn_task(InnerTask *task) {
void InnerScheduler::enqueue_task(InnerTask *task, TaskStatusFlags status) {
// TODO: Change this to appropriate phase as it becomes implemented
LOG_INFO(SCHEDULER, "Enqueing task: {}, Status: {}", task, status);
// std::cout << "Enqueing task: " << task->get_name()
// << " Instance: " << task->instance << std::endl;
if (status.mappable && (task->get_state() < TaskState::MAPPED)) {
LOG_INFO(SCHEDULER, "Enqueing task: {} to mapper", task);
task->set_status(TaskStatus::MAPPABLE);
Expand Down
1 change: 1 addition & 0 deletions src/python/parla/cython/core.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ cdef extern from "include/runtime.hpp" nogil:

string get_name()

int get_state_int()
int set_state(int state)
void add_device_req(void* dev_ptr, long mem_sz, int num_vcus)
void begin_arch_req_addition()
Expand Down
4 changes: 4 additions & 0 deletions src/python/parla/cython/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,10 @@ cdef class PyInnerTask:
status = c_self.notify_dependents_wrapper()
return status

cpdef get_state_int(self):
cdef InnerTask* c_self = self.c_task
return c_self.get_state_int()

cpdef set_state(self, int state):
cdef InnerTask* c_self = self.c_task
return c_self.set_state(state)
Expand Down
83 changes: 31 additions & 52 deletions src/python/parla/cython/scheduler.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,9 @@ class WorkerThread(ControllableThread, SchedulerContext):

while self._should_run:
self.status = "Waiting"

nvtx.push_range(message="worker::wait", domain="Python Runtime", color="blue")
self.inner_worker.wait_for_task()

self.task = self.inner_worker.get_task()

if isinstance(self.task, core.DataMovementTaskAttributes):
self.task_attrs = self.task
self.task = DataMovementTask()
Expand All @@ -220,8 +218,6 @@ class WorkerThread(ControllableThread, SchedulerContext):
# comment(wlr): Need this is all cases currently. FIXME: Add stream/event creation in C++ so python isn't the owner.
_global_data_tasks[id(self.task)] = self.task

nvtx.pop_range(domain="Python Runtime")

self.status = "Running"

if isinstance(self.task, Task):
Expand All @@ -242,8 +238,6 @@ class WorkerThread(ControllableThread, SchedulerContext):
else:
active_task.handle_runahead_dependencies()

nvtx.push_range(message="worker::run", domain="Python Runtime", color="blue")

# Push the task to the thread local stack
Locals.push_task(active_task)

Expand All @@ -257,90 +251,63 @@ class WorkerThread(ControllableThread, SchedulerContext):
parray_target_id = device_manager.globalid_to_parrayid(global_target_id)
parray._auto_move(parray_target_id, True)

core.binlog_2("Worker", "Running task: ", active_task.inner_task, " on worker: ", self.inner_worker)
# Run the task body (this may complete the task or return a continuation)
# The body may return asynchronusly before kernels have completed, in which case the task will be marked as runahead
active_task.run()
state = active_task.state

# Pop the task from the thread local stack
Locals.pop_task()

# Log events on all 'task default' streams
device_context.record_events()

nvtx.pop_range(domain="Python Runtime")

nvtx.push_range(message="worker::cleanup", domain="Python Runtime", color="blue")

final_state = active_task.state

# FIXME: This can be cleaned up and hidden from this function with a better interface...
if active_task.runahead == SyncType.NONE:
device_context.finalize()

# TODO(wlr): Add better exception handling
if isinstance(final_state, tasks.TaskException):
raise TaskBodyException(active_task.state.exception)

elif isinstance(final_state, tasks.TaskRunning):
nvtx.push_range(message="worker::continuation", domain="Python Runtime", color="red")
# print("CONTINUATION: ", active_task.taskid.full_name, active_task.state.dependencies, flush=True)
active_task.dependencies = active_task.state.dependencies
active_task.func = active_task.state.func
active_task.args = active_task.state.args
if isinstance(state, tasks.TaskRunning):

active_task.dependencies = state.dependencies
active_task.func = state.func
active_task.args = state.args

active_task.inner_task.clear_dependencies()
active_task.add_dependencies(active_task.dependencies, process=False)
nvtx.pop_range(domain="Python Runtime")

elif isinstance(final_state, tasks.TaskRunahead):
core.binlog_2("Worker", "Runahead task: ", active_task.inner_task, " on worker: ", self.inner_worker)

# print("Cleaning up Task", active_task, flush=True)

if USE_PYTHON_RUNAHEAD:
# Handle synchronization in Python (for debugging, works!)
self.scheduler.inner_scheduler.task_cleanup_presync(self.inner_worker, active_task.inner_task, active_task.state.value)
self.scheduler.inner_scheduler.task_cleanup_presync(self.inner_worker, active_task.inner_task, state.value)
if active_task.runahead != SyncType.NONE:
device_context.synchronize(events=True)
self.scheduler.inner_scheduler.task_cleanup_postsync(self.inner_worker, active_task.inner_task, active_task.state.value)
self.scheduler.inner_scheduler.task_cleanup_postsync(self.inner_worker, active_task.inner_task, state.value)
else:
# Handle synchronization in C++
self.scheduler.inner_scheduler.task_cleanup(self.inner_worker, active_task.inner_task, active_task.state.value)
self.scheduler.inner_scheduler.task_cleanup(self.inner_worker, active_task.inner_task, state.value)

if active_task.runahead != SyncType.NONE:
device_context.return_streams()

if isinstance(final_state, tasks.TaskRunahead):
final_state = tasks.TaskCompleted(final_state.return_value)
if active_task.is_completed():
active_task.cleanup()
active_task.state = tasks.TaskCompleted(active_task.result)

core.binlog_2("Worker", "Completed task: ", active_task.inner_task, " on worker: ", self.inner_worker)

active_task.state = final_state
self.task = None

nvtx.pop_range(domain="Python Runtime")
elif self._should_run:
raise WorkerThreadException("%r Worker: Woke without a task", self.index)
else:
break

except Exception as e:
tb = traceback.format_exc()
print("Exception in Worker Thread ", self, ": ", e, tb, flush=True)
print("Exception in Worker Thread ", self, " during handling of ", self.task.name, ": ", e, tb, flush=True)

self.scheduler.exception_stack.append(e)
self.scheduler.stop()

if isinstance(e, TaskBodyException):
raise WorkerThreadException(f"Unhandled Exception in Task: {self.task.get_name()}") from e
if isinstance(e, KeyboardInterrupt):
print("You pressed Ctrl+C! In a worker!", flush=True)
raise e
else:
raise WorkerThreadException("Unhandled Exception on "+str(self))

def stop(self):
super().stop()
self.inner_worker.stop()
Expand Down Expand Up @@ -398,12 +365,15 @@ class Scheduler(ControllableThread, SchedulerContext):

for t in self.worker_threads:
t.join()

#print("Exiting Scheduler", flush=True)

except Exception as e:
self.exception_stack.append(e)

finally:
#print(self.exception_stack, flush=True)
if len(self.exception_stack) > 0:
raise self.exception_stack[0]
finally:
pass

def run(self):
Expand Down Expand Up @@ -468,7 +438,6 @@ class Scheduler(ControllableThread, SchedulerContext):
"""
return self.inner_scheduler.get_parray_state(global_dev_id, parray_parent_id)


def _task_callback(task, body):
"""
A function which forwards to a python function in the appropriate device context.
Expand All @@ -482,26 +451,36 @@ def _task_callback(task, body):
if inspect.iscoroutine(body):
try:
in_value_task = getattr(task, "value_task", None)
in_value = in_value_task and in_value_task.result

if in_value_task is not None:
in_value = in_value_task.result
#print(in_value_task.state)
#print(f"Task invalue1", task, in_value_task, body, in_value, in_value_task.state, in_value_task.result, type(in_value_task), flush=True)
else:
in_value = None
#print(f"Task invalue2", task, in_value_task, body, in_value, type(task), type(in_value_task), flush=True)

new_task_info = body.send(in_value)
#print(f"Task new_task_info", task, new_task_info, body, flush=True)
task.value_task = None
if not isinstance(new_task_info, tasks.TaskAwaitTasks):
raise TypeError(
"Parla coroutine tasks must yield a TaskAwaitTasks")
dependencies = new_task_info.dependencies
value_task = new_task_info.value_task
#print(dependencies)
if value_task:
assert isinstance(value_task, Task)
task.value_task = value_task
return tasks.TaskRunning(_task_callback, (body,), dependencies)
return tasks.TaskRunning(_task_callback, (body,), dependencies, id=task.name)
except StopIteration as e:
#print(f"Task StopIteration", task, e, e.args, flush=True)
result = None
if e.args:
(result,) = e.args
return tasks.TaskRunahead(result)
else:
result = body()
#print(f"Task body", task, body, result, flush=True)
return tasks.TaskRunahead(result)
finally:
pass
48 changes: 20 additions & 28 deletions src/python/parla/cython/tasks.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class TaskRunning(TaskState):
@brief This state specifies that a task is executing in a stream.
"""

__slots__ = ["func", "args", "dependencies"]
__slots__ = ["func", "args", "dependencies", "id"]

@property
def value(self):
Expand All @@ -138,16 +138,22 @@ class TaskRunning(TaskState):
def is_terminal(self):
return False

@property
def return_value(self):
print("TaskRunning has no return value", self, "task: ", self.id, flush=True)
raise NotImplementedError()

# The argument dependencies intentially has no type hint.
# Callers can pass None if they want to pass empty dependencies.
def __init__(self, func, args, dependencies: Optional[Iterable] = None):
def __init__(self, func, args, dependencies: Optional[Iterable] = None, id=None):
if dependencies is not None:
self.dependencies = dependencies
else:
self.dependencies = []

self.args = args
self.func = func
self.id = id

def clear_dependencies(self):
self.dependencies = []
Expand Down Expand Up @@ -222,7 +228,7 @@ class TaskException(TaskState):
self.traceback = tb

def __repr__(self):
return "TaskException({})".format(self.exception)
return f"TaskException({self.exception}, {self.traceback})"


TaskAwaitTasks = namedtuple("AwaitTasks", ["dependencies", "value_task"])
Expand Down Expand Up @@ -445,7 +451,7 @@ class Task:
@return The return value of the task body or an exception if the task threw an exception. Returns None if the task has not completed.
"""

if isinstance(self.state, TaskCompleted):
if isinstance(self.state, TaskCompleted) or isinstance(self.state, TaskRunahead):
return self.state.return_value
elif isinstance(self.state, TaskException):
return self.state.exception
Expand All @@ -464,30 +470,10 @@ class Task:
"""!
@brief Run the task body.
"""
#if not isinstance(self.state, TaskRunning):
# self.state = TaskRunning(self.func, self.args, id=self.name)

task_state = None
self.state = TaskRunning(self.func, self.args)
try:

task_state = self._execute_task()

task_state = task_state or TaskRunahead(None)

except Exception as e:
tb = traceback.format_exc()
task_state = TaskException(e, tb)
self.state = task_state

print("Exception in Task ", self, ": ", e, tb, flush=True)

if isinstance(e, KeyboardInterrupt):
print("You pressed Ctrl+C! In a Task!", flush=True)
raise e
# print("Task {} failed with exception: {} \n {}".format(self.name, e, tb), flush=True)

finally:
assert(task_state is not None)
self.state = task_state
self.state = self._execute_task()

def __await__(self):
return (yield TaskAwaitTasks([self], self))
Expand Down Expand Up @@ -560,7 +546,13 @@ class Task:
"""!
@brief Get the state of the task (from the C++ runtime)
"""
return self.inner_task.get_state()
return self.inner_task.get_state_int()

def is_completed(self):
"""!
@brief Get the completion status of the task.
"""
return self.get_state() == 7

def set_complete(self):
self.inner_task.set_complete()
Expand Down

0 comments on commit 20e6321

Please sign in to comment.