关于 C++:Linux epoll 与 C++ 协程

7次阅读

共计 6436 个字符,预计需要花费 17 分钟才能阅读完成。

简介

本文使用 C++20 引入的协程来编写一个 Linux epoll 程序。在此实现中,用户使用异步操作时无需再提供自己的回调函数。以此处实现的 asyncRead() 为例:

  • 使用 asyncRead() 所需的参数和 read() 大致相同,无需传入回调;
  • asyncRead() 的内部会向 epoll 注册要监听的文件描述符、感兴趣的事件和要执行的回调(由实现提供,而无需使用者传入);
  • 当事件未就绪时,co_await asyncRead() 会挂起当前协程;
  • 当事件就绪时,epoll 循环中会执行具体的 I/O 操作(此处将其提交到 I/O 线程池中执行),当 I/O 操作完成时,恢复协程的运行。

1. ThreadPool

此处使用了两个线程池:

  • I/O 线程池:用于执行 I/O 操作;
  • 工作线程池:用于处理客户端连接(此处以 tcp 回显程序为例)。

此处使用的是自己实现的线程池,具体实现见 https://segmentfault.com/a/11…。

2. IOContext

IOContext 类对 Linux epoll 做了简单的封装。

io_context.h:

#ifndef IOCONTEXT_H
#define IOCONTEXT_H

#include <sys/epoll.h>
#include <unistd.h>
#include <unordered_map>
#include <functional>
#include "thread_pool.h"

// Type-erased, argument-less callback that IOContext::run() hands to the I/O
// thread pool once epoll reports the watched fd ready.
using callback_t = std::function<void()>;

// Per-fd registration data stored by IOContext::post(int, int, const Args&);
// currently it only carries the callback to run when the event fires.
struct Args
{callback_t m_cb;};

// Thin wrapper around a Linux epoll instance. Every registration is one-shot:
// when epoll reports an fd ready, run() unregisters the fd and submits the
// stored callback to the I/O thread pool.
class IOContext
{
public:
    // Creates the epoll instance plus two thread pools: nIOThreads threads run
    // readiness callbacks, nJobThreads threads run tasks given to post(task).
    IOContext(int nIOThreads=2, int nJobThreads=2);

    // Starts watching fd for `events` and stores args (which holds the callback
    // to run when the event is reported). Returns false if epoll_ctl rejects the
    // registration, e.g. because fd is already registered.
    bool post(int fd, int events, const Args& args);

    // Submits a task to the job (worker) thread pool; forwards submitTask()'s result.
    bool post(const Task& task);

    // Stops watching fd and drops the callback stored for it.
    void remove(int fd);

    // Event loop: waits for ready events indefinitely and dispatches their
    // callbacks; returns only when epoll_wait fails.
    void run();

private:
    int m_fd;                              // epoll instance descriptor
    std::unordered_map<int, Args> m_args;  // fd -> registered callback; accessed under m_lock
    std::mutex m_lock;  // NOTE(review): <mutex> is not included here; presumably arrives via thread_pool.h
    ThreadPool m_ioPool;     // I/O thread pool: runs readiness callbacks
    ThreadPool m_jobPool;    // job (worker) thread pool: runs post(task) tasks
};

#endif

io_context.cpp:

#include "io_context.h"

#include <errno.h>

#include <array>
#include <cstring>
#include <iostream>
#include <system_error>
#include <utility>

// Creates the epoll instance and the I/O / job thread pools.
// Throws std::system_error if the epoll instance cannot be created, rather than
// carrying on with an invalid descriptor (which would make run() return at once
// without any diagnostic).
IOContext::IOContext(int nIOThreads, int nJobThreads)
    : m_ioPool(nIOThreads), m_jobPool(nJobThreads)
{
    // epoll_create1 drops epoll_create()'s obsolete size hint; EPOLL_CLOEXEC keeps
    // the descriptor from leaking into child processes across exec().
    m_fd = epoll_create1(EPOLL_CLOEXEC);
    if (m_fd < 0)
    {
        throw std::system_error(errno, std::generic_category(), "epoll_create1");
    }
}

bool IOContext::post(int fd, int events, const Args& args)
{
    struct epoll_event event;
    event.events = events;
    event.data.fd = fd;

    std::lock_guard<std::mutex> lock(m_lock);

    int err = epoll_ctl(m_fd, EPOLL_CTL_ADD, fd, &event);
    if (err == 0)
    {m_args[fd] = args;
    }

    return err == 0;
}

