From 9593becc2ace71e2fb0cb05adf59831a815ce795 Mon Sep 17 00:00:00 2001 From: Alexey Dolotov Date: Wed, 20 May 2026 17:55:06 +0000 Subject: [PATCH] internal/cli: consolidate duplicated SNI-DNS check MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `doctor`'s checkSecretHost and the proxy-startup warnSNIMismatch each carried their own copy of the same logic: resolve the secret hostname, determine the server's public IPv4/IPv6 (config first, getIP fallback), and compare the two sets. Extract that data-gathering into runSNICheck (internal/cli/sni_check.go), returning an sniCheckResult. The success decision stays with each caller because the rules genuinely differ — `doctor` reports OK when any family matches, while the startup warning requires every detected family to match — so only the gathering is shared, not the verdict. No behavior change: both callers produce byte-identical output and the same return values as before. --- internal/cli/doctor.go | 52 +++++++++++++------------- internal/cli/run_proxy.go | 51 ++++++++----------------- internal/cli/sni_check.go | 78 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 120 insertions(+), 61 deletions(-) create mode 100644 internal/cli/sni_check.go diff --git a/internal/cli/doctor.go b/internal/cli/doctor.go index 3d7e3d4cd..e566f7e38 100644 --- a/internal/cli/doctor.go +++ b/internal/cli/doctor.go @@ -361,26 +361,17 @@ func (d *Doctor) checkFrontingDomain(ntw mtglib.Network) bool { } func (d *Doctor) checkSecretHost(resolver *net.Resolver, ntw mtglib.Network) bool { - addresses, err := resolver.LookupIPAddr(context.Background(), d.conf.Secret.Host) - if err != nil { + res := runSNICheck(context.Background(), resolver, d.conf, ntw) + + if res.ResolveErr != nil { tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck "description": fmt.Sprintf("cannot resolve DNS name of %s", d.conf.Secret.Host), - "error": err, + "error": res.ResolveErr, }) return false } - ourIP4 := d.conf.PublicIPv4.Get(nil) - if ourIP4 == nil { - ourIP4 = getIP(ntw, "tcp4") - } - - ourIP6 := d.conf.PublicIPv6.Get(nil) - if ourIP6 == nil { - ourIP6 = getIP(ntw, "tcp6") - } - - if ourIP4 == nil && ourIP6 == nil { + if !res.PublicIPKnown() { tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck "description": "cannot detect public IP address", "error": errors.New("cannot detect automatically and public-ipv4/public-ipv6 are not set in config"), @@ -388,25 +379,34 @@ func (d *Doctor) checkSecretHost(resolver *net.Resolver, ntw mtglib.Network) boo return false } - strAddresses := []string{} - for _, value := range addresses { - if (ourIP4 != nil && value.IP.String() == ourIP4.String()) || - (ourIP6 != nil && value.IP.String() == ourIP6.String()) { - tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck - "ip": value.IP, - "hostname": d.conf.Secret.Host, - }) - return true + if res.IPv4Match || res.IPv6Match { + var matched net.IP + + for _, ip := range res.Resolved { + if (res.OurIPv4 != nil && ip.String() == res.OurIPv4.String()) || + (res.OurIPv6 != nil && ip.String() == res.OurIPv6.String()) { + matched = ip + break + } } - strAddresses = append(strAddresses, `"`+value.IP.String()+`"`) + tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck + "ip": matched, + "hostname": d.conf.Secret.Host, + }) + return true + } + + strAddresses := make([]string, 0, len(res.Resolved)) + for _, ip := range res.Resolved { + strAddresses = append(strAddresses, `"`+ip.String()+`"`) } tplEDNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck "hostname": d.conf.Secret.Host, "resolved": strings.Join(strAddresses, ", "), - "ip4": ourIP4, - "ip6": ourIP6, + "ip4": res.OurIPv4, + "ip6": res.OurIPv6, }) return false diff --git a/internal/cli/run_proxy.go b/internal/cli/run_proxy.go index 11520847d..704666381 100644 --- a/internal/cli/run_proxy.go +++ b/internal/cli/run_proxy.go @@ -215,72 +215,53 @@ func warnSNIMismatch(conf *config.Config, ntw mtglib.Network, log mtglib.Logger) return } - addresses, err := net.DefaultResolver.LookupIPAddr(context.Background(), host) - if err != nil { + res := runSNICheck(context.Background(), net.DefaultResolver, conf, ntw) + + if res.ResolveErr != nil { log.BindStr("hostname", host). - WarningError("SNI-DNS check: cannot resolve secret hostname", err) + WarningError("SNI-DNS check: cannot resolve secret hostname", res.ResolveErr) return } - ourIP4 := conf.PublicIPv4.Get(nil) - if ourIP4 == nil { - ourIP4 = getIP(ntw, "tcp4") - } - - ourIP6 := conf.PublicIPv6.Get(nil) - if ourIP6 == nil { - ourIP6 = getIP(ntw, "tcp6") - } - - if ourIP4 == nil && ourIP6 == nil { + if !res.PublicIPKnown() { log.Warning("SNI-DNS check: cannot detect public IP address; set public-ipv4/public-ipv6 in config or run 'mtg doctor'") return } - v4Match := ourIP4 == nil - v6Match := ourIP6 == nil - - for _, addr := range addresses { - if ourIP4 != nil && addr.IP.String() == ourIP4.String() { - v4Match = true - } - - if ourIP6 != nil && addr.IP.String() == ourIP6.String() { - v6Match = true - } - } + v4Match := res.OurIPv4 == nil || res.IPv4Match + v6Match := res.OurIPv6 == nil || res.IPv6Match if v4Match && v6Match { return } - resolved := make([]string, 0, len(addresses)) - for _, addr := range addresses { - resolved = append(resolved, addr.IP.String()) + resolved := make([]string, 0, len(res.Resolved)) + for _, ip := range res.Resolved { + resolved = append(resolved, ip.String()) } our := "" - if ourIP4 != nil { - our = ourIP4.String() + if res.OurIPv4 != nil { + our = res.OurIPv4.String() } - if ourIP6 != nil { + if res.OurIPv6 != nil { if our != "" { our += "/" } - our += ourIP6.String() + our += res.OurIPv6.String() } entry := log.BindStr("hostname", host). BindStr("resolved", strings.Join(resolved, ", ")). BindStr("public_ip", our) - if ourIP4 != nil { + if res.OurIPv4 != nil { entry = entry.BindStr("ipv4_match", fmt.Sprintf("%t", v4Match)) } - if ourIP6 != nil { + if res.OurIPv6 != nil { entry = entry.BindStr("ipv6_match", fmt.Sprintf("%t", v6Match)) } diff --git a/internal/cli/sni_check.go b/internal/cli/sni_check.go new file mode 100644 index 000000000..d4dbade9b --- /dev/null +++ b/internal/cli/sni_check.go @@ -0,0 +1,78 @@ +package cli + +import ( + "context" + "net" + + "github.com/9seconds/mtg/v2/internal/config" + "github.com/9seconds/mtg/v2/mtglib" +) + +// sniCheckResult holds the data gathered while comparing the secret +// hostname's DNS records against this server's public IP addresses. +// +// IPv4Match / IPv6Match report whether a resolved record actually equals the +// corresponding public IP. They are false when that family's public IP could +// not be determined — there is nothing to compare against. Callers decide +// what counts as a clean result from these fields: `mtg doctor` and the +// startup warning apply different rules. +type sniCheckResult struct { + Resolved []net.IP + OurIPv4 net.IP + OurIPv6 net.IP + IPv4Match bool + IPv6Match bool + ResolveErr error +} + +// PublicIPKnown reports whether at least one public IP family was detected. +func (r sniCheckResult) PublicIPKnown() bool { + return r.OurIPv4 != nil || r.OurIPv6 != nil +} + +// runSNICheck resolves conf.Secret.Host and compares the records with this +// server's public IPv4 and IPv6. Public IPs come from config first and fall +// back to on-the-fly detection via ntw. It gathers data only — it does not +// decide success; see sniCheckResult. +func runSNICheck( + ctx context.Context, + resolver *net.Resolver, + conf *config.Config, + ntw mtglib.Network, +) sniCheckResult { + res := sniCheckResult{} + + addrs, err := resolver.LookupIPAddr(ctx, conf.Secret.Host) + if err != nil { + res.ResolveErr = err + + return res + } + + res.Resolved = make([]net.IP, 0, len(addrs)) + for _, a := range addrs { + res.Resolved = append(res.Resolved, a.IP) + } + + res.OurIPv4 = conf.PublicIPv4.Get(nil) + if res.OurIPv4 == nil { + res.OurIPv4 = getIP(ntw, "tcp4") + } + + res.OurIPv6 = conf.PublicIPv6.Get(nil) + if res.OurIPv6 == nil { + res.OurIPv6 = getIP(ntw, "tcp6") + } + + for _, ip := range res.Resolved { + if res.OurIPv4 != nil && ip.String() == res.OurIPv4.String() { + res.IPv4Match = true + } + + if res.OurIPv6 != nil && ip.String() == res.OurIPv6.String() { + res.IPv6Match = true + } + } + + return res +}