在这里插入图片描述

一、poll函数

#include <poll.h>

int poll(struct pollfd *fds, nfds_t nfds, int timeout);

功能:poll 是 Linux 系统中的一种 I/O 多路复用机制,主要用于同时监控多个文件描述符的状态

参数

  • fds:指向 struct pollfd 数组的指针,每个元素指定一个要监控的文件描述符及其关注的事件
  • nfds:数组中元素的数量(即监控的文件描述符总数)
  • timeout:超时时间(毫秒):
    • -1:永久阻塞,直到有事件发生
    • 0:立即返回(非阻塞模式)
    • >0:指定超时时间,超时后返回

返回值

  • 正数:表示就绪的文件描述符总数
  • 0:表示超时(无文件描述符就绪)
  • -1:表示错误,并设置 errno

struct pollfd 结构

struct pollfd {
    int fd;         // 文件描述符
    short events;   // 关注的事件(输入掩码,如 POLLIN、POLLOUT)
    short revents;  // 实际发生的事件(输出掩码,由内核填充)
};

在这里插入图片描述

poll 函数支持的标准事件类型,本质上是宏,只有一个比特位为1,通过与events和revents异或分为两种情况:

  • 调用时:用户告诉内核,需要关注文件描述符中的events事件
  • 返回时:内核告诉用户,用户关注的文件描述符,有revents中的事件准备就绪
事件 描述 是否可以作为输入 是否可以作为输出
POLLIN 有普通数据或优先数据可读
POLLRDNORM 有普通数据可读
POLLRDBAND 有优先级带数据可读
POLLPRI 有高优先级带数据可读
POLLOUT 有普通数据或优先数据可写
POLLWRNORM 有普通数据可写
POLLWRBAND 有优先级带数据可写
POLLRDHUP TCP连接的对端关闭连接,或关闭了写操作
POLLHUP 挂起
POLLERR 错误
POLLNVAL 文件描述符未打开

二、poll的优缺点

  • 优点
    1. poll 只负责等待,可以等待多个文件描述符,在IO的时候效率会比较高
    2. 输入和输出参数进行分离,events和revents,不需要再对poll的参数进行频繁的重置了
    3. poll使用了动态数组,所以 poll 能够检测文件描述符的个数也是没有有限的
  • 缺点
    1. 用户和内核之间,需要一直进行数据拷贝
    2. 在编写代码的时候,需要遍历动态数组,可能会影响select的效率
    3. poll 会让操作系统在底层遍历要关心的所有文件描述符,会导致效率降低

三、实现poll服务器(只关心读事件)

3.1 Log.hpp(日志)

#pragma once

#include "LockGuard.hpp"
#include <iostream>
#include <string>
#include <stdarg.h>
#include <stdio.h>
#include <pthread.h>
#include <time.h>
#include <unistd.h>
#include <sys/types.h>
#include <fcntl.h>
#include <sys/stat.h>

using namespace std;

// 日志等级
enum
{
    Debug = 0, // 调试
    Info,      // 正常
    Warning,   // 警告
    Error,     // 错误,但程序并未直接退出
    Fatal      // 程序直接挂掉
};

enum
{
    Screen = 10, // 打印到显示器上
    OneFile,     // 打印到一个文件中
    ClassFile    // 按照日志等级打印到不同的文件中
};

string LevelToString(int level)
{
    switch (level)
    {
    case Debug:
        return "Debug";
    case Info:
        return "Info";
    case Warning:
        return "Warning";
    case Error:
        return "Error";
    case Fatal:
        return "Fatal";
    default:
        return "Unknow";
    }
}

const char *default_filename = "log.";
const int default_style = Screen;
const char *defaultdir = "log";

class Log
{
public:
    Log()
        : style(default_style), filename(default_filename)
    {
        // mkdir(defaultdir,0775);
        pthread_mutex_init(&_log_mutex, nullptr);
    }

    void SwitchStyle(int sty)
    {
        style = sty;
    }

    void WriteLogToOneFile(const string &logname, const string &logmessage)
    {
        int fd = open(logname.c_str(), O_CREAT | O_WRONLY | O_APPEND, 0666);
        if (fd == -1)
            return;

        {
            LockGuard lockguard(&_log_mutex);
            write(fd, logmessage.c_str(), logmessage.size());
        }
        close(fd);
    }