// Hands `task` to the job (worker) thread pool and forwards submitTask()'s result.
bool IOContext::post(const Task& task)
{
    const bool accepted = m_jobPool.submitTask(task);
    return accepted;
}

// Stops watching fd and forgets its callback.
// The map entry is dropped even if EPOLL_CTL_DEL fails (e.g. fd was already
// closed or was never registered): keeping a stale callback would only leak
// the entry. The failure is still reported, on stderr.
void IOContext::remove(int fd)
{
    std::lock_guard<std::mutex> lock(m_lock);
    if (epoll_ctl(m_fd, EPOLL_CTL_DEL, fd, nullptr) != 0)
    {
        std::cerr << "remove:" << strerror(errno) << "\n";
    }
    m_args.erase(fd);
}

void IOContext::run()
{
    int timeout = -1;
    size_t nEvents = 32;
    struct epoll_event* eventList = new struct epoll_event[nEvents];

    while (true)
    {int nReady = epoll_wait(m_fd, eventList, nEvents, timeout);
        if (nReady < 0)
        {delete []eventList;
            return;
        }

        for (int i = 0; i < nReady; i++)
        {int fd = eventList[i].data.fd;

            m_lock.lock();
            auto cb = m_args[fd].m_cb;
            m_lock.unlock();

            remove(fd);

            m_ioPool.submitTask([=]()
            {cb();
            });
        }
    }
}

3. Awaitable

实现 C++ 协程所需的类型,具体解释见 https://segmentfault.com/a/11…。

awaitable.h:

#ifndef AWAITABLE_H
#define AWAITABLE_H

#include <coroutine>
#include <sys/socket.h>
#include <sys/types.h>
#include <iostream>
#include "io_context.h"

// Which call the readiness callback performs once epoll reports the fd ready.
enum class HandlerType
{
    Read,    // read(fd, buf, n)
    Write,   // write(fd, buf, n)
    Accept,  // accept(fd, nullptr, nullptr)
};

// Awaitable returned by asyncRead/asyncWrite/asyncAccept. co_await always
// suspends. await_suspend() registers m_fd with the IOContext, and the registered
// callback performs the I/O and then resumes the coroutine. co_await yields the
// stored I/O result.
class Awaitable
{
public:
    Awaitable(IOContext* ctx, int fd, int events, void* buf, size_t n, HandlerType ht);

    bool await_ready();
    void await_suspend(std::coroutine_handle<> handle);
    int await_resume();

private:
    IOContext* m_ctx;     // event loop the fd is registered with (never deleted here)
    int m_fd;             // descriptor to wait on
    int m_events;         // epoll event mask, e.g. EPOLLIN / EPOLLOUT
    void* m_buf;          // caller's buffer for read/write (nullptr for accept)
    size_t m_n;           // buffer size in bytes
    // Result of read()/write()/accept(). Defaults to -1 ("failed") so it is never
    // read indeterminate if the coroutine is resumed without the callback having run.
    int m_result = -1;
    HandlerType m_ht;     // which I/O call the callback performs
};


// Return type for fire-and-forget coroutines (handleConnection, serverThread).
// The body starts running as soon as the coroutine is called (initial_suspend
// never suspends), and the frame is freed automatically when the body finishes
// (final_suspend never suspends). The returned object carries no handle.
struct CoroRetType
{
    struct promise_type
    {
        CoroRetType get_return_object();
        std::suspend_never initial_suspend();
        std::suspend_never final_suspend() noexcept;
        void return_void();
        void unhandled_exception();
    };
};
#endif

awaitable.cpp:

#include <cstring>
#include "awaitable.h"

Awaitable::Awaitable(IOContext *ctx, int fd, int events, void* buf, size_t n, HandlerType ht)
        : m_ctx(ctx), m_fd(fd), m_events(events), m_buf(buf), m_n(n), m_ht(ht)
{}

bool Awaitable::await_ready()
{return false;}

int Awaitable::await_resume()
{return m_result;}

