2团
Published on 2025-11-04 / 3 Visits
0
0

使用io.Copy简化Golang TCP代理实现的思考

以前尝试给Easegress中添加四层代理功能时,曾面临双向连接错误控制的挑战。
当时的想法是每条TCP连接分别使用两个goroutine进行读写,通过buf实现流量拷贝,并在读写出错时判断应仅关闭读端、写端,还是整个连接。
此外还需处理两端TCP连接因读写能力不匹配而引发的流量控制问题。

1 初版实现

package tcpproxy

import (
    "io"
    "net"
    "runtime/debug"
    "sync"
    "sync/atomic"
    "time"

    "github.com/megaease/easegress/pkg/logger"
    "github.com/megaease/easegress/pkg/util/fasttime"
    "github.com/megaease/easegress/pkg/util/iobufferpool"
    "github.com/megaease/easegress/pkg/util/timerpool"
)

// writeBufSize is the capacity of each connection's writeBufferChan.
const writeBufSize = 8

// tcpBufferPool recycles fixed-size read buffers across connections to
// avoid a fresh allocation per read.
var tcpBufferPool = sync.Pool{
    New: func() interface{} {
        buf := make([]byte, iobufferpool.DefaultBufferReadCapacity)
        return buf
    },
}

// Connection wraps a raw TCP connection with buffered, loop-driven
// read/write handling and close coordination.
type Connection struct {
    closed     uint32 // set 0->1 via CAS in Close; guards against double close
    rawConn    net.Conn
    localAddr  net.Addr
    remoteAddr net.Addr

    readBuffer      []byte                          // pooled read buffer, lazily taken from tcpBufferPool
    writeBuffers    net.Buffers                     // raw byte views over ioBuffers, flushed via WriteTo
    ioBuffers       []*iobufferpool.StreamBuffer    // owned buffers pending write; released after flush
    writeBufferChan chan *iobufferpool.StreamBuffer // handoff from Write() to the write loop

    mu               sync.Mutex
    connStopChan     chan struct{} // closed when this connection closes
    listenerStopChan chan struct{} // closed when the owning listener stops

    lastReadDeadlineTime  time.Time // last read deadline set, to skip redundant SetReadDeadline calls
    lastWriteDeadlineTime time.Time // last write deadline set, to skip redundant SetWriteDeadline calls

    onRead  func(buffer *iobufferpool.StreamBuffer) // executed by the read loop for each chunk read
    onClose func(event ConnectionEvent)             // invoked once from Close with the triggering event
}

// NewClientConn wraps a connection accepted from a client into a
// *Connection whose loops are started later via Start.
func NewClientConn(conn net.Conn, listenerStopChan chan struct{}) *Connection {
    c := &Connection{
        rawConn:          conn,
        localAddr:        conn.LocalAddr(),
        remoteAddr:       conn.RemoteAddr(),
        listenerStopChan: listenerStopChan,
    }
    c.connStopChan = make(chan struct{})
    c.writeBufferChan = make(chan *iobufferpool.StreamBuffer, writeBufSize)
    return c
}

// SetOnRead registers the callback that the read loop invokes with each
// chunk of data read from the connection.
func (c *Connection) SetOnRead(onRead func(buffer *iobufferpool.StreamBuffer)) {
    c.onRead = onRead
}

// SetOnClose registers the callback invoked once when the connection is
// closed, receiving the event that triggered the close.
func (c *Connection) SetOnClose(onclose func(event ConnectionEvent)) {
    c.onClose = onclose
}

// Start launches the read loop and the write loop, each in its own
// goroutine. A panic in either loop is logged with a stack trace and
// closes the connection instead of crashing the process.
func (c *Connection) Start() {
    recoverFn := func() {
        if r := recover(); r != nil {
            logger.Errorf("tcp read/write loop panic: %v\n%s\n", r, string(debug.Stack()))
            c.Close(NoFlush, LocalClose)
        }
    }

    for _, loop := range []func(){c.startReadLoop, c.startWriteLoop} {
        loop := loop // capture per iteration (pre-Go 1.22 semantics)
        go func() {
            defer recoverFn()
            loop()
        }()
    }
}

