【.Net】TCP级别的反向代理、Socket连接池和数据包解析器

2/11/2016来源:C#应用人气:4106

背景

最近特别忙,博客久未更新。
回顾了一下2010-2011年的一些.Net项目代码,觉得对初学者可能有一定参考作用,这里share一下。
主要包括:

  1. TCP反向代理
  2. Socket连接池
  3. 数据包解析

反向代理

一般的Web反向代理大家很熟悉了,主要是通过在客户端和服务端之间架设一层代理服务器,转发客户端的请求至服务端或数据库,并将结果回复给客户端。
其特点主要有:

1、缓存一些数据库I/O过重、却更新不频繁的数据,或者静态数据,如文件、图片等。
2、隔离客户端(公网)和服务端(windows服务、Web服务、文件服务),仅将反向代理服务器的ip、名称、host和端口等暴露给公网。
3、基于第2点,其应该是轻量的、可随时重启的,这在服务端自身所在的服务器重启代价较高或不能忍受重启的条件下,极为有用。

比如服务端本身需要处理大量业务逻辑,可能涉及重计算(cpu和内存要求高)、重I/O(磁盘和网络要求高)或者混合类型,那么服务端的机器成本就很高,因为需要更强力的cpu,更大容量的内存,和更快的磁盘和网络,如果还需要考虑DDOS和CC防御能力,服务端的机器成本将急剧上升。

此时,可考虑反向代理技术,选择带硬防、配置比服务端低很多的廉价机器,来作为反向代理服务器(比如阿x云的云主机,带5G硬防,但其非SSD的云磁盘I/O能力很差,此时不能作为业务服务端的宿主机器,但可以作为反向代理服务器),来组成反向代理分布集群。

DDOS攻击,流量需要聚合到一个峰值,才会打死带防机器,而根据DDOS攻击者所具备的流量打压机器和网络条件的不同,这通常需要一段时间,通过反向代理分布集群,一台反向代理被打死,死亡或黑洞窗口通常在半小时至数小时内,如果能保证有相对充裕的反向代理储备,使得整个集群阵亡前,能有跳出黑洞复生的代理机重新加入集群为客户端提供服务,那么就可以形成对抗。即使储备不足,至少可以为受到攻击时的决策赢得更多时间。

综上所述,反向代理技术通过增加额外的网络传输时间,却获得了很多客户端与服务端直接连接所不具备的优势。

通常的web应用服务器,如nginx都可提供反向代理能力。

但tcp级别的反向代理还是比较少的。

当时的项目倒逼出了这么一个需求,其中用到一些基础组件如下。

连接池

如果客户端使用tcp协议和反向代理服务器通讯,比如常见的桌面客户端,那么可以考虑单个长连接 + 异步的方式连接至代理服务器。

而多台代理服务器和真正的业务服务端之间,由于代理和服务端之间多为同步通讯,为了效率,可考虑使用多连接 + 池化的技术,让连接介于长、短之间,综合两者的长处。

