在这里插入图片描述

一、epoll相关接口

1.1 epoll_create函数

#include <sys/epoll.h>

int epoll_create(int size);

功能:创建一个 epoll 实例(文件描述符),用于管理被监控的文件描述符。

参数

  • size:在 Linux 2.6.8 之后被忽略,但需传入大于 0 的值(历史遗留参数)

返回值:成功返回 epoll 实例的文件描述符,失败返回 -1。


1.2 epoll_ctl函数

#include <sys/epoll.h>

int epoll_ctl(int epfd, int op, int fd, struct epoll_event *event);

功能:注册、修改或删除对指定文件描述符的监控。

参数

  • epfd:epoll 实例的文件描述符(由 epoll_create 返回)
  • op:操作类型:
    • EPOLL_CTL_ADD:注册监控
    • EPOLL_CTL_MOD:修改已注册的监控事件
    • EPOLL_CTL_DEL:删除监控(此时 event 可为 NULL)
  • fd:要监控的文件描述符
  • event:指向 struct epoll_event 的指针,指定监控的事件类型和用户数据

struct epoll_event 结构

struct epoll_event {
    uint32_t     events;    // 事件掩码(如 EPOLLIN、EPOLLOUT)
    epoll_data_t data;      // 用户数据(可存储 fd 或指针)
};

typedef union epoll_data {
    void        *ptr;
    int          fd;
    uint32_t     u32;
    uint64_t     u64;
} epoll_data_t;

返回值

  • 0:操作成功
  • -1:操作失败,并设置 errno 以指示具体错误类型

常用事件标志

事件 描述
EPOLLIN 对应的文件描述符可以读
EPOLLOUT 对应的文件描述符可以写
EPOLLPRI 对应的文件描述符有紧急的数据可读
EPOLLERR 对应的文件描述符发生错误
EPOLLHUP 将EPOLL设为边缘触发(Edge Triggered)模式
EPOLLET 将EPOLL设为水平触发(Level Triggered)模式
EPOLLONESHOT 只监听一次事件,当监听完这次事件之后,如果还需要继续监听这个socket的话,需要再次把这个socket加入到EPOLL队列里

1.3 epoll_wait函数

#include <sys/epoll.h>

int epoll_wait(int epfd, struct epoll_event *events, int maxevents, int timeout);

功能:等待 epoll 实例中监控的文件描述符上有事件发生。

参数

  • epfd:epoll 实例的文件描述符
  • events:用于存储就绪事件的数组(由用户分配)
  • maxevents:数组大小(必须大于 0)
  • timeout:超时时间(毫秒):
    • -1:永久阻塞,直到有事件发生
    • 0:立即返回,不阻塞
    • >0:指定超时时间

返回值

  • 正数:就绪事件的数量
  • 0:超时
  • -1:错误(如 epfd 无效)

二、epoll的原理 + 相关接口

我们知道select和poll需要一直遍历所有的文件描述符才能确认是否有事件就绪,这势必会导致效率低下的问题。

epoll不需要一直遍历就可以知道哪些文件描述符的哪些事件已经就绪了,它是如何做到的呢?

之前的文章中讲到过,操作系统并不需要轮询遍历硬件,就能知道硬件上是否有数据就绪,这是通过硬件中断实现的,再通过中断号,执行对应中断向量表中的方法,就可以将硬件上的数据拿到内存中了。

epoll就是通过硬件中断+回调函数的方法,实现不需要一直遍历所以的文件描述符就能知道哪些文件描述符的哪些事件已经就绪了。


下面我就详细的讲解一下 epoll 的原理和 epoll 相关接口。

进程在运行的时候,操作系统会为其创建一个 task_struct,操作系统还会为每个进程创建一个 files_struct,files_struct 中有一个进程文件描述符表,文件描述符表中有一个数组用来存储被打开文件的结构体对象的地址。

进程在调用 epoll_create 函数的时候,操作系统会为进程创建一个 epoll模型,并返回一个文件描述符,操作系统还会创建一个 struct file 结构体,通过文件描述符指向 struct file 结构体,而 struct file 结构体中有一个字段会指向 epoll 模型。

epoll 模型中有两个重要的部分就是红黑树就绪队列,红黑树类似于 select 和 poll 中需要用户手动维护的数组,而这里的红黑树是由内核维护的。

当用户需要增加对某个文件描述符上的某个事件关注时,可以调用 epoll_ctl 函数,操作系统会创建一个 epitem 节点(结构体) ,节点中包含文件描述符、对应关注事件等数据,然后该节点会被链接到红黑树中,并且操作系统会在网卡的驱动中添加回调函数

epoll_ctl 函数的作用是注册、修改或删除对指定文件描述符的关注,同理修改对某个文件描述符上的某个事件关注,就是修改红黑树中对应的 epitem 节点,删除对某个文件描述符上的某个事件关注,就是删除红黑树中对应的 epitem 节点。

当网卡中有数据时,网卡会发送中断信号,操作系统就知道网卡中有事件就绪,此时会调用回调函数,判断该事件是否与用户关注的文件描述符中的事件有关,有关则将相关在红黑树中的 epitem 节点 链入到就绪队列中,此时节点还在红黑树中

用户通过调用 epoll_wait 函数等待 epoll 实例中监控的文件描述符上有事件发生,epoll_wait 函数只需要通过就绪队列中是否有节点,就能判断是否有事件就绪,时间复杂度为O(1),而 epoll_wait 函数获取所有就绪节点,需要遍历就绪队列,时间复杂度为O(N),这是无法优化的。

epoll_wait 函数获取就绪节点的时候,会严格按照就绪队列中节点的顺序,将数据放在用户传入的 struct epoll_event 数组中,并返回就绪事件的个数。如果说就绪队列中的节点数量大于用户传入的 maxevents 时,epoll_wait 函数存放 maxevents 个数据后就返回,其余的节点会保留在就绪队列中,方便用户下一次读取。