// Write queues buf to be flushed by the write loop.
// It first tries a non-blocking send, then blocks up to 60 seconds for a
// free slot; on timeout buf is released and ErrWriteBufferChanTimeout is
// returned. If writeBufferChan has already been closed (the close site is
// not visible in this chunk), the send panics and the deferred recover
// converts that panic into ErrConnectionHasClosed.
func (c *Connection) Write(buf *iobufferpool.StreamBuffer) (err error) {
    defer func() {
        if r := recover(); r != nil {
            logger.Errorf("tcp connection has closed, local addr: %s, remote addr: %s, err: %+v",
                c.localAddr.String(), c.remoteAddr.String(), r)
            err = ErrConnectionHasClosed
        }
    }()

    // fast path: queue without blocking when the channel has room
    select {
    case c.writeBufferChan <- buf:
        return
    default:
    }

    // try to send data again in 60 seconds
    t := timerpool.Get(60 * time.Second)
    select {
    case c.writeBufferChan <- buf:
    case <-t.C:
        buf.Release()
        err = ErrWriteBufferChanTimeout
    }
    timerpool.Put(t)
    return
}

// startReadLoop reads from rawConn until the connection or its listener
// stops, handing every chunk to the onRead callback.
// Read-deadline timeouts are expected (the 15s deadline set in doReadIO
// lets the loop periodically notice connStopChan); EOF and other errors
// close the connection and end the loop.
func (c *Connection) startReadLoop() {
    defer func() {
        // return the pooled buffer, restored to its full-capacity slice
        if c.readBuffer != nil {
            tcpBufferPool.Put(c.readBuffer[:iobufferpool.DefaultBufferReadCapacity])
        }
    }()

    for {
        // exit promptly if the connection or the listener has been stopped
        select {
        case <-c.connStopChan:
            return
        case <-c.listenerStopChan:
            logger.Debugf("connection close due to listener stopped, local addr: %s, remote addr: %s",
                c.localAddr.String(), c.remoteAddr.String())
            c.Close(NoFlush, LocalClose)
            return
        default:
        }

        n, err := c.doReadIO()
        if n > 0 {
            // NOTE(review): onRead gets a view into c.readBuffer; it is
            // assumed to consume the bytes before the next read — confirm.
            c.onRead(iobufferpool.NewStreamBuffer(c.readBuffer[:n]))
        }

        if err == nil {
            continue
        }

        if te, ok := err.(net.Error); ok && te.Timeout() {
            // deadline expiry: re-check for close, then keep reading
            select {
            case <-c.connStopChan:
                logger.Debugf("connection has closed, exit read loop, local addr: %s, remote addr: %s",
                    c.localAddr.String(), c.remoteAddr.String())
                return
            default:
            }
            continue // ignore timeout error, continue read data
        }

        if err == io.EOF {
            logger.Debugf("remote close connection, local addr: %s, remote addr: %s, err: %s",
                c.localAddr.String(), c.remoteAddr.String(), err.Error())
            c.Close(NoFlush, RemoteClose)
        } else {
            logger.Errorf("error on read, local addr: %s, remote addr: %s, err: %s",
                c.localAddr.String(), c.remoteAddr.String(), err.Error())
            c.Close(NoFlush, OnReadErrClose)
        }
        return
    }
}