下面给出项目中真实使用过的连接池代码,实现中参考了当时MongoDB的C#驱动部分:

    /// <summary>
    /// 连接池
    /// 特性及更新:
    /// 1:从单个移除不可用连接,变为批量移除
    /// 2:移除连接不再防止,暴露重连风暴风险
    ///   目的:尽快尽多发现不可用连接,防止请求失败
    ///   考虑:一般只开放200个连接,没什么大问题.
    /// 5:增大排队线程数和排队超时时间,考虑:网络抖动和业务层慢操作
    /// 6:增大连接最大存活和空闲时间,考虑:网络抖动和业务层慢操作
    /// 7:尽最大可能负载请求并减轻 某一瞬间 传递给主力的请求和连接数目
    /// </summary>
    public class sessionPool
    {
        PRivate object _poolLock = new object();
        public int PoolSize { get; set; }

        public IList<SyncTcpSession> AvaliableSessions
        {
            get { return _avaliableSessions; }
        }

        public ILog Logger;
        private int _waitQueueSize;
        private bool _inMaintainPoolSize;
        private bool _inEnsureMinConnectionPoolSizeWorkItem;
        private IList<SyncTcpSession> _avaliableSessions = new List<SyncTcpSession>();
        public int MaxWaitQueueSize { get; set; }
        public int MaxConnectionPoolSize { get; set; }
        public int MinConnectionPoolSize { get; set; }

        public TimeSpan WaitQueueTimeout { get; set; }
        /// <summary>
        /// 连接最大存活时间(分)
        /// </summary>
        public TimeSpan MaxConnectionLifeTime { get; set; }
        /// <summary>
        /// 连接最大空闲时间(秒)
        /// </summary>
        public TimeSpan MaxConnectionIdleTime { get; set; }
        public IPEndPoint RemoteAddr { get; set; }

        public SessionPool(ILog log)
        {
            Logger = log;
        }

        /// <summary>
        /// 获取可用连接
        /// </summary>
        /// <returns></returns>
        public SyncTcpSession GetAvaliableSession()
        {
            lock (_poolLock)
            {
                //等待获取连接的线程发生了严重积压
                //说明连接数量不足以应付业务,或者业务层处理积压
                //考虑优化业务层和数据库或者增加超时时间
                if (_waitQueueSize >= MaxWaitQueueSize)
                {
                    var ex = new Exception("等待获取连接的线程数过多!");
                    Logger.Error(ex.Message, ex);
                    return null;
                }
                _waitQueueSize += 1;
                try
                {
                    DateTime timeoutAt = DateTime.Now + WaitQueueTimeout;
                    while (true)
                    {
                        //有可用连接
                        if (_avaliableSessions.Count > 0)
                        {
                            //先尝试找到已经打开过的连接
                            for (int i = _avaliableSessions.Count - 1; i >= 0; i--)
                            {
                                if (_avaliableSessions[i].State == SessionState.Open)
                                {
                                    var connection = _avaliableSessions[i];
                                    _avaliableSessions.RemoveAt(i);
                                    return connection;
                                }
                            }

                            //否则去掉最近最少使用的连接,并返回新连接
                            AvaliableSessions[0].Close();
                            AvaliableSessions.RemoveAt(0);
                            return new SyncTcpSession(this);
                        }

                        //无可用连接,新建连接
                        if (PoolSize < MaxConnectionPoolSize)
                        {
                            var connection = new SyncTcpSession(this);
                            PoolSize += 1;
                            return connection;
                        }

                        //不能创建新的连接也没有可用连接,等待连接被回收.
                        var timeRemaining = timeoutAt - DateTime.Now;
                        if (timeRemaining > TimeSpan.Zero)
                        {
                            Monitor.Wait(_poolLock, timeRemaining);
                        }
                        else
                        {
                            //等待超时,说明连接数量不足以应付业务,或者业务层处理积压,考虑优化业务层和数据库或者增加超时时间
                            var ex = new TimeoutException("等待SyncTcpSession已超时.");
                            Logger.Error(ex.Message, ex);
                        }
                    }
                }
                finally
                {
                    _waitQueueSize -= 1;
                }
            }
        }

        /// <summary>
        /// 清空连接池
        /// </summary>
        public void Clear()
        {
            lock (_poolLock)
            {
                foreach (var connection in AvaliableSessions)
                {
                    connection.Close();
                }
                AvaliableSessions.Clear();
                PoolSize = 0;
                Monitor.Pulse(_poolLock);
                Logger.Info("连接池已清空.");
            }
        }
        /// <summary>
        /// 维护连接池的连接数量
        /// </summary>
        public void MaintainPoolSize()
        {
            if (_inMaintainPoolSize)
            {
                return;
            }

            _inMaintainPoolSize = true;
            try
            {
                IList<SyncTcpSession> connectionsToRemove = new List<SyncTcpSession>();
                lock (_poolLock)
                {
                    var now = DateTime.Now;
                    //已改为:移除全部不可用连接,暴露连接风暴风险,但考虑实际业务连接很少闲置,连接风暴风险较小
                    for (int i = AvaliableSessions.Count - 1; i >= 0; i--)
                    {
                        var connection = AvaliableSessions[i];
                        if (now > connection.CreatedAt + MaxConnectionLifeTime
                            || now > connection.LastUsedAt + MaxConnectionIdleTime
                            || connection.IsConnected() == false)
                        {
                            //超过最大生命、闲置时间或未连接则关闭
                            //加入删除集合
                            connectionsToRemove.Add(connection);
                            //从可用连接中移除 
                            AvaliableSessions.RemoveAt(i);
                        }
                    }
                    // }
                }

                //在锁外移除
                if (connectionsToRemove.Any())
                {
                    int i = 0;
                    foreach (var connToRemove in connectionsToRemove)
                    {
                        i++;
                        RemoveConnection(connToRemove);
                    }
                    Logger.InfoFormat("批量移除连接:数量{0}.", i);
                }

                if (PoolSize < MinConnectionPoolSize)
                {
                    ThreadPool.QueueUserWorkItem(EnsureMinConnectionPoolSizeWorkItem,null);
                }
            }
            finally
            {
                _inMaintainPoolSize = false;
            }
        }

        private void EnsureMinConnectionPoolSizeWorkItem(object state)
        {
            if (_inEnsureMinConnectionPoolSizeWorkItem)
            {
                return;
            }

            _inEnsureMinConnectionPoolSizeWorkItem = true;
            try
            {
                while (true)
                {
                    lock (_poolLock)
                    {
                        if (PoolSize >= MinConnectionPoolSize)
                        {
                            return;
                        }
                    }

                    var connection = new SyncTcpSession(this);
                    try
                    {
                        var added = false;
                        lock (_poolLock)
                        {
                            if (PoolSize < MaxConnectionPoolSize)
                            {
                                AvaliableSessions.Add(connection);
                                PoolSize++;
                                added = true;
                                Monitor.Pulse(_poolLock);
                            }
                        }

                        if (!added)
                        {
                            connection.Close();
                        }
                    }
                    catch
                    {
                        Thread.Sleep(TimeSpan.FromSeconds(1));
                    }
                }
            }
            catch
            {
            }
            finally
            {
                _inEnsureMinConnectionPoolSizeWorkItem = false;
            }
        }

        /// <summary>
        /// 回收连接
        /// </summary>
        /// <param name="connection"></param>
        public void ReleaseConnection(SyncTcpSession connection)
        {
            //每次都关闭连接,则退化为短连接
            // RemoveConnection(connection);
            // return;
            if (connection == null)
                return;
            if (connection.SessionPool != this)
            {
                connection.Close();
                Logger.Info("连接不属于此连接池.");
            }

            if (connection.State != SessionState.Open)
            {
                RemoveConnection(connection);
                Logger.Info("移除连接:连接已关闭.");
                return;
            }

            if (DateTime.Now - connection.CreatedAt > MaxConnectionLifeTime)
            {
                RemoveConnection(connection);
                Logger.Info("移除连接:超过最大存活时间.");
                return;
            }

            lock (_poolLock)
            {
                connection.LastUsedAt = DateTime.Now;
                AvaliableSessions.Add(connection);
                Monitor.Pulse(_poolLock);
            }
        }

        /// <summary>
        /// 移除并关闭连接
        /// </summary>
        /// <param name="connection"></param>
        private void RemoveConnection(SyncTcpSession connection)
        {
            lock (_poolLock)
            {
                    AvaliableSessions.Remove(connection); 
                    PoolSize -= 1;
                    Monitor.Pulse(_poolLock);
            }

            connection.Close();
        }

上面的代码中,使用到的同步连接SyncTcpSession,主要用于服务器间的同步通讯,其保证的特性有:

  1. 与连接池关联并受其管理。
  2. 无状态,网络失败则清理连接,超时和重传由客户端处理。
  3. 提供同步通讯能力。
  4. 部署时和业务服务器间网络为内网,所以简化包的解析。
  5. 不使用buffer池,仅使用MemoryStream对象,按需构造对象,此考虑基于代理任务不是很重,以转发为主,缓存为辅,且有多台组成集群,整体内存相对充裕(相比之下,业务服务端采用了buffer池)。

代码如下:

    /// <summary>
    /// 同步会话对象
    /// </summary>
    public class SyncTcpSession
    {
        private readonly object _instanceLock = new object();
        private readonly ILog _logger;

        private readonly IPEndPoint _remoteAddr;
        private readonly SessionPool _sessionPool;
        private Socket _socket;

        private SessionState _state = SessionState.Initial;

        public SessionPool SessionPool { get { return _sessionPool; } }

        public SyncTcpSession(SessionPool pool)
        {
            _logger = pool.Logger;
            _sessionPool = pool;
            _remoteAddr = _sessionPool.RemoteAddr;
            CreatedAt = DateTime.Now;
            _socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)
                          {SendTimeout = 90000, ReceiveTimeout = 180000};
        }

        public DateTime CreatedAt { get; set; }

        public DateTime LastUsedAt { get; set; }

        public SessionState State { get { return _state; } }

        public void Connect()
        {
            _socket.Connect(_remoteAddr);
            LastUsedAt = DateTime.Now;
            _state = SessionState.Open;
        }

        public void Close()
        {
            lock (_instanceLock)
            {
                try
                {
                    if (_socket == null || _state == SessionState.Closed)
                        return;
                    if (_socket.Connected)
                    {
                        try
                        {
                            _socket.Shutdown(SocketShutdown.Both);
                        }
                        finally
                        {
                            _socket.Close();
                            _socket.Dispose();
                            _socket = null;
                        }
                        _state = SessionState.Closed;
                    }
                }
                catch
                {
                }
              //  _logger.Info("session is closed");
            }
        }

        /// <summary>
        /// 获取当前是否连接
        /// </summary>
        /// <returns></returns>
        public bool IsConnected()
        {
            if(_state!=SessionState.Open)
            {
                return false;
            }

            byte[] tmp = new byte[1];
            try
            {
                _socket.Blocking = false;
                _socket.Send(tmp, 0, 0);
                _socket.Blocking = true;
                return true;
            }catch(SocketException ex)
            {
                _logger.Error("[Not Connected]");
                return false;
            }
        }

        public bool Send(byte[] sentBytes)
        {
            lock (_instanceLock)
            {
                LastUsedAt = DateTime.Now;
                if (_state == SessionState.Initial)
                {
                    //如果没有连接成功,返回失败
                    if(!Open())
                    {
                        _state = SessionState.Closed;
                        return false;
                    }
                }

                //如果当前没有连接,返回失败
                if(!IsConnected())
                {
                    _state = SessionState.Closed;
                    return false;
                }

                //此时连接可能已经关闭了
                int allLen = sentBytes.Length;
                try
                {
                    while (allLen > 0)
                    {
                        //发送数据到缓冲区,不保证立即在网络中传递,可能会发送超时
                        int sent = _socket.Send(sentBytes, 0, sentBytes.Length, SocketFlags.None);
                        allLen -= sent;
                    }

                    LastUsedAt = DateTime.Now;
                    return true;
                }
                catch (SocketException ex)
                {
                    //如果出现错误,返回失败
                    _logger.Error(string.Format("[Send Failed] {0}", ex.Message), ex);
                    _state = SessionState.Closed;
                    return false;
                }
            }
        }

        public byte[] Receive(out bool successful)
        {
            successful = false;
            const int headerLen = 8;
            byte[] ret = null;
            bool foundHeader = false;
            LastUsedAt = DateTime.Now;
            if (_socket == null || !_socket.Connected) return null;
            lock (_instanceLock)
            {
                // 部署环境,内网比较稳定,简化包解析
                var buffer = new byte[16*1024];
                int remaining = -1;
                using (var ms = new MemoryStream())
                {
                    try
                    {
                        while (remaining != 0)
                        {
                            int allLen = _socket.Receive(buffer, 0, buffer.Length, SocketFlags.None);
                            if (!foundHeader)
                            {
                                if (allLen >= headerLen)
                                {
                                    foundHeader = true;
                                    int bodyLen = (buffer[4] << 24) + (buffer[5] << 16) + (buffer[6] << 8) +
                                                  buffer[7];
                                    remaining = (int) (headerLen + bodyLen - ms.Length);
                                }
                            }
                            ms.Write(buffer, 0, allLen);
                            if (foundHeader)
                            {
                                remaining -= allLen;
                            }
                        }

                        ret = new byte[ms.Length];
                        ms.Position = 0;
                        ms.Read(ret, 0, ret.Length);

                        LastUsedAt = DateTime.Now;
                        successful = true;
                        return ret;
                    }
                    catch (Exception ex)
                    {
                        successful = false;
                        _state = SessionState.Closed;
                        _logger.Error(string.Format("[Recv Failed] {0}", ex.Message), ex);
                    }
                }
            }

            return ret;
        }

        public bool Open()
        {
            try
            {
                Connect();
                return true;
            }
            catch (Exception ex)
            {
                _state = SessionState.Closed;
                _logger.Error(string.Format("[Open Failed] {0}", ex.Message), ex);
                return false;
            }
        }

        public bool SendRequest(byte[] package)
        {
           return Send(package);
        }

        public byte[] GetBody(byte[] data)
        {
            if (data == null || data.Length < 1)
            {
                return null;
            }

            if (data.Length < 8)
            {
                _logger.Error("接收到的数据包长度不足");
                return null;
            }

            int bodyLen = (data[4] << 24) + (data[5] << 16) + (data[6] << 8) +
                          data[7];
            var body = new byte[bodyLen];
            if (bodyLen + 8 != data.Length)
            {
                _logger.ErrorFormat("包长有误:totalLen:({0}),bodyLen:{1}", data.Length, body.Length);
                return null;
            }
            Buffer.BlockCopy(data, 8, body, 0, bodyLen);
            return body;
        }
    }