// Registers m_fd with the IOContext for m_events. When epoll reports the fd
// ready, the callback (run on an I/O pool thread) performs the requested I/O,
// stores its result in m_result and resumes the suspended coroutine.
// If the registration itself fails, nothing would ever resume the coroutine
// (it would leak, suspended forever). Instead it is resumed immediately with
// m_result == -1.
void Awaitable::await_suspend(std::coroutine_handle<> handle)
{
    // Capturing `this` is safe: the awaitable lives in the coroutine frame until
    // await_resume() has run.
    Args args{[handle, this]()
    {
        switch (m_ht)
        {
            case HandlerType::Read:
                m_result = read(m_fd, m_buf, m_n);
                break;
            case HandlerType::Write:
                m_result = write(m_fd, m_buf, m_n);
                break;
            case HandlerType::Accept:
                m_result = accept(m_fd, nullptr, nullptr);
                break;
        }

        handle.resume();
    }};

    if (!m_ctx->post(m_fd, m_events, args))
    {
        m_result = -1;
        // Nothing may touch *this after this call: the frame may already be gone.
        handle.resume();
    }
}


// The returned object holds no handle; callers cannot await or join the coroutine.
CoroRetType CoroRetType::promise_type::get_return_object()
{
    return {};
}

// Start executing the coroutine body immediately when it is called.
std::suspend_never CoroRetType::promise_type::initial_suspend()
{
    return {};
}

// Do not park at the end: the frame is destroyed as soon as the body finishes.
std::suspend_never CoroRetType::promise_type::final_suspend() noexcept
{
    return {};
}

// No value to store on co_return / falling off the end.
void CoroRetType::promise_type::return_void()
{
}

// An exception escaping the coroutine body is fatal.
void CoroRetType::promise_type::unhandled_exception()
{
    std::terminate();
}

4. 异步操作

以下函数只构造并返回一个 Awaitable 对象;对其 co_await 时,才会向 IOContext 注册要监听的文件描述符及感兴趣的事件,并挂起当前协程。

io_util.h:

#ifndef IO_UTIL_H
#define IO_UTIL_H

#include "io_context.h"
#include "awaitable.h"

// Each helper returns an Awaitable. co_await on it suspends the calling coroutine
// until epoll reports fd ready, then performs the operation on an I/O pool thread
// and yields the call's raw return value (-1 on error).
// ctx and buf are stored as raw pointers and used later by the callback, so both
// must remain valid until the co_await completes.

// Waits until fd is readable (EPOLLIN), then read(fd, buf, n).
Awaitable asyncRead(IOContext* ctx, int fd, void* buf, size_t n);
// Waits until fd is writable (EPOLLOUT), then write(fd, buf, n).
Awaitable asyncWrite(IOContext* ctx, int fd, void* buf, size_t n);
// Waits until the listening socket fd is readable (EPOLLIN), then
// accept(fd, nullptr, nullptr); co_await yields the new connection's descriptor.
Awaitable asyncAccept(IOContext* ctx, int fd);

#endif

io_util.cpp:

#include "io_util.h"

// Awaitable that, once fd is readable, reads up to n bytes into buf.
Awaitable asyncRead(IOContext* ctx, int fd, void* buf, size_t n)
{
    const int interest = EPOLLIN;
    const HandlerType op = HandlerType::Read;
    return Awaitable(ctx, fd, interest, buf, n, op);
}

// Awaitable that, once fd is writable, writes up to n bytes from buf.
Awaitable asyncWrite(IOContext* ctx, int fd, void* buf, size_t n)
{
    const int interest = EPOLLOUT;
    const HandlerType op = HandlerType::Write;
    return Awaitable(ctx, fd, interest, buf, n, op);
}

// Awaitable that, once the listening socket fd is readable, accepts a connection.
// No buffer is involved, hence nullptr / 0.
Awaitable asyncAccept(IOContext* ctx, int fd)
{
    const int interest = EPOLLIN;
    const HandlerType op = HandlerType::Accept;
    return Awaitable(ctx, fd, interest, nullptr, 0, op);
}

5. 例子

main.cpp:

#include <thread>
#include <coroutine>
#include <iostream>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <cstring>
#include <errno.h>
#include "io_util.h"

// Shared configuration and state of the echo demo.
static std::mutex ioLock;                     // serializes console output from client threads
static constexpr uint16_t port = 6666;        // TCP port the echo server listens on
static constexpr int backlog = 32;            // listen() backlog
static constexpr char Msg[] = "hello, cpp!";  // message each client sends
// Derived from Msg so the length can never drift out of sync with the literal.
static constexpr size_t MsgLen = sizeof(Msg) - 1;
static IOContext ioContext;                   // event loop shared by the server coroutines