// startWriteLoop drains writeBufferChan, batching up to writeBufSize
// queued buffers into one vectored flush (doWrite) per iteration.
// A write timeout closes the connection; iobufferpool.ErrEOF (an EOF
// marker buffer was flushed) closes it locally; any other write error just
// ends this loop and leaves the close to the read loop.
func (c *Connection) startWriteLoop() {
    var err error
    for {
        select {
        case <-c.connStopChan:
            logger.Debugf("connection exit write loop, local addr: %s, remote addr: %s",
                c.localAddr.String(), c.remoteAddr.String())
            return
        case buf, ok := <-c.writeBufferChan:
            if !ok {
                // channel closed (close site not visible in this chunk)
                return
            }
            c.appendBuffer(buf)
        NoMoreData:
            // Keep reading until writeBufferChan is empty
            // writeBufferChan may be full when writeLoop call doWrite
            for i := 0; i < writeBufSize-1; i++ {
                select {
                case buf, ok := <-c.writeBufferChan:
                    if !ok {
                        return
                    }
                    c.appendBuffer(buf)
                default:
                    break NoMoreData
                }
            }
        }

        if _, err = c.doWrite(); err == nil {
            continue
        }

        if te, ok := err.(net.Error); ok && te.Timeout() {
            select {
            case <-c.connStopChan:
                logger.Debugf("connection has closed, exit write loop, local addr: %s, remote addr: %s",
                    c.localAddr.String(), c.remoteAddr.String())
                return
            default:
            }

            c.Close(NoFlush, OnWriteTimeout)
            return
        }

        if err == iobufferpool.ErrEOF {
            logger.Debugf("finish write with eof, local addr: %s, remote addr: %s",
                c.localAddr.String(), c.remoteAddr.String())
            c.Close(NoFlush, LocalClose)
        } else {
            // remote call CloseRead, so just exit write loop, wait read loop exit
            logger.Errorf("error on write, local addr: %s, remote addr: %s, err: %+v",
                c.localAddr.String(), c.remoteAddr.String(), err)
        }
        return
    }
}

// appendBuffer stages buf for the next flush: the owned buffer goes into
// ioBuffers (for release/EOF tracking) and its byte view into writeBuffers
// (for the vectored write). A nil buf is silently ignored.
func (c *Connection) appendBuffer(buf *iobufferpool.StreamBuffer) {
    if buf != nil {
        c.ioBuffers = append(c.ioBuffers, buf)
        c.writeBuffers = append(c.writeBuffers, buf.Bytes())
    }
}

// Close shuts the connection down.
// FlushWrite does not close immediately: it queues an EOF marker so the
// write loop flushes pending data first and then closes via ErrEOF.
// Otherwise the close runs at most once (guarded by a CAS on c.closed):
// it signals both loops via connStopChan, expires the deadline to unblock
// any in-flight read/write, runs onClose, and closes the raw connection.
func (c *Connection) Close(ccType CloseType, event ConnectionEvent) {
    defer func() {
        // onClose or the nested Write may panic; keep Close from crashing
        // the caller
        if r := recover(); r != nil {
            logger.Errorf("connection close panic, err: %+v\n%s", r, string(debug.Stack()))
        }
    }()

    if ccType == FlushWrite {
        _ = c.Write(iobufferpool.NewEOFStreamBuffer())
        return
    }

    if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
        // connection has already closed, so there is no need to execute below code
        return
    }

    // close tcp rawConn read first
    logger.Debugf("enter connection close func(%s), local addr: %s, remote addr: %s",
        event, c.localAddr.String(), c.remoteAddr.String())

    close(c.connStopChan)
    _ = c.rawConn.SetDeadline(time.Now()) // notify break read/write loop

    c.onClose(event)
    _ = c.rawConn.Close()
}

// doReadIO performs one blocking read into the pooled read buffer,
// refreshing a 15-second read deadline first so the read loop can wake up
// periodically to observe close signals.
func (c *Connection) doReadIO() (bufLen int, err error) {
    if c.readBuffer == nil {
        c.readBuffer = tcpBufferPool.Get().([]byte)
    }

    // add read deadline setting optimization?
    // https://github.com/golang/go/issues/15133
    curr := fasttime.Now().Add(15 * time.Second)
    // there is no need to set readDeadline in too short time duration
    // NOTE(review): diff > 0 is true whenever the previous deadline is more
    // than a millisecond old, so SetReadDeadline is skipped only for reads
    // within the same millisecond — confirm this is the intended threshold.
    if diff := curr.Sub(c.lastReadDeadlineTime).Milliseconds(); diff > 0 {
        _ = c.rawConn.SetReadDeadline(curr)
        c.lastReadDeadlineTime = curr
    }
    return c.rawConn.(io.Reader).Read(c.readBuffer)
}

