简介
本文使用 C++20 引入的协程来编写一个 Linux epoll 程序。在此实现中,用户使用异步操作时无需再提供自己的回调函数。以此处实现的 asyncRead()
为例:
- 使用 `asyncRead()` 所需的参数和 `read()` 大致相同,无需传入回调;`asyncRead()` 的内部会向 epoll 注册要监听的文件描述符、感兴趣的事件和要执行的回调(由实现提供,而无需使用者传入);
- 当事件未就绪时,`co_await asyncRead()` 会挂起当前协程;
- 当事件就绪时,epoll 循环中会执行具体的 I/O 操作(此处将其提交到 I/O 线程池中执行),当 I/O 操作完成时,恢复协程的运行。
1. ThreadPool
此处使用了两个线程池:
- I/O 线程池:用于执行 I/O 操作;
- 工作线程池:用于处理客户端连接(此处以 TCP 回显程序为例)。
此处使用的是自己实现的线程池,具体实现见 https://segmentfault.com/a/11...。
2. IOContext
`IOContext` 类对 Linux epoll 做了简单的封装。
io_context.h:
#ifndef IOCONTEXT_H#define IOCONTEXT_H#include <sys/epoll.h>#include <unistd.h>#include <unordered_map>#include <functional>#include "thread_pool.h"using callback_t = std::function<void()>;struct Args{ callback_t m_cb;};class IOContext{public: IOContext(int nIOThreads=2, int nJobThreads=2); // 监听文件描述符 fd,感兴趣事件为 events,args 中蕴含要执行的回调 bool post(int fd, int events, const Args& args); // 提交工作至工作线程池 bool post(const Task& task); // 不再关注文件描述符 fd,并移除相应的回调 void remove(int fd); // 继续监听、期待事件就绪 void run();private: int m_fd; std::unordered_map<int, Args> m_args; std::mutex m_lock; ThreadPool m_ioPool; // I/O 线程池 ThreadPool m_jobPool; // 工作线程池};#endif
io_context.cpp:
#include "io_context.h"

#include <cstring>
#include <errno.h>
#include <iostream>
#include <vector>

IOContext::IOContext(int nIOThreads, int nJobThreads)
    : m_ioPool(nIOThreads), m_jobPool(nJobThreads)
{
    // The size argument is only a hint, but it must be positive.
    m_fd = epoll_create(1024);
    if (m_fd < 0)
    {
        std::cout << "epoll_create: " << strerror(errno) << "\n";
    }
}

// Add `fd` to the epoll set for `events` and remember its callback.
// The lock is held across epoll_ctl and the map update so that run() can
// never see a ready fd whose callback has not been stored yet.
bool IOContext::post(int fd, int events, const Args& args)
{
    struct epoll_event event{};
    event.events = events;
    event.data.fd = fd;

    std::lock_guard<std::mutex> lock(m_lock);
    if (epoll_ctl(m_fd, EPOLL_CTL_ADD, fd, &event) != 0)
    {
        return false;
    }
    m_args[fd] = args;
    return true;
}

// Submit a task to the job thread pool.
bool IOContext::post(const Task& task)
{
    return m_jobPool.submitTask(task);
}

// Remove `fd` from the epoll set; its callback is dropped only if that succeeds.
void IOContext::remove(int fd)
{
    std::lock_guard<std::mutex> lock(m_lock);
    if (epoll_ctl(m_fd, EPOLL_CTL_DEL, fd, nullptr) == 0)
    {
        m_args.erase(fd);
    }
    else
    {
        std::cout << "remove: " << strerror(errno) << "\n";
    }
}