进程调用 epoll_wait 函数完后,存放 struct epoll_event 数组中的都是有效的文件描述符和对应就绪事件,用户根据函数的返回值,从头开始遍历,中途不会出现无效文件描述符和没有就绪的时间,也就是说用户不会访问到无效信息

操作系统中可以有多个进程,一个进程中可以有多个 epoll模型,一个 epoll模型中又有红黑树、就绪队列等相关数据,操作系统就需要对 epoll模型进行管理先描述再组织,将红黑树、就绪队列等数据统一保存到结构体 eventpoll 中,再通过链表的方式将所有的 eventpoll 节点进行链接,对epoll模型的管理,就转变为了对链表的增删查改

struct file结构体中有一个字段会指向 epoll模型,实际上指的就是 eventpoll 结构体,也就是说通过struct file结构体也可以管理epoll模型。之前的文章中讲到过,struct file结构体可以指向各种本地设备,例如键盘、磁盘、网卡等,struct file结构体还可以指向网络中的套接字,我们还讲过Linux操作系统下,一切皆文件,所以在操作系统中只需要将struct file结构体管理好,就可以将操作系统中的绝大部分资源管理好

在这里插入图片描述


三、epoll 工作方式

3.1 理解水平触发和边缘触发

epoll 有2种工作方式:水平触发(LT)和边缘触发(ET)

这里以故事的方式帮助大家理解两种工作模式:

前几天小Z在网络平台上购买了五个快递,此时小Z正在家中参加学校的线上会议。小A和小B是快递站的两名员工,刚好今天五个快递都到了快递站,小A就带着小Z的三个快递和其他人的快递来到了小Z的小区,小B就带着小Z的两个快递和其他人的快递来到了小Z的小区。

当小A到达了小Z所在的这栋楼时,小A就给小Z打电话,告诉小Z你的快递到了,下来拿一下。小Z就告诉小A自己正在开会,让小A等一下,小A就说行吧,然后就去派送其他人的快递了,过了一段时间,小A又给小Z打电话,小Z还是在开会没时间,又让小A等一下,小A又去派送别人的快递,重复了几个几次后,小A再次给小Z打电话,此时小Z的会议刚好开完,就下楼去取快递了,但是快递太大了,小Z只能拿两个上去,此时小A还有一个小Z的快递,小Z上楼后,学校有突发事件,要求小Z再次参加会议,小A在楼下等了一段时间后,又给小Z打电话,说小Z还有一个快递未取,小Z又说正在开会,让小A等一下,重复几次后,小B也来到了小Z所在的这栋楼,小B看见小A就打招呼,问他在干什么,小A说正在派送小Z的快递,小B一看自己也有两个小Z的快递,就让小A一起派送了,此时小A手上就有三个小Z的快递,然后小A又给小Z打电话,告诉小Z又到了两个快递,让小Z下来取一下,小Z还是再开会,让小A等一下,重复几次后,会议开完了,小A打电话给小Z,小Z这时候就下楼将所有的快递全部拿上去了。

拿到快递后,小Z又买了五个快递,过了几天,小Z的五个快递到了快递站,又是小A和小B送快递,小B就带着小Z的三个快递和其他人的快递来到了小Z的小区,小A就带着小Z的两个快递和其他人的快递来到了小Z的小区。

当小B到达了小Z所在的这栋楼时,此时小Z正在参加学校的线上会议,小B就给小Z打电话,告诉小Z你的快递到了,下来拿一下,还告诉了小Z,他就只打这一次电话,如果小Z不下来拿,他就再也不打电话了。小Z一想小B只打一次电话,就有点害怕今天拿不到快递了,就先不管会议,下楼拿快递了,但是快递太大了,小Z只能拿两个快递,此时小B手上还有一个小Z的快递,小Z上去以后就没有动静了,但是小B也不会再给小Z打电话了,此时小A也来到了小Z所在的这栋楼,小A看见小B就打招呼,问小B在干什么,小B就说在派送小A的快递,小A一看自己刚好有两个小Z的快递,就让小B一起送了,小B此时手上就有三个小Z的快递,由于新到了两个快递,小B就给小Z打电话,你又到了两个快递,你下来拿一下快递,还告诉了小Z,他就只打这一次电话,如果小Z不下来拿,他就再也不打电话了,小Z一看小B就不好惹,就下楼将所有的快递都拿上楼了。

在上面的故事中讲到了小A和小B为小Z送快递的事情

  • 小A只要手上有小Z的快递,就会一直给小Z打电话
  • 小B只有第一次送快递和后序有新快递到来的时候才会给小Z打电话

小A送快递的方式就是采用了水平触发策略,小B送快递的方式就是采用了边缘触发策略

  • 水平触发策略:底层只要有数据,epoll就会一直通知上层取数据
  • 边缘触发策略:底层有数据了,epoll只通知一次,让上层取数据,后序不再通知,直到底层收到了新的数据,epoll才会进行下一次通知

3.2 对比水平触发和边缘触发

水平触发是 epoll 的默认行为,那水平触发和边缘触发哪个效率更高呢?

显然是边缘触发(ET)的效率更高,在ET策略下,没有无效的通知,全部都是有效的,ET策略下 epoll 只会通知上层一次,倒逼上层要取数据,并且要将本轮数据取完。

数据取完后,底层的接收缓冲区的空间更大,给对方发送的窗口大小也就更大,对方的滑动窗口也就更大了,从概率上就提高了双方的通信效率。

但是应该如何保证上层将本轮数据取完呢?那就只能让上层循环读取,但是上层不知道底层是否有数据,必定会导致阻塞,如何让上层不阻塞呢?使用非阻塞IO的方式进行读取。

上层使用非阻塞IO的方式循环读取底层的数据,就能保证上层将本轮数据取完。


3.3 理解ET模式和非阻塞文件描述符

使用 ET 模式的 epoll,需要将文件描述设置为非阻塞,这个不是接口上的要求,而是 “工程实践” 上的要求。