数据包解析

为了提供分片发送数据的能力,同时避免粘包问题,必须要进行包解析,完备版的包解析代码提供如下特性:

  1. 可区分收到的字节流是否包含一个完整的包,若是,抛出事件,若不是,继续接收数据。
  2. 可检测包头是否完整。
  3. 可检测包体是否完整。
  4. 网络质量较差时(国外互通国内、非局域网、跨电信、联通)使用此解析器,客户端和代理之间通讯,采用此解析器。

代码如下:

    /// <summary>
    /// 包解析器
    /// </summary>
    public class PackageAnalyzer
    {
        public PackageAnalyzer(int headerLen)
        {
            _headerLen = headerLen;
            _header = new byte[_headerLen];
        }

        /// <summary>
        /// 包头长度
        /// </summary>
        private readonly int _headerLen;

        /// <summary>
        /// 包头缓冲
        /// </summary>
        private readonly byte[] _header;

        /// <summary>
        /// 还差多少字节组成一个完整包
        /// </summary>
        private int _requiredDataLength;

        /// <summary>
        /// 包头协议标识字节数组中已收到的字节数
        /// </summary>
        private int _receivedHeaderLength;

        /// <summary>
        /// 包头获取状态
        /// </summary>
        private Header _headerFlag;

        /// <summary>
        /// 包头状态
        /// </summary>
        private enum Header
        {
            NotFound, Found, PartialFound
        }

        /// <summary>
        /// 包头是否已保存
        /// </summary>
        private bool _headerWritten;

        /// <summary>
        /// 包存储器
        /// </summary>
        public BufferWriter Writer { get; set; }

        /// <summary>
        /// 是否允许变长存储
        /// </summary>
        public bool EnabledVariant { get; set; }

        /// <summary>
        /// 包分析成功时的处理委托
        /// </summary>
        /// <param name="requestInfo"></param>
        public delegate void OnAnalyzeSuccess(BinaryResponseInfo requestInfo);

        /// <summary>
        /// 从字节流分析包头,并获取完整包
        /// </summary>
        /// <param name="data">字节流</param>
        /// <param name="offset">偏移</param>
        /// <param name="total">总字节数</param>
        /// <param name="onAnalyzeSuccessCallback">分包成功时回调</param>
        public void TryProcess(byte[] data, int offset, int total, OnAnalyzeSuccess onAnalyzeSuccessCallback)
        {
            while (total > 0)
            {
                //还没有获取头部
                if (_headerFlag == Header.NotFound)
                {
                    //剩余
                    if (total >= _headerLen)
                    {
                        //获取完整头部
                        _headerFlag = Header.Found;
                        Array.Copy(data, offset, _header, 0, _headerLen);
                        offset += _headerLen;
                        total -= _headerLen;

                        //获取数据长度
                        _requiredDataLength = (_header[4] << 24) + (_header[5] << 16) + (_header[6] << 8) + _header[7];
                        _receivedHeaderLength = 0;
                        //继续处理
                    } //不足
                    else
                    {
                        Array.Copy(data, offset, _header, 0, total);
                        _receivedHeaderLength += total;
                        _headerFlag = Header.PartialFound;
                        break;
                    }
                }
                //已获取头部
                if (_headerFlag == Header.Found)
                {
                    //可以获取完整数据
                    if (total >= _requiredDataLength)
                    {
                        //保存数据
                        //还未写入头部
                        if (!_headerWritten)
                        {
                            Writer = new BufferWriter(System.Text.Encoding.UTF8);
                            _headerWritten = true;
                            Writer.Write(_header, 0, _headerLen);
                        }
                        Writer.Write(data, offset, _requiredDataLength);

                        offset += _requiredDataLength;
                        total -= _requiredDataLength;

                        //重置全部状态
                        _requiredDataLength = 0;
                        _receivedHeaderLength = 0;
                        _headerFlag = Header.NotFound;
                        _headerWritten = false;

                        //获得了完整的包,开始读取消息
                        //////////////////////////////////////////////////////

                        var reader = new BufferReader(System.Text.Encoding.UTF8, Writer) { EnabledVariant = EnabledVariant };

                        BinaryResponseInfo responseInfo;
                        MessageRead(reader, out responseInfo);
                        if (responseInfo != null)
                            if (onAnalyzeSuccessCallback != null)
                                onAnalyzeSuccessCallback(responseInfo);

                        //////////////////////////////////////////////////////

                    }
                    //不能获取完整数据
                    else
                    {
                        //保存数据
                        //还未写入头部
                        if (!_headerWritten)
                        {
                            Writer = new BufferWriter(System.Text.Encoding.UTF8);
                            _headerWritten = true;
                            Writer.Write(_header, 0, _headerLen);
                        }
                        //写入数据
                        Writer.Write(data, offset, total);
                        _requiredDataLength -= total;
                        break;
                    }
                }

                //部分获取头部
                if (_headerFlag == Header.PartialFound)
                {
                    //不能获取完整头部
                    if (total + _receivedHeaderLength < _headerLen)
                    {
                        Array.Copy(data, offset, _header, _receivedHeaderLength, total);
                        _receivedHeaderLength += total;
                        break;
                    }
                    //可以获取完整头部
                    else
                    {
                        _headerFlag = Header.Found;
                        var delta = _headerLen - _receivedHeaderLength;
                        Array.Copy(data, offset, _header, _receivedHeaderLength, delta);

                        total -= delta;
                        offset += delta;

                        //获取数据长度
                        _requiredDataLength = (_header[4] << 24) + (_header[5] << 16) + (_header[6] << 8) + _header[7];
                        _receivedHeaderLength = 0;
                        //继续处理
                    }
                }
            }

        }


        /// <summary>
        /// 从包头获取数据长度
        /// </summary>
        /// <param name="buffer">包头</param>
        /// <param name="offset">偏移</param>
        /// <param name="length">长度</param>
        /// <returns></returns>
        public virtual int GetDataLengthFromHeader(byte[] buffer, int offset, int length)
        {
            //第5、6、7、8个字节标识长度
            return (buffer[offset + 4] << 24) + (buffer[offset + 5] << 16) + (buffer[offset + 6] << 8) +
                   (buffer[offset + 7]);
        }

        /// <summary>
        /// 获取消息实体
        /// </summary>
        /// <param name="reader">数据流读取器</param>
        /// <param name="requestInfo">请求信息</param>
        private void MessageRead(BufferReader reader, out BinaryResponseInfo requestInfo)
        {
            var package = reader.ToBytes();
            //服务标识(调用业务端的哪个服务)
            var serviceKey = System.Text.Encoding.UTF8.GetString(package, 0, 4);
            var bodyLen = (package[4] << 24) + (package[5] << 16) + (package[6] << 8) + (package[7]);

            //  System.Diagnostics.Debug.Assert(bodyLen > 0, "包数据为空");
            var body = new byte[bodyLen];
            Buffer.BlockCopy(package, _headerLen, body, 0, bodyLen);
            requestInfo = new BinaryResponseInfo() { Key = serviceKey, Body = body };
        }
    }

由于历史原因,一些命名不是很恰当,代码结构也不是很好,比如上面这个类,叫做Parser可能更恰当。

结语

希望对初学网络编程的朋友有所帮助。