    void WriteLogToClassFile(const string &levelstr, const string &logmessage)
    {
        mkdir(defaultdir, 0775);

        string name = defaultdir;
        name += "/";
        name += filename;
        name += levelstr;

        WriteLogToOneFile(name, logmessage);
    }

    void WriteLog(int level, const string &logmessage)
    {
        switch (style)
        {
        case Screen:
        {
            LockGuard lockguard(&_log_mutex);
            cout << logmessage;
        }
        break;
        case OneFile:
            WriteLogToClassFile("All", logmessage);
            break;
        case ClassFile:
            WriteLogToClassFile(LevelToString(level), logmessage);
            break;
        default:
            break;
        }
    }

    string GetTime()
    {
        time_t CurrentTime = time(nullptr);

        struct tm *curtime = localtime(&CurrentTime);
        char time[128];

        // localtime 的年是从1900开始的,所以要加1900, 月是从0开始的所以加1
        snprintf(time, sizeof(time), "%d-%d-%d %d:%d:%d",
                 curtime->tm_year + 1900, curtime->tm_mon + 1, curtime->tm_mday,
                 curtime->tm_hour, curtime->tm_min, curtime->tm_sec);

        return time;
        return "";
    }

    void LogMessage(int level, const char *format, ...)
    {
        char left[1024];
        string Levelstr = LevelToString(level).c_str();
        string Timestr = GetTime().c_str();
        string Idstr = to_string(getpid());
        snprintf(left, sizeof(left), "[%s][%s][%s] ",
                 Levelstr.c_str(), Timestr.c_str(), Idstr.c_str());

        va_list args;
        va_start(args, format);
        char right[1024];
        vsnprintf(right, sizeof(right), format, args);

        string logmessage = left;
        logmessage += right;

        WriteLog(level, logmessage);

        va_end(args);
    }

    ~Log()
    {
        pthread_mutex_destroy(&_log_mutex);
    };

private:
    int style;
    string filename;

    pthread_mutex_t _log_mutex;
};

Log lg;

class Conf
{
public:
    Conf()
    {
        lg.SwitchStyle(Screen);
    }
    ~Conf()
    {
    }
};

Conf conf;

3.2 Lockguard.hpp(自动管理锁)

#pragma once

#include <iostream>

class Mutex
{
public:
    Mutex(pthread_mutex_t* lock)
        :pmutex(lock)
    {}

    void Lock()
    {
        pthread_mutex_lock(pmutex);
    }

    void Unlock()
    {
        pthread_mutex_unlock(pmutex);
    }

    ~Mutex()
    {}
public:
    pthread_mutex_t* pmutex;
};

class LockGuard
{
public:
    LockGuard(pthread_mutex_t* lock)
        :mutex(lock)
    {
        mutex.Lock();
    }

    ~LockGuard()
    {
        mutex.Unlock();
    }
public:
    Mutex mutex;
};


3.3 Socket.hpp(封装套接字)

#pragma once

#include <iostream>
#include <string>
#include <string.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <errno.h>

#define CONV(addrptr) (struct sockaddr*)addrptr

enum{
    Socket_err = 1,
    Bind_err,
    Listen_err
};

const static int defalutsockfd = -1;
const int defalutbacklog = 5;

class Socket
{
public:
    virtual ~Socket(){};
    virtual void CreateSocketOrDie() = 0;
    virtual void BindSocketOrDie(uint16_t port) = 0;
    virtual void ListenSocketOrDie(int backlog) = 0;
    virtual int AcceptConnection(std::string* ip , uint16_t* port) = 0;
    virtual bool ConnectServer(const std::string& serverip , uint16_t serverport) = 0;
    virtual int GetSockFd() = 0;
    virtual void SetSockFd(int sockfd) = 0;
    virtual void CloseSockFd() = 0;
    virtual bool Recv(std::string& buffer,int size) = 0;
    virtual void Send(const std::string& send_string) = 0;

public:
    void BuildListenSocketMethod(uint16_t port,int backlog = defalutbacklog)
    {
        CreateSocketOrDie();
        BindSocketOrDie(port);
        ListenSocketOrDie(backlog);
    }
    bool BuildConnectSocketMethod(const std::string& serverip , uint16_t serverport)
    {
        CreateSocketOrDie();
        return ConnectServer(serverip,serverport);
    }
    void BuildNormalSocketMethod(int sockfd)
    {
        SetSockFd(sockfd);
    }
};