假设这样的场景:服务器接受到一个10k的请求,会向客户端返回一个应答数据,如果客户端收不到应答,不会发送第二个10k请求。

如果服务端写的代码是阻塞式的read,并且一次只 read 1k 数据的话(read不能保证一次就把所有的数据都读出来,参考 man 手册的说明,可能被信号打断),剩下的9k数据就会待在缓冲区中。

此时由于 epoll 是ET模式,并不会认为文件描述符读就绪,epoll_wait 就不会再次返回,剩下的 9k 数据会一直在缓冲区中,直到下一次客户端再给服务器写数据,epoll_wait 才能返回。

但是问题来了,服务器只读到1k个数据,要10k读完才会给客户端返回响应数据,客户端要读到服务器的响应,客户端发送了下一个请求,epoll_wait 才会返回,才能去读缓冲区中剩余的数据。

所以,为了解决上述问题(阻塞read不一定能一下把完整的请求读完),于是就可以使用非阻塞轮训的方式来读缓冲区,保证一定能把完整的请求都读出来。而如果是LT没这个问题,只要缓冲区中的数据没读完,就能够让 epoll_wait 返回文件描述符读就绪。


四、epoll的优点(和 select 的缺点对应)

  • 接口使用方便:虽然拆分成了三个函数,但是反而使用起来更方便高效,不需要每次循环都设置关注的文件描述符,也做到了输入输出参数分离开
  • 数据拷贝轻量:只在合适的时候调用 EPOLL_CTL_ADD 将文件描述符结构拷贝到内核中,这个操作并不频繁(而select/poll都是每次循环都要进行拷贝)
  • 事件回调机制:避免使用遍历,而是使用回调函数的方式,将就绪的文件描述符结构加入到就绪队列中,epoll_wait 返回直接访问就绪队列就知道哪些文件描述符就绪,这个操作时间复杂度O(1),即使文件描述符数目很多,效率也不会受到影响
  • 没有数量限制:文件描述符数目无上限

注意!!
网上有些博客说,epoll中使用了内存映射机制

  • 内存映射机制:内核直接将就绪队列通过mmap的方式映射到用户态,避免了拷贝内存这样的额外性能开销

这种说法是不准确的,我们定义的struct epoll_event是我们在用户空间中分配好的内存,势必还是需要将内核的数据拷贝到这个用户空间的内存中的。


五、epoll的使用场景

epoll的高性能,是有一定的特定场景的,如果场景选择的不适宜,epoll的性能可能适得其反。

对于多连接,且多连接中只有一部分连接比较活跃时,比较适合使用epoll。
例如,典型的一个需要处理上万个客户端的服务器,例如各种互联网APP的入口服务器,这样的服务器就很适合epoll。

如果只是系统内部,服务器和服务器之间进行通信,只有少数的几个连接,这种情况下用epoll就并不合适。具体要根据需求和场景特点来决定使用哪种IO模型。


六、实现epoll服务器

6.1 Log.hpp(日志)

#pragma once

#include "LockGuard.hpp"
#include <iostream>
#include <string>
#include <stdarg.h>
#include <stdio.h>
#include <pthread.h>
#include <time.h>
#include <unistd.h>
#include <sys/types.h>
#include <fcntl.h>
#include <sys/stat.h>

using namespace std;

// 日志等级
enum
{
    Debug = 0, // 调试
    Info,      // 正常
    Warning,   // 警告
    Error,     // 错误,但程序并未直接退出
    Fatal      // 程序直接挂掉
};

enum
{
    Screen = 10, // 打印到显示器上
    OneFile,     // 打印到一个文件中
    ClassFile    // 按照日志等级打印到不同的文件中
};

string LevelToString(int level)
{
    switch (level)
    {
    case Debug:
        return "Debug";
    case Info:
        return "Info";
    case Warning:
        return "Warning";
    case Error:
        return "Error";
    case Fatal:
        return "Fatal";
    default:
        return "Unknow";
    }
}

const char *default_filename = "log.";
const int default_style = Screen;
const char *defaultdir = "log";

class Log
{
public:
    Log()
        : style(default_style), filename(default_filename)
    {
        // mkdir(defaultdir,0775);
        pthread_mutex_init(&_log_mutex, nullptr);
    }

    void SwitchStyle(int sty)
    {
        style = sty;
    }

    void WriteLogToOneFile(const string &logname, const string &logmessage)
    {
        int fd = open(logname.c_str(), O_CREAT | O_WRONLY | O_APPEND, 0666);
        if (fd == -1)
            return;

        {
            LockGuard lockguard(&_log_mutex);
            write(fd, logmessage.c_str(), logmessage.size());
        }
        close(fd);
    }

    void WriteLogToClassFile(const string &levelstr, const string &logmessage)
    {
        mkdir(defaultdir, 0775);

        string name = defaultdir;
        name += "/";
        name += filename;
        name += levelstr;

        WriteLogToOneFile(name, logmessage);
    }

    void WriteLog(int level, const string &logmessage)
    {
        switch (style)
        {
        case Screen:
        {
            LockGuard lockguard(&_log_mutex);
            cout << logmessage;
        }
        break;
        case OneFile:
            WriteLogToClassFile("All", logmessage);
            break;
        case ClassFile:
            WriteLogToClassFile(LevelToString(level), logmessage);
            break;
        default:
            break;
        }
    }

    string GetTime()
    {
        time_t CurrentTime = time(nullptr);

        struct tm *curtime = localtime(&CurrentTime);
        char time[128];

        // localtime 的年是从1900开始的,所以要加1900, 月是从0开始的所以加1
        snprintf(time, sizeof(time), "%d-%d-%d %d:%d:%d",
                 curtime->tm_year + 1900, curtime->tm_mon + 1, curtime->tm_mday,
                 curtime->tm_hour, curtime->tm_min, curtime->tm_sec);

        return time;
        return "";
    }

