简介
本文使用 C++20 引入的协程来编写一个 Linux epoll 程序。在此实现中,用户使用异步操作时再也无需提供自己的回调函数。以此处实现的 asyncRead() 为例:
- 使用 asyncRead() 所需的参数和 read() 大致相同,无需传入回调;
- co_await asyncRead() 时,其内部会向 epoll 注册要监听的文件描述符、感兴趣的事件和要执行的回调(由实现提供,而无需使用者传入);
- 当事件未就绪时,co_await asyncRead() 会挂起当前协程;
- 当事件就绪时,epoll 循环中会执行具体的 I/O 操作(此处将其提交到 I/O 线程池中执行),当 I/O 操作完成时,恢复协程的运行。
1. ThreadPool
此处使用了两个线程池:
- I/O 线程池:用于执行 I/O 操作;
- 工作线程池:用于处理客户端连接(此处以 TCP 回显程序为例)。
此处使用的是我自己实现的线程池,具体实现见 https://segmentfault.com/a/11…。
2. IOContext
IOContext 类对 Linux epoll 做了简单的封装。
io_context.h:
#ifndef IOCONTEXT_H
#define IOCONTEXT_H
#include <sys/epoll.h>
#include <unistd.h>
#include <unordered_map>
#include <functional>
#include "thread_pool.h"
// Type-erased callable that IOContext runs once a watched fd becomes ready.
using callback_t = std::function<void()>;

// Per-fd registration data stored by IOContext::post(fd, events, args).
struct Args
{
    callback_t m_cb{}; // callback handed to the I/O thread pool when the fd is ready
};
// Thin wrapper around a Linux epoll instance plus two thread pools:
// one runs the readiness callbacks (I/O), the other runs submitted jobs.
class IOContext
{
public:
    // Creates the epoll instance and the pools (nIOThreads I/O workers,
    // nJobThreads job workers).
    IOContext(int nIOThreads = 2, int nJobThreads = 2);

    // Starts watching `fd` for `events`; `args` carries the callback to run
    // once the fd is ready. Returns true if the fd was registered.
    bool post(int fd, int events, const Args& args);

    // Submits `task` to the job thread pool.
    bool post(const Task& task);

    // Stops watching `fd` and drops its stored callback.
    void remove(int fd);

