From 8c7a441c8a46fd0156dd9edb615bfb8d42120eec Mon Sep 17 00:00:00 2001 From: Alan Frindell Date: Thu, 14 Nov 2024 02:14:17 -0800 Subject: [PATCH 1/2] Implement FETCH streams (#8) Summary: Adds support to MoQFramer, MoQCodec and MoQSession to send and receive fetch streams Differential Revision: D65532676 --- moxygen/MoQCodec.cpp | 18 +++- moxygen/MoQCodec.h | 2 + moxygen/MoQFramer.cpp | 45 +++++++++- moxygen/MoQFramer.h | 6 +- moxygen/MoQServer.cpp | 8 -- moxygen/MoQServer.h | 2 - moxygen/MoQSession.cpp | 143 +++++++++++++++++++++++++++++--- moxygen/MoQSession.h | 56 ++++++++++--- moxygen/test/MoQSessionTest.cpp | 12 +++ moxygen/test/Mocks.h | 1 + 10 files changed, 253 insertions(+), 40 deletions(-) diff --git a/moxygen/MoQCodec.cpp b/moxygen/MoQCodec.cpp index f1e5556..568667e 100644 --- a/moxygen/MoQCodec.cpp +++ b/moxygen/MoQCodec.cpp @@ -128,7 +128,10 @@ void MoQObjectStreamCodec::onIngress( case StreamType::STREAM_HEADER_SUBGROUP: parseState_ = ParseState::OBJECT_STREAM; break; - // CONTROL doesn't have a wire type yet. + case StreamType::FETCH_HEADER: + parseState_ = ParseState::FETCH_HEADER; + break; + // CONTROL doesn't have a wire type yet. default: XLOG(DBG4) << "Stream not allowed: 0x" << std::setfill('0') << std::setw(sizeof(uint64_t) * 2) << std::hex @@ -154,6 +157,19 @@ void MoQObjectStreamCodec::onIngress( } break; } + case ParseState::FETCH_HEADER: { + auto newCursor = cursor; + auto res = parseFetchHeader(newCursor); + if (res.hasError()) { + XLOG(DBG6) << __func__ << " " << uint32_t(res.error()); + connError_ = res.error(); + break; + } + curObjectHeader_.trackIdentifier = SubscribeID(res.value()); + parseState_ = ParseState::MULTI_OBJECT_HEADER; + cursor = newCursor; + break; + } case ParseState::OBJECT_STREAM: { auto newCursor = cursor; auto res = parseStreamHeader(newCursor, streamType_); diff --git a/moxygen/MoQCodec.h b/moxygen/MoQCodec.h index 54b02fe..8e7016f 100644 --- a/moxygen/MoQCodec.h +++ b/moxygen/MoQCodec.h @@ -139,6 +139,7 @@ class MoQObjectStreamCodec : public MoQCodec { public: ~ObjectCallback() override = default; + virtual void onFetchHeader(uint64_t subscribeID) = 0; virtual void onObjectHeader(ObjectHeader objectHeader) = 0; virtual void onObjectPayload( @@ -166,6 +167,7 @@ class MoQObjectStreamCodec : public MoQCodec { STREAM_HEADER_TYPE, DATAGRAM, OBJECT_STREAM, + FETCH_HEADER, MULTI_OBJECT_HEADER, OBJECT_PAYLOAD, // OBJECT_PAYLOAD_NO_LENGTH diff --git a/moxygen/MoQFramer.cpp b/moxygen/MoQFramer.cpp index f25b0c7..0e870a2 100644 --- a/moxygen/MoQFramer.cpp +++ b/moxygen/MoQFramer.cpp @@ -188,6 +188,15 @@ folly::Expected parseServerSetup( return serverSetup; } +folly::Expected parseFetchHeader( + folly::io::Cursor& cursor) noexcept { + auto subscribeID = quic::decodeQuicInteger(cursor); + if (!subscribeID) { + return folly::makeUnexpected(ErrorCode::PARSE_UNDERFLOW); + } + return subscribeID->first; +} + folly::Expected parseObjectHeader( folly::io::Cursor& cursor, size_t length) noexcept { @@ -288,10 +297,13 @@ folly::Expected parseMultiObjectHeader( const ObjectHeader& headerTemplate) noexcept { DCHECK( streamType == StreamType::STREAM_HEADER_TRACK || - streamType == StreamType::STREAM_HEADER_SUBGROUP); + streamType == StreamType::STREAM_HEADER_SUBGROUP || + streamType == StreamType::FETCH_HEADER); + // TODO get rid of this auto length = cursor.totalLength(); ObjectHeader objectHeader = headerTemplate; - if (streamType == StreamType::STREAM_HEADER_TRACK) { + if (streamType == StreamType::STREAM_HEADER_TRACK || + streamType == StreamType::FETCH_HEADER) { auto group = quic::decodeQuicInteger(cursor, length); if (!group) { return folly::makeUnexpected(ErrorCode::PARSE_UNDERFLOW); @@ -302,12 +314,28 @@ folly::Expected parseMultiObjectHeader( } else { objectHeader.forwardPreference = ForwardPreference::Subgroup; } + if (streamType == StreamType::FETCH_HEADER) { + objectHeader.forwardPreference = ForwardPreference::Fetch; + auto subgroup = quic::decodeQuicInteger(cursor, length); + if (!subgroup) { + return folly::makeUnexpected(ErrorCode::PARSE_UNDERFLOW); + } + length -= subgroup->second; + objectHeader.subgroup = subgroup->first; + } auto id = quic::decodeQuicInteger(cursor, length); if (!id) { return folly::makeUnexpected(ErrorCode::PARSE_UNDERFLOW); } length -= id->second; objectHeader.id = id->first; + if (streamType == StreamType::FETCH_HEADER) { + if (length < 2) { + return folly::makeUnexpected(ErrorCode::PARSE_UNDERFLOW); + } + objectHeader.priority = cursor.readBE(); + length--; + } auto payloadLength = quic::decodeQuicInteger(cursor, length); if (!payloadLength) { return folly::makeUnexpected(ErrorCode::PARSE_UNDERFLOW); @@ -1154,6 +1182,9 @@ WriteResult writeStreamHeader( folly::to_underlying(StreamType::STREAM_HEADER_SUBGROUP), size, error); + } else if (objectHeader.forwardPreference == ForwardPreference::Fetch) { + writeVarint( + writeBuf, folly::to_underlying(StreamType::FETCH_HEADER), size, error); } else { LOG(FATAL) << "Unsupported forward preference to stream header"; } @@ -1162,7 +1193,9 @@ WriteResult writeStreamHeader( writeVarint(writeBuf, objectHeader.group, size, error); writeVarint(writeBuf, objectHeader.subgroup, size, error); } - writeVarint(writeBuf, objectHeader.priority, size, error); + if (objectHeader.forwardPreference != ForwardPreference::Fetch) { + writeVarint(writeBuf, objectHeader.priority, size, error); + } if (error) { return folly::makeUnexpected(quic::TransportErrorCode::INTERNAL_ERROR); } @@ -1199,12 +1232,16 @@ WriteResult writeObject( if (objectHeader.forwardPreference != ForwardPreference::Subgroup) { writeVarint(writeBuf, objectHeader.group, size, error); } + if (objectHeader.forwardPreference == ForwardPreference::Fetch) { + writeVarint(writeBuf, objectHeader.subgroup, size, error); + } writeVarint(writeBuf, objectHeader.id, size, error); CHECK( objectHeader.status != ObjectStatus::NORMAL || (objectHeader.length && *objectHeader.length > 0)) << "Normal objects require non-zero length"; - if (objectHeader.forwardPreference == ForwardPreference::Datagram) { + if (objectHeader.forwardPreference == ForwardPreference::Datagram || + objectHeader.forwardPreference == ForwardPreference::Fetch) { writeBuf.append(&objectHeader.priority, 1); size += 1; } diff --git a/moxygen/MoQFramer.h b/moxygen/MoQFramer.h index b133b3d..f196988 100644 --- a/moxygen/MoQFramer.h +++ b/moxygen/MoQFramer.h @@ -112,6 +112,7 @@ enum class StreamType : uint64_t { OBJECT_DATAGRAM = 1, STREAM_HEADER_TRACK = 0x2, STREAM_HEADER_SUBGROUP = 0x4, + FETCH_HEADER = 0x5, CONTROL = 100000000 }; @@ -168,7 +169,7 @@ folly::Expected parseServerSetup( folly::io::Cursor& cursor, size_t length) noexcept; -enum class ForwardPreference : uint8_t { Track, Subgroup, Datagram }; +enum class ForwardPreference : uint8_t { Track, Subgroup, Datagram, Fetch }; enum class ObjectStatus : uint64_t { NORMAL = 0, @@ -264,6 +265,9 @@ folly::Expected parseObjectHeader( folly::io::Cursor& cursor, size_t length) noexcept; +folly::Expected parseFetchHeader( + folly::io::Cursor& cursor) noexcept; + folly::Expected parseStreamHeader( folly::io::Cursor& cursor, StreamType streamType) noexcept; diff --git a/moxygen/MoQServer.cpp b/moxygen/MoQServer.cpp index 20aafd4..054ed5d 100644 --- a/moxygen/MoQServer.cpp +++ b/moxygen/MoQServer.cpp @@ -109,14 +109,6 @@ void MoQServer::ControlVisitor::operator()(FetchCancel fetchCancel) const { XLOG(INFO) << "FetchCancel id=" << fetchCancel.subscribeID; } -void MoQServer::ControlVisitor::operator()(FetchOk fetchOk) const { - XLOG(INFO) << "FetchOk id=" << fetchOk.subscribeID; -} - -void MoQServer::ControlVisitor::operator()(FetchError fetchError) const { - XLOG(INFO) << "FetchError id=" << fetchError.subscribeID; -} - void MoQServer::ControlVisitor::operator()(SubscribeDone subscribeDone) const { XLOG(INFO) << "SubscribeDone id=" << subscribeDone.subscribeID << " code=" << folly::to_underlying(subscribeDone.statusCode) diff --git a/moxygen/MoQServer.h b/moxygen/MoQServer.h index b189bd2..66b604b 100644 --- a/moxygen/MoQServer.h +++ b/moxygen/MoQServer.h @@ -39,8 +39,6 @@ class MoQServer { void operator()(MaxSubscribeId maxSubscribeId) const override; void operator()(Fetch fetch) const override; void operator()(FetchCancel fetchCancel) const override; - void operator()(FetchOk fetchOk) const override; - void operator()(FetchError fetchError) const override; void operator()(Unannounce unannounce) const override; void operator()(AnnounceCancel announceCancel) const override; void operator()(SubscribeAnnounces subscribeAnnounces) const override; diff --git a/moxygen/MoQSession.cpp b/moxygen/MoQSession.cpp index 6c65135..58e12fc 100644 --- a/moxygen/MoQSession.cpp +++ b/moxygen/MoQSession.cpp @@ -24,6 +24,10 @@ MoQSession::~MoQSession() { subTrack.second->subscribeError( {/*TrackHandle fills in subId*/ 0, 500, "session closed", folly::none}); } + for (auto& fetch : fetches_) { + fetch.second->fetchError( + {/*TrackHandle fills in subId*/ 0, 500, "session closed"}); + } for (auto& pendingAnn : pendingAnnounce_) { pendingAnn.second.setValue(folly::makeUnexpected( AnnounceError({pendingAnn.first, 500, "session closed"}))); @@ -168,10 +172,13 @@ folly::coro::Task MoQSession::readLoop( proxygen::WebTransport::StreamReadHandle* readHandle) { XLOG(DBG1) << __func__ << " sess=" << this; std::unique_ptr codec; + MoQObjectStreamCodec* objCodec = nullptr; if (streamType == StreamType::CONTROL) { codec = std::make_unique(dir_, this); } else { - codec = std::make_unique(this); + auto res = std::make_unique(this); + objCodec = res.get(); + codec = std::move(res); } auto id = readHandle->getID(); codec->setStreamId(id); @@ -191,6 +198,16 @@ folly::coro::Task MoQSession::readLoop( } fin = streamData->fin; XLOG_IF(DBG3, fin) << "End of stream id=" << id << " sess=" << this; + if (fin && objCodec) { + auto id = objCodec->getTrackIdentifier(); + if (auto subscribeID = std::get_if(&id)) { // it's fetch + auto track = getTrack(id); + if (track) { + track->fin(); + fetches_.erase(*subscribeID); + } + } + } } } } @@ -243,7 +260,15 @@ std::shared_ptr MoQSession::getTrack( } track = trackIt->second; } else { - // TODO - handle subscribe ID + auto subscribeID = std::get(trackIdentifier); + XLOG(DBG3) << "getTrack subID=" << subscribeID; + auto trackIt = fetches_.find(subscribeID); + if (trackIt == fetches_.end()) { + // received an object for unknown subscribe ID + XLOG(ERR) << "unknown subscribe ID=" << subscribeID << " sess=" << this; + return nullptr; + } + track = trackIt->second; } return track; } @@ -420,19 +445,37 @@ void MoQSession::onMaxSubscribeId(MaxSubscribeId maxSubscribeId) { } void MoQSession::onFetch(Fetch fetch) { - XLOG(ERR) << "Not implemented yet"; + XLOG(DBG1) << __func__ << " sess=" << this; + controlMessages_.enqueue(std::move(fetch)); } void MoQSession::onFetchCancel(FetchCancel fetchCancel) { - XLOG(ERR) << "Not implemented yet"; + XLOG(DBG1) << __func__ << " sess=" << this; + controlMessages_.enqueue(std::move(fetchCancel)); } void MoQSession::onFetchOk(FetchOk fetchOk) { - XLOG(ERR) << "Not implemented yet"; + XLOG(DBG1) << __func__ << " sess=" << this; + auto fetchIt = fetches_.find(fetchOk.subscribeID); + if (fetchIt == fetches_.end()) { + XLOG(ERR) << "No matching subscribe ID=" << fetchOk.subscribeID + << " sess=" << this; + return; + } + auto trackHandle = fetchIt->second; + trackHandle->fetchOK(trackHandle); } void MoQSession::onFetchError(FetchError fetchError) { - XLOG(ERR) << "Not implemented yet"; + XLOG(DBG1) << __func__ << " sess=" << this; + auto fetchIt = fetches_.find(fetchError.subscribeID); + if (fetchIt == fetches_.end()) { + XLOG(ERR) << "No matching subscribe ID=" << fetchError.subscribeID + << " sess=" << this; + return; + } + fetchIt->second->fetchError(fetchError); + fetches_.erase(fetchIt); } void MoQSession::onAnnounce(Announce ann) { @@ -747,6 +790,60 @@ void MoQSession::subscribeUpdate(SubscribeUpdate subUpdate) { controlWriteEvent_.signal(); } +folly::coro::Task< + folly::Expected, FetchError>> +MoQSession::fetch(Fetch fetch) { + XLOG(DBG1) << __func__ << " sess=" << this; + auto g = + folly::makeGuard([func = __func__] { XLOG(DBG1) << "exit " << func; }); + auto fullTrackName = fetch.fullTrackName; + if (nextSubscribeID_ >= peerMaxSubscribeID_) { + XLOG(WARN) << "Issuing fetch that will fail; nextSubscribeID_=" + << nextSubscribeID_ + << " peerMaxSubscribeid_=" << peerMaxSubscribeID_ + << " sess=" << this; + } + auto subID = nextSubscribeID_++; + fetch.subscribeID = subID; + auto wres = writeFetch(controlWriteBuf_, std::move(fetch)); + if (!wres) { + XLOG(ERR) << "writeFetch failed" << " sess=" << this; + co_return folly::makeUnexpected( + FetchError({subID, 500, "local write failed"})); + } + controlWriteEvent_.signal(); + auto subTrack = fetches_.emplace( + std::piecewise_construct, + std::forward_as_tuple(subID), + std::forward_as_tuple(std::make_shared( + fullTrackName, subID, cancellationSource_.getToken()))); + + auto res = co_await subTrack.first->second->fetchReady(); + XLOG(DBG1) << __func__ + << " fetchReady trackHandle=" << subTrack.first->second; + co_return res; +} + +void MoQSession::fetchOk(FetchOk fetchOk) { + XLOG(DBG1) << __func__ << " sess=" << this; + auto res = writeFetchOk(controlWriteBuf_, fetchOk); + if (!res) { + XLOG(ERR) << "writeFetchOk failed" << " sess=" << this; + return; + } + controlWriteEvent_.signal(); +} + +void MoQSession::fetchError(FetchError subErr) { + XLOG(DBG1) << __func__ << " sess=" << this; + auto res = writeFetchError(controlWriteBuf_, std::move(subErr)); + if (!res) { + XLOG(ERR) << "writeFetchError failed" << " sess=" << this; + return; + } + controlWriteEvent_.signal(); +} + namespace { constexpr uint32_t IdMask = 0x1FFFFF; uint64_t groupOrder(GroupOrder groupOrder, uint64_t group) { @@ -808,6 +905,29 @@ folly::SemiFuture MoQSession::publishStatus( return publishImpl(objHeader, subscribeID, 0, nullptr, true, false); } +void MoQSession::closeFetchStream(SubscribeID subID) { + PublishKey publishKey{ + TrackIdentifier(subID), 0, 0, ForwardPreference::Fetch, 0}; + auto pubDataIt = publishDataMap_.find(publishKey); + if (pubDataIt == publishDataMap_.end()) { + XLOG(ERR) << "Invalid subscribeID to closeFetchStream=" << subID.value; + return; + } + if (pubDataIt->second.objectLength && *pubDataIt->second.objectLength > 0) { + XLOG(ERR) << "Non-zero length remaining in previous obj id=" << subID.value; + return; + } + XLOG(DBG1) << "Closing fetch stream=" << pubDataIt->second.streamID; + auto writeRes = + wt_->writeStreamData(pubDataIt->second.streamID, nullptr, true); + if (!writeRes) { + XLOG(ERR) << "Failed to close fetch stream sess=" << this + << " error=" << static_cast(writeRes.error()); + return; + } + publishDataMap_.erase(pubDataIt); +} + folly::SemiFuture MoQSession::publishImpl( const ObjectHeader& objHeader, SubscribeID subscribeID, @@ -840,10 +960,9 @@ folly::SemiFuture MoQSession::publishImpl( // - Next portion of the object calls this function again with payloadOffset // > 0 if (payloadOffset != 0) { - XLOG(WARN) - << __func__ - << " Can't start publishing in the middle. Disgregard data for this new obj with payloadOffset = " - << payloadOffset << " sess=" << this; + XLOG(WARN) << __func__ << " Can't start publishing in the middle. " + << "Disgregard data for this new obj with payloadOffset = " + << payloadOffset << " sess=" << this; return folly::makeSemiFuture(folly::exception_wrapper( std::runtime_error("Can't start publishing in the middle."))); } @@ -880,6 +999,7 @@ folly::SemiFuture MoQSession::publishImpl( pubDataIt = res.first; // Serialize multi-object stream header if (objHeader.forwardPreference == ForwardPreference::Track || + objHeader.forwardPreference == ForwardPreference::Fetch || objHeader.forwardPreference == ForwardPreference::Subgroup) { writeStreamHeader(writeBuf, objHeader); } @@ -892,7 +1012,8 @@ folly::SemiFuture MoQSession::publishImpl( // new object // validate group and object are moving in the right direction bool multiObject = false; - if (objHeader.forwardPreference == ForwardPreference::Track) { + if (objHeader.forwardPreference == ForwardPreference::Track || + objHeader.forwardPreference == ForwardPreference::Fetch) { if (objHeader.group < pubDataIt->second.group) { XLOG(ERR) << "Decreasing group in Track" << " sess=" << this; return folly::makeSemiFuture(folly::exception_wrapper( diff --git a/moxygen/MoQSession.h b/moxygen/MoQSession.h index 8124c0b..5e25e3b 100644 --- a/moxygen/MoQSession.h +++ b/moxygen/MoQSession.h @@ -56,6 +56,8 @@ class MoQSession : public MoQControlCodec::ControlCallback, SubscribeUpdate, Unsubscribe, SubscribeDone, + Fetch, + FetchCancel, MaxSubscribeId, TrackStatusRequest, TrackStatus, @@ -126,12 +128,6 @@ class MoQSession : public MoQControlCodec::ControlCallback, virtual void operator()(FetchCancel fetchCancel) const { XLOG(INFO) << "FetchCancel subID=" << fetchCancel.subscribeID; } - virtual void operator()(FetchOk fetchOk) const { - XLOG(INFO) << "FetchOk subID=" << fetchOk.subscribeID; - } - virtual void operator()(FetchError fetchError) const { - XLOG(INFO) << "FetchError subID=" << fetchError.subscribeID; - } virtual void operator()(TrackStatusRequest trackStatusRequest) const { XLOG(INFO) << "Subscribe ftn=" << trackStatusRequest.fullTrackName.trackNamespace @@ -182,6 +178,10 @@ class MoQSession : public MoQControlCodec::ControlCallback, folly::Expected, SubscribeError>>(); promise_ = std::move(contract.first); future_ = std::move(contract.second); + auto contract2 = folly::coro::makePromiseContract< + folly::Expected, FetchError>>(); + fetchPromise_ = std::move(contract2.first); + fetchFuture_ = std::move(contract2.second); } void setTrackName(FullTrackName trackName) { @@ -207,7 +207,6 @@ class MoQSession : public MoQControlCodec::ControlCallback, ready() { co_return co_await std::move(future_); } - void subscribeOK( std::shared_ptr self, GroupOrder order, @@ -224,6 +223,22 @@ class MoQSession : public MoQControlCodec::ControlCallback, } } + folly::coro::Task, FetchError>> + fetchReady() { + co_return co_await std::move(fetchFuture_); + } + void fetchOK(std::shared_ptr self) { + XCHECK_EQ(self.get(), this); + XLOG(DBG1) << __func__ << " trackHandle=" << this; + fetchPromise_.setValue(std::move(self)); + } + void fetchError(FetchError fetchErr) { + if (!promise_.isFulfilled()) { + fetchErr.subscribeID = subscribeID_; + fetchPromise_.setValue(folly::makeUnexpected(std::move(fetchErr))); + } + } + struct ObjectSource { ObjectHeader header; FullTrackName fullTrackName; @@ -268,12 +283,14 @@ class MoQSession : public MoQControlCodec::ControlCallback, private: FullTrackName fullTrackName_; SubscribeID subscribeID_; - folly::coro::Promise< - folly::Expected, SubscribeError>> - promise_; - folly::coro::Future< - folly::Expected, SubscribeError>> - future_; + using SubscribeResult = + folly::Expected, SubscribeError>; + folly::coro::Promise promise_; + folly::coro::Future future_; + using FetchResult = + folly::Expected, FetchError>; + folly::coro::Promise fetchPromise_; + folly::coro::Future fetchFuture_; folly:: F14FastMap, std::shared_ptr> objects_; @@ -293,6 +310,12 @@ class MoQSession : public MoQControlCodec::ControlCallback, void subscribeDone(SubscribeDone subDone); void subscribeUpdate(SubscribeUpdate subUpdate); + folly::coro::Task, FetchError>> + fetch(Fetch fetch); + void fetchOk(FetchOk fetchOk); + void fetchError(FetchError fetchError); + void fetchCancel(FetchCancel fetchCancel); + class WebTransportException : public std::runtime_error { public: explicit WebTransportException( @@ -319,6 +342,7 @@ class MoQSession : public MoQControlCodec::ControlCallback, folly::SemiFuture publishStatus( const ObjectHeader& objHeader, SubscribeID subscribeID); + void closeFetchStream(SubscribeID subID); void onNewUniStream(proxygen::WebTransport::StreamReadHandle* rh) override; void onNewBidiStream(proxygen::WebTransport::BidiStreamHandle bh) override; @@ -348,6 +372,7 @@ class MoQSession : public MoQControlCodec::ControlCallback, uint64_t id, std::unique_ptr payload, bool eom) override; + void onFetchHeader(uint64_t subscribeID) override {} void onSubscribe(SubscribeRequest subscribeRequest) override; void onSubscribeUpdate(SubscribeUpdate subscribeUpdate) override; void onSubscribeOk(SubscribeOk subscribeOk) override; @@ -403,6 +428,8 @@ class MoQSession : public MoQControlCodec::ControlCallback, return group == other.group && subgroup == other.subgroup; } else if (pref == ForwardPreference::Track) { return true; + } else if (pref == ForwardPreference::Fetch) { + return true; } return false; } @@ -459,6 +486,9 @@ class MoQSession : public MoQControlCodec::ControlCallback, // Track Alias -> Track Handle folly::F14FastMap, TrackAlias::hash> subTracks_; + folly:: + F14FastMap, SubscribeID::hash> + fetches_; folly::F14FastMap subIdToTrackAlias_; diff --git a/moxygen/test/MoQSessionTest.cpp b/moxygen/test/MoQSessionTest.cpp index cf4cac2..07c3342 100644 --- a/moxygen/test/MoQSessionTest.cpp +++ b/moxygen/test/MoQSessionTest.cpp @@ -24,6 +24,8 @@ class MockControlVisitorBase { virtual void onSubscribeDone(SubscribeDone subscribeDone) const = 0; virtual void onMaxSubscribeId(MaxSubscribeId maxSubscribeId) const = 0; virtual void onUnsubscribe(Unsubscribe unsubscribe) const = 0; + virtual void onFetch(Fetch fetch) const = 0; + virtual void onFetchCancel(FetchCancel fetchCancel) const = 0; virtual void onAnnounce(Announce announce) const = 0; virtual void onUnannounce(Unannounce unannounce) const = 0; virtual void onAnnounceCancel(AnnounceCancel announceCancel) const = 0; @@ -95,6 +97,16 @@ class MockControlVisitor : public MoQSession::ControlVisitor, onUnsubscribe(unsubscribe); } + MOCK_METHOD(void, onFetch, (Fetch), (const)); + void operator()(Fetch fetch) const override { + onFetch(fetch); + } + + MOCK_METHOD(void, onFetchCancel, (FetchCancel), (const)); + void operator()(FetchCancel fetchCancel) const override { + onFetchCancel(fetchCancel); + } + MOCK_METHOD(void, onMaxSubscribeId, (MaxSubscribeId), (const)); void operator()(MaxSubscribeId maxSubscribeId) const override { onMaxSubscribeId(maxSubscribeId); diff --git a/moxygen/test/Mocks.h b/moxygen/test/Mocks.h index 6591282..5b145db 100644 --- a/moxygen/test/Mocks.h +++ b/moxygen/test/Mocks.h @@ -22,6 +22,7 @@ class MockMoQCodecCallback : public MoQControlCodec::ControlCallback, uint64_t id, std::unique_ptr payload, bool eom)); + MOCK_METHOD(void, onFetchHeader, (uint64_t subscribeID)); MOCK_METHOD(void, onSubscribe, (SubscribeRequest subscribeRequest)); MOCK_METHOD(void, onSubscribeUpdate, (SubscribeUpdate subscribeUpdate)); MOCK_METHOD(void, onSubscribeOk, (SubscribeOk subscribeOk)); From b495e4e4bbe403947ba693028933555f3eb3d03c Mon Sep 17 00:00:00 2001 From: Alan Frindell Date: Thu, 14 Nov 2024 02:14:17 -0800 Subject: [PATCH 2/2] Add MoQSession::drain API Summary: This will close the session when all fetches and subscribes are complete. This required changing the objects() and payload() read loops a bit. Since the sources are shared pointers, the callees can continue reading from them after the session is gone. We use try_dequeue() to drain the queue after the session's token has been cancelled. Reviewed By: NEUDitao Differential Revision: D65532678 --- moxygen/MoQSession.cpp | 35 ++++++++++++++++++++++++++++++++--- moxygen/MoQSession.h | 22 ++++++++++++++++++---- 2 files changed, 50 insertions(+), 7 deletions(-) diff --git a/moxygen/MoQSession.cpp b/moxygen/MoQSession.cpp index 58e12fc..186146f 100644 --- a/moxygen/MoQSession.cpp +++ b/moxygen/MoQSession.cpp @@ -63,6 +63,18 @@ void MoQSession::start() { } } +void MoQSession::drain() { + XLOG(DBG1) << __func__ << " sess=" << this; + draining_ = true; + checkForCloseOnDrain(); +} + +void MoQSession::checkForCloseOnDrain() { + if (draining_ && fetches_.empty() && subTracks_.empty()) { + close(); + } +} + void MoQSession::close(folly::Optional error) { XLOG(DBG1) << __func__ << " sess=" << this; if (wt_) { @@ -205,6 +217,7 @@ folly::coro::Task MoQSession::readLoop( if (track) { track->fin(); fetches_.erase(*subscribeID); + checkForCloseOnDrain(); } } } @@ -408,6 +421,7 @@ void MoQSession::onSubscribeError(SubscribeError subErr) { subTracks_[trackAliasIt->second]->subscribeError(std::move(subErr)); subTracks_.erase(trackAliasIt->second); subIdToTrackAlias_.erase(trackAliasIt); + checkForCloseOnDrain(); } void MoQSession::onSubscribeDone(SubscribeDone subscribeDone) { @@ -421,7 +435,13 @@ void MoQSession::onSubscribeDone(SubscribeDone subscribeDone) { } // TODO: handle final object and status code + // TODO: there could still be objects in flight. Removing from maps now + // will prevent their delivery. I think the only way to handle this is with + // timeouts. subTracks_[trackAliasIt->second]->fin(); + subTracks_.erase(trackAliasIt->second); + subIdToTrackAlias_.erase(trackAliasIt); + checkForCloseOnDrain(); controlMessages_.enqueue(std::move(subscribeDone)); } @@ -476,6 +496,7 @@ void MoQSession::onFetchError(FetchError fetchError) { } fetchIt->second->fetchError(fetchError); fetches_.erase(fetchIt); + checkForCloseOnDrain(); } void MoQSession::onAnnounce(Announce ann) { @@ -677,9 +698,15 @@ MoQSession::TrackHandle::objects() { folly::makeGuard([func = __func__] { XLOG(DBG1) << "exit " << func; }); auto cancelToken = co_await folly::coro::co_current_cancellation_token; auto mergeToken = folly::CancellationToken::merge(cancelToken, cancelToken_); - while (!mergeToken.isCancellationRequested()) { - auto obj = co_await folly::coro::co_withCancellation( - mergeToken, newObjects_.dequeue()); + while (!cancelToken.isCancellationRequested()) { + auto optionalObj = newObjects_.try_dequeue(); + std::shared_ptr obj; + if (optionalObj) { + obj = *optionalObj; + } else { + obj = co_await folly::coro::co_withCancellation( + mergeToken, newObjects_.dequeue()); + } if (!obj) { XLOG(DBG3) << "Out of objects for trackHandle=" << this << " id=" << subscribeID_; @@ -753,6 +780,8 @@ void MoQSession::unsubscribe(Unsubscribe unsubscribe) { XLOG(ERR) << "writeUnsubscribe failed" << " sess=" << this; return; } + // we rely on receiving subscribeDone after unsubscribe to remove from + // subTracks_ controlWriteEvent_.signal(); } diff --git a/moxygen/MoQSession.h b/moxygen/MoQSession.h index 5e25e3b..cda426f 100644 --- a/moxygen/MoQSession.h +++ b/moxygen/MoQSession.h @@ -39,6 +39,7 @@ class MoQSession : public MoQControlCodec::ControlCallback, ~MoQSession() override; void start(); + void drain(); void close(folly::Optional error = folly::none); void setup(ClientSetup setup); @@ -252,14 +253,25 @@ class MoQSession : public MoQControlCodec::ControlCallback, co_return nullptr; } folly::IOBufQueue payloadBuf{folly::IOBufQueue::cacheChainLength()}; - while (true) { - auto buf = co_await folly::coro::co_withCancellation( - cancelToken, payloadQueue.dequeue()); + auto curCancelToken = + co_await folly::coro::co_current_cancellation_token; + auto mergeToken = + folly::CancellationToken::merge(curCancelToken, cancelToken); + while (!curCancelToken.isCancellationRequested()) { + std::unique_ptr buf; + auto optionalBuf = payloadQueue.try_dequeue(); + if (optionalBuf) { + buf = std::move(*optionalBuf); + } else { + buf = co_await folly::coro::co_withCancellation( + cancelToken, payloadQueue.dequeue()); + } if (!buf) { - co_return payloadBuf.move(); + break; } payloadBuf.append(std::move(buf)); } + co_return payloadBuf.move(); } }; @@ -400,6 +412,7 @@ class MoQSession : public MoQControlCodec::ControlCallback, void onTrackStatus(TrackStatus trackStatus) override; void onGoaway(Goaway goaway) override; void onConnectionError(ErrorCode error) override; + void checkForCloseOnDrain(); folly::SemiFuture publishImpl( const ObjectHeader& objHeader, @@ -524,6 +537,7 @@ class MoQSession : public MoQControlCodec::ControlCallback, moxygen::TimedBaton sentSetup_; moxygen::TimedBaton receivedSetup_; bool setupComplete_{false}; + bool draining_{false}; folly::CancellationSource cancellationSource_; // SubscribeID must be a unique monotonically increasing number that is