    void LogMessage(int level, const char *format, ...)
    {
        char left[1024];
        string Levelstr = LevelToString(level).c_str();
        string Timestr = GetTime().c_str();
        string Idstr = to_string(getpid());
        snprintf(left, sizeof(left), "[%s][%s][%s] ",
                 Levelstr.c_str(), Timestr.c_str(), Idstr.c_str());

        va_list args;
        va_start(args, format);
        char right[1024];
        vsnprintf(right, sizeof(right), format, args);

        string logmessage = left;
        logmessage += right;

        WriteLog(level, logmessage);

        va_end(args);
    }

    ~Log()
    {
        pthread_mutex_destroy(&_log_mutex);
    };

private:
    int style;
    string filename;

    pthread_mutex_t _log_mutex;
};

Log lg;

class Conf
{
public:
    Conf()
    {
        lg.SwitchStyle(Screen);
    }
    ~Conf()
    {
    }
};

Conf conf;

6.2 Lockguard.hpp(自动管理锁)

#pragma once

#include <iostream>

class Mutex
{
public:
    Mutex(pthread_mutex_t* lock)
        :pmutex(lock)
    {}

    void Lock()
    {
        pthread_mutex_lock(pmutex);
    }

    void Unlock()
    {
        pthread_mutex_unlock(pmutex);
    }

    ~Mutex()
    {}
public:
    pthread_mutex_t* pmutex;
};

class LockGuard
{
public:
    LockGuard(pthread_mutex_t* lock)
        :mutex(lock)
    {
        mutex.Lock();
    }

    ~LockGuard()
    {
        mutex.Unlock();
    }
public:
    Mutex mutex;
};

6.3 Socket.hpp(封装套接字)

#pragma once

#include <iostream>
#include <string>
#include <string.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <errno.h>

#define CONV(addrptr) (struct sockaddr*)addrptr

enum{
    Socket_err = 1,
    Bind_err,
    Listen_err
};

const static int defalutsockfd = -1;
const int defalutbacklog = 5;

class Socket
{
public:
    virtual ~Socket(){};
    virtual void CreateSocketOrDie() = 0;
    virtual void BindSocketOrDie(uint16_t port) = 0;
    virtual void ListenSocketOrDie(int backlog) = 0;
    virtual int AcceptConnection(std::string* ip , uint16_t* port) = 0;
    virtual bool ConnectServer(const std::string& serverip , uint16_t serverport) = 0;
    virtual int GetSockFd() = 0;
    virtual void SetSockFd(int sockfd) = 0;
    virtual void CloseSockFd() = 0;
    virtual bool Recv(std::string& buffer,int size) = 0;
    virtual void Send(const std::string& send_string) = 0;

public:
    void BuildListenSocketMethod(uint16_t port,int backlog = defalutbacklog)
    {
        CreateSocketOrDie();
        BindSocketOrDie(port);
        ListenSocketOrDie(backlog);
    }
    bool BuildConnectSocketMethod(const std::string& serverip , uint16_t serverport)
    {
        CreateSocketOrDie();
        return ConnectServer(serverip,serverport);
    }
    void BuildNormalSocketMethod(int sockfd)
    {
        SetSockFd(sockfd);
    }
};

class TcpSocket : public Socket
{
public:
    TcpSocket(int sockfd = defalutsockfd)
        :_sockfd(sockfd)
    {}
    ~TcpSocket(){};

    void CreateSocketOrDie() override
    {
        _sockfd = ::socket(AF_INET,SOCK_STREAM,0);
        if(_sockfd < 0) exit(Socket_err);
    }

    void BindSocketOrDie(uint16_t port) override
    {
        struct sockaddr_in addr;
        memset(&addr,0,sizeof(addr));

        addr.sin_family = AF_INET;
        addr.sin_addr.s_addr = INADDR_ANY;
        addr.sin_port = htons(port);
        socklen_t len = sizeof(addr);

        int n = ::bind(_sockfd,CONV(&addr),len);
        if(n < 0) exit(Bind_err);
    }

    void ListenSocketOrDie(int backlog) override
    {
        int n = ::listen(_sockfd,backlog);
        if(n < 0) exit(Listen_err);
    }

    int AcceptConnection(std::string* clientip , uint16_t* clientport) override
    {
        struct sockaddr_in client;
        memset(&client,0,sizeof(client));
        socklen_t len = sizeof(client);
        int fd = ::accept(_sockfd,CONV(&client),&len);

        if(fd < 0) return -1;

        char buffer[64];
        inet_ntop(AF_INET,&client.sin_addr,buffer,len);
        *clientip = buffer;
        *clientport = ntohs(client.sin_port);

        return fd;
    }   

    bool ConnectServer(const std::string& serverip , uint16_t serverport) override
    {
        struct sockaddr_in server;
        memset(&server,0,sizeof(server));
        server.sin_family = AF_INET;
        // server.sin_addr.s_addr =  inet_addr(serverip.c_str());
        inet_pton(AF_INET,serverip.c_str(),&server.sin_addr);
        server.sin_port = htons(serverport);
        socklen_t len = sizeof(server);
        
        int n = connect(_sockfd,CONV(&server),len);
        if(n < 0) return false;
        else return true;
    }
    
    int GetSockFd() override
    {
        return _sockfd;
    }
    void SetSockFd(int sockfd) override
    {
        _sockfd = sockfd;
    }

    void CloseSockFd() override
    {
        if(_sockfd > defalutsockfd)
        {
            close(_sockfd);
        }
    }

    bool Recv(std::string& buffer , int size)override
    {
        char inbuffer[size];
        int n = recv(_sockfd,inbuffer,sizeof(inbuffer)-1,0);
        if(n > 0)
        {
            inbuffer[n] = 0;
        }
        else
        {
            return false;
        }

        buffer += inbuffer;

        return true;
    }

    void Send(const std::string& send_string)
    {
        send(_sockfd,send_string.c_str(),send_string.size(),0);
    }


private:
    int _sockfd;
};