// doWrite refreshes the 15-second write deadline (skipping the syscall if
// the last deadline was set within the same millisecond) and then flushes
// all staged buffers via doWriteIO.
func (c *Connection) doWrite() (int64, error) {
    deadline := fasttime.Now().Add(15 * time.Second)
    if deadline.Sub(c.lastWriteDeadlineTime).Milliseconds() > 0 {
        _ = c.rawConn.SetWriteDeadline(deadline)
        c.lastWriteDeadlineTime = deadline
    }
    return c.doWriteIO()
}

// writeBufLen reports the total number of bytes currently staged in the
// write buffers.
func (c *Connection) writeBufLen() int {
    total := 0
    for _, b := range c.writeBuffers {
        total += len(b)
    }
    return total
}

// doWriteIO flushes all staged buffers with one vectored write
// (net.Buffers.WriteTo) and then releases them back to their pool.
// If any flushed buffer carried the EOF marker, err is set to
// iobufferpool.ErrEOF so the write loop knows to close the connection.
func (c *Connection) doWriteIO() (bytesSent int64, err error) {
    buffers := c.writeBuffers
    bytesSent, err = buffers.WriteTo(c.rawConn)
    if err != nil {
        // NOTE(review): on error the staged buffers are not released here;
        // presumably the close path triggered by the caller handles that —
        // confirm.
        return bytesSent, err
    }

    // release owned buffers; nil out entries so nothing is retained by the
    // reused backing arrays
    for i, buf := range c.ioBuffers {
        c.ioBuffers[i] = nil
        c.writeBuffers[i] = nil
        if buf.EOF() {
            err = iobufferpool.ErrEOF
        }
        buf.Release()
    }
    // reset length, keep capacity for the next batch
    c.ioBuffers = c.ioBuffers[:0]
    c.writeBuffers = c.writeBuffers[:0]
    return
}

// ServerConnection wraps an outbound TCP connection to a backend server.
type ServerConnection struct {
    Connection                   // embedded; remoteAddr holds the backend address
    connectTimeout time.Duration // dial timeout; 0 means Connect's 10s default
}

// NewServerConn builds a not-yet-connected wrapper around a backend server
// address; call Connect to actually dial. connectTimeout is in
// milliseconds.
func NewServerConn(connectTimeout uint32, serverAddr net.Addr, listenerStopChan chan struct{}) *ServerConnection {
    return &ServerConnection{
        Connection: Connection{
            remoteAddr:       serverAddr,
            connStopChan:     make(chan struct{}),
            listenerStopChan: listenerStopChan,
            writeBufferChan:  make(chan *iobufferpool.StreamBuffer, writeBufSize),
        },
        connectTimeout: time.Duration(connectTimeout) * time.Millisecond,
    }
}

// Connect dials the backend server and, on success, starts the
// connection's read/write loops. It returns false when the address is
// missing or dialing fails.
func (u *ServerConnection) Connect() bool {
    addr := u.remoteAddr
    if addr == nil {
        // BUG FIX: the original logged addr.String() in this branch, which
        // panics on a nil net.Addr; log without dereferencing it.
        logger.Errorf("cannot connect because the server address is nil")
        return false
    }

    timeout := u.connectTimeout
    if timeout == 0 {
        timeout = 10 * time.Second // default dial timeout
    }

    var err error
    u.rawConn, err = net.DialTimeout("tcp", addr.String(), timeout)
    if err != nil {
        if err == io.EOF {
            logger.Errorf("cannot connect because the server has been closed, server addr: %s", addr.String())
        } else if te, ok := err.(net.Error); ok && te.Timeout() {
            logger.Errorf("connect to server timeout, server addr: %s", addr.String())
        } else {
            logger.Errorf("connect to server failed, server addr: %s, err: %s", addr.String(), err.Error())
        }
        return false
    }

    u.localAddr = u.rawConn.LocalAddr()
    // DialTimeout("tcp", ...) returns a *net.TCPConn today; guard the
    // assertion anyway so a future transport change cannot panic here.
    if tc, ok := u.rawConn.(*net.TCPConn); ok {
        _ = tc.SetNoDelay(true)
        _ = tc.SetKeepAlive(true)
    }
    u.Start()
    return true
}