class TcpSocket : public Socket
{
public:
    TcpSocket(int sockfd = defalutsockfd)
        :_sockfd(sockfd)
    {}
    ~TcpSocket(){};

    void CreateSocketOrDie() override
    {
        _sockfd = ::socket(AF_INET,SOCK_STREAM,0);
        if(_sockfd < 0) exit(Socket_err);
    }

    void BindSocketOrDie(uint16_t port) override
    {
        struct sockaddr_in addr;
        memset(&addr,0,sizeof(addr));

        addr.sin_family = AF_INET;
        addr.sin_addr.s_addr = INADDR_ANY;
        addr.sin_port = htons(port);
        socklen_t len = sizeof(addr);

        int n = ::bind(_sockfd,CONV(&addr),len);
        if(n < 0) exit(Bind_err);
    }

    void ListenSocketOrDie(int backlog) override
    {
        int n = ::listen(_sockfd,backlog);
        if(n < 0) exit(Listen_err);
    }

    int AcceptConnection(std::string* clientip , uint16_t* clientport) override
    {
        struct sockaddr_in client;
        memset(&client,0,sizeof(client));
        socklen_t len = sizeof(client);
        int fd = ::accept(_sockfd,CONV(&client),&len);

        if(fd < 0) return -1;

        char buffer[64];
        inet_ntop(AF_INET,&client.sin_addr,buffer,len);
        *clientip = buffer;
        *clientport = ntohs(client.sin_port);

        return fd;
    }   

    bool ConnectServer(const std::string& serverip , uint16_t serverport) override
    {
        struct sockaddr_in server;
        memset(&server,0,sizeof(server));
        server.sin_family = AF_INET;
        // server.sin_addr.s_addr =  inet_addr(serverip.c_str());
        inet_pton(AF_INET,serverip.c_str(),&server.sin_addr);
        server.sin_port = htons(serverport);
        socklen_t len = sizeof(server);
        
        int n = connect(_sockfd,CONV(&server),len);
        if(n < 0) return false;
        else return true;
    }
    
    int GetSockFd() override
    {
        return _sockfd;
    }
    void SetSockFd(int sockfd) override
    {
        _sockfd = sockfd;
    }

    void CloseSockFd() override
    {
        if(_sockfd > defalutsockfd)
        {
            close(_sockfd);
        }
    }

    bool Recv(std::string& buffer , int size)override
    {
        char inbuffer[size];
        int n = recv(_sockfd,inbuffer,sizeof(inbuffer)-1,0);
        if(n > 0)
        {
            inbuffer[n] = 0;
        }
        else
        {
            return false;
        }

        buffer += inbuffer;

        return true;
    }

    void Send(const std::string& send_string)
    {
        send(_sockfd,send_string.c_str(),send_string.size(),0);
    }


private:
    int _sockfd;
};

3.4 PollServer.hpp(服务端封装)

#pragma once

#include <iostream>
#include <string>
#include <poll.h>
#include "Socket.hpp"
#include "Log.hpp"
#include <memory>

using namespace std;

const static uint16_t defalutport = 8888;
const static int gbacklog = 8;
const static int num = 1024;