6.4 Calculator.hpp

#pragma once

#include <memory>
#include "Protocol.hpp"

enum
{
    Success = 0,
    DivZeroErr,
    ModZeroErr,
    Unknown
};

class Calculator
{
public:
    Calculator()
    {}

    std::shared_ptr<Response> Cal(std::shared_ptr<Request> req)
    {
        std::shared_ptr<Response> resp = factory.BuildResponse();
        switch(req->GetOp())
        {
            case '+':
            {
                resp->SetResult(req->GetX()+req->GetY());
                break;
            }
            case '-':
            {
                resp->SetResult(req->GetX()-req->GetY());
                break;
            }
            case '*':
            {
                resp->SetResult(req->GetX()*req->GetY());
                break;
            }
            case '/':
            {
                if(req->GetY() == 0)
                {
                    resp->SetCode(DivZeroErr);
                }
                else
                {
                    resp->SetResult(req->GetX()/req->GetY());
                }
                break;
            }
            case '%':
            {
                if(req->GetY() == 0)
                {
                    resp->SetCode(ModZeroErr);
                }
                else
                {
                    resp->SetResult(req->GetX()%req->GetY());
                }
                break;
            }
            default:
            {
                resp->SetCode(Unknown);
                break;
            }
        }
        return resp;
    }

    ~Calculator(){}
public:
    Factory factory;
};

6.5 Comm.hpp

#pragma once

#include <iostream>
#include <fcntl.h>
#include <unistd.h>
#include <errno.h>

using namespace std;

// 将对应文件描述符设置为非阻塞IO
void SetNonBlock(int fd)
{
    int mode = fcntl(fd, F_GETFL);
    if (mode < 0)
    {
        perror("fcntl");
        return;
    }

    fcntl(fd, F_SETFL, mode | O_NONBLOCK);
}

6.6 Protocol.hpp

#pragma once

#include <memory>
#include <iostream>
#include <string>
#include <jsoncpp/json/json.h>

#define SelfDefine 1

const std::string ProtSep = " ";
const std::string LineBreakSep = "\r\n";  // 这里使用telnet模拟客户端,所以行分割符\r\n

// "len\nx op y\n"
// "len\nresult code\n"
// 添加自描述报头 len代表报文的长度,不包含后面的\n
// 解决用户区分报文边界问题
std::string EnCode(const std::string &info)
{
    std::string message = std::to_string(info.size()) + LineBreakSep + info + LineBreakSep;
    return message;
}

// "l"
// "len"
// "len\n"
// "len\nx"
// "len\nx op "
// "len\nx op y"
// "len\nx op y\n"
// "len\nx op y\n""le"
// "len\nx op y\n""len\nx"
// "len\nx op y\n""len\nx op y\n"

// "len\nresult code\n""len\nresult code\n"
// 取出报文
bool DeCode(std::string &message, std::string *info)
{
    // 读到len
    auto pos = message.find(LineBreakSep);
    if (pos == std::string::npos)
        return false;

    std::string len = message.substr(0, pos);
    int messagelen = stoi(len);

    // 保证读到完整的报文
    int total = len.size() + messagelen + 2 * LineBreakSep.size();
    if (message.size() < total)
        return false;

    *info = message.substr(pos + LineBreakSep.size(),messagelen);
    
    // 对已经读完的报文,再message中删除
    message.erase(0, total);

    return true;
}

// 请求
class Request
{
public:
    Request()
    {
    }
    Request(int data_x, int data_y, char op)
        : _data_x(data_x), _data_y(data_y), _oper(op)
    {
    }

    void Debug()
    {
        std::cout << _data_x << " " << _oper << " " << _data_y << std::endl;
    }

    void Test()
    {
        _data_x++;
        _data_y++;
    }

    // x op y
    // 序列化
    bool Serialize(std::string *out)
    {
#ifdef SelfDefine
        // 自己设计的反序列化方案
        *out = std::to_string(_data_x) + ProtSep + _oper + ProtSep + std::to_string(_data_y);
        return true;
#else
        // 成熟的Json序列化方案
        Json::Value root;
        root["data_x"] = _data_x;
        root["data_y"] = _data_y;
        root["oper"] = _oper;

        Json::FastWriter writer;
        *out = writer.write(root);

        return true;
#endif
    }
    // x op y
    // 反序列化
    bool Deserialize(const std::string &in)
    {
#ifdef SelfDefine
        // 自己设计的反序列化方案
        auto pos = in.find(ProtSep);
        if (pos == std::string::npos)
            return false;
        auto rpos = in.rfind(ProtSep);
        if (rpos == std::string::npos)
            return false;

        _data_x = stoi(in.substr(0, pos));
        _data_y = stoi(in.substr(rpos + ProtSep.size()));

        std::string op = in.substr(pos + ProtSep.size(), rpos - (pos + ProtSep.size()));
        if (op.size() != 1)
            return false;

        _oper = op[0];
        return true;
#else
        // 成熟的Json反序列化方案
        Json::Value root;
        Json::Reader reader;
        reader.parse(in, root);

        _data_x = root["data_x"].asInt();
        _data_y = root["data_y"].asInt();
        _oper = root["oper"].asInt();

        return true;
#endif
    }

    int GetX()
    {
        return _data_x;
    }

    int GetY()
    {
        return _data_y;
    }

    char GetOp()
    {
        return _oper;
    }

private:
    int _data_x; // 第一个参数
    int _data_y; // 第一个参数
    char _oper;  // 操作符
};

// 响应
class Response
{
public:
    Response()
        : _result(0), _code(0)
    {
    }
    Response(int result, int code)
        : _result(result), _code(code)
    {
    }