以上是当初的实现,此段代码通过两个goroutine分别处理读写操作,并在读写过程中进行错误处理和流量控制。

    readBuffer      []byte
    writeBuffers    net.Buffers
    ioBuffers       []*iobufferpool.StreamBuffer
    writeBufferChan chan *iobufferpool.StreamBuffer

代码中还通过以上字段去控制读写缓冲区和写缓冲通道。虽然现在已经记不清具体的设计细节,但可以看出写得相当不好:引入了太多复杂度,关键是自己也缺乏有效的控制手段。

2 使用io.Copy

当初陈皓老师曾建议使用io.Copy来简化读写逻辑,因其已内置高效的流量拷贝和边界情况处理。
但当时我已按自行方案完成初版,既不舍推翻,也因工作紧张未能重构。
如今回看,采用io.Copy确实是更优选择,可借助Go标准库的成熟实现降低复杂度。

此外,UDP代理因其无连接特性带来较多实现复杂性,若从一开始就暂缓支持,或许更为合理。

2.1 简化后的TCP代理示例

package main

import (
    "io"
    "log"
    "net"
    "sync"
)

// TCPProxy is a minimal TCP proxy that forwards every connection accepted
// on ListenAddr to TargetAddr.
type TCPProxy struct {
    ListenAddr string // address to listen on, e.g. ":8080"
    TargetAddr string // backend address to forward to, e.g. "localhost:80"
}

// Start resolves ListenAddr, accepts client connections in a loop, and
// proxies each one in its own goroutine. It returns a non-nil error when
// listening fails or Accept reports a permanent error.
func (p *TCPProxy) Start() error {
    tcpAddr, err := net.ResolveTCPAddr("tcp", p.ListenAddr)
    if err != nil {
        return err
    }

    listener, err := net.ListenTCP("tcp", tcpAddr)
    if err != nil {
        return err
    }
    defer listener.Close()

    log.Printf("TCP代理启动,监听地址: %s,目标地址: %s", p.ListenAddr, p.TargetAddr)

    for {
        clientConn, err := listener.AcceptTCP()
        if err != nil {
            // BUG FIX: the original continued on every error, so a permanent
            // failure (e.g. a closed listener) spun this loop forever.
            // Retry only timeout errors; surface everything else.
            if ne, ok := err.(net.Error); ok && ne.Timeout() {
                log.Printf("接受连接失败: %v", err)
                continue
            }
            return err
        }

        go p.handleConnection(clientConn)
    }
}

// handleConnection proxies one client connection to the backend target,
// copying bytes in both directions with io.Copy until both sides finish.
func (p *TCPProxy) handleConnection(clientConn *net.TCPConn) {
    defer clientConn.Close()

    // BUG FIX: the original ignored both errors below; a failed dial left
    // serverConn nil and the deferred serverConn.Close() panicked.
    tcpAddr, err := net.ResolveTCPAddr("tcp", p.TargetAddr)
    if err != nil {
        log.Printf("解析目标地址失败: %v", err)
        return
    }
    serverConn, err := net.DialTCP("tcp", nil, tcpAddr)
    if err != nil {
        log.Printf("连接目标服务器失败: %v", err)
        return
    }
    defer serverConn.Close()

    var wg sync.WaitGroup
    wg.Add(2)

    // client -> server
    go func() {
        defer wg.Done()
        defer serverConn.Close()

        _, _ = io.Copy(serverConn, clientConn)
    }()

    // server -> client
    go func() {
        defer wg.Done()
        defer clientConn.Close()

        _, _ = io.Copy(clientConn, serverConn)
    }()

    wg.Wait()
}

