refactor: extract session/netcfg/tunnel, add mac/dhcp/policy-route flags

- Split cmd/softether-go into main.go (flags, reconnect loop) and
  session.go (session lifecycle, DHCP orchestration)
- Extract network config to pkg/netcfg (TAP config, routing, DNS, policy routes)
- Move frame bridging to pkg/client/tunnel.go as Bridge() method
- Add -mac, -dhcp, -policy-route-table CLI flags
- Add SetMAC() to pkg/tap for deterministic DHCP assignments
- Update all docs to reflect new structure and flags

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Git Sagar 2026-06-06 16:43:12 +05:30
parent 846ed96ff4
commit 17c1063e1f
10 changed files with 495 additions and 332 deletions

View file

@ -98,6 +98,62 @@ func (t *Tunnel) WriteFrames(frames [][]byte) error {
return nil
}
// FrameHandler is called for each Ethernet frame received from the server.
// Returning a non-nil error stops the bridge.
type FrameHandler func(frame []byte) error
// Bridge runs bidirectional frame forwarding between the tunnel and a TAP device.
// tapRead reads one Ethernet frame into buf, returning (n, err).
// tapWrite writes one Ethernet frame.
// onFrame is called for each server frame before writing to TAP (e.g. for DHCP).
// Blocks until an error occurs on either direction.
func (t *Tunnel) Bridge(tapRead func(buf []byte) (int, error), tapWrite func(buf []byte) (int, error), onFrame FrameHandler) error {
errCh := make(chan error, 2)
// Server → TAP
go func() {
for {
frames, err := t.ReadFrames()
if err != nil {
errCh <- fmt.Errorf("read from server: %w", err)
return
}
for _, frame := range frames {
if onFrame != nil {
if err := onFrame(frame); err != nil {
errCh <- err
return
}
}
if _, err := tapWrite(frame); err != nil {
errCh <- fmt.Errorf("write to tap: %w", err)
return
}
}
}
}()
// TAP → Server
go func() {
buf := make([]byte, 1600)
for {
n, err := tapRead(buf)
if err != nil {
errCh <- fmt.Errorf("read from tap: %w", err)
return
}
frame := make([]byte, n)
copy(frame, buf[:n])
if err := t.WriteFrames([][]byte{frame}); err != nil {
errCh <- fmt.Errorf("write to server: %w", err)
return
}
}
}()
return <-errCh
}
// StartKeepalive sends periodic keepalive packets to prevent the server from
// timing out the connection. Must be called after the session enters tunnel mode.
// See: https://github.com/SoftEtherVPN/SoftEtherVPN/blob/v5.02.5187/src/Cedar/Connection.c#L1779

216
pkg/netcfg/netcfg.go Normal file
View file