    // Event loop: waits for ready fds and dispatches their callbacks.
    void run();

private:
    int m_fd;                             // epoll instance descriptor
    std::unordered_map<int, Args> m_args; // fd -> registered callback
    std::mutex m_lock;                    // guards m_args together with epoll_ctl calls
    ThreadPool m_ioPool;                  // pool that runs I/O callbacks
    ThreadPool m_jobPool;                 // pool that runs submitted jobs
};
#endif
io_context.cpp:
#include "io_context.h"
#include <cstring>
#include <errno.h>
#include <iostream>
// Starts the I/O and job thread pools, then creates the epoll instance.
// epoll_create1(EPOLL_CLOEXEC) replaces epoll_create(size): the size hint has
// been ignored since Linux 2.6.8, and CLOEXEC keeps the epoll fd from leaking
// into exec'd child processes.
// The fd is created in the body (after the pools) so that errno is read
// immediately after the call, before anything else can overwrite it.
// On failure the error is reported and m_fd stays negative, so later epoll_*
// calls fail.
IOContext::IOContext(int nIOThreads, int nJobThreads)
    : m_ioPool(nIOThreads), m_jobPool(nJobThreads)
{
    m_fd = epoll_create1(EPOLL_CLOEXEC);
    if (m_fd < 0)
    {
        std::cout << "epoll_create1:" << strerror(errno) << "\n";
    }
}
bool IOContext::post(int fd, int events, const Args& args)
{
struct epoll_event event;
event.events = events;
event.data.fd = fd;
std::lock_guard<std::mutex> lock(m_lock);
int err = epoll_ctl(m_fd, EPOLL_CTL_ADD, fd, &event);
if (err == 0)
{m_args[fd] = args;
}
return err == 0;
}
// Hands `task` to the job thread pool and reports what submitTask reports.
bool IOContext::post(const Task& task)
{
    const bool submitted = m_jobPool.submitTask(task);
    return submitted;
}
void IOContext::remove(int fd)
{std::lock_guard<std::mutex> lock(m_lock);
int err = epoll_ctl(m_fd, EPOLL_CTL_DEL, fd, nullptr);
if (err == 0)
{m_args.erase(fd);
}
else
{std::cout << "remove:" << strerror(errno) << "\n";
}
}
void IOContext::run()
{
int timeout = -1;
size_t nEvents = 32;
struct epoll_event* eventList = new struct epoll_event[nEvents];
while (true)
{int nReady = epoll_wait(m_fd, eventList, nEvents, timeout);
if (nReady < 0)
{delete []eventList;
return;
}
for (int i = 0; i < nReady; i++)
{int fd = eventList[i].data.fd;
m_lock.lock();
auto cb = m_args[fd].m_cb;
m_lock.unlock();
remove(fd);
m_ioPool.submitTask([=]()
{cb();
});
}
}
}
3. Awaitable
实现 C++ 协程所需的类型,具体解释见 https://segmentfault.com/a/11…。
awaitable.h:
#ifndef AWAITABLE_H
#define AWAITABLE_H
#include <coroutine>
#include <sys/socket.h>
#include <sys/types.h>
#include <iostream>
#include "io_context.h"
// Which call the readiness callback performs once the fd is ready.
enum class HandlerType
{
    Read,   // read(fd, buf, n)
    Write,  // write(fd, buf, n)
    Accept, // accept(fd, nullptr, nullptr): accept a client connection
};
// Awaiter returned by asyncRead/asyncWrite/asyncAccept. co_await on it always
// suspends. await_suspend registers m_fd with m_ctx, and the readiness
// callback performs the operation selected by m_ht and then resumes the coroutine.
class Awaitable
{
public:
    Awaitable(IOContext* ctx, int fd, int events, void* buf, size_t n, HandlerType ht);

    // Whether the result is available without suspending.
    bool await_ready();
    // Registers the fd and the operation callback with the IOContext.
    void await_suspend(std::coroutine_handle<> handle);
    // Result of the read/write/accept call (-1 on error).
    int await_resume();

private:
    IOContext* m_ctx;  // event loop to register with (not deleted by this class)
    int m_fd;          // descriptor to wait on
    int m_events;      // epoll events of interest
    void* m_buf;       // buffer for Read/Write (unused for Accept)
    size_t m_n;        // buffer length for Read/Write
    int m_result = -1; // operation result; -1 until the operation has run, so
                       // await_resume never reads an indeterminate value
    HandlerType m_ht;  // which operation to perform
};
// Return type for fire-and-forget coroutines. The body starts running
// immediately (initial_suspend returns suspend_never), and the frame is
// destroyed as soon as the body finishes (final_suspend returns suspend_never).
// The returned object holds no coroutine handle.
struct CoroRetType
{
    struct promise_type
    {
        CoroRetType get_return_object();
        std::suspend_never initial_suspend();
        std::suspend_never final_suspend() noexcept;
        void return_void();
        void unhandled_exception(); // terminates the program
    };
};
#endif
awaitable.cpp:
#include <cstring>
#include "awaitable.h"
Awaitable::Awaitable(IOContext *ctx, int fd, int events, void* buf, size_t n, HandlerType ht)
: m_ctx(ctx), m_fd(fd), m_events(events), m_buf(buf), m_n(n), m_ht(ht)
{}
// Never ready up front: the coroutine always suspends, so the I/O call is only
// made after epoll reports the fd ready.
bool Awaitable::await_ready()
{
    return false;
}
// Value of the co_await expression: the result stored by the readiness callback.
int Awaitable::await_resume()
{
    return m_result;
}
// Registers m_fd with the IOContext together with a callback. Once the fd is
// ready, the callback performs the operation selected by m_ht, stores its
// result in m_result and resumes the suspended coroutine.
// If registration fails (e.g. an invalid fd, or an fd that is already
// registered), no callback would ever run. In that case the coroutine is
// resumed immediately with m_result = -1 instead of staying suspended forever.
void Awaitable::await_suspend(std::coroutine_handle<> handle)
{
    auto cb = [handle, this]() mutable
    {
        switch (m_ht)
        {
        case HandlerType::Read:
            m_result = read(m_fd, m_buf, m_n);
            break;
        case HandlerType::Write:
            m_result = write(m_fd, m_buf, m_n);
            break;
        case HandlerType::Accept:
            m_result = accept(m_fd, nullptr, nullptr);
            break;
        }
        handle.resume();
    };
    Args args{cb};
    if (!m_ctx->post(m_fd, m_events, args))
    {
        m_result = -1;
        // resume() may run the coroutine to completion and destroy its frame
        // (which owns this awaiter), so `this` must not be touched afterwards.
        handle.resume();
    }
}
// The returned object carries no state; callers cannot await or cancel it.
CoroRetType CoroRetType::promise_type::get_return_object()
{
    return {};
}

