softether-go/pkg/client/tunnel.go
Git Sagar 47e06b525c tunnel: replace write mutex with channel-based single writer
All writes (frames, keepalive, DHCP renewal) are queued to a buffered
channel and drained by a single writer goroutine. Eliminates mutex
contention on the data path entirely.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-06-06 23:16:27 +05:30

221 lines
6.3 KiB
Go

package client
import (
"encoding/binary"
"fmt"
"io"
"math/rand"
"sync"
"time"
)
// TCP block framing constants.
//
// Once the HTTP handshake is done, SoftEther switches the connection to raw
// TCP framing, with the same layout in both directions:
//
//	Data:      uint32(numBlocks), then numBlocks x [uint32(blockSize) + blockData]
//	Keepalive: uint32(0xFFFFFFFF) + uint32(randSize) + randData
//
// See: https://github.com/SoftEtherVPN/SoftEtherVPN/blob/v5.02.5187/src/Cedar/Connection.c#L1654 (TcpSockRecv)
// See: https://github.com/SoftEtherVPN/SoftEtherVPN/blob/v5.02.5187/src/Cedar/Connection.c#L1761 (TcpSockSend)
const (
	// writeChanSize is the depth of the buffered write queue.
	writeChanSize = 128

	// keepAliveInterval is how often StartKeepalive queues a keepalive.
	keepAliveInterval = 3 * time.Second

	// keepAliveMagic, sent in the numBlocks position, marks a keepalive.
	keepAliveMagic uint32 = 0xFFFFFFFF

	// maxKeepaliveSize caps the random padding in an outgoing keepalive.
	maxKeepaliveSize uint32 = 512
)
// Tunnel handles bidirectional TCP block framing for Ethernet frames over a
// SoftEther VPN session. Each "block" is one Ethernet frame.
//
// All writes go through writeCh and are performed by a single writer
// goroutine (writeLoop), so the connection is never written concurrently.
// ReadFrames reads directly from the connection on the caller's goroutine.
type Tunnel struct {
	sess *Session

	writeCh   chan []byte   // serialized messages queued for the single writer; never closed (see Close)
	writeDone chan struct{} // closed when writeLoop exits, after a write error or Close

	errMu    sync.Mutex
	writeErr error // first write error seen by writeLoop; guarded by errMu

	stopCh  chan struct{} // closed by Close to stop the writer and keepalive goroutines
	stopped sync.Once     // ensures stopCh is closed only once
}

// NewTunnel creates a tunnel from an established session and starts its
// writer goroutine, which runs until Close is called or a write fails.
// Call StartKeepalive() before reading/writing frames.
func NewTunnel(sess *Session) *Tunnel {
	t := &Tunnel{
		sess:      sess,
		writeCh:   make(chan []byte, writeChanSize),
		writeDone: make(chan struct{}),
		stopCh:    make(chan struct{}),
	}
	go t.writeLoop()
	return t
}

// writeLoop is the single goroutine that writes to the connection.
// All writes are serialized through writeCh, so no mutex guards Conn.
//
// It exits on the first write error (recording it for WriteFrames) or once
// Close is called. Either way it closes writeDone, so senders stop waiting
// on a queue that nobody drains any more.
func (t *Tunnel) writeLoop() {
	defer close(t.writeDone)
	for {
		select {
		case <-t.stopCh:
			return
		case buf := <-t.writeCh:
			if _, err := t.sess.Conn.Write(buf); err != nil {
				t.errMu.Lock()
				t.writeErr = err
				t.errMu.Unlock()
				return
			}
		}
	}
}

// lastWriteErr returns the error that stopped writeLoop, or nil if no write
// has failed.
func (t *Tunnel) lastWriteErr() error {
	t.errMu.Lock()
	defer t.errMu.Unlock()
	return t.writeErr
}