// Echoes one message (at most MsgLen bytes) back on a connected client socket,
// then closes it. Runs as a fire-and-forget coroutine.
CoroRetType handleConnection(int fd)
{
    char buf[MsgLen+1] = {0};

    int n = co_await asyncRead(&ioContext, fd, buf, MsgLen);
    if (n <= 0)
    {
        // EOF or read error: nothing to echo. A -1 must not reach asyncWrite(),
        // where it would become a huge size_t length.
        close(fd);
        co_return;
    }
    buf[n] = '\0';  // 0 < n <= MsgLen, so the terminator stays inside buf

    co_await asyncWrite(&ioContext, fd, buf, n);
    close(fd);
}

// Accept loop of the echo server, written as a coroutine (it runs on whichever
// thread resumes it, not on a dedicated thread). It listens on `port`, repeatedly
// co_awaits a new connection, and posts handleConnection() for it to the job pool.
// If the listening socket cannot be set up, it reports the error and stops.
CoroRetType serverThread()
{
    int listenSock = socket(AF_INET, SOCK_STREAM, 0);
    if (listenSock < 0)
    {
        const int err = errno;
        std::cerr << "socket:" << strerror(err) << "\n";
        co_return;
    }

    // Allow quick restarts while old connections linger in TIME_WAIT.
    int value = 1;
    setsockopt(listenSock, SOL_SOCKET, SO_REUSEADDR, &value, sizeof(int));

    struct sockaddr_in addr;
    memset(&addr, 0, sizeof(addr));
    addr.sin_port = htons(port);
    addr.sin_family = AF_INET;
    addr.sin_addr.s_addr = htonl(INADDR_ANY);

    if (bind(listenSock, reinterpret_cast<const struct sockaddr*>(&addr), sizeof(addr)) < 0)
    {
        const int err = errno;
        std::cerr << "bind:" << strerror(err) << "\n";
        close(listenSock);
        co_return;
    }
    if (listen(listenSock, backlog) < 0)
    {
        const int err = errno;
        std::cerr << "listen:" << strerror(err) << "\n";
        close(listenSock);
        co_return;
    }

    while (true)
    {
        int clientSock = co_await asyncAccept(&ioContext, listenSock);
        if (clientSock < 0)
        {
            // accept() failed (e.g. connection aborted, fd limit reached):
            // there is no connection to serve, keep listening.
            continue;
        }

        ioContext.post([clientSock]()
        {
            handleConnection(clientSock);
        });
    }
}

// Demo client. After a crude 1 s delay (to give the server time to start
// listening), it connects to 127.0.0.1:port, sends Msg, reads the echo and
// prints it while holding ioLock.
void clientThread()
{
    using namespace std::literals;

    std::this_thread::sleep_for(1s);

    int sock = socket(AF_INET, SOCK_STREAM, 0);
    if (sock < 0)
    {
        const int err = errno;
        std::lock_guard<std::mutex> lock(ioLock);
        std::cerr << "clientThread socket:" << strerror(err) << '\n';
        return;
    }

    struct sockaddr_in addr;
    memset(&addr, 0, sizeof(addr));
    addr.sin_port = htons(port);
    addr.sin_family = AF_INET;
    inet_pton(AF_INET, "127.0.0.1", &addr.sin_addr);

    if (connect(sock, reinterpret_cast<const struct sockaddr*>(&addr), sizeof(addr)) < 0)
    {
        const int err = errno;
        close(sock);
        std::lock_guard<std::mutex> lock(ioLock);
        std::cerr << "clientThread connect:" << strerror(err) << '\n';
        return;
    }

    char buf[MsgLen+1] = {0};

    // Read back at most as many bytes as were actually sent. A failed write (-1)
    // must not turn into a huge read length that overflows buf.
    ssize_t n = write(sock, Msg, MsgLen);
    ssize_t got = (n > 0) ? read(sock, buf, static_cast<size_t>(n)) : 0;
    if (got < 0)
        got = 0;
    buf[got] = '\0';  // 0 <= got <= MsgLen, so the index is in bounds

    std::lock_guard<std::mutex> lock(ioLock);
    std::cout << "clientThread:" << buf << '\n';

    close(sock);
}

// Starts the server coroutine. Because initial_suspend never suspends, it runs
// until its first co_await, which registers the listening socket. main() then
// launches N detached echo clients and runs the epoll event loop on this thread.
int main()
{
    serverThread();

    constexpr int N = 10;
    for (int launched = 0; launched != N; ++launched)
        std::thread(clientThread).detach();

    ioContext.run();
}
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
全文完
 0