// main runs a proxy on :8080 forwarding to a local HTTP service on port
// 80, and exits with a fatal log if the proxy fails to start.
func main() {
    p := &TCPProxy{
        ListenAddr: ":8080",
        TargetAddr: "localhost:80",
    }

    if err := p.Start(); err != nil {
        log.Fatalf("代理启动失败: %v", err)
    }
}

2.2 添加缓冲区

简化的代码中,每次调用io.Copy均会使用默认缓冲区大小(32KB)进行数据传输。
此时,可能会想到使用io.CopyBuffer来指定更合适的缓冲区大小,以优化性能,代码示例如下:

// handleConnection proxies one client connection to the backend target
// using io.CopyBuffer with pooled buffers, half-closing each direction
// (CloseWrite) when its source reaches EOF.
// bufferPool is assumed to be a package-level sync.Pool of *[]byte whose
// definition is not shown here.
func (p *TCPProxy) handleConnection(clientConn *net.TCPConn) {
    defer clientConn.Close()

    // BUG FIX: the original ignored both errors below; a failed dial left
    // serverConn nil and the deferred serverConn.Close() panicked.
    tcpAddr, err := net.ResolveTCPAddr("tcp", p.TargetAddr)
    if err != nil {
        log.Printf("解析目标地址失败: %v", err)
        return
    }
    serverConn, err := net.DialTCP("tcp", nil, tcpAddr)
    if err != nil {
        log.Printf("连接目标服务器失败: %v", err)
        return
    }
    defer serverConn.Close()

    var wg sync.WaitGroup
    wg.Add(2)

    // client -> server: half-close the server side on client EOF so the
    // server sees EOF while the reverse direction keeps flowing
    go func() {
        defer wg.Done()
        defer serverConn.CloseWrite()

        bufPtr := bufferPool.Get().(*[]byte)
        defer bufferPool.Put(bufPtr)

        _, _ = io.CopyBuffer(serverConn, clientConn, *bufPtr)
    }()

    // server -> client
    go func() {
        defer wg.Done()
        defer clientConn.CloseWrite()

        bufPtr := bufferPool.Get().(*[]byte)
        defer bufferPool.Put(bufPtr)

        _, _ = io.CopyBuffer(clientConn, serverConn, *bufPtr)
    }()

    wg.Wait()
}

但是这样真的能提升性能吗?以下是io.CopyBuffer的实现代码:

// (Quoted verbatim from the Go standard library, io/io.go — shown here for
// reference only.)
// CopyBuffer is identical to Copy except that it stages through the
// provided buffer (if one is required) rather than allocating a
// temporary one. If buf is nil, one is allocated; otherwise if it has
// zero length, CopyBuffer panics.
//
// If either src implements [WriterTo] or dst implements [ReaderFrom],
// buf will not be used to perform the copy.
func CopyBuffer(dst Writer, src Reader, buf []byte) (written int64, err error) {
    if buf != nil && len(buf) == 0 {
        panic("empty buffer in CopyBuffer")
    }
    return copyBuffer(dst, src, buf)
}