// writeClosedErr returns nil while the writer still accepts messages.
// Otherwise it returns "tunnel closed" after Close, or the write error that
// stopped writeLoop.
func (t *Tunnel) writeClosedErr() error {
	select {
	case <-t.stopCh:
		return fmt.Errorf("tunnel closed")
	default:
	}
	select {
	case <-t.writeDone:
		if err := t.lastWriteErr(); err != nil {
			return err
		}
		// writeLoop only exits without an error when stopCh was closed.
		return fmt.Errorf("tunnel closed")
	default:
		return nil
	}
}

// Close stops the keepalive and writer goroutines and closes the underlying
// connection. Closing the connection also unblocks a writer stuck in Write.
// Calling Close again does not panic; it returns whatever Conn.Close returns.
//
// writeCh is deliberately left open. Senders (WriteFrames, the keepalive
// goroutine) can race with Close, and a send on a closed channel panics even
// inside a select that also has a ready <-stopCh case.
func (t *Tunnel) Close() error {
	t.stopped.Do(func() { close(t.stopCh) })
	return t.sess.Conn.Close()
}

// ReadFrames reads a batch of Ethernet frames from the server.
// Returns nil (no error) for keepalive packets. Blocks until data arrives.
// It reads the connection without locking, so it must not be called from
// more than one goroutine at a time.
func (t *Tunnel) ReadFrames() ([][]byte, error) {
	// Every header field is a big-endian uint32. Reading each one with
	// io.ReadFull into a single reused array is cheaper on this hot path than
	// calling binary.Read per field. The errors are the same: io.EOF on a
	// clean end, io.ErrUnexpectedEOF on a partial field.
	var hdr [4]byte
	readU32 := func() (uint32, error) {
		if _, err := io.ReadFull(t.sess.Conn, hdr[:]); err != nil {
			return 0, err
		}
		return binary.BigEndian.Uint32(hdr[:]), nil
	}

	numBlocks, err := readU32()
	if err != nil {
		return nil, fmt.Errorf("read num blocks: %w", err)
	}

	// Keepalive: server sends 0xFFFFFFFF + uint32(size) + random data.
	if numBlocks == keepAliveMagic {
		size, err := readU32()
		if err != nil {
			return nil, fmt.Errorf("read keepalive size: %w", err)
		}
		if _, err := io.CopyN(io.Discard, t.sess.Conn, int64(size)); err != nil {
			return nil, fmt.Errorf("discard keepalive: %w", err)
		}
		return nil, nil
	}

	// numBlocks comes from the peer. Cap the up-front capacity so a corrupt
	// count cannot force a huge allocation before any block has arrived;
	// append still grows the slice if that many blocks really follow.
	const maxPrealloc = 64
	capHint := numBlocks
	if capHint > maxPrealloc {
		capHint = maxPrealloc
	}
	frames := make([][]byte, 0, capHint)
	for i := uint32(0); i < numBlocks; i++ {
		size, err := readU32()
		if err != nil {
			return nil, fmt.Errorf("read block size: %w", err)
		}
		buf := make([]byte, size)
		if _, err := io.ReadFull(t.sess.Conn, buf); err != nil {
			return nil, fmt.Errorf("read block data: %w", err)
		}
		frames = append(frames, buf)
	}
	return frames, nil
}