// Start executing the coroutine body right away.
std::suspend_never CoroRetType::promise_type::initial_suspend()
{
    return {};
}

// No final suspension: the frame is destroyed as soon as the body finishes.
std::suspend_never CoroRetType::promise_type::final_suspend() noexcept
{
    return {};
}

// Coroutines of this type produce no value.
void CoroRetType::promise_type::return_void()
{
}

// Any exception escaping the coroutine body ends the program.
void CoroRetType::promise_type::unhandled_exception()
{
    std::terminate();
}
4. 异步操作
以下函数仅构造并返回一个 Awaitable 对象;在对其 co_await 时,才会向 IOContext 注册要监听的文件描述符及感兴趣的事件,随后挂起当前协程。
io_util.h:
#ifndef IO_UTIL_H
#define IO_UTIL_H
#include "io_context.h"
#include "awaitable.h"
// These functions only build an Awaitable; nothing is registered until the
// result is co_awaited. Discarding the result is therefore always a bug,
// hence [[nodiscard]].

// co_await yields the result of read(fd, buf, n) once fd is readable.
[[nodiscard]] Awaitable asyncRead(IOContext* ctx, int fd, void* buf, size_t n);
// co_await yields the result of write(fd, buf, n) once fd is writable.
[[nodiscard]] Awaitable asyncWrite(IOContext* ctx, int fd, void* buf, size_t n);
// co_await yields the accepted client fd (or -1) once the listening fd is readable.
[[nodiscard]] Awaitable asyncAccept(IOContext* ctx, int fd);
#endif
io_util.cpp:
#include "io_util.h"
// Awaiter that waits for EPOLLIN on fd, then calls read(fd, buf, n).
Awaitable asyncRead(IOContext* ctx, int fd, void* buf, size_t n)
{
    return Awaitable(ctx, fd, EPOLLIN, buf, n, HandlerType::Read);
}

// Awaiter that waits for EPOLLOUT on fd, then calls write(fd, buf, n).
Awaitable asyncWrite(IOContext* ctx, int fd, void* buf, size_t n)
{
    return Awaitable(ctx, fd, EPOLLOUT, buf, n, HandlerType::Write);
}

