softether-go/pkg/client/tunnel.go
Git Sagar b3f4c5f42b tunnel: add write mutex for concurrent safety
WriteFrames and keepalive both write multi-part messages to the TLS
connection. Without synchronization, their writes could interleave
and corrupt the framing. Add writeMu to serialize all tunnel writes.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-06-06 22:53:06 +05:30

197 lines
5.8 KiB
Go

package client
import (
"encoding/binary"
"fmt"
"io"
"math/rand"
"sync"
"time"
)
// TCP block framing constants.
// After the HTTP handshake completes, SoftEther switches to raw TCP framing:
//
// Send: uint32(numBlocks) + [uint32(blockSize) + blockData]...
// Recv: same format
// Keepalive: uint32(0xFFFFFFFF) + uint32(randSize) + randData
//
// See: https://github.com/SoftEtherVPN/SoftEtherVPN/blob/v5.02.5187/src/Cedar/Connection.c#L1654 (TcpSockRecv)
// See: https://github.com/SoftEtherVPN/SoftEtherVPN/blob/v5.02.5187/src/Cedar/Connection.c#L1761 (TcpSockSend)
const (
// keepAliveMagic occupies the block-count position of a message header to
// mark the message as a keepalive instead of a batch of frames.
keepAliveMagic uint32 = 0xFFFFFFFF // Magic value indicating a keepalive packet
// maxKeepaliveSize is the largest random payload, in bytes, that
// StartKeepalive attaches to a keepalive it sends.
maxKeepaliveSize uint32 = 512 // Max random data in keepalive
// keepAliveInterval is the period between keepalives sent by StartKeepalive.
keepAliveInterval = 3 * time.Second
)
// Tunnel handles bidirectional TCP block framing for Ethernet frames over a
// SoftEther VPN session. Each "block" is one Ethernet frame.
//
// A Tunnel embeds a sync.Mutex and a sync.Once, so it must be used through a
// pointer and never copied.
type Tunnel struct {
// sess is the established session; its Conn carries the framed traffic.
sess *Session
// writeMu serializes writers so the multi-part messages written by
// WriteFrames and by the keepalive goroutine never interleave.
writeMu sync.Mutex
// stopCh is closed by Close to tell the keepalive goroutine to exit.
stopCh chan struct{}
// stopped ensures stopCh is closed at most once across repeated Close calls.
stopped sync.Once
}
// NewTunnel wraps an already-established session in a Tunnel.
//
// The returned Tunnel owns a fresh stop channel but starts no goroutines on
// its own; call StartKeepalive() before reading/writing frames.
func NewTunnel(sess *Session) *Tunnel {
	tun := new(Tunnel)
	tun.sess = sess
	tun.stopCh = make(chan struct{})
	return tun
}
// Close signals the keepalive goroutine to stop and then closes the session's
// connection, returning whatever error that close reports.
//
// The stop signal is delivered at most once, so calling Close repeatedly does
// not panic on the stop channel; the connection's Close is invoked every time.
func (t *Tunnel) Close() error {
	signalStop := func() {
		close(t.stopCh)
	}
	t.stopped.Do(signalStop)

	closeErr := t.sess.Conn.Close()
	return closeErr
}
// ReadFrames reads one message from the server and returns the Ethernet frames
// it carries. It blocks until a complete message has arrived.
//
// A message is a big-endian uint32 block count followed by that many blocks,
// each a big-endian uint32 length and then the frame bytes. A count equal to
// keepAliveMagic marks a keepalive: its payload is read and discarded and
// ReadFrames returns (nil, nil), so callers must tolerate an empty result.
//
// The header fields come straight off the network and are therefore not
// trusted for sizing allocations: the result slice is pre-sized to at most
// maxPreallocBlocks entries (it still grows to the real count as blocks
// arrive), and a block longer than maxBlockSize is rejected with an error
// before any buffer is allocated for it.
//
// Every error is wrapped with the step that failed. When an error is returned
// the rest of the current message has not been consumed, so the stream is no
// longer aligned on a message boundary.
func (t *Tunnel) ReadFrames() ([][]byte, error) {
	const (
		// maxPreallocBlocks bounds only the initial capacity hint; it does
		// not limit how many blocks a message may contain.
		maxPreallocBlocks = 256
		// maxBlockSize is the largest single block accepted.
		// NOTE(review): chosen generously (well above standard and jumbo
		// Ethernet frame sizes); confirm against the limit the server
		// enforces and tighten if appropriate.
		maxBlockSize = 1 << 16
	)

	var numBlocks uint32
	if err := binary.Read(t.sess.Conn, binary.BigEndian, &numBlocks); err != nil {
		return nil, fmt.Errorf("read num blocks: %w", err)
	}

	// Keepalive: server sends 0xFFFFFFFF + uint32(size) + random data.
	// The payload is streamed to io.Discard, so its size costs no memory.
	if numBlocks == keepAliveMagic {
		var size uint32
		if err := binary.Read(t.sess.Conn, binary.BigEndian, &size); err != nil {
			return nil, fmt.Errorf("read keepalive size: %w", err)
		}
		if _, err := io.CopyN(io.Discard, t.sess.Conn, int64(size)); err != nil {
			return nil, fmt.Errorf("discard keepalive: %w", err)
		}
		return nil, nil
	}

	// Do not let a 4-byte header dictate the allocation size: a bogus count
	// close to 2^32 would otherwise panic or exhaust memory immediately.
	capHint := numBlocks
	if capHint > maxPreallocBlocks {
		capHint = maxPreallocBlocks
	}
	frames := make([][]byte, 0, capHint)

	for i := uint32(0); i < numBlocks; i++ {
		var size uint32
		if err := binary.Read(t.sess.Conn, binary.BigEndian, &size); err != nil {
			return nil, fmt.Errorf("read block size: %w", err)
		}
		if size > maxBlockSize {
			return nil, fmt.Errorf("read block size: block of %d bytes exceeds limit of %d", size, maxBlockSize)
		}
		buf := make([]byte, size)
		if _, err := io.ReadFull(t.sess.Conn, buf); err != nil {
			return nil, fmt.Errorf("read block data: %w", err)
		}
		frames = append(frames, buf)
	}
	return frames, nil
}
// WriteFrames transmits frames to the server as one batch: a big-endian uint32
// block count, then for each frame a big-endian uint32 length followed by the
// frame bytes.
//
// The write mutex is held for the entire batch, so concurrent callers and the
// keepalive goroutine cannot interleave their messages with this one. The
// first write that fails aborts the batch and its error is returned wrapped
// with the step that failed.
func (t *Tunnel) WriteFrames(frames [][]byte) error {
	t.writeMu.Lock()
	defer t.writeMu.Unlock()

	conn := t.sess.Conn
	putUint32 := func(v uint32) error {
		return binary.Write(conn, binary.BigEndian, v)
	}

	if err := putUint32(uint32(len(frames))); err != nil {
		return fmt.Errorf("write num blocks: %w", err)
	}

	for i := range frames {
		payload := frames[i]
		if err := putUint32(uint32(len(payload))); err != nil {
			return fmt.Errorf("write block size: %w", err)
		}
		_, err := conn.Write(payload)
		if err != nil {
			return fmt.Errorf("write block data: %w", err)
		}
	}
	return nil
}
// FrameHandler is called for each Ethernet frame received from the server.
// Returning a non-nil error stops the bridge.
//
// Bridge invokes the handler before passing the same frame slice to its
// tapWrite callback, and reports a handler error without wrapping it.
type FrameHandler func(frame []byte) error
// Bridge pumps Ethernet frames in both directions between the tunnel and a TAP
// device until one direction fails, then returns that first error.
//
//   - tapRead fills buf with one Ethernet frame from the TAP device and
//     returns the number of bytes read.
//   - tapWrite delivers one Ethernet frame to the TAP device.
//   - onFrame, when non-nil, sees every server frame before it is written to
//     TAP (e.g. for DHCP); an error from it ends the bridge and is returned
//     as-is.
//
// Bridge itself does not close the tunnel or stop the direction that is still
// healthy when it returns.
func (t *Tunnel) Bridge(tapRead func(buf []byte) (int, error), tapWrite func(buf []byte) (int, error), onFrame FrameHandler) error {
	// One slot per pump, so the pump that finishes second can still report
	// its error and exit even though nobody is receiving any more.
	results := make(chan error, 2)

	// Server → TAP: forward every frame of every batch, in order.
	serverToTap := func() error {
		for {
			batch, err := t.ReadFrames()
			if err != nil {
				return fmt.Errorf("read from server: %w", err)
			}
			for i := range batch {
				pkt := batch[i]
				if onFrame != nil {
					if hookErr := onFrame(pkt); hookErr != nil {
						return hookErr
					}
				}
				if _, werr := tapWrite(pkt); werr != nil {
					return fmt.Errorf("write to tap: %w", werr)
				}
			}
		}
	}

	// TAP → Server: each frame read from TAP is sent as a one-block batch.
	//
	// Note: tapRead blocks on the TAP fd which doesn't support deadlines.
	// On disconnect, this pump survives until the next TAP frame arrives,
	// at which point WriteFrames fails on the closed connection and it exits.
	tapToServer := func() error {
		scratch := make([]byte, 1600)
		for {
			got, err := tapRead(scratch)
			if err != nil {
				return fmt.Errorf("read from tap: %w", err)
			}
			batch := [][]byte{scratch[:got]}
			if werr := t.WriteFrames(batch); werr != nil {
				return fmt.Errorf("write to server: %w", werr)
			}
		}
	}

	go func() { results <- serverToTap() }()
	go func() { results <- tapToServer() }()

	first := <-results
	return first
}
// StartKeepalive launches a goroutine that sends a keepalive packet every
// keepAliveInterval to prevent the server from timing out the connection.
// Must be called after the session enters tunnel mode.
//
// Each keepalive is keepAliveMagic, a big-endian uint32 payload length and
// between 1 and maxKeepaliveSize bytes of random data, written while holding
// the write mutex so it cannot interleave with WriteFrames.
//
// The goroutine exits when Close is called or as soon as any write fails. A
// failed write aborts that keepalive immediately; no further part of the
// packet is sent.
//
// See: https://github.com/SoftEtherVPN/SoftEtherVPN/blob/v5.02.5187/src/Cedar/Connection.c#L1779
func (t *Tunnel) StartKeepalive() {
	go func() {
		ticker := time.NewTicker(keepAliveInterval)
		defer ticker.Stop()

		rng := rand.New(rand.NewSource(time.Now().UnixNano()))
		randBuf := make([]byte, maxKeepaliveSize)

		// send writes one complete keepalive under the write lock and stops
		// at the first failed write, so a payload is never emitted without
		// its header having been written successfully.
		send := func(payload []byte) error {
			t.writeMu.Lock()
			defer t.writeMu.Unlock()

			if err := binary.Write(t.sess.Conn, binary.BigEndian, keepAliveMagic); err != nil {
				return err
			}
			if err := binary.Write(t.sess.Conn, binary.BigEndian, uint32(len(payload))); err != nil {
				return err
			}
			_, err := t.sess.Conn.Write(payload)
			return err
		}

		for {
			select {
			case <-t.stopCh:
				return
			case <-ticker.C:
				// Intn yields [0, maxKeepaliveSize), so size is in
				// [1, maxKeepaliveSize] and always fits in randBuf.
				size := uint32(rng.Intn(int(maxKeepaliveSize))) + 1
				rng.Read(randBuf[:size])
				if err := send(randBuf[:size]); err != nil {
					return
				}
			}
		}
	}()
}