// Event loop: wait for ready descriptors, unregister each one, and run its
// callback on the I/O pool. Returns only if epoll_wait fails with an error
// other than EINTR.
void IOContext::run()
{
    constexpr int kMaxEvents = 32;
    std::vector<struct epoll_event> eventList(kMaxEvents);

    while (true)
    {
        int nReady = epoll_wait(m_fd, eventList.data(), kMaxEvents, -1);
        if (nReady < 0)
        {
            if (errno == EINTR)
            {
                continue;  // interrupted by a signal, not a real failure
            }
            std::cout << "run: " << strerror(errno) << "\n";
            return;
        }

        for (int i = 0; i < nReady; i++)
        {
            int fd = eventList[i].data.fd;

            // Look the callback up without default-inserting an empty one
            // (calling an empty std::function would throw on the pool thread).
            callback_t cb;
            {
                std::lock_guard<std::mutex> lock(m_lock);
                auto it = m_args.find(fd);
                if (it != m_args.end())
                {
                    cb = std::move(it->second.m_cb);
                }
            }

            // One-shot registration: stop watching fd before its callback runs,
            // so the callback (or the coroutine it resumes) can post fd again.
            remove(fd);

            if (cb && !m_ioPool.submitTask([cb]() { cb(); }))
            {
                std::cout << "run: failed to submit callback for fd " << fd << "\n";
            }
        }
    }
}
3. Awaitable
实现 C++ 协程所需的类型,具体解释见 https://segmentfault.com/a/11...。
awaitable.h:
#ifndef AWAITABLE_H#define AWAITABLE_H#include <coroutine>#include <sys/socket.h>#include <sys/types.h>#include <iostream>#include "io_context.h"// 回调须要执行的操作类型:读、写、承受客户端连贯enum class HandlerType{ Read, Write, Accept,};class Awaitable{public: Awaitable(IOContext* ctx, int fd, int events, void* buf, size_t n, HandlerType ht); bool await_ready(); void await_suspend(std::coroutine_handle<> handle); int await_resume();private: IOContext* m_ctx; int m_fd; int m_events; void* m_buf; size_t m_n; int m_result; HandlerType m_ht;};struct CoroRetType{public: struct promise_type { CoroRetType get_return_object(); std::suspend_never initial_suspend(); std::suspend_never final_suspend() noexcept; void return_void(); void unhandled_exception(); };};#endif
awaitable.cpp:
#include <cstring>#include "awaitable.h"Awaitable::Awaitable(IOContext *ctx, int fd, int events, void* buf, size_t n, HandlerType ht) : m_ctx(ctx), m_fd(fd), m_events(events), m_buf(buf), m_n(n), m_ht(ht){}bool Awaitable::await_ready(){ return false;}int Awaitable::await_resume(){ return m_result;}// 注册要监听的文件描述符、感兴趣的事件及要执行的回调void Awaitable::await_suspend(std::coroutine_handle<> handle){ auto cb = [handle, this]() mutable { switch (m_ht) { case HandlerType::Read: m_result = read(m_fd, m_buf, m_n); break; case HandlerType::Write: m_result = write(m_fd, m_buf, m_n); break; case HandlerType::Accept: m_result = accept(m_fd, nullptr, nullptr); break; } handle.resume(); }; Args args{cb}; m_ctx->post(m_fd, m_events, args);}CoroRetType CoroRetType::promise_type::get_return_object(){ return CoroRetType();}std::suspend_never CoroRetType::promise_type::initial_suspend(){ return std::suspend_never{};}std::suspend_never CoroRetType::promise_type::final_suspend() noexcept{ return std::suspend_never{};}void CoroRetType::promise_type::return_void(){}void CoroRetType::promise_type::unhandled_exception(){ std::terminate();}
4. 异步操作
以下函数仅构造并返回一个 Awaitable 对象;当对其执行 co_await 时,才会向 IOContext 注册要监听的文件描述符及感兴趣的事件,随后挂起当前协程。
io_util.h:
#ifndef IO_UTIL_H#define IO_UTIL_H#include "io_context.h"#include "awaitable.h"Awaitable asyncRead(IOContext* ctx, int fd, void* buf, size_t n);Awaitable asyncWrite(IOContext* ctx, int fd, void* buf, size_t n);Awaitable asyncAccept(IOContext* ctx, int fd);#endif
io_util.cpp:
#include "io_util.h"Awaitable asyncRead(IOContext* ctx, int fd, void* buf, size_t n){ return Awaitable(ctx, fd, EPOLLIN, buf, n, HandlerType::Read);}Awaitable asyncWrite(IOContext* ctx, int fd, void* buf, size_t n){ return Awaitable(ctx, fd, EPOLLOUT, buf, n, HandlerType::Write);}Awaitable asyncAccept(IOContext* ctx, int fd){ return Awaitable(ctx, fd, EPOLLIN, nullptr, 0, HandlerType::Accept);}
5. 例子
main.cpp:
#include <thread>#include <coroutine>#include <iostream>#include <netinet/in.h>#include <arpa/inet.h>#include <cstring>#include <errno.h>#include "io_util.h"static std::mutex ioLock;static uint16_t port = 6666;static int backlog = 32;static const char* Msg = "hello, cpp!";static const size_t MsgLen = 11;static IOContext ioContext;CoroRetType handleConnection(int fd){ char buf[MsgLen+1] = {0}; int n; n = co_await asyncRead(&ioContext, fd, buf, MsgLen); buf[n+1] = '\0'; co_await asyncWrite(&ioContext, fd, buf, n); close(fd);}CoroRetType serverThread(){ int listenSock = socket(AF_INET, SOCK_STREAM, 0); int value = 1; setsockopt(listenSock, SOL_SOCKET, SO_REUSEADDR, &value, sizeof(int)); struct sockaddr_in addr; memset(&addr, 0, sizeof(addr)); addr.sin_port = htons(port); addr.sin_family = AF_INET; addr.sin_addr.s_addr = htonl(INADDR_ANY); int err = bind(listenSock, (const struct sockaddr*)&addr, sizeof(addr)); listen(listenSock, backlog); while (true) { int clientSock = co_await asyncAccept(&ioContext, listenSock); auto h = [=]() { handleConnection(clientSock); }; ioContext.post(h); }}void clientThread(){ using namespace std::literals; std::this_thread::sleep_for(1s); int sock = socket(AF_INET, SOCK_STREAM, 0); struct sockaddr_in addr; memset(&addr, 0, sizeof(addr)); addr.sin_port = htons(port); addr.sin_family = AF_INET; inet_pton(AF_INET, "127.0.0.1", &addr.sin_addr); connect(sock, (const struct sockaddr*)&addr, sizeof(addr)); char buf[MsgLen+1] = {0}; ssize_t n = write(sock, Msg, MsgLen); read(sock, buf, n); buf[n+1] = '\0'; std::lock_guard<std::mutex> lock(ioLock); std::cout << "clientThread: " << buf << '\n'; close(sock);}int main(){ serverThread(); constexpr int N = 10; for (int i = 0; i < N; i++) { std::thread t(clientThread); t.detach(); } ioContext.run();}
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!