softether-go/pkg/netcfg/netcfg.go
Git Sagar 51824b830e netcfg: add -connmark flag for DNAT reply routing
When VPN traffic is DNAT'd to local namespaces/VMs, reply packets have
a different source IP (namespace veth) so the policy route's
"from <VPN_IP>" rule doesn't match. CONNMARK marks all connections
arriving on the VPN interface and restores the mark on reply packets,
routing them back through the tunnel via an fwmark rule.

New flag: -connmark (requires -policy-route-table)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-06-07 01:06:17 +05:30

270 lines
8.6 KiB
Go

// Package netcfg manages network configuration for the VPN tunnel:
// TAP interface addressing, routing (default, static, policy), DNS, and
// server host routes.
package netcfg
import (
"fmt"
"log"
"net"
"os"
"os/exec"
"strings"
"git.sagar.ch/sagar/softether-go/pkg/dhcp"
)
// Options controls which DHCP-provided network settings to apply.
//
// NOTE(review): the per-field notes below are inferred from the field names
// and from the parameters of ConfigureTAP / ConfigurePolicyRoute; confirm
// against the code that fills and consumes Options.
type Options struct {
	// DHCP presumably enables obtaining the TAP address via DHCP — TODO confirm.
	DHCP bool
	// AcceptDefaultGW: install a default route from the lease
	// (corresponds to ConfigureTAP's acceptDefaultGW parameter).
	AcceptDefaultGW bool
	// AcceptStaticRoutes: install the lease's static routes
	// (corresponds to ConfigureTAP's acceptStaticRoutes parameter).
	AcceptStaticRoutes bool
	// AcceptDNS: rewrite resolv.conf with the lease's DNS servers
	// (corresponds to ConfigureTAP's acceptDNS parameter).
	AcceptDNS bool
	// PolicyRouteTable is the routing table number handed to
	// ConfigurePolicyRoute; presumably 0 disables policy routing — confirm.
	PolicyRouteTable int
	// ConnMark enables CONNMARK/fwmark rules so replies to DNAT'd
	// connections return through the tunnel. Set by the -connmark flag,
	// which requires -policy-route-table.
	ConnMark bool
}
// ConfigureTAP sets the IP address, routes, and DNS on a TAP interface from a DHCP lease.
// Returns a cleanup function that undoes the changes.
//
// The interface address is lease.ClientIP with the prefix length taken from
// lease.SubnetMask. An unusable mask or a failing "ip addr add" is returned
// as an error together with a no-op cleanup. Route and DNS steps are best
// effort: their failures are logged and configuration continues.
//
// Routes:
//   - With acceptStaticRoutes, every entry of lease.Routes is added via its
//     gateway. A static 0.0.0.0/0 entry is only added (metric 50) when
//     acceptDefaultGW is also set.
//   - With acceptDefaultGW, a default route via lease.Gateway (metric 50) is
//     added, unless a static default route from lease.Routes is in effect
//     because static routes were accepted.
//
// DNS: with acceptDNS and a non-empty lease.DNS, resolv.conf is backed up
// and rewritten.
//
// The returned cleanup flushes all addresses on ifname and restores the
// resolv.conf backup, if one was taken.
func ConfigureTAP(ifname string, lease *dhcp.Lease, acceptDefaultGW, acceptStaticRoutes, acceptDNS bool) (func(), error) {
	noop := func() {}

	ones, bits := lease.SubnetMask.Size()
	if bits == 0 || ones == 0 {
		// Size reports 0,0 for a missing or non-canonical mask, and a zero
		// mask is /0. Either would add "<ip>/0", which makes the kernel
		// install a connected route for the whole address space on the TAP.
		return noop, fmt.Errorf("lease for %s has no valid subnet mask", lease.ClientIP)
	}
	addr := fmt.Sprintf("%s/%d", lease.ClientIP, ones)
	log.Printf("tap %s: adding address %s", ifname, addr)
	if err := run("ip", "addr", "add", addr, "dev", ifname); err != nil {
		return noop, fmt.Errorf("ip addr add: %w", err)
	}

	if acceptStaticRoutes {
		for _, r := range lease.Routes {
			dest := r.Dest.String()
			if dest == "0.0.0.0/0" {
				// A static default route is a default-gateway decision.
				if acceptDefaultGW {
					log.Printf("tap %s: adding static default route via %s", ifname, r.Gateway)
					if err := run("ip", "route", "add", "default", "via", r.Gateway.String(), "dev", ifname, "metric", "50"); err != nil {
						log.Printf("warning: static default route: %v", err)
					}
				} else {
					log.Printf("tap %s: skipping static default route via %s (accept-default-gateway not set)", ifname, r.Gateway)
				}
				continue
			}
			log.Printf("tap %s: adding static route %s via %s", ifname, dest, r.Gateway)
			if err := run("ip", "route", "add", dest, "via", r.Gateway.String(), "dev", ifname); err != nil {
				log.Printf("warning: static route %s via %s: %v", dest, r.Gateway, err)
			}
		}
	}

	// The lease's static default route only supersedes lease.Gateway when
	// static routes are actually being applied. Otherwise declining static
	// routes would also silently drop the accepted default gateway.
	staticDefault := acceptStaticRoutes && hasDefaultRoute(lease.Routes)
	if acceptDefaultGW && lease.Gateway != nil && !staticDefault {
		log.Printf("tap %s: adding default route via %s", ifname, lease.Gateway)
		if err := run("ip", "route", "add", "default", "via", lease.Gateway.String(), "dev", ifname, "metric", "50"); err != nil {
			log.Printf("warning: default route: %v", err)
		}
	}

	var savedResolv []byte
	if acceptDNS && len(lease.DNS) > 0 {
		log.Printf("tap %s: setting DNS servers %v", ifname, lease.DNS)
		savedResolv = backupResolv()
		writeResolv(lease.DNS)
	}

	cleanup := func() {
		log.Printf("tap %s: flushing addresses", ifname)
		run("ip", "addr", "flush", "dev", ifname)
		if savedResolv != nil {
			restoreResolv(savedResolv)
		}
	}
	return cleanup, nil
}
// ReconfigureTAP flushes the current TAP config and applies a new lease.
// Used when DHCP renewal returns a different IP address.
//
// This is best effort. The exit status of the address flush is not checked,
// and an error from ConfigureTAP is logged rather than returned. The cleanup
// function returned by this inner ConfigureTAP call is discarded. A cleanup
// obtained from an earlier ConfigureTAP call still flushes every address on
// ifname (including the new one) and restores the resolv.conf it backed up.
//
// NOTE(review): policy routing installed by ConfigurePolicyRoute matches the
// old client IP; confirm callers re-run it after a reconfiguration.
func ReconfigureTAP(ifname string, lease *dhcp.Lease, acceptDefaultGW, acceptStaticRoutes, acceptDNS bool) {
	log.Printf("tap %s: reconfiguring for new IP", ifname)
	run("ip", "addr", "flush", "dev", ifname)

	_, cfgErr := ConfigureTAP(ifname, lease, acceptDefaultGW, acceptStaticRoutes, acceptDNS)
	if cfgErr == nil {
		return
	}
	log.Printf("warning: reconfigure tap: %v", cfgErr)
}
// ConfigurePolicyRoute sets up policy routing so packets from the VPN IP are routed
// back through the VPN gateway. Needed when the VPN server forwards ports to the
// client — without it, reply packets use the default route instead of the VPN tunnel.
//
// It installs a default route via lease.Gateway on ifname in routing table
// `table`, plus an "ip rule" sending traffic from lease.ClientIP to that table.
// Every ip rule already pointing at the table is removed first, for example
// rules left behind by a previous run that exited without cleanup.
//
// When connmark is true, it also sets up CONNMARK rules so that DNAT'd
// connections (e.g. port forwards to namespaces/VMs) have their reply traffic
// routed back through the tunnel. The connmark value equals the table number:
//   - connections arriving on ifname get that connmark;
//   - packets of those connections arriving on other interfaces have the
//     connmark restored into their packet mark;
//   - an fwmark rule sends marked packets to the table.
//
// Each iptables rule is only appended if an identical rule is not already
// present.
//
// Failures are logged, not returned. If the lease has no gateway, nothing is
// configured and a no-op cleanup is returned.
//
// The returned cleanup removes every ip rule pointing at the table and the
// table's default route. When connmark is set, it also removes every copy of
// the two CONNMARK rules.
func ConfigurePolicyRoute(ifname string, lease *dhcp.Lease, table int, connmark bool) func() {
	// Upper bound on repeated delete commands, so a command that unexpectedly
	// keeps succeeding cannot loop forever.
	const maxDeletes = 64

	t := fmt.Sprintf("%d", table)
	mark := t // the table number doubles as the connmark/fwmark value

	if lease.Gateway == nil {
		log.Printf("warning: policy route: lease has no gateway, skipping table %s", t)
		return func() {}
	}
	clientIP := lease.ClientIP.String()
	gw := lease.Gateway.String()

	// deleteTableRules removes every ip rule that looks up table t, including
	// the fwmark rule. One "ip rule del" removes only the first matching rule,
	// so it is repeated until it fails (no match left).
	deleteTableRules := func() {
		for i := 0; i < maxDeletes; i++ {
			if runQuiet("ip", "rule", "del", "table", t) != nil {
				return
			}
		}
	}

	// iptables rule specs (chain, matches, target) in the mangle table.
	setMarkRule := []string{"PREROUTING", "-i", ifname, "-j", "CONNMARK", "--set-mark", mark}
	restoreMarkRule := []string{"PREROUTING", "!", "-i", ifname, "-m", "connmark", "--mark", mark, "-j", "CONNMARK", "--restore-mark"}
	mangle := func(op string, spec []string) []string {
		return append([]string{"-t", "mangle", op}, spec...)
	}
	// ensureMangleRule appends spec unless an identical rule already exists
	// ("iptables -C" exits non-zero when the rule is absent).
	ensureMangleRule := func(spec []string) {
		if runQuiet("iptables", mangle("-C", spec)...) == nil {
			return
		}
		if err := run("iptables", mangle("-A", spec)...); err != nil {
			log.Printf("warning: connmark iptables rule %v: %v", spec, err)
		}
	}
	// removeMangleRule deletes every copy of spec ("iptables -D" removes one
	// copy per call).
	removeMangleRule := func(spec []string) {
		for i := 0; i < maxDeletes; i++ {
			if runQuiet("iptables", mangle("-D", spec)...) != nil {
				return
			}
		}
	}

	// Policy route: packets from the VPN IP use the VPN gateway.
	deleteTableRules()
	if err := run("ip", "route", "replace", "default", "via", gw, "dev", ifname, "table", t); err != nil {
		log.Printf("warning: policy default route: %v", err)
	}
	if err := run("ip", "rule", "add", "from", clientIP, "table", t); err != nil {
		log.Printf("warning: policy rule: %v", err)
	} else {
		log.Printf("policy route: from %s via %s dev %s table %s", clientIP, gw, ifname, t)
	}

	if connmark {
		// Reply packets from DNAT targets (e.g. a namespace veth) carry a
		// source IP other than the VPN IP, so the "from <VPN_IP>" rule misses
		// them. Marking the connection on ingress and restoring the mark on
		// replies lets the fwmark rule steer them into the tunnel table.
		log.Printf("connmark: adding CONNMARK rules on %s (mark %s)", ifname, mark)
		ensureMangleRule(setMarkRule)
		ensureMangleRule(restoreMarkRule)
		if err := run("ip", "rule", "add", "fwmark", mark, "table", t); err != nil {
			log.Printf("warning: fwmark rule: %v", err)
		} else {
			log.Printf("connmark: fwmark %s → table %s for DNAT reply routing", mark, t)
		}
	}

	return func() {
		deleteTableRules()
		runQuiet("ip", "route", "del", "default", "table", t)
		if connmark {
			log.Printf("connmark: removing CONNMARK rules (mark %s)", mark)
			removeMangleRule(setMarkRule)
			removeMangleRule(restoreMarkRule)
		}
		log.Printf("policy route: cleaned up table %s", t)
	}
}
// AddServerRoute adds a /32 host route to the VPN server via the current default gateway.
// This prevents routing loops when the VPN's default route is installed.
// Returns a cleanup function that removes the route.
//
// A nil serverIP, an undetectable default gateway, or a failing
// "ip route add" all yield a no-op cleanup (the last two are logged).
func AddServerRoute(serverIP net.IP) func() {
	nothing := func() {}
	if serverIP == nil {
		return nothing
	}

	gateway, device := getDefaultGateway()
	if gateway == nil {
		log.Println("warning: no default gateway found, skipping server route")
		return nothing
	}

	hostRoute := serverIP.String() + "/32"
	cmdArgs := []string{"route", "add", hostRoute, "via", gateway.String()}
	// Pin the route to the uplink device when one was reported.
	if device != "" {
		cmdArgs = append(cmdArgs, "dev", device)
	}
	if addErr := run("ip", cmdArgs...); addErr != nil {
		log.Printf("warning: add server route: %v", addErr)
		return nothing
	}
	log.Printf("added route: %s via %s", hostRoute, gateway)

	removeRoute := func() {
		run("ip", "route", "del", hostRoute)
		log.Printf("removed route: %s", hostRoute)
	}
	return removeRoute
}
// ResolveHost resolves a hostname to an IPv4 address. If it's already an IP, returns it.
//
// Literal addresses (IPv4 or IPv6) are returned as parsed, without any lookup.
// For a hostname, the first IPv4 result of the lookup is returned in its
// 4-byte form. A failed lookup or a name with no IPv4 address is logged
// and yields nil.
func ResolveHost(host string) net.IP {
	if literal := net.ParseIP(host); literal != nil {
		return literal
	}

	addrs, lookupErr := net.LookupIP(host)
	if lookupErr != nil {
		log.Printf("warning: could not resolve %s: %v", host, lookupErr)
		return nil
	}
	for i := range addrs {
		if v4 := addrs[i].To4(); v4 != nil {
			return v4
		}
	}

	log.Printf("warning: no IPv4 address for %s", host)
	return nil
}
// hasDefaultRoute reports whether routes contains an IPv4 default route,
// i.e. an entry whose destination mask is 0 of 32 bits. Only the mask is
// examined, not the destination address.
func hasDefaultRoute(routes []dhcp.Route) bool {
	for i := range routes {
		if ones, bits := routes[i].Dest.Mask.Size(); bits == 32 && ones == 0 {
			return true
		}
	}
	return false
}
// resolvPath is the resolver configuration file that DNS settings are
// written to and restored into.
const resolvPath = "/etc/resolv.conf"