    // result code
    // 序列化
    bool Serialize(std::string *out)
    {
#ifdef SelfDefine
        // 自己设计的序列化方案
        *out = std::to_string(_result) + ProtSep + std::to_string(_code);
        return true;

#else
        // 成熟的Json序列化方案
        Json::Value root;
        root["result"] = _result;
        root["code"] = _code;

        Json::FastWriter writer;
        *out = writer.write(root);

        return true;
#endif
    }
    // result code
    // 反序列化
    bool Deserialize(const std::string &in)
    {
#ifdef SelfDefine
        // 自己设计的反序列化方案
        auto pos = in.find(ProtSep);
        if (pos == std::string::npos)
            return false;

        _result = stoi(in.substr(0, pos));
        _code = stoi(in.substr(pos + ProtSep.size()));

        return true;
#else
        // 成熟的Json反序列化方案
        Json::Value root;
        Json::Reader reader;
        reader.parse(in, root);

        _result = root["result"].asInt();
        _code = root["code"].asInt();

        return true;
#endif
    }

    void SetResult(int reslut)
    {
        _result = reslut;
    }

    void SetCode(int code)
    {
        _code = code;
    }

    int GetResult()
    {
        return _result;
    }

    int GetCode()
    {
        return _code;
    }

private:
    int _result; // 答案
    int _code;   // 答案是否有效
};

// 工厂模式,建造类设计模式
class Factory
{
public:
    // 使用智能指针创建Request对象
    std::shared_ptr<Request> BuildRequest()
    {
        std::shared_ptr<Request> req = std::make_shared<Request>();

        return req;
    }
    std::shared_ptr<Request> BuildRequest(int data_x, int data_y, char op)
    {
        std::shared_ptr<Request> req = std::make_shared<Request>(data_x, data_y, op);

        return req;
    }

    // 使用智能指针创建Response对象
    std::shared_ptr<Response> BuildResponse()
    {
        std::shared_ptr<Response> resp = std::make_shared<Response>();

        return resp;
    }
    std::shared_ptr<Response> BuildResponse(int result, int code)
    {
        std::shared_ptr<Response> resp = std::make_shared<Response>(result, code);

        return resp;
    }
};

6.7 Accepter.hpp

#pragma once

#include "Connection.hpp"
#include "HandlerConnection.hpp"
#include "Log.hpp"

class Accepter // 连接管理器
{
public:
    Accepter() {}

    // ET模式下,我们并不能确认只有一个连接到来,所以这里需要循环判断
    void AccepterConnection(Connection *conn)
    {
        errno = 0;
        
        while (true)
        {
            struct sockaddr_in client;
            socklen_t len = sizeof(client);

            int sockfd = accept(conn->GetSockfd(), (sockaddr *)&client, &len);
            if (sockfd >= 0)
            {
                lg.LogMessage(Info, "get a new link , sockfd : %d\n", sockfd);

                auto reader = bind(&HandlerConnection::Reader,placeholders::_1);
                auto writer = bind(&HandlerConnection::Writer,placeholders::_1);
                auto excepter = bind(&HandlerConnection::Excepter,placeholders::_1);

                Connection *new_conn = ConnectionFactory::BuildNormalConnection(sockfd, EPOLLIN | EPOLLET, reader, writer, excepter,conn->_R);
                conn->_R->AddConnection(new_conn);
            }
            else
            {
                if (errno == EAGAIN)
                {
                    break;
                }
                else if (errno == EINTR)
                {
                    continue;
                }
                else
                {
                    lg.LogMessage(Warning, "get a new link error\n");
                    break;
                }
            }
        }
    }

    ~Accepter()
    {
    }
};

6.8 Connection.hpp

#pragma once
#include <iostream>
#include <functional>
#include <unistd.h>
#include "TcpServer.hpp"

using namespace std;

class Connection;
class TcpServer;

using func_t = function<void(Connection *)>;

class Connection
{
public:
    Connection(int sockfd, uint32_t events, TcpServer *R)
        : _sockfd(sockfd), _events(events), _R(R)
    {
    }

    // 注册连接三个回调处理函数
    void RegisterCallback(func_t reader, func_t writer, func_t excepter)
    {
        _reader = reader;
        _writer = writer;
        _excepter = excepter;
    }

    // 判断输出缓冲区是否为空
    bool OutBufferEmpty()
    {
        return _outbuffer.empty();
    }

    // 向输入缓冲区中追加数据
    void AddInBuffer(string buffer)
    {
        _inbuffer += buffer;
    }

    // 向输出缓冲区中追加数据
    void AddOutBuffer(string buffer)
    {
        _outbuffer += buffer;
    }

    // 获取连接的文件描述符
    int GetSockfd()
    {
        return _sockfd;
    }

    // 获取输入缓冲区的引用
    string& GetInbuffer()
    {
        return _inbuffer;
    }

    // 获取输出缓冲区的引用
    string& GetOutbuffer()
    {
        return _outbuffer;
    }

    // 获取连接对应文件描述符关注事件
    uint32_t GetEvents()
    {
        return _events;
    }

    // 设置连接对应文件描述符关注事件
    void SetEvents(uint32_t events)
    {
        _events = events;
    }

    // 关闭文件描述符
    void Close(int sockfd)
    {
        close(sockfd);
    }

    ~Connection()
    {
    }

private:
    int _sockfd;      // 文件描述符
    uint32_t _events; // 事件

    string _inbuffer;  // 输入缓冲区
    string _outbuffer; // 输入缓冲区

public:
    func_t _reader;   // 读方法
    func_t _writer;   // 写方法
    func_t _excepter; // 异常处理方法

    TcpServer *_R; // 回指指针
};

class ConnectionFactory
{
public:
    // 构建Listen文件描述符的连接
    static Connection *BuildListenConnection(int sockfd, uint32_t events, func_t reader, TcpServer *R)
    {
        Connection *conn = new Connection(sockfd,events,R);
        conn->RegisterCallback(reader, nullptr, nullptr);

        return conn;
    }
    
    // 构建普通文件描述符的连接
    static Connection *BuildNormalConnection(int sockfd, uint32_t events, func_t reader, func_t writer, func_t excepter, TcpServer *R)
    {
        Connection *conn = new Connection(sockfd,events,R);
        conn->RegisterCallback(reader, writer, excepter);

        return conn;
    }
};

