简介

本文使用 C++20 引入的协程来编写一个 Linux epoll 程序。在此实现中,用户使用异步操作时不再需要提供自己的回调函数。以此处实现的 asyncRead() 为例:

  • 使用 asyncRead() 所需的参数和 read() 大致相同,无需传入回调;
  • asyncRead() 的内部会向 epoll 注册要监听的文件描述符、感兴趣的事件和要执行的回调(由实现提供,而无需使用者传入);
  • 当事件未就绪时,co_await asyncRead() 会挂起当前协程;
  • 当事件就绪时,epoll 循环中会执行具体的 I/O 操作(此处将其提交到 I/O 线程池中执行),当 I/O 操作完成时,恢复协程的运行。

1. ThreadPool

此处使用了两个线程池:

  • I/O 线程池:用于执行 I/O 操作;
  • 工作线程池:用于处理客户端连接(此处以 TCP 回显程序为例)。

此处使用的是自己实现的线程池,具体实现见 https://segmentfault.com/a/11...。

2. IOContext

IOContext 类对 Linux epoll 做了简单的封装。

io_context.h:

#ifndef IOCONTEXT_H#define IOCONTEXT_H#include <sys/epoll.h>#include <unistd.h>#include <unordered_map>#include <functional>#include "thread_pool.h"using callback_t = std::function<void()>;struct Args{    callback_t m_cb;};class IOContext{public:    IOContext(int nIOThreads=2, int nJobThreads=2);    // 监听文件描述符 fd,感兴趣事件为 events,args 中蕴含要执行的回调    bool post(int fd, int events, const Args& args);    // 提交工作至工作线程池    bool post(const Task& task);    // 不再关注文件描述符 fd,并移除相应的回调    void remove(int fd);    // 继续监听、期待事件就绪    void run();private:    int m_fd;    std::unordered_map<int, Args> m_args;    std::mutex m_lock;    ThreadPool m_ioPool;     // I/O 线程池    ThreadPool m_jobPool;    // 工作线程池};#endif

io_context.cpp:

#include "io_context.h"

#include <cstring>
#include <errno.h>

#include <array>
#include <iostream>

/// Creates the epoll instance and both thread pools.
IOContext::IOContext(int nIOThreads, int nJobThreads)
    : m_ioPool(nIOThreads), m_jobPool(nJobThreads)
{
    m_fd = epoll_create(1024);
    if (m_fd < 0)
    {
        // Without an epoll instance every later post()/run() call will fail;
        // report the reason instead of failing silently.
        std::cout << "epoll_create: " << strerror(errno) << "\n";
    }
}

/// Adds fd to the epoll set for `events` and remembers `args` for it.
/// @return true when epoll_ctl(EPOLL_CTL_ADD) succeeded; on failure nothing is stored.
bool IOContext::post(int fd, int events, const Args& args)
{
    struct epoll_event event{};  // value-initialised so no indeterminate bytes reach the kernel
    event.events = events;
    event.data.fd = fd;

    // Registration and bookkeeping happen under one lock so run() never
    // observes a ready descriptor whose callback has not been stored yet.
    std::lock_guard<std::mutex> lock(m_lock);
    if (epoll_ctl(m_fd, EPOLL_CTL_ADD, fd, &event) != 0)
    {
        return false;
    }
    m_args[fd] = args;
    return true;
}

/// Forwards a task to the job thread pool.
bool IOContext::post(const Task& task)
{
    return m_jobPool.submitTask(task);
}

/// Removes fd from the epoll set; its stored callback is erased only when the
/// removal succeeded, otherwise the epoll_ctl error is printed.
void IOContext::remove(int fd)
{
    std::lock_guard<std::mutex> lock(m_lock);
    if (epoll_ctl(m_fd, EPOLL_CTL_DEL, fd, nullptr) == 0)
    {
        m_args.erase(fd);
    }
    else
    {
        std::cout << "remove: " << strerror(errno) << "\n";
    }
}

/// Event loop: blocks in epoll_wait and, for every ready descriptor,
/// unregisters it and hands its callback to the I/O thread pool.
/// Returns when epoll_wait fails with an error other than EINTR.
void IOContext::run()
{
    constexpr int timeout = -1;  // block indefinitely
    std::array<struct epoll_event, 32> eventList{};

    while (true)
    {
        const int nReady = epoll_wait(m_fd, eventList.data(),
                                      static_cast<int>(eventList.size()), timeout);
        if (nReady < 0)
        {
            if (errno == EINTR)
            {
                // Interrupted by a signal: not a real failure, keep serving.
                continue;
            }
            return;
        }

        for (int i = 0; i < nReady; i++)
        {
            const int fd = eventList[i].data.fd;

            // Copy the callback out under the lock; find() avoids inserting
            // an empty entry when the descriptor is unknown.
            callback_t cb;
            {
                std::lock_guard<std::mutex> lock(m_lock);
                const auto it = m_args.find(fd);
                if (it != m_args.end())
                {
                    cb = it->second.m_cb;
                }
            }

            // Each registration is one-shot: the descriptor is removed before
            // its callback runs, so the callback may register it again.
            remove(fd);

            if (!cb)
            {
                // No callback recorded; invoking an empty std::function would throw.
                continue;
            }
            m_ioPool.submitTask([cb]()
            {
                cb();
            });
        }
    }
}

3. Awaitable

实现 C++ 协程所需的类型,具体解释见 https://segmentfault.com/a/11...。

awaitable.h:

#ifndef AWAITABLE_H#define AWAITABLE_H#include <coroutine>#include <sys/socket.h>#include <sys/types.h>#include <iostream>#include "io_context.h"// 回调须要执行的操作类型:读、写、承受客户端连贯enum class HandlerType{    Read, Write, Accept,};class Awaitable{public:    Awaitable(IOContext* ctx, int fd, int events, void* buf, size_t n, HandlerType ht);    bool await_ready();    void await_suspend(std::coroutine_handle<> handle);    int await_resume();private:    IOContext* m_ctx;    int m_fd;    int m_events;    void* m_buf;    size_t m_n;    int m_result;    HandlerType m_ht;};struct CoroRetType{public:    struct promise_type    {        CoroRetType get_return_object();        std::suspend_never initial_suspend();        std::suspend_never final_suspend() noexcept;        void return_void();        void unhandled_exception();    };};#endif

awaitable.cpp:

#include <cstring>#include "awaitable.h"Awaitable::Awaitable(IOContext *ctx, int fd, int events, void* buf, size_t n, HandlerType ht)        : m_ctx(ctx), m_fd(fd), m_events(events), m_buf(buf), m_n(n), m_ht(ht){}bool Awaitable::await_ready(){    return false;}int Awaitable::await_resume(){    return m_result;}// 注册要监听的文件描述符、感兴趣的事件及要执行的回调void Awaitable::await_suspend(std::coroutine_handle<> handle){    auto cb = [handle, this]() mutable    {        switch (m_ht)        {            case HandlerType::Read:                m_result = read(m_fd, m_buf, m_n);                break;            case HandlerType::Write:                m_result = write(m_fd, m_buf, m_n);                break;            case HandlerType::Accept:                m_result = accept(m_fd, nullptr, nullptr);                break;        }                handle.resume();    };    Args args{cb};    m_ctx->post(m_fd, m_events, args);}CoroRetType CoroRetType::promise_type::get_return_object(){    return CoroRetType();}std::suspend_never CoroRetType::promise_type::initial_suspend(){    return std::suspend_never{};}std::suspend_never CoroRetType::promise_type::final_suspend() noexcept{    return std::suspend_never{};}void CoroRetType::promise_type::return_void(){}void CoroRetType::promise_type::unhandled_exception(){    std::terminate();}

4. 异步操作

以下函数各返回一个 Awaitable 对象;对其执行 co_await 时,会向 IOContext 注册要监听的文件描述符及感兴趣的事件,然后挂起当前协程。

io_util.h:

#ifndef IO_UTIL_H#define IO_UTIL_H#include "io_context.h"#include "awaitable.h"Awaitable asyncRead(IOContext* ctx, int fd, void* buf, size_t n);Awaitable asyncWrite(IOContext* ctx, int fd, void* buf, size_t n);Awaitable asyncAccept(IOContext* ctx, int fd);#endif

io_util.cpp:

#include "io_util.h"Awaitable asyncRead(IOContext* ctx, int fd, void* buf, size_t n){    return Awaitable(ctx, fd, EPOLLIN, buf, n, HandlerType::Read);}Awaitable asyncWrite(IOContext* ctx, int fd, void* buf, size_t n){    return Awaitable(ctx, fd, EPOLLOUT, buf, n, HandlerType::Write);}Awaitable asyncAccept(IOContext* ctx, int fd){    return Awaitable(ctx, fd, EPOLLIN, nullptr, 0, HandlerType::Accept);}

5. 例子

main.cpp:

#include <thread>#include <coroutine>#include <iostream>#include <netinet/in.h>#include <arpa/inet.h>#include <cstring>#include <errno.h>#include "io_util.h"static std::mutex ioLock;static uint16_t port = 6666;static int backlog = 32;static const char* Msg = "hello, cpp!";static const size_t MsgLen = 11;static IOContext ioContext;CoroRetType handleConnection(int fd){    char buf[MsgLen+1] = {0};    int n;    n = co_await asyncRead(&ioContext, fd, buf, MsgLen);    buf[n+1] = '\0';    co_await asyncWrite(&ioContext, fd, buf, n);    close(fd);}CoroRetType serverThread(){    int listenSock = socket(AF_INET, SOCK_STREAM, 0);    int value = 1;    setsockopt(listenSock, SOL_SOCKET, SO_REUSEADDR, &value, sizeof(int));    struct sockaddr_in addr;    memset(&addr, 0, sizeof(addr));    addr.sin_port = htons(port);    addr.sin_family = AF_INET;    addr.sin_addr.s_addr = htonl(INADDR_ANY);    int err = bind(listenSock, (const struct sockaddr*)&addr, sizeof(addr));    listen(listenSock, backlog);    while (true)    {        int clientSock = co_await asyncAccept(&ioContext, listenSock);        auto h = [=]()        {            handleConnection(clientSock);        };        ioContext.post(h);    }}void clientThread(){    using namespace std::literals;    std::this_thread::sleep_for(1s);    int sock = socket(AF_INET, SOCK_STREAM, 0);    struct sockaddr_in addr;    memset(&addr, 0, sizeof(addr));    addr.sin_port = htons(port);    addr.sin_family = AF_INET;    inet_pton(AF_INET, "127.0.0.1", &addr.sin_addr);    connect(sock, (const struct sockaddr*)&addr, sizeof(addr));    char buf[MsgLen+1] = {0};    ssize_t n = write(sock, Msg, MsgLen);    read(sock, buf, n);    buf[n+1] = '\0';    std::lock_guard<std::mutex> lock(ioLock);    std::cout << "clientThread: " << buf << '\n';    close(sock);}int main(){    serverThread();    constexpr int N = 10;    for (int i = 0; i < N; i++)    {        std::thread t(clientThread);        t.detach();    }    ioContext.run();}
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!