// (Quoted verbatim from the Go standard library, io/io.go — shown here for
// reference only. Note the two fast paths at the top: when either side
// implements WriterTo/ReaderFrom, the caller-supplied buf is never used.)
// copyBuffer is the actual implementation of Copy and CopyBuffer.
// if buf is nil, one is allocated.
func copyBuffer(dst Writer, src Reader, buf []byte) (written int64, err error) {
    // If the reader has a WriteTo method, use it to do the copy.
    // Avoids an allocation and a copy.
    if wt, ok := src.(WriterTo); ok {
        return wt.WriteTo(dst)
    }
    // Similarly, if the writer has a ReadFrom method, use it to do the copy.
    if rf, ok := dst.(ReaderFrom); ok {
        return rf.ReadFrom(src)
    }
    if buf == nil {
        size := 32 * 1024
        if l, ok := src.(*LimitedReader); ok && int64(size) > l.N {
            if l.N < 1 {
                size = 1
            } else {
                size = int(l.N)
            }
        }
        buf = make([]byte, size)
    }
    for {
        nr, er := src.Read(buf)
        if nr > 0 {
            nw, ew := dst.Write(buf[0:nr])
            if nw < 0 || nr < nw {
                nw = 0
                if ew == nil {
                    ew = errInvalidWrite
                }
            }
            written += int64(nw)
            if ew != nil {
                err = ew
                break
            }
            if nr != nw {
                err = ErrShortWrite
                break
            }
        }
        if er != nil {
            if er != EOF {
                err = er
            }
            break
        }
    }
    return written, err
}

从代码可以看出,io.CopyBuffer在源实现了WriterTo接口或目标实现了ReaderFrom接口时,并不会使用提供的缓冲区。
net.TCPConn类型正是实现了这些接口,具体代码详见:

// (Quoted verbatim from the Go standard library, net/tcpsock.go — shown
// here for reference only.)
// ReadFrom implements the [io.ReaderFrom] ReadFrom method.
func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) {
    if !c.ok() {
        return 0, syscall.EINVAL
    }
    n, err := c.readFrom(r)
    if err != nil && err != io.EOF {
        err = &OpError{Op: "readfrom", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
    }
    return n, err
}

因此,在这种情况下,使用io.CopyBuffer并不会带来性能提升,反而是凭空添加了代码复杂度。

2.3 ReadFrom和WriteTo接口

net.TCPConn实现了io.ReaderFrom和io.WriterTo接口,这使得io.Copy能够直接利用这些高效的读写方法进行数据传输。
以下是ReadFrom和WriteTo的实现代码片段:

// (Quoted from Go's net package.) readFrom first tries the kernel-level
// transfer paths — splice, then sendfile — and falls back to a generic
// userspace copy loop only when neither path can handle the source.
func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
    if n, err, handled := spliceFrom(c.fd, r); handled {
        return n, err
    }
    if n, err, handled := sendFile(c.fd, r); handled {
        return n, err
    }
    return genericReadFrom(c, r)
}

// (Quoted from Go's net package.) writeTo tries the kernel splice path
// first and falls back to a generic userspace copy loop when it cannot be
// used.
func (c *TCPConn) writeTo(w io.Writer) (int64, error) {
    if n, err, handled := spliceTo(w, c.fd); handled {
        return n, err
    }
    return genericWriteTo(c, w)
}

查阅源码注释,可以发现spliceFrom、sendFile以及spliceTo等函数均利用了操作系统底层的高效数据传输机制,如Linux的splice和sendfile系统调用,从而实现零拷贝数据传输,极大提升了性能。

3 遗憾

在利用io.Copy简化数据流拷贝的同时,其对于长连接的空闲超时(Idle Timeout)管理能力确实存在不足。
*net.TCPConn提供的SetDeadline方法所设定的是连接级别的绝对超时时间,这对于HTTP短连接较为适用,但在需要维持连接长时间存活的四层代理场景中则缺乏灵活性。
一个可行的优化思路是,封装支持超时控制的Reader/Writer接口,在每次读写操作前动态更新超时时间,从而实现真正的空闲超时机制。
然而,这种方案需要自行实现完整的读写循环,无法直接复用io.Copy等系统级API,这在一定程度上又回到了最初的复杂度问题,需要在简洁性与控制粒度之间进行权衡。

还是不折腾了,寄希望于Go未来在net包中提供更灵活的超时控制机制吧,正如io.Copy借助ReadFrom和WriteTo接口获得高效实现那样。


Comment