@ -0,0 +1,216 @@
// Package netcfg manages network configuration for the VPN tunnel:
// TAP interface addressing, routing (default, static, policy), DNS, and
// server host routes.
package netcfg
import (
"fmt"
"log"
"net"
"os"
"os/exec"
"strings"
"git.sagar.ch/sagar/softether-go/pkg/dhcp"
)
// Options controls which DHCP-provided network settings to apply.
type Options struct {
DHCP bool
AcceptDefaultGW bool
AcceptStaticRoutes bool
AcceptDNS bool
PolicyRouteTable int
}
// ConfigureTAP sets the IP address, routes, and DNS on a TAP interface from a DHCP lease.
// Returns a cleanup function that undoes the changes.
func ConfigureTAP(ifname string, lease *dhcp.Lease, acceptDefaultGW, acceptStaticRoutes, acceptDNS bool) (func(), error) {
noop := func() {}
ones, _ := lease.SubnetMask.Size()
addr := fmt.Sprintf("%s/%d", lease.ClientIP, ones)
if err := run("ip", "addr", "add", addr, "dev", ifname); err != nil {
return noop, fmt.Errorf("ip addr add: %w", err)
}
if acceptStaticRoutes && len(lease.Routes) > 0 {
for _, r := range lease.Routes {
dest := r.Dest.String()
if dest == "0.0.0.0/0" {
if acceptDefaultGW {
if err := run("ip", "route", "add", "default", "via", r.Gateway.String(), "dev", ifname, "metric", "50"); err != nil {
log.Printf("warning: static default route: %v", err)
}
}
continue
}
if err := run("ip", "route", "add", dest, "via", r.Gateway.String(), "dev", ifname); err != nil {
log.Printf("warning: static route %s via %s: %v", dest, r.Gateway, err)
}
}
}
if acceptDefaultGW && lease.Gateway != nil && !hasDefaultRoute(lease.Routes) {
if err := run("ip", "route", "add", "default", "via", lease.Gateway.String(), "dev", ifname, "metric", "50"); err != nil {
log.Printf("warning: default route: %v", err)
}
}
var savedResolv []byte
if acceptDNS && len(lease.DNS) > 0 {
savedResolv = backupResolv()
writeResolv(lease.DNS)
}
cleanup := func() {
run("ip", "addr", "flush", "dev", ifname)
if savedResolv != nil {
restoreResolv(savedResolv)
}
}
return cleanup, nil
}
// ConfigurePolicyRoute sets up policy routing so packets from the VPN IP are routed
// back through the VPN gateway. Needed when the VPN server forwards ports to the
// client — without it, reply packets use the default route instead of the VPN tunnel.
func ConfigurePolicyRoute(ifname string, lease *dhcp.Lease, table int) func() {
t := fmt.Sprintf("%d", table)
clientIP := lease.ClientIP.String()
gw := lease.Gateway.String()
run("ip", "rule", "del", "table", t)
run("ip", "route", "replace", "default", "via", gw, "dev", ifname, "table", t)
if err := run("ip", "rule", "add", "from", clientIP, "table", t); err != nil {
log.Printf("warning: policy rule: %v", err)
} else {
log.Printf("policy route: from %s via %s dev %s table %s", clientIP, gw, ifname, t)
}
return func() {
run("ip", "rule", "del", "table", t)
run("ip", "route", "del", "default", "table", t)
log.Printf("policy route: cleaned up table %s", t)
}
}
// AddServerRoute adds a /32 host route to the VPN server via the current default gateway.
// This prevents routing loops when the VPN's default route is installed.
// Returns a cleanup function that removes the route.
func AddServerRoute(serverIP net.IP) func() {
noop := func() {}
if serverIP == nil {
return noop
}
gw, dev := getDefaultGateway()
if gw == nil {
log.Println("warning: no default gateway found, skipping server route")
return noop
}
route := serverIP.String() + "/32"
args := []string{"route", "add", route, "via", gw.String()}
if dev != "" {
args = append(args, "dev", dev)
}
if err := run("ip", args...); err != nil {
log.Printf("warning: add server route: %v", err)
return noop
}
log.Printf("added route: %s via %s", route, gw)
return func() {
run("ip", "route", "del", route)
log.Printf("removed route: %s", route)
}
}
// ResolveHost resolves a hostname to an IPv4 address. If it's already an IP, returns it.
func ResolveHost(host string) net.IP {
if ip := net.ParseIP(host); ip != nil {
return ip
}
ips, err := net.LookupIP(host)
if err != nil {
log.Printf("warning: could not resolve %s: %v", host, err)
return nil
}
for _, ip := range ips {
if ip.To4() != nil {
return ip.To4()
}
}
log.Printf("warning: no IPv4 address for %s", host)
return nil
}
func hasDefaultRoute(routes []dhcp.Route) bool {
for _, r := range routes {
ones, bits := r.Dest.Mask.Size()
if ones == 0 && bits == 32 {
return true
}
}
return false
}
const resolvPath = "/etc/resolv.conf"
func backupResolv() []byte {
data, err := os.ReadFile(resolvPath)
if err != nil {
log.Printf("warning: backup resolv.conf: %v", err)
return nil
}
return data
}
func writeResolv(servers []net.IP) {
var buf strings.Builder
buf.WriteString("# Generated by softether-go\n")
for _, ip := range servers {
fmt.Fprintf(&buf, "nameserver %s\n", ip)
}
if err := os.WriteFile(resolvPath, []byte(buf.String()), 0644); err != nil {
log.Printf("warning: write resolv.conf: %v", err)
return
}
log.Printf("dns: set nameservers %v", servers)
}
func restoreResolv(saved []byte) {
if err := os.WriteFile(resolvPath, saved, 0644); err != nil {
log.Printf("warning: restore resolv.conf: %v", err)
return
}
log.Println("dns: restored resolv.conf")
}
func getDefaultGateway() (net.IP, string) {
out, err := exec.Command("ip", "route", "show", "default").Output()
if err != nil {
return nil, ""
}
fields := strings.Fields(string(out))
var gw net.IP
var dev string
for i, f := range fields {
if f == "via" && i+1 < len(fields) {
gw = net.ParseIP(fields[i+1])
}
if f == "dev" && i+1 < len(fields) {
dev = fields[i+1]
}
}
return gw, dev
}
func run(name string, args ...string) error {
cmd := exec.Command(name, args...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}

View file

@ -106,6 +106,27 @@ func (d *Device) MAC() (net.HardwareAddr, error) {
return mac, nil
}
// SetMAC sets the hardware (MAC) address of the TAP interface.
func (d *Device) SetMAC(mac net.HardwareAddr) error {
sock, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
if err != nil {
return fmt.Errorf("socket: %w", err)
}
defer unix.Close(sock)
var ifr [40]byte
copy(ifr[:ifnameSize], d.Name)
ifr[ifnameSize] = 1 // sa_family = ARPHRD_ETHER
ifr[ifnameSize+1] = 0
copy(ifr[ifnameSize+2:ifnameSize+8], mac)
_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(sock), unix.SIOCSIFHWADDR, uintptr(unsafe.Pointer(&ifr[0])))
if errno != 0 {
return fmt.Errorf("SIOCSIFHWADDR: %w", errno)
}
return nil
}
// SetUp brings the TAP interface up (equivalent to `ip link set <name> up`).
func (d *Device) SetUp() error {
sock, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)