When I first tried to add a Layer 4 (TCP) proxy to Easegress, the hard part was controlling errors on a bidirectional connection.
The design at the time gave each TCP connection two goroutines, one for reading and one for writing, copied traffic through a buffer, and on every read or write error had to decide whether to close only the read side, only the write side, or the whole connection.
On top of that came flow control, because the two TCP connections involved could read and write at very different rates.
1 The Initial Implementation
package tcpproxy
import (
"io"
"net"
"runtime/debug"
"sync"
"sync/atomic"
"time"
"github.com/megaease/easegress/pkg/logger"
"github.com/megaease/easegress/pkg/util/fasttime"
"github.com/megaease/easegress/pkg/util/iobufferpool"
"github.com/megaease/easegress/pkg/util/timerpool"
)
const writeBufSize = 8
var tcpBufferPool = sync.Pool{
New: func() interface{} {
buf := make([]byte, iobufferpool.DefaultBufferReadCapacity)
return buf
},
}
// Connection wrap tcp connection
type Connection struct {
closed uint32
rawConn net.Conn
localAddr net.Addr
remoteAddr net.Addr
readBuffer []byte
writeBuffers net.Buffers
ioBuffers []*iobufferpool.StreamBuffer
writeBufferChan chan *iobufferpool.StreamBuffer
mu sync.Mutex
connStopChan chan struct{} // use for connection close
listenerStopChan chan struct{} // use for listener close
lastReadDeadlineTime time.Time
lastWriteDeadlineTime time.Time
onRead func(buffer *iobufferpool.StreamBuffer) // execute read filters
onClose func(event ConnectionEvent)
}
// NewClientConn wrap connection create from client
func NewClientConn(conn net.Conn, listenerStopChan chan struct{}) *Connection {
return &Connection{
rawConn: conn,
localAddr: conn.LocalAddr(),
remoteAddr: conn.RemoteAddr(),
listenerStopChan: listenerStopChan,
mu: sync.Mutex{},
connStopChan: make(chan struct{}),
writeBufferChan: make(chan *iobufferpool.StreamBuffer, writeBufSize),
}
}
// SetOnRead set connection read handle
func (c *Connection) SetOnRead(onRead func(buffer *iobufferpool.StreamBuffer)) {
c.onRead = onRead
}
// SetOnClose set close callback
func (c *Connection) SetOnClose(onclose func(event ConnectionEvent)) {
c.onClose = onclose
}
// Start running connection read/write loop
func (c *Connection) Start() {
fnRecover := func() {
if r := recover(); r != nil {
logger.Errorf("tcp read/write loop panic: %v\n%s\n", r, string(debug.Stack()))
c.Close(NoFlush, LocalClose)
}
}
go func() {
defer fnRecover()
c.startReadLoop()
}()
go func() {
defer fnRecover()
c.startWriteLoop()
}()
}
// Write receive other connection data
func (c *Connection) Write(buf *iobufferpool.StreamBuffer) (err error) {
defer func() {
if r := recover(); r != nil {
logger.Errorf("tcp connection has closed, local addr: %s, remote addr: %s, err: %+v",
c.localAddr.String(), c.remoteAddr.String(), r)
err = ErrConnectionHasClosed
}
}()
select {
case c.writeBufferChan <- buf:
return
default:
}
// the channel is full: wait up to 60 seconds for space before giving up
t := timerpool.Get(60 * time.Second)
select {
case c.writeBufferChan <- buf:
case <-t.C:
buf.Release()
err = ErrWriteBufferChanTimeout
}
timerpool.Put(t)
return
}
func (c *Connection) startReadLoop() {
defer func() {
if c.readBuffer != nil {
tcpBufferPool.Put(c.readBuffer[:iobufferpool.DefaultBufferReadCapacity])
}
}()
for {
select {
case <-c.connStopChan:
return
case <-c.listenerStopChan:
logger.Debugf("connection close due to listener stopped, local addr: %s, remote addr: %s",
c.localAddr.String(), c.remoteAddr.String())
c.Close(NoFlush, LocalClose)
return
default:
}
n, err := c.doReadIO()
if n > 0 {
c.onRead(iobufferpool.NewStreamBuffer(c.readBuffer[:n]))
}
if err == nil {
continue
}
if te, ok := err.(net.Error); ok && te.Timeout() {
select {
case <-c.connStopChan:
logger.Debugf("connection has closed, exit read loop, local addr: %s, remote addr: %s",
c.localAddr.String(), c.remoteAddr.String())
return
default:
}
continue // ignore timeout error, continue read data
}
if err == io.EOF {
logger.Debugf("remote close connection, local addr: %s, remote addr: %s, err: %s",
c.localAddr.String(), c.remoteAddr.String(), err.Error())
c.Close(NoFlush, RemoteClose)
} else {
logger.Errorf("error on read, local addr: %s, remote addr: %s, err: %s",
c.localAddr.String(), c.remoteAddr.String(), err.Error())
c.Close(NoFlush, OnReadErrClose)
}
return
}
}
func (c *Connection) startWriteLoop() {
var err error
for {
select {
case <-c.connStopChan:
logger.Debugf("connection exit write loop, local addr: %s, remote addr: %s",
c.localAddr.String(), c.remoteAddr.String())
return
case buf, ok := <-c.writeBufferChan:
if !ok {
return
}
c.appendBuffer(buf)
NoMoreData:
// Keep reading until writeBufferChan is empty
// writeBufferChan may be full when writeLoop call doWrite
for i := 0; i < writeBufSize-1; i++ {
select {
case buf, ok := <-c.writeBufferChan:
if !ok {
return
}
c.appendBuffer(buf)
default:
break NoMoreData
}
}
}
if _, err = c.doWrite(); err == nil {
continue
}
if te, ok := err.(net.Error); ok && te.Timeout() {
select {
case <-c.connStopChan:
logger.Debugf("connection has closed, exit write loop, local addr: %s, remote addr: %s",
c.localAddr.String(), c.remoteAddr.String())
return
default:
}
c.Close(NoFlush, OnWriteTimeout)
return
}
if err == iobufferpool.ErrEOF {
logger.Debugf("finish write with eof, local addr: %s, remote addr: %s",
c.localAddr.String(), c.remoteAddr.String())
c.Close(NoFlush, LocalClose)
} else {
// remote call CloseRead, so just exit write loop, wait read loop exit
logger.Errorf("error on write, local addr: %s, remote addr: %s, err: %+v",
c.localAddr.String(), c.remoteAddr.String(), err)
}
return
}
}
func (c *Connection) appendBuffer(buf *iobufferpool.StreamBuffer) {
if buf == nil {
return
}
c.ioBuffers = append(c.ioBuffers, buf)
c.writeBuffers = append(c.writeBuffers, buf.Bytes())
}
// Close connection close function
func (c *Connection) Close(ccType CloseType, event ConnectionEvent) {
defer func() {
if r := recover(); r != nil {
logger.Errorf("connection close panic, err: %+v\n%s", r, string(debug.Stack()))
}
}()
if ccType == FlushWrite {
_ = c.Write(iobufferpool.NewEOFStreamBuffer())
return
}
if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
// connection has already closed, so there is no need to execute below code
return
}
// close tcp rawConn read first
logger.Debugf("enter connection close func(%s), local addr: %s, remote addr: %s",
event, c.localAddr.String(), c.remoteAddr.String())
close(c.connStopChan)
_ = c.rawConn.SetDeadline(time.Now()) // notify break read/write loop
c.onClose(event)
_ = c.rawConn.Close()
}
func (c *Connection) doReadIO() (bufLen int, err error) {
if c.readBuffer == nil {
c.readBuffer = tcpBufferPool.Get().([]byte)
}
// add read deadline setting optimization?
// https://github.com/golang/go/issues/15133
curr := fasttime.Now().Add(15 * time.Second)
// there is no need to set readDeadline in too short time duration
if diff := curr.Sub(c.lastReadDeadlineTime).Milliseconds(); diff > 0 {
_ = c.rawConn.SetReadDeadline(curr)
c.lastReadDeadlineTime = curr
}
return c.rawConn.(io.Reader).Read(c.readBuffer)
}
func (c *Connection) doWrite() (int64, error) {
curr := fasttime.Now().Add(15 * time.Second)
// there is no need to set writeDeadline in too short time duration
if diff := curr.Sub(c.lastWriteDeadlineTime).Milliseconds(); diff > 0 {
_ = c.rawConn.SetWriteDeadline(curr)
c.lastWriteDeadlineTime = curr
}
return c.doWriteIO()
}
func (c *Connection) writeBufLen() (bufLen int) {
for _, buf := range c.writeBuffers {
bufLen += len(buf)
}
return
}
func (c *Connection) doWriteIO() (bytesSent int64, err error) {
buffers := c.writeBuffers
bytesSent, err = buffers.WriteTo(c.rawConn)
if err != nil {
return bytesSent, err
}
for i, buf := range c.ioBuffers {
c.ioBuffers[i] = nil
c.writeBuffers[i] = nil
if buf.EOF() {
err = iobufferpool.ErrEOF
}
buf.Release()
}
c.ioBuffers = c.ioBuffers[:0]
c.writeBuffers = c.writeBuffers[:0]
return
}
// ServerConnection wrap tcp connection to backend server
type ServerConnection struct {
Connection
connectTimeout time.Duration
}
// NewServerConn construct tcp server connection
func NewServerConn(connectTimeout uint32, serverAddr net.Addr, listenerStopChan chan struct{}) *ServerConnection {
conn := &ServerConnection{
Connection: Connection{
remoteAddr: serverAddr,
writeBufferChan: make(chan *iobufferpool.StreamBuffer, writeBufSize),
mu: sync.Mutex{},
connStopChan: make(chan struct{}),
listenerStopChan: listenerStopChan,
},
connectTimeout: time.Duration(connectTimeout) * time.Millisecond,
}
return conn
}
// Connect create backend server tcp connection
func (u *ServerConnection) Connect() bool {
addr := u.remoteAddr
if addr == nil {
logger.Errorf("cannot connect because the server has been closed, server addr: %s", addr.String())
return false
}
timeout := u.connectTimeout
if timeout == 0 {
timeout = 10 * time.Second
}
var err error
u.rawConn, err = net.DialTimeout("tcp", addr.String(), timeout)
if err != nil {
if err == io.EOF {
logger.Errorf("cannot connect because the server has been closed, server addr: %s", addr.String())
} else if te, ok := err.(net.Error); ok && te.Timeout() {
logger.Errorf("connect to server timeout, server addr: %s", addr.String())
} else {
logger.Errorf("connect to server failed, server addr: %s, err: %s", addr.String(), err.Error())
}
return false
}
u.localAddr = u.rawConn.LocalAddr()
_ = u.rawConn.(*net.TCPConn).SetNoDelay(true)
_ = u.rawConn.(*net.TCPConn).SetKeepAlive(true)
u.Start()
return true
}
That was the initial implementation: two goroutines handle reading and writing respectively, with error handling and flow control woven into both loops.
readBuffer []byte
writeBuffers net.Buffers
ioBuffers []*iobufferpool.StreamBuffer
writeBufferChan chan *iobufferpool.StreamBuffer
The fields above were used to manage the read buffer, the pending write buffers and the write channel. I can no longer reconstruct every design decision, but it is plain that the code was not good: it introduced far too much complexity, and, crucially, I had no effective means of keeping that complexity under control.
2 Using io.Copy
Back then, 陈皓 (Chen Hao) had suggested using io.Copy to simplify the read/write logic, since it already provides efficient traffic copying and handles the edge cases.
But I had just finished the first version with my own approach; I was reluctant to throw it away, and a heavy workload left no room for a rewrite.
Looking back, io.Copy really was the better choice: leaning on the mature implementation in Go's standard library would have removed a lot of complexity.
Beyond that, the UDP proxy added considerable implementation complexity because of its connectionless nature; deferring UDP support from the start would probably have been the wiser call.
2.1 A Simplified TCP Proxy Example
package main
import (
"io"
"log"
"net"
"sync"
)
// TCPProxy is a minimal TCP proxy.
type TCPProxy struct {
ListenAddr string
TargetAddr string
}
func (p *TCPProxy) Start() error {
tcpAddr, err := net.ResolveTCPAddr("tcp", p.ListenAddr)
if err != nil {
return err
}
listener, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
return err
}
defer listener.Close()
log.Printf("TCP代理启动,监听地址: %s,目标地址: %s", p.ListenAddr, p.TargetAddr)
for {
clientConn, err := listener.AcceptTCP()
if err != nil {
log.Printf("接受连接失败: %v", err)
continue
}
go p.handleConnection(clientConn)
}
}
func (p *TCPProxy) handleConnection(clientConn *net.TCPConn) {
defer clientConn.Close()
tcpAddr, err := net.ResolveTCPAddr("tcp", p.TargetAddr)
if err != nil {
log.Printf("failed to resolve target address %s: %v", p.TargetAddr, err)
return
}
serverConn, err := net.DialTCP("tcp", nil, tcpAddr)
if err != nil {
log.Printf("failed to connect to target %s: %v", p.TargetAddr, err)
return
}
defer serverConn.Close()
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
defer serverConn.Close()
_, _ = io.Copy(serverConn, clientConn)
}()
go func() {
defer wg.Done()
defer clientConn.Close()
_, _ = io.Copy(clientConn, serverConn)
}()
wg.Wait()
}
func main() {
proxy := &TCPProxy{
ListenAddr: ":8080",
TargetAddr: "localhost:80", // 示例:代理到本地HTTP服务
}
if err := proxy.Start(); err != nil {
log.Fatalf("代理启动失败: %v", err)
}
}
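A note on this version: when one copy direction finishes, its goroutine closes the peer connection outright, which unblocks the io.Copy running in the other direction but gives up TCP half-close semantics; the variant in the next section uses CloseWrite instead, so each direction can be shut down independently. Assuming something is listening on localhost:80, the proxy can be exercised with, for example, curl http://127.0.0.1:8080/.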
2.2 Adding a Buffer
In the simplified code, each io.Copy call transfers data using the default buffer size (32KB).
At this point one might reach for io.CopyBuffer together with a sync.Pool of buffers (bufferPool below) to supply a more suitably sized, reusable buffer, in the hope of better performance:
// bufferPool supplies reusable copy buffers for the goroutines below.
var bufferPool = sync.Pool{
New: func() interface{} {
buf := make([]byte, 32*1024)
return &buf
},
}
func (p *TCPProxy) handleConnection(clientConn *net.TCPConn) {
defer clientConn.Close()
tcpAddr, err := net.ResolveTCPAddr("tcp", p.TargetAddr)
if err != nil {
log.Printf("failed to resolve target address %s: %v", p.TargetAddr, err)
return
}
serverConn, err := net.DialTCP("tcp", nil, tcpAddr)
if err != nil {
log.Printf("failed to connect to target %s: %v", p.TargetAddr, err)
return
}
defer serverConn.Close()
var wg sync.WaitGroup
wg.Add(2)
// copy data from the client to the server
go func() {
defer wg.Done()
defer serverConn.CloseWrite()
bufPtr := bufferPool.Get().(*[]byte)
defer bufferPool.Put(bufPtr)
_, _ = io.CopyBuffer(serverConn, clientConn, *bufPtr)
}()
// copy data from the server to the client
go func() {
defer wg.Done()
defer clientConn.CloseWrite()
bufPtr := bufferPool.Get().(*[]byte)
defer bufferPool.Put(bufPtr)
_, _ = io.CopyBuffer(clientConn, serverConn, *bufPtr)
}()
wg.Wait()
}
But does this actually improve performance? Here is the implementation of io.CopyBuffer:
// CopyBuffer is identical to Copy except that it stages through the
// provided buffer (if one is required) rather than allocating a
// temporary one. If buf is nil, one is allocated; otherwise if it has
// zero length, CopyBuffer panics.
//
// If either src implements [WriterTo] or dst implements [ReaderFrom],
// buf will not be used to perform the copy.
func CopyBuffer(dst Writer, src Reader, buf []byte) (written int64, err error) {
if buf != nil && len(buf) == 0 {
panic("empty buffer in CopyBuffer")
}
return copyBuffer(dst, src, buf)
}
// copyBuffer is the actual implementation of Copy and CopyBuffer.
// if buf is nil, one is allocated.
func copyBuffer(dst Writer, src Reader, buf []byte) (written int64, err error) {
// If the reader has a WriteTo method, use it to do the copy.
// Avoids an allocation and a copy.
if wt, ok := src.(WriterTo); ok {
return wt.WriteTo(dst)
}
// Similarly, if the writer has a ReadFrom method, use it to do the copy.
if rf, ok := dst.(ReaderFrom); ok {
return rf.ReadFrom(src)
}
if buf == nil {
size := 32 * 1024
if l, ok := src.(*LimitedReader); ok && int64(size) > l.N {
if l.N < 1 {
size = 1
} else {
size = int(l.N)
}
}
buf = make([]byte, size)
}
for {
nr, er := src.Read(buf)
if nr > 0 {
nw, ew := dst.Write(buf[0:nr])
if nw < 0 || nr < nw {
nw = 0
if ew == nil {
ew = errInvalidWrite
}
}
written += int64(nw)
if ew != nil {
err = ew
break
}
if nr != nw {
err = ErrShortWrite
break
}
}
if er != nil {
if er != EOF {
err = er
}
break
}
}
return written, err
}
As the code shows, io.CopyBuffer does not use the provided buffer at all when the source implements WriterTo or the destination implements ReaderFrom.
And net.TCPConn implements exactly these interfaces; see the relevant code:
// ReadFrom implements the [io.ReaderFrom] ReadFrom method.
func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) {
if !c.ok() {
return 0, syscall.EINVAL
}
n, err := c.readFrom(r)
if err != nil && err != io.EOF {
err = &OpError{Op: "readfrom", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return n, err
}
So in this scenario io.CopyBuffer brings no performance gain at all; it merely adds code complexity for nothing.
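That said, if the proxy really does need the bytes to pass through user space, for example to run filters on the traffic, the fast path can be defeated on purpose by wrapping both connections in types that expose nothing but Read and Write. The sketch below is my own illustration of that trick (readerOnly, writerOnly and copyThroughBuffer are invented names, not part of the original code), it only needs the standard io package, and it trades away the zero-copy path described in the next section:
// readerOnly and writerOnly promote only Read/Write, hiding the WriterTo and
// ReadFrom methods that *net.TCPConn would otherwise expose to io.CopyBuffer.
type readerOnly struct{ io.Reader }
type writerOnly struct{ io.Writer }

// copyThroughBuffer forces io.CopyBuffer to stage every chunk through buf,
// which is only worthwhile when the payload must be inspected in user space.
func copyThroughBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
	return io.CopyBuffer(writerOnly{dst}, readerOnly{src}, buf)
}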
2.3 The ReadFrom and WriteTo Interfaces
net.TCPConn implements the io.ReaderFrom and io.WriterTo interfaces, which is what lets io.Copy hand the transfer straight over to these efficient methods.
Here are fragments of the underlying readFrom and writeTo implementations:
func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
if n, err, handled := spliceFrom(c.fd, r); handled {
return n, err
}
if n, err, handled := sendFile(c.fd, r); handled {
return n, err
}
return genericReadFrom(c, r)
}
func (c *TCPConn) writeTo(w io.Writer) (int64, error) {
if n, err, handled := spliceTo(w, c.fd); handled {
return n, err
}
return genericWriteTo(c, w)
}
Reading the comments in the source, it turns out that spliceFrom, sendFile and spliceTo all lean on efficient data-transfer mechanisms in the operating system, such as Linux's splice and sendfile system calls, which give zero-copy transfers and a substantial performance boost.
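As a quick sanity check, both interface implementations can be asserted at compile time. This tiny snippet is illustrative only; the io.WriterTo assertion assumes a fairly recent toolchain, since the exported WriteTo method on *net.TCPConn arrived later than ReadFrom (around Go 1.22, if I remember correctly):
// Compile-time assertions that *net.TCPConn provides both fast-path
// interfaces used by io.Copy (needs only the io and net packages).
var (
	_ io.ReaderFrom = (*net.TCPConn)(nil)
	_ io.WriterTo   = (*net.TCPConn)(nil)
)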
3 Regrets
While io.Copy simplifies the data pumping, its support for idle-timeout management on long-lived connections is genuinely lacking.
The SetDeadline family on *net.TCPConn sets an absolute, connection-level deadline, which is fine for short-lived HTTP connections but too inflexible for a Layer 4 proxy that must keep connections alive for a long time.
One workable idea is to wrap the connection in a Reader/Writer that refreshes the deadline before every read and write, which yields a real idle-timeout mechanism.
The catch is that such a wrapper either means writing the full read/write loop by hand or, if it is still fed to io.Copy, giving up the system-level fast path, so to some degree it circles back to the original complexity problem; a trade-off has to be made between simplicity and fine-grained control.
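To make the idea concrete, here is a minimal sketch of such a wrapper (my own illustration, not Easegress code; idleTimeoutConn and proxyOneWay are invented names, and the io, net and time packages are assumed). Because the wrapper hides ReadFrom/WriteTo, io.Copy falls back to its internal buffer and the zero-copy path from section 2.3 is lost:
// idleTimeoutConn refreshes the deadline before every Read/Write, turning the
// absolute deadline of net.Conn into an idle timeout.
type idleTimeoutConn struct {
	net.Conn
	timeout time.Duration
}

func (c *idleTimeoutConn) Read(p []byte) (int, error) {
	if err := c.Conn.SetReadDeadline(time.Now().Add(c.timeout)); err != nil {
		return 0, err
	}
	return c.Conn.Read(p)
}

func (c *idleTimeoutConn) Write(p []byte) (int, error) {
	if err := c.Conn.SetWriteDeadline(time.Now().Add(c.timeout)); err != nil {
		return 0, err
	}
	return c.Conn.Write(p)
}

// proxyOneWay copies one direction of traffic and aborts if that direction
// stays idle (no successful read or write) for longer than timeout.
func proxyOneWay(dst, src net.Conn, timeout time.Duration) error {
	_, err := io.Copy(
		&idleTimeoutConn{Conn: dst, timeout: timeout},
		&idleTimeoutConn{Conn: src, timeout: timeout},
	)
	return err
}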
For now I have stopped fighting it, and I am pinning my hopes on Go eventually offering more flexible timeout control in the net package, much as io.Copy came to take advantage of the ReaderFrom and WriterTo interfaces.