diff --git a/coroio/epoll.hpp b/coroio/epoll.hpp index 74a9beb..07c8b26 100644 --- a/coroio/epoll.hpp +++ b/coroio/epoll.hpp @@ -12,6 +12,7 @@ namespace NNet { class TEPoll: public TPollerBase { public: using TSocket = NNet::TSocket; + using TFileHandle = NNet::TFileHandle; TEPoll() : Fd_(epoll_create1(EPOLL_CLOEXEC)) diff --git a/coroio/kqueue.hpp b/coroio/kqueue.hpp index 0515ed3..e512c47 100644 --- a/coroio/kqueue.hpp +++ b/coroio/kqueue.hpp @@ -12,6 +12,7 @@ namespace NNet { class TKqueue: public TPollerBase { public: using TSocket = NNet::TSocket; + using TFileHandle = NNet::TFileHandle; TKqueue() : Fd_(kqueue()) diff --git a/coroio/poll.hpp b/coroio/poll.hpp index d932b4e..1636ce4 100644 --- a/coroio/poll.hpp +++ b/coroio/poll.hpp @@ -14,6 +14,7 @@ namespace NNet { class TPoll: public TPollerBase { public: using TSocket = NNet::TSocket; + using TFileHandle = NNet::TFileHandle; void Poll() { auto deadline = Timers_.empty() ? TTime::max() : Timers_.top().Deadline; diff --git a/coroio/select.hpp b/coroio/select.hpp index 74fb19c..c5f092e 100644 --- a/coroio/select.hpp +++ b/coroio/select.hpp @@ -12,6 +12,7 @@ namespace NNet { class TSelect: public TPollerBase { public: using TSocket = NNet::TSocket; + using TFileHandle = NNet::TFileHandle; void Poll() { auto deadline = Timers_.empty() ? TTime::max() : Timers_.top().Deadline; diff --git a/coroio/socket.hpp b/coroio/socket.hpp index d454266..bf0f4c0 100644 --- a/coroio/socket.hpp +++ b/coroio/socket.hpp @@ -37,94 +37,57 @@ class TAddress { struct sockaddr_in Addr_; }; -class TSocket { -public: - TSocket(TAddress&& addr, TPollerBase& poller) - : Poller_(&poller) - , Addr_(std::move(addr)) - , Fd_(Create()) - { } - - TSocket(const TAddress& addr, int fd, TPollerBase& poller) - : Poller_(&poller) - , Addr_(addr) - , Fd_(fd) - { - Setup(Fd_); +template +struct TSockOpAwaitable { + bool await_ready() { + SafeRun(); + return (ready = (ret >= 0)); } - TSocket(const TAddress& addr, TPollerBase& poller) - : Poller_(&poller) - , Addr_(addr) - , Fd_(Create()) - { } - - TSocket(TSocket&& other) - { - *this = std::move(other); - } - - ~TSocket() - { - if (Fd_ >= 0) { - close(Fd_); - Poller_->RemoveEvent(Fd_); + int await_resume() { + if (!ready) { + SafeRun(); } + return ret; } - TSocket() = default; - TSocket(const TSocket& other) = delete; - TSocket& operator=(TSocket& other) const = delete; - - TSocket& operator=(TSocket&& other) { - if (this != &other) { - Poller_ = other.Poller_; - Addr_ = other.Addr_; - Fd_ = other.Fd_; - other.Fd_ = -1; + void SafeRun() { + ((T*)this)->run(); + if (ret < 0 && !(errno==EINTR||errno==EAGAIN||errno==EINPROGRESS)) { + throw std::system_error(errno, std::generic_category()); } - return *this; } - auto Connect(TTime deadline = TTime::max()) { - struct TAwaitable { - bool await_ready() { - int ret = connect(fd, (struct sockaddr*) &addr, sizeof(addr)); - if (ret < 0 && !(errno == EINTR||errno==EAGAIN||errno==EINPROGRESS)) { - throw std::system_error(errno, std::generic_category(), "connect"); - } - return ret >= 0; - } + TPollerBase* poller; + int fd; + char* b; size_t s; + int ret; + bool ready; +}; - void await_suspend(std::coroutine_handle<> h) { - poller->AddWrite(fd, h); - if (deadline != TTime::max()) { - poller->AddTimer(fd, deadline, h); - } - } +template +class TSocketBase { +public: + TSocketBase(TPollerBase& poller) + : Poller_(&poller) + , Fd_(Create()) + { } - void await_resume() { - if (deadline != TTime::max() && poller->RemoveTimer(fd, deadline)) { - throw std::system_error(std::make_error_code(std::errc::timed_out)); - } - } + TSocketBase(int fd, TPollerBase& poller) + : Poller_(&poller) + , Fd_(Setup(fd)) + { } - TPollerBase* poller; - int fd; - sockaddr_in addr; - TTime deadline; - }; - return TAwaitable{Poller_, Fd_, Addr_.Addr(), deadline}; - } + TSocketBase() = default; auto ReadSome(char* buf, size_t size) { - struct TAwaitableRead: public TAwaitable { + struct TAwaitableRead: public TSockOpAwaitable { void run() { - ret = read(fd, b, s); + this->ret = TSockOps::read(this->fd, this->b, this->s); } void await_suspend(std::coroutine_handle<> h) { - poller->AddRead(fd, h); + this->poller->AddRead(this->fd, h); } }; return TAwaitableRead{Poller_,Fd_,buf,size}; @@ -134,15 +97,15 @@ class TSocket { auto ReadSomeYield(char* buf, size_t size) { struct TAwaitableRead: public TAwaitable { bool await_ready() { - return (ready = false); + return (this->ready = false); } void run() { - ret = read(fd, b, s); + this->ret = TSockOps::read(this->fd, this->b, this->s); } void await_suspend(std::coroutine_handle<> h) { - poller->AddRead(fd, h); + this->poller->AddRead(this->fd, h); } }; return TAwaitableRead{Poller_,Fd_,buf,size}; @@ -151,11 +114,11 @@ class TSocket { auto WriteSome(char* buf, size_t size) { struct TAwaitableWrite: public TAwaitable { void run() { - ret = write(fd, b, s); + this->ret = TSockOps::write(this->fd, this->b, this->s); } void await_suspend(std::coroutine_handle<> h) { - poller->AddWrite(fd, h); + this->poller->AddWrite(this->fd, h); } }; return TAwaitableWrite{Poller_,Fd_,buf,size}; @@ -164,77 +127,30 @@ class TSocket { auto WriteSomeYield(char* buf, size_t size) { struct TAwaitableWrite: public TAwaitable { bool await_ready() { - return (ready = false); + return (this->ready = false); } void run() { - ret = write(fd, b, s); + this->ret = TSockOps::write(this->fd, this->b, this->s); } void await_suspend(std::coroutine_handle<> h) { - poller->AddWrite(fd, h); + this->poller->AddWrite(this->fd, h); } }; return TAwaitableWrite{Poller_,Fd_,buf,size}; } - auto Accept() { - struct TAwaitable { - bool await_ready() const { return false; } - void await_suspend(std::coroutine_handle<> h) { - poller->AddRead(fd, h); - } - TSocket await_resume() { - sockaddr_in clientaddr; - socklen_t len = sizeof(clientaddr); - - int clientfd = accept(fd, (sockaddr*)&clientaddr, &len); - if (clientfd < 0) { - throw std::system_error(errno, std::generic_category(), "accept"); - } - - return TSocket{clientaddr, clientfd, *poller}; - } - - TPollerBase* poller; - int fd; - }; - - return TAwaitable{Poller_, Fd_}; - } - - void Bind() { - auto addr = Addr_.Addr(); - if (bind(Fd_, (struct sockaddr*)&addr, sizeof(addr)) < 0) { - throw std::system_error(errno, std::generic_category(), "bind"); - } - } - - void Listen(int backlog = 128) { - if (listen(Fd_, backlog) < 0) { - throw std::system_error(errno, std::generic_category(), "listen"); - } - } - - const TAddress& Addr() const { - return Addr_; - } - - int Fd() const { - return Fd_; - } - -protected: // TODO: XXX +protected: int Create() { auto s = socket(PF_INET, SOCK_STREAM, 0); if (s < 0) { throw std::system_error(errno, std::generic_category(), "socket"); } - Setup(s); - return s; + return Setup(s); } - void Setup(int s) { + int Setup(int s) { struct stat statbuf; fstat(s, &statbuf); if (S_ISSOCK(statbuf.st_mode)) { @@ -255,6 +171,7 @@ class TSocket { if (fcntl(s, F_SETFL, flags | O_NONBLOCK) < 0) { throw std::system_error(errno, std::generic_category(), "fcntl"); } + return s; } template @@ -285,8 +202,163 @@ class TSocket { bool ready; }; - int Fd_ = -1; TPollerBase* Poller_ = nullptr; + int Fd_ = -1; +}; + +class TFileOps { +public: + static auto read(int fd, void* buf, size_t count) { + return ::read(fd, buf, count); + } + + static auto write(int fd, const void* buf, size_t count) { + return ::write(fd, buf, count); + } +}; + +class TFileHandle: public TSocketBase { +public: + TFileHandle(int fd, TPollerBase& poller) + : TSocketBase(fd, poller) + { } + + TFileHandle() = default; +}; + +class TSockOps { +public: + static auto read(int fd, void* buf, size_t count) { + return ::recv(fd, buf, count, 0); + } + + static auto write(int fd, const void* buf, size_t count) { + return ::send(fd, buf, count, 0); + } +}; + +class TSocket: public TSocketBase { +public: + TSocket(TAddress&& addr, TPollerBase& poller) + : TSocketBase(poller) + , Addr_(std::move(addr)) + { } + + TSocket(const TAddress& addr, int fd, TPollerBase& poller) + : TSocketBase(fd, poller) + , Addr_(addr) + { } + + TSocket(const TAddress& addr, TPollerBase& poller) + : TSocketBase(poller) + , Addr_(addr) + { } + + TSocket(TSocket&& other) + { + *this = std::move(other); + } + + ~TSocket() + { + if (Fd_ >= 0) { + close(Fd_); + Poller_->RemoveEvent(Fd_); + } + } + + TSocket() = default; + TSocket(const TSocket& other) = delete; + TSocket& operator=(TSocket& other) const = delete; + + TSocket& operator=(TSocket&& other) { + if (this != &other) { + Poller_ = other.Poller_; + Addr_ = other.Addr_; + Fd_ = other.Fd_; + other.Fd_ = -1; + } + return *this; + } + + auto Connect(TTime deadline = TTime::max()) { + struct TAwaitable { + bool await_ready() { + int ret = connect(fd, (struct sockaddr*) &addr, sizeof(addr)); + if (ret < 0 && !(errno == EINTR||errno==EAGAIN||errno==EINPROGRESS)) { + throw std::system_error(errno, std::generic_category(), "connect"); + } + return ret >= 0; + } + + void await_suspend(std::coroutine_handle<> h) { + poller->AddWrite(fd, h); + if (deadline != TTime::max()) { + poller->AddTimer(fd, deadline, h); + } + } + + void await_resume() { + if (deadline != TTime::max() && poller->RemoveTimer(fd, deadline)) { + throw std::system_error(std::make_error_code(std::errc::timed_out)); + } + } + + TPollerBase* poller; + int fd; + sockaddr_in addr; + TTime deadline; + }; + return TAwaitable{Poller_, Fd_, Addr_.Addr(), deadline}; + } + + auto Accept() { + struct TAwaitable { + bool await_ready() const { return false; } + void await_suspend(std::coroutine_handle<> h) { + poller->AddRead(fd, h); + } + TSocket await_resume() { + sockaddr_in clientaddr; + socklen_t len = sizeof(clientaddr); + + int clientfd = accept(fd, (sockaddr*)&clientaddr, &len); + if (clientfd < 0) { + throw std::system_error(errno, std::generic_category(), "accept"); + } + + return TSocket{clientaddr, clientfd, *poller}; + } + + TPollerBase* poller; + int fd; + }; + + return TAwaitable{Poller_, Fd_}; + } + + void Bind() { + auto addr = Addr_.Addr(); + if (bind(Fd_, (struct sockaddr*)&addr, sizeof(addr)) < 0) { + throw std::system_error(errno, std::generic_category(), "bind"); + } + } + + void Listen(int backlog = 128) { + if (listen(Fd_, backlog) < 0) { + throw std::system_error(errno, std::generic_category(), "listen"); + } + } + + const TAddress& Addr() const { + return Addr_; + } + + int Fd() const { + return Fd_; + } + +protected: TAddress Addr_; }; diff --git a/coroio/uring.hpp b/coroio/uring.hpp index 6895bfd..6fe7c2d 100644 --- a/coroio/uring.hpp +++ b/coroio/uring.hpp @@ -24,6 +24,7 @@ class TUringSocket; class TUring: public TPollerBase { public: using TSocket = NNet::TUringSocket; + using TFileHandle = NNet::TFileHandle; TUring(int queueSize = 256) : RingFd_(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK)) @@ -242,6 +243,11 @@ class TUringSocket: public TSocket , Uring_(&poller) { } + TUringSocket(int fd, TUring& poller) + : TSocket({}, fd, poller) + , Uring_(&poller) + { } + TUringSocket() = default; auto Accept() { diff --git a/examples/bench.cpp b/examples/bench.cpp index b92e33e..7dbe9fa 100644 --- a/examples/bench.cpp +++ b/examples/bench.cpp @@ -69,9 +69,9 @@ TTestTask yield(TPollerBase& poller) { template std::chrono::microseconds run_one(int num_pipes, int num_writes, int num_active) { Stat s; - using TSocket = typename TPoller::TSocket; + using TFileHandle = typename TPoller::TFileHandle; TLoop loop; - vector pipes; + vector pipes; vector> handles; pipes.reserve(num_pipes*2); handles.reserve(num_pipes+num_writes); @@ -81,8 +81,8 @@ std::chrono::microseconds run_one(int num_pipes, int num_writes, int num_active) if (pipe(&p[0]) < 0) { throw std::system_error(errno, std::generic_category(), "pipe"); } - pipes.emplace_back(std::move(TSocket{{}, p[0], loop.Poller()})); - pipes.emplace_back(std::move(TSocket{{}, p[1], loop.Poller()})); + pipes.emplace_back(std::move(TFileHandle{p[0], loop.Poller()})); + pipes.emplace_back(std::move(TFileHandle{p[1], loop.Poller()})); } s.writes = num_writes; diff --git a/examples/echoclient.cpp b/examples/echoclient.cpp index 9b29c66..4569c58 100644 --- a/examples/echoclient.cpp +++ b/examples/echoclient.cpp @@ -21,12 +21,13 @@ template NNet::TTestTask client(TPoller& poller, TAddress addr, int buffer_size) { using TSocket = typename TPoller::TSocket; + using TFileHandle = typename TPoller::TFileHandle; std::vector out(buffer_size); std::vector in(buffer_size); ssize_t size = 1; try { - TSocket input{TAddress{}, 0, poller}; // stdin + TFileHandle input{0, poller}; // stdin TSocket socket{std::move(addr), poller}; co_await socket.Connect();