6.9 Epoller.hpp

#pragma once

#include <sys/epoll.h>
#include <string.h>
#include <set>
#include "Log.hpp"

using namespace std;

const static int defaultepfd = -1;
const static int defaultsize = 1024;

class Epoller
{
public:
    Epoller()
        : _epfd(defaultepfd)
    {
    }

    // 初始化Epoller
    void InitEpoller()
    {
        _epfd = epoll_create(defaultsize);

        if (_epfd == defaultepfd)
        {
            lg.LogMessage(Fatal, "epoll_create fail , error: %s , errno : %d\n", strerror(errno), errno);
            exit(-1);
        }
        else
        {
            lg.LogMessage(Info, "epoll_create success , epfd : %d\n", _epfd);
        }
    }

    // 向对应文件描述符中添加事件关心
    int AddEvent(int sockfd, uint32_t events)
    {
        fd_list.insert(sockfd); // for test

        struct epoll_event event;

        event.events = events;
        event.data.fd = sockfd;

        int n = epoll_ctl(_epfd, EPOLL_CTL_ADD, sockfd, &event);

        if (n == -1)
        {
            lg.LogMessage(Fatal, "add %d event fail , error: %s , errno : %d\n", sockfd, strerror(errno), errno);
        }
        else
        {
            lg.LogMessage(Info, "add %d event success\n", sockfd);
        }

        return n;
    }

    // 等待事件就绪
    int Wait(struct epoll_event *events, int maxevents, int timeout)
    {
        int n = epoll_wait(_epfd, events, maxevents, timeout);

        if (n == -1)
        {
            lg.LogMessage(Error, "wait %d fail , error: %s , errno : %d\n", events->data.fd, strerror(errno), errno);
        }
        else
        {
            lg.LogMessage(Info, "wait %d success , events %d \n", events->data.fd, events->events);
        }

        return n;
    }

    // 用于测试,输出目前所有连接的文件描述符
    void DebugFdList() // for test
    {
        std::cout << "fd list is : ";
        for (auto &fd : fd_list)
        {
            std::cout << fd << " ";
        }
        std::cout << std::endl;
    }

    // 修改对应文件描述符的事件关心
    void ModEvents(int sockfd, uint32_t events)
    {
        struct epoll_event event;

        event.events = events;
        event.data.fd = sockfd;

        int n = epoll_ctl(_epfd, EPOLL_CTL_MOD, sockfd, &event);
        if (n < 0)
        {
            lg.LogMessage(Error, "epoll_ctl mod fail, error : %s , errno : %d\n", strerror(errno), errno);
        }
    }

    // 删除对应文件描述符的事件关心
    void DelEvents(int sockfd)
    {
        fd_list.erase(sockfd); // for test

        int n = epoll_ctl(_epfd, EPOLL_CTL_DEL, sockfd, nullptr);
        if (n < 0)
        {
            lg.LogMessage(Error, "epoll_ctl del fail, error : %s , errno : %d\n", strerror(errno), errno);
        }
    }

    ~Epoller()
    {
        if (_epfd >= 0)
        {
            close(_epfd);
        }
        lg.LogMessage(Info, "epoll close success\n");
    }

private:
    int _epfd;
    set<int> fd_list;
};

6.10 HandlerConnection.hpp

#pragma once

#include <iostream>
#include <cerrno>
#include "Connection.hpp"
#include "Log.hpp"
#include "Protocol.hpp"
#include "Calculator.hpp"

const static int buffer_size = 1024;

class HandlerConnection
{
public:
    // 处理请求
    static void HandlerRequest(Connection *conn)
    {
        string &inbuffer = conn->GetInbuffer();
        string messages; // 从inbuffer中截取的一个完整报文
        Factory factory;

        auto req = factory.BuildRequest();

        Calculator cal; // 负责业务

        // 明确报文边界,解决粘包问题
        while (DeCode(inbuffer, &messages))
        {
            // 反序列化,得到完整报文再进行后序操作
            if (!req->Deserialize(messages))
                continue;

            // 业务处理
            auto resp = cal.Cal(req);

            // 序列化
            string info;
            resp->Serialize(&info);

            // 封装完整报文
            string outmessage = EnCode(info);

            // 将报文加入输出缓冲区
            conn->AddOutBuffer(outmessage);
        }

        if (!conn->OutBufferEmpty())
            conn->_writer(conn);
    }

    // 读
    static void Reader(Connection *conn)
    {
        errno = 0;

        string &inbuffer = conn->GetInbuffer();
        // 不能保证一次性读完,需要重复读
        while (true)
        {
            char buffer[1024];
            ssize_t n = recv(conn->GetSockfd(), buffer, sizeof(buffer) - 1, 0);

            if (n > 0)
            {
                buffer[n] = 0;
                conn->AddInBuffer(buffer);
            }
            else
            {
                if (errno == EAGAIN)
                {
                    break;
                }
                else if (errno == EINTR)
                {
                    continue;
                }
                else
                {
                    lg.LogMessage(Info, "reader error , becasue : %s , errno : ", strerror(errno), errno);
                    conn->_excepter(conn);
                    return;
                }
            }
            // 读取完毕后,可能出现完整报文,尝试处理报文
            HandlerRequest(conn);
        }
    }

    // 写
    static void Writer(Connection *conn)
    {
        errno = 0;

        string &outbuffer = conn->GetOutbuffer();

        // 不能保证一次性写完,需要重复写
        while (true)
        {
            ssize_t n = send(conn->GetSockfd(), outbuffer.c_str(), outbuffer.size(), 0);

            if (n >= 0)
            {
                // 发送完毕后,将数据从输出缓冲区中删除
                outbuffer.erase(0, n);

                if (conn->OutBufferEmpty())
                    break;
            }
            else
            {
                if (errno == EAGAIN)
                {
                    break;
                }
                else if (errno == EINTR)
                {
                    continue;
                }
                else
                {
                    lg.LogMessage(Info, "writer error , becasue : %s , errno : ", strerror(errno), errno);
                    conn->_excepter(conn);
                    return;
                }
            }

            if (!conn->OutBufferEmpty())
                conn->_R->EnableReadWrite(conn, true, true);
            else
                conn->_R->EnableReadWrite(conn, true, false);
        }
    }