// WriteFrames sends a batch of Ethernet frames to the server.
// Safe for concurrent use: messages are queued for the single writer goroutine.
//
// The frames are copied into a fresh buffer before queuing, so callers may
// reuse their slices as soon as WriteFrames returns. Because the write
// happens asynchronously, a nil return only means the batch was queued. A
// failure of an earlier write is reported by this or a later call.
// After Close it returns an error "tunnel closed".
func (t *Tunnel) WriteFrames(frames [][]byte) error {
	// Calculate total size: 4 (numBlocks) + per frame: 4 (size) + len(data)
	total := 4
	for _, f := range frames {
		total += 4 + len(f)
	}
	buf := make([]byte, total)
	off := 0
	binary.BigEndian.PutUint32(buf[off:], uint32(len(frames)))
	off += 4
	for _, f := range frames {
		binary.BigEndian.PutUint32(buf[off:], uint32(len(f)))
		off += 4
		off += copy(buf[off:], f)
	}

	// Refuse up front once the tunnel is closed or the writer has died.
	// select picks randomly among ready cases, so without this check a send
	// into spare queue space could still "succeed" after shutdown.
	if err := t.writeClosedErr(); err != nil {
		return err
	}
	select {
	case t.writeCh <- buf:
		return t.lastWriteErr()
	case <-t.stopCh:
	case <-t.writeDone:
	}
	return t.writeClosedErr()
}
// FrameHandler inspects one Ethernet frame received from the server before
// Bridge writes it to the TAP device. A non-nil error ends Bridge's
// server-to-TAP loop. Bridge returns that error unless the other direction
// failed first.
type FrameHandler func(f []byte) error
// Bridge pumps Ethernet frames in both directions between the tunnel and a
// TAP device. It blocks until either direction fails and returns that first
// error.
//
//   - tapRead fills buf with one frame read from TAP and returns its length.
//   - tapWrite delivers one server frame to TAP.
//   - onFrame, when non-nil, sees every server frame before it reaches TAP
//     (e.g. for DHCP). A non-nil return stops the bridge with that error.
//
// Bridge does not close the tunnel or the TAP device. The direction that did
// not fail keeps running until its own I/O fails.
func (t *Tunnel) Bridge(tapRead func(buf []byte) (int, error), tapWrite func(buf []byte) (int, error), onFrame FrameHandler) error {
	// Server -> TAP: forward every frame of every batch, letting onFrame veto.
	serverToTap := func() error {
		for {
			frames, err := t.ReadFrames()
			if err != nil {
				return fmt.Errorf("read from server: %w", err)
			}
			for _, frame := range frames {
				if onFrame != nil {
					if err := onFrame(frame); err != nil {
						return err
					}
				}
				if _, err := tapWrite(frame); err != nil {
					return fmt.Errorf("write to tap: %w", err)
				}
			}
		}
	}

	// TAP -> Server: one frame per WriteFrames call. buf is reused across
	// reads, which is safe because WriteFrames copies the frame into its own
	// buffer before queuing it.
	//
	// tapRead blocks on the TAP fd, which doesn't support deadlines. After a
	// disconnect this direction lingers until the next TAP frame arrives; the
	// write to the server is then expected to fail, and the goroutine exits.
	tapToServer := func() error {
		buf := make([]byte, 1600)
		for {
			n, err := tapRead(buf)
			if err != nil {
				return fmt.Errorf("read from tap: %w", err)
			}
			if err := t.WriteFrames([][]byte{buf[:n]}); err != nil {
				return fmt.Errorf("write to server: %w", err)
			}
		}
	}

	// One slot per direction, so neither goroutine blocks on its single send
	// after Bridge has already returned.
	errCh := make(chan error, 2)
	go func() { errCh <- serverToTap() }()
	go func() { errCh <- tapToServer() }()
	return <-errCh
}
// StartKeepalive launches a goroutine that queues a keepalive message every
// keepAliveInterval, so the server does not time out an idle connection.
// Call it once the session has entered tunnel mode. The goroutine exits when
// the tunnel's stop channel is closed.
//
// Each message is keepAliveMagic, then a uint32 length n in
// [1, maxKeepaliveSize], then n random bytes.
// See: https://github.com/SoftEtherVPN/SoftEtherVPN/blob/v5.02.5187/src/Cedar/Connection.c#L1779
func (t *Tunnel) StartKeepalive() {
	go func() {
		ticker := time.NewTicker(keepAliveInterval)
		defer ticker.Stop()
		rng := rand.New(rand.NewSource(time.Now().UnixNano()))

		for {
			select {
			case <-t.stopCh:
				return
			case <-ticker.C:
			}

			n := uint32(rng.Intn(int(maxKeepaliveSize))) + 1
			msg := make([]byte, 8+n)
			binary.BigEndian.PutUint32(msg[:4], keepAliveMagic)
			binary.BigEndian.PutUint32(msg[4:8], n)
			// Fill the padding in place; (*rand.Rand).Read always returns a nil error.
			rng.Read(msg[8:])

			select {
			case t.writeCh <- msg:
			case <-t.stopCh:
				return
			}
		}
	}()
}