class PollServer
{
private:
    void HandlerEvent()
    {
        for (int i = 0; i < _num; i++)
        {
            // 是否监控
            if (_rfds[i].fd == -1)
                continue;

            // 是否就绪
            int fd = _rfds[i].fd;

            if (_rfds[i].revents & POLLIN)
            {
                // 是新连接到来,还是新数据到来
                // 新连接到来
                if (fd == _listensock->GetSockFd())
                {
                    lg.LogMessage(Info, "get a new link\n");

                    string clientip;
                    uint16_t cilentport;
                    // 由于select已经检测到listensock已经就绪了,这里不会阻塞
                    int sockfd = _listensock->AcceptConnection(&clientip, &cilentport);
                    if (sockfd == -1)
                    {
                        lg.LogMessage(Error, "accept error\n");
                        continue;
                    }
                    lg.LogMessage(Info, "get a client , client info# %s %d , fd:%d\n", clientip.c_str(), cilentport, sockfd);

                    // 这里已经获取连接成功,由于底层数据不一定就绪
                    // 所以这里需要将新连接的文件描述符交给poll托管
                    // 只需将文件描述符加入到_rfds即可
                    int pos = 0;
                    for (; pos < _num; pos++)
                    {
                        if (_rfds[pos].fd == -1)
                        {
                            _rfds[pos].fd = sockfd;
                            _rfds[pos].events |= POLLIN;
                            break;
                        }
                    }

                    // 当存储上限时,可以选择扩容,由于poll并不是很重要,这里我为了方便就直接关闭文件描述符
                    if(pos == _num)
                    {
                        close(sockfd);
                        lg.LogMessage(Warning, "server is full...!\n");
                    }
                }
                else
                {  // 是新数据来了
                    // 这里读是有问题的
                    char buffer[1024];
                    bool flag = recv(fd,buffer,1024,0);
                    if(flag)  // 读取成功
                    {
                        lg.LogMessage(Info,"client say# %s\n",buffer);
                    }
                    else  // 读取失败
                    {
                        lg.LogMessage(Warning,"cilent quit !! close fd : %d\n",fd);
                        close(fd);
                        _rfds[i].fd = -1;
                        _rfds[i].events = 0;
                        _rfds[i].revents = 0;
                    }
                }
            }
        }
    }

public:
    PollServer(uint16_t port = defalutport)
        : _port(port), _listensock(new TcpSocket()), _isrunning(false), _num(num),_rfds(new pollfd[_num])
    {
    }

    void Init()
    {
        _listensock->BuildListenSocketMethod(_port, gbacklog);
        for (int i = 0; i < _num; i++)
        {
            _rfds[i].fd = -1;
            _rfds[i].events = 0;
            _rfds[i].revents = 0;
        }
        _rfds[0].fd = _listensock.get()->GetSockFd();
        _rfds[0].events |= POLLIN;
    }

    void Loop()
    {
        _isrunning = true;
        while (_isrunning)
        {
            PrintDebug();

            int timeout = 1000;
            ssize_t n = poll(_rfds,_num,timeout);

            switch (n)
            {
            case -1:
            {
                lg.LogMessage(Fatal, "select Error\n");
                break;
            }
            case 0:
            {
                lg.LogMessage(Info, "select timeout...");
                break;
            }
            default:
            {
                lg.LogMessage(Info, "select success , begin handler event\n");
                HandlerEvent();
                break;
            }
            }
        }
        _isrunning = false;
    }

    void Stop()
    {
        _isrunning = false;
    }

    // 查看当前哪些文件描述符需要被监控
    void PrintDebug()
    {
        std::cout << "current select rfds list is : ";
        for (int i = 0; i < _num; i++)
        {
            if (_rfds[i].fd == -1)
                continue;
            else
                std::cout << _rfds[i].fd << " ";
        }
        std::cout << std::endl;
    }

    ~PollServer() 
    {
        delete[] _rfds;
    }

private:
    unique_ptr<Socket> _listensock;
    uint16_t _port;
    bool _isrunning;

    int _num;
    struct pollfd* _rfds;
};

3.5 Main.cpp(服务端)

#include <iostream>
#include <memory>
#include "PollServer.hpp"

using namespace std;

// ./pollServer port
int main(int argc , char* argv[])
{
    if(argc != 2)
    {
        cout << "Usage : " << argv[0] << " port" << endl;
        exit(0); 
    }

    uint16_t localport = stoi(argv[1]);
    unique_ptr<PollServer> svr = make_unique<PollServer>(localport);

    svr->Init();
    svr->Loop();

    return 0;
}


结尾

如果有什么建议和疑问,或是有什么错误,大家可以在评论区中提出。
希望大家以后也能和我一起进步!!🌹🌹
如果这篇文章对你有用的话,希望大家给一个三连支持一下!!🌹🌹

在这里插入图片描述

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