diff --git a/src/engine/Join.cpp b/src/engine/Join.cpp index 6d2376bb9..967ebfe55 100644 --- a/src/engine/Join.cpp +++ b/src/engine/Join.cpp @@ -435,43 +435,64 @@ Result::Generator Join::runLazyJoinAndConvertToGenerator( Result::IdTableVocabPair, std::function> auto action, OptionalPermutation permutation) const { - std::atomic_flag write = true; + // TODO This heavily mixes a synchronization algorithm with + // the actual logic. This should be refactored. + std::mutex mutex; + std::condition_variable cv; + enum struct State { Inner, Outer, Finished }; + State state = State::Inner; + struct CancelException : public std::exception {}; std::variant storage; - ad_utility::JThread thread{[&write, &storage, &action, &permutation]() { - auto writeValue = [&write, &storage](auto value) noexcept { - storage = std::move(value); - write.clear(); - write.notify_one(); - }; - auto writeValueAndWait = [&permutation, &write, - &writeValue](Result::IdTableVocabPair value) { - AD_CORRECTNESS_CHECK(write.test()); - applyPermutation(value.idTable_, permutation); - writeValue(std::move(value)); - // Wait until we are allowed to write again. - write.wait(false); - }; - auto addValue = [&writeValueAndWait](IdTable& idTable, - LocalVocab& localVocab) { - if (idTable.size() < CHUNK_SIZE) { - return; - } - writeValueAndWait({std::move(idTable), std::move(localVocab)}); - }; - try { - auto finalValue = action(addValue); - if (!finalValue.idTable_.empty()) { - writeValueAndWait(std::move(finalValue)); - } - writeValue(std::monostate{}); - } catch (...) { - writeValue(std::current_exception()); - } + ad_utility::JThread thread{ + [&mutex, &cv, &state, &storage, &action, &permutation]() { + std::unique_lock lock(mutex); + auto wait = [&]() { + cv.wait(lock, [&]() { return state != State::Outer; }); + if (state == State::Finished) { + throw CancelException{}; + } + }; + + auto writeValue = [&cv, &state, &storage](auto value) noexcept { + storage = std::move(value); + state = State::Outer; + cv.notify_one(); + }; + auto writeValueAndWait = [&state, &permutation, &writeValue, + &wait](Result::IdTableVocabPair value) { + AD_CORRECTNESS_CHECK(state == State::Inner); + applyPermutation(value.idTable_, permutation); + writeValue(std::move(value)); + wait(); + }; + auto addValue = [&writeValueAndWait](IdTable& idTable, + LocalVocab& localVocab) { + if (idTable.size() < CHUNK_SIZE) { + return; + } + writeValueAndWait({std::move(idTable), std::move(localVocab)}); + }; + try { + auto finalValue = action(addValue); + if (!finalValue.idTable_.empty()) { + writeValueAndWait(std::move(finalValue)); + } + writeValue(std::monostate{}); + } catch (...) { + writeValue(std::current_exception()); + } + }}; + std::unique_lock lock{mutex}; + cv.wait(lock, [&state]() { return state == State::Outer; }); + auto cleanup = absl::Cleanup{[&cv, &state, &lock]() { + state = State::Finished; + lock.unlock(); + cv.notify_one(); }}; while (true) { // Wait for read phase. - write.wait(true); + cv.wait(lock, [&state] { return state == State::Outer; }); if (std::holds_alternative(storage)) { break; } @@ -479,9 +500,10 @@ Result::Generator Join::runLazyJoinAndConvertToGenerator( std::rethrow_exception(std::get(storage)); } co_yield std::get(storage); - // Initiate write phase. - write.test_and_set(); - write.notify_one(); + state = State::Inner; + lock.unlock(); + cv.notify_one(); + lock.lock(); } }