// Awaiter that waits for EPOLLIN on the listening fd, then calls accept().
// No buffer is involved, so buf/n are nullptr/0.
Awaitable asyncAccept(IOContext* ctx, int fd)
{
    return Awaitable(ctx, fd, EPOLLIN, nullptr, 0, HandlerType::Accept);
}
5. 例子
main.cpp:
#include <thread>
#include <coroutine>
#include <iostream>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <cstring>
#include <errno.h>
#include "io_util.h"
// Serialises the client threads' writes to std::cout.
static std::mutex ioLock;
static uint16_t port = 6666; // TCP port the echo server listens on
static int backlog = 32;     // listen() backlog
// Message each client sends. MsgLen is derived from the literal so the two
// cannot drift apart; it excludes the terminating '\0'.
static constexpr char Msg[] = "hello, cpp!";
static constexpr size_t MsgLen = sizeof(Msg) - 1;
// Process-wide event loop used by the server coroutines; constructed with the
// constructor's default pool sizes.
static IOContext ioContext{};
// Echoes one message on an accepted client socket, then closes it.
// Reads up to MsgLen bytes and writes back exactly the bytes received.
CoroRetType handleConnection(int fd)
{
    char buf[MsgLen + 1] = {0};
    int n = co_await asyncRead(&ioContext, fd, buf, MsgLen);
    // n < 0: the read failed; n == 0: the peer closed. Either way there is
    // nothing to echo, and a negative length would turn into a huge size_t.
    if (n <= 0)
    {
        close(fd);
        co_return;
    }
    // Terminate right after the last byte read. n <= MsgLen, so buf[n] is in
    // range; buf[n + 1] would be past the end when n == MsgLen.
    buf[n] = '\0';
    co_await asyncWrite(&ioContext, fd, buf, n);
    close(fd);
}
// Accept loop of the echo server. Creates a listening TCP socket on `port`
// and, for each accepted client, posts handleConnection to the job pool.
// Errors from socket/bind/listen are reported and end the coroutine, instead
// of waiting forever on an unusable socket.
CoroRetType serverThread()
{
    int listenSock = socket(AF_INET, SOCK_STREAM, 0);
    if (listenSock < 0)
    {
        std::cout << "socket:" << strerror(errno) << "\n";
        co_return;
    }
    int value = 1;
    setsockopt(listenSock, SOL_SOCKET, SO_REUSEADDR, &value, sizeof(value));
    struct sockaddr_in addr;
    memset(&addr, 0, sizeof(addr));
    addr.sin_port = htons(port);
    addr.sin_family = AF_INET;
    addr.sin_addr.s_addr = htonl(INADDR_ANY);
    if (bind(listenSock, reinterpret_cast<const struct sockaddr*>(&addr), sizeof(addr)) < 0
        || listen(listenSock, backlog) < 0)
    {
        std::cout << "bind/listen:" << strerror(errno) << "\n";
        close(listenSock);
        co_return;
    }
    while (true)
    {
        int clientSock = co_await asyncAccept(&ioContext, listenSock);
        if (clientSock < 0)
        {
            // Don't start an echo coroutine on an invalid descriptor.
            std::cout << "accept:" << strerror(errno) << "\n";
            continue;
        }
        ioContext.post([=]() { handleConnection(clientSock); });
    }
}
// Test client. After a 1 s delay (giving the server time to listen), it
// connects to 127.0.0.1:port, sends Msg, reads the echo and prints it.
void clientThread()
{
    using namespace std::literals;
    std::this_thread::sleep_for(1s);
    int sock = socket(AF_INET, SOCK_STREAM, 0);
    if (sock < 0)
    {
        return;
    }
    struct sockaddr_in addr;
    memset(&addr, 0, sizeof(addr));
    addr.sin_port = htons(port);
    addr.sin_family = AF_INET;
    inet_pton(AF_INET, "127.0.0.1", &addr.sin_addr);
    if (connect(sock, reinterpret_cast<const struct sockaddr*>(&addr), sizeof(addr)) < 0)
    {
        int err = errno;
        close(sock);
        std::lock_guard<std::mutex> lock(ioLock);
        std::cout << "clientThread: connect:" << strerror(err) << '\n';
        return;
    }
    char buf[MsgLen + 1] = {0};
    ssize_t sent = write(sock, Msg, MsgLen);
    // TCP may deliver the echo in several pieces. Keep reading until every
    // byte sent has come back, or the peer stops sending.
    ssize_t received = 0;
    while (sent > 0 && received < sent)
    {
        ssize_t r = read(sock, buf + received, sent - received);
        if (r <= 0)
        {
            break;
        }
        received += r;
    }
    // Terminate after the bytes actually received (received <= MsgLen).
    buf[received] = '\0';
    std::lock_guard<std::mutex> lock(ioLock);
    std::cout << "clientThread:" << buf << '\n';
    close(sock);
}
// Starts the accept coroutine, which runs until its first co_await. Then it
// launches N detached echo clients and drives the event loop on the main thread.
int main()
{
    serverThread();

    constexpr int N = 10;
    for (int launched = 0; launched < N; ++launched)
    {
        std::thread(clientThread).detach();
    }

    ioContext.run();
}
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!