// backupResolv returns the current contents of resolvPath so they can be
// restored later. A read failure is logged and reported as nil.
func backupResolv() []byte {
	contents, readErr := os.ReadFile(resolvPath)
	if readErr == nil {
		return contents
	}
	log.Printf("warning: backup resolv.conf: %v", readErr)
	return nil
}
func writeResolv(servers []net.IP) {
var buf strings.Builder
buf.WriteString("# Generated by softether-go\n")
for _, ip := range servers {
fmt.Fprintf(&buf, "nameserver %s\n", ip)
}
if err := os.WriteFile(resolvPath, []byte(buf.String()), 0644); err != nil {
log.Printf("warning: write resolv.conf: %v", err)
return
}
log.Printf("dns: set nameservers %v", servers)
}
// restoreResolv writes a previously saved copy of the resolver configuration
// back to resolvPath. A write failure is logged; nothing is returned.
func restoreResolv(saved []byte) {
	writeErr := os.WriteFile(resolvPath, saved, 0644)
	if writeErr != nil {
		log.Printf("warning: restore resolv.conf: %v", writeErr)
		return
	}
	log.Println("dns: restored resolv.conf")
}
// getDefaultGateway reports the gateway address and output device of the
// first default route printed by "ip route show default".
// Either result is zero (nil / "") when it cannot be determined: the command
// fails, there is no default route, or that route has no "via" or "dev" field.
//
// Only the first route entry is parsed: its first line plus any indented
// continuation lines, such as multipath "nexthop via ... dev ..." lines.
// Within that entry, the first "via" and the first "dev" are taken, so the
// gateway and device always describe the same route (for multipath, the
// same nexthop). Scanning every line would let a later route's fields
// overwrite an earlier one's and mix gateway and device across routes.
//
// NOTE(review): the first entry is assumed to be the route the kernel prefers
// (lowest metric) — confirm on the target iproute2/kernel.
func getDefaultGateway() (net.IP, string) {
	out, err := exec.Command("ip", "route", "show", "default").Output()
	if err != nil {
		return nil, ""
	}

	var gw net.IP
	var dev string
	for n, line := range strings.Split(string(out), "\n") {
		// A non-indented line after the first starts the next route entry.
		if n > 0 && line != "" && line[0] != ' ' && line[0] != '\t' {
			break
		}
		fields := strings.Fields(line)
		for i := 0; i+1 < len(fields); i++ {
			switch {
			case fields[i] == "via" && gw == nil:
				gw = net.ParseIP(fields[i+1])
			case fields[i] == "dev" && dev == "":
				dev = fields[i+1]
			}
		}
	}
	return gw, dev
}
// run executes the named command with args, forwarding its output to this
// process's stdout/stderr. It returns the error from (*exec.Cmd).Run, which
// is non-nil when the command cannot be started or exits unsuccessfully.
func run(name string, args ...string) error {
	c := exec.Command(name, args...)
	c.Stdout, c.Stderr = os.Stdout, os.Stderr
	return c.Run()
}
// runQuiet executes the named command with args and discards its output.
// It returns the error from (*exec.Cmd).Run. Callers use it for best-effort
// commands whose failure is expected or ignorable.
func runQuiet(name string, args ...string) error {
	cmd := exec.Command(name, args...)
	// Stdout and Stderr stay nil, so exec connects them to the null device.
	return cmd.Run()
}