    // 异常处理
    static void Excepter(Connection *conn)
    {
        errno = 0;

        lg.LogMessage(Info, "connection erase, sockfd: %d\n", conn->GetSockfd());
        // 从unordered_map中删除该连接
        conn->_R->DelConnection(conn->GetSockfd());
        conn->Close(conn->GetSockfd());

        delete conn;
    }
};

6.11 Makefile

epoll_server:Main.cpp
	g++ $^ -o $@ -std=c++14
.PHONY:clean
clean:
	rm -f epoll_server

6.12 TcpServer.hpp

#pragma once

#include <iostream>
#include <unordered_map>
#include <sys/epoll.h>
#include <sys/types.h>
#include <sys/socket.h>

#include "Connection.hpp"
#include "Epoller.hpp"
#include "Comm.hpp"
#include "Log.hpp"

#define EPOLL_NOBLOCK  -1

class TcpServer
{
    const static int gmaxevents = 64;

public:
    TcpServer()
        : _isrunning(false), _timeout(EPOLL_NOBLOCK)
    {
        _epoller.InitEpoller();
    }

    // 判断连接是否存在
    bool IsConnectionExists(int sockfd)
    {
        auto iter = _connections.find(sockfd);
        if (iter == _connections.end())
            return false;
        else
            return true;
    }

    // 修改连接文件描述符是否关系读写时间
    void EnableReadWrite(Connection *conn, bool readable, bool writeable)
    {
        if (IsConnectionExists(conn->GetSockfd()))
            return;

        uint32_t events = (readable ? EPOLLIN : 0) | (writeable ? EPOLLOUT : 0);

        conn->SetEvents(events);
    }

    // 创建新连接,并加入到unordered_map中
    void AddConnection(Connection *conn)
    {
        if (IsConnectionExists(conn->GetSockfd()))
            return;

        lg.LogMessage(Info, "get a new connection , sockfd : %d \n", conn->GetSockfd());
        _connections[conn->GetSockfd()] = conn;
        _epoller.AddEvent(conn->GetSockfd(), conn->GetEvents());

        SetNonBlock(conn->GetSockfd()); // 将文件描述符设置为非阻塞
        lg.LogMessage(Info,"set %d non block\n",conn->GetSockfd());
    }

    // 从unordered_ma中删除连接
    void DelConnection(int sockfd)
    {
        if(!IsConnectionExists(sockfd))
            return;
        // 从unordered_map中删除该连接
        _connections.erase(sockfd);
        // 从Epoller中删除对应文件描述符
        _epoller.DelEvents(sockfd);
    }

    // 单次循环
    void LoopOnce(int timeout)
    {
        int n = _epoller.Wait(_revs, gmaxevents, timeout);
        for (int i = 0; i < n; i++)
        {
            // 将所有的异常,转换为读写异常
            if (_revs[i].events & EPOLLERR || _revs[i].events & EPOLLHUP)
            {
                _revs[i].events |= (EPOLLIN | EPOLLOUT);
            }

            int sockfd = _revs[i].data.fd;
            if (_revs[i].events & EPOLLIN)
            {
                if (_connections[sockfd]->_reader)
                    _connections[sockfd]->_reader(_connections[sockfd]);
            }

            if (_revs[i].events & EPOLLOUT)
            {
                if (_connections[sockfd]->_writer)
                    _connections[sockfd]->_writer(_connections[sockfd]);
            }
        }
    }

    void Dispatcher() // 事件派发器
    {
        _isrunning = true;
        while (_isrunning)
        {
            _epoller.DebugFdList(); // for test
            LoopOnce(_timeout);
        }
    }

    ~TcpServer()
    {
    }

private:
    unordered_map<int, Connection *> _connections;
    Epoller _epoller;

    struct epoll_event _revs[gmaxevents];
    int _timeout;

    bool _isrunning;
};


6.13 Main.cpp

#include <iostream>
#include <string>
#include <memory>

#include "TcpServer.hpp"
#include "Connection.hpp"
#include "Socket.hpp"
#include "Accepter.hpp"

const static int gbacklog = 8;

using namespace std;

void Usage(const string &proc)
{
    cout << "Usage : " << proc << " port" << endl;
}

int main(int argc, char *argv[])
{
    if (argc != 2)
    {
        Usage(argv[0]);
        exit(-1);
    }

    uint16_t port = stoi(argv[1]);

    // 1. 创建listensock
    unique_ptr<Socket> listen_socket = make_unique<TcpSocket>();
    listen_socket->BuildListenSocketMethod(port, gbacklog);

    // 2. 创建tcpserver
    unique_ptr<TcpServer> srv = make_unique<TcpServer>();
    unique_ptr<Accepter> accepter = make_unique<Accepter>();
    auto listen_reader = bind(&Accepter::AccepterConnection,accepter.get(),placeholders::_1);

    // 3. 构建listen对应的connection,添加到tcpserver中
    Connection *listen_connection = ConnectionFactory::BuildListenConnection(listen_socket->GetSockFd(), EPOLLIN | EPOLLET, listen_reader,srv.get());
    srv->AddConnection(listen_connection);

    // 4. 开始事件派发
    srv->Dispatcher();

    return 0;
}

结尾

如果有什么建议和疑问,或是有什么错误,大家可以在评论区中提出。
希望大家以后也能和我一起进步!!🌹🌹
如果这篇文章对你有用的话,希望大家给一个三连支持一下!!🌹🌹

在这里插入图片描述

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