Skip to content

Commit a7894dd

Browse files
committed
Add context to some DNS utils; export a couple functions
Newly exported functions are marked as experimental since I may refactor or unexport their API again.
1 parent b24a7ba commit a7894dd

File tree

3 files changed

+56
-39
lines changed

3 files changed

+56
-39
lines changed

dnsutil.go

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package certmagic
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67
"net"
@@ -18,21 +19,24 @@ import (
1819
//
1920
// It has been modified.
2021

21-
// findZoneByFQDN determines the zone apex for the given fqdn by recursing
22-
// up the domain labels until the nameserver returns a SOA record in the
23-
// answer section. The logger must be non-nil.
24-
func findZoneByFQDN(logger *zap.Logger, fqdn string, nameservers []string) (string, error) {
22+
// FindZoneByFQDN determines the zone apex for the given fully-qualified
23+
// domain name (FQDN) by recursing up the domain labels until the nameserver
24+
// returns a SOA record in the answer section. The logger must be non-nil.
25+
//
26+
// EXPERIMENTAL: This API was previously unexported, and may be changed or
27+
// unexported again in the future. Do not rely on it at this time.
28+
func FindZoneByFQDN(ctx context.Context, logger *zap.Logger, fqdn string, nameservers []string) (string, error) {
2529
if !strings.HasSuffix(fqdn, ".") {
2630
fqdn += "."
2731
}
28-
soa, err := lookupSoaByFqdn(logger, fqdn, nameservers)
32+
soa, err := lookupSoaByFqdn(ctx, logger, fqdn, nameservers)
2933
if err != nil {
3034
return "", err
3135
}
3236
return soa.zone, nil
3337
}
3438

35-
func lookupSoaByFqdn(logger *zap.Logger, fqdn string, nameservers []string) (*soaCacheEntry, error) {
39+
func lookupSoaByFqdn(ctx context.Context, logger *zap.Logger, fqdn string, nameservers []string) (*soaCacheEntry, error) {
3640
logger = logger.Named("soa_lookup")
3741

3842
if !strings.HasSuffix(fqdn, ".") {
@@ -42,13 +46,17 @@ func lookupSoaByFqdn(logger *zap.Logger, fqdn string, nameservers []string) (*so
4246
fqdnSOACacheMu.Lock()
4347
defer fqdnSOACacheMu.Unlock()
4448

49+
if err := ctx.Err(); err != nil {
50+
return nil, err
51+
}
52+
4553
// prefer cached version if fresh
4654
if ent := fqdnSOACache[fqdn]; ent != nil && !ent.isExpired() {
4755
logger.Debug("using cached SOA result", zap.String("entry", ent.zone))
4856
return ent, nil
4957
}
5058

51-
ent, err := fetchSoaByFqdn(logger, fqdn, nameservers)
59+
ent, err := fetchSoaByFqdn(ctx, logger, fqdn, nameservers)
5260
if err != nil {
5361
return nil, err
5462
}
@@ -66,15 +74,19 @@ func lookupSoaByFqdn(logger *zap.Logger, fqdn string, nameservers []string) (*so
6674
return ent, nil
6775
}
6876

69-
func fetchSoaByFqdn(logger *zap.Logger, fqdn string, nameservers []string) (*soaCacheEntry, error) {
77+
func fetchSoaByFqdn(ctx context.Context, logger *zap.Logger, fqdn string, nameservers []string) (*soaCacheEntry, error) {
7078
var err error
7179
var in *dns.Msg
7280

7381
labelIndexes := dns.Split(fqdn)
7482
for _, index := range labelIndexes {
83+
if err := ctx.Err(); err != nil {
84+
return nil, err
85+
}
86+
7587
domain := fqdn[index:]
7688

77-
in, err = dnsQuery(domain, dns.TypeSOA, nameservers, true)
89+
in, err = dnsQuery(ctx, domain, dns.TypeSOA, nameservers, true)
7890
if err != nil {
7991
continue
8092
}
@@ -122,12 +134,12 @@ func dnsMsgContainsCNAME(msg *dns.Msg) bool {
122134
return false
123135
}
124136

125-
func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) {
137+
func dnsQuery(ctx context.Context, fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) {
126138
m := createDNSMsg(fqdn, rtype, recursive)
127139
var in *dns.Msg
128140
var err error
129141
for _, ns := range nameservers {
130-
in, err = sendDNSQuery(m, ns)
142+
in, err = sendDNSQuery(ctx, m, ns)
131143
if err == nil && len(in.Answer) > 0 {
132144
break
133145
}
@@ -147,16 +159,16 @@ func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
147159
return m
148160
}
149161

150-
func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) {
162+
func sendDNSQuery(ctx context.Context, m *dns.Msg, ns string) (*dns.Msg, error) {
151163
udp := &dns.Client{Net: "udp", Timeout: dnsTimeout}
152-
in, _, err := udp.Exchange(m, ns)
164+
in, _, err := udp.ExchangeContext(ctx, m, ns)
153165
// two kinds of errors we can handle by retrying with TCP:
154166
// truncation and timeout; see https://github.com/caddyserver/caddy/issues/3639
155167
truncated := in != nil && in.Truncated
156168
timeoutErr := err != nil && strings.Contains(err.Error(), "timeout")
157169
if truncated || timeoutErr {
158170
tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
159-
in, _, err = tcp.Exchange(m, ns)
171+
in, _, err = tcp.ExchangeContext(ctx, m, ns)
160172
}
161173
return in, err
162174
}
@@ -205,7 +217,8 @@ func systemOrDefaultNameservers(path string, defaults []string) []string {
205217
return config.Servers
206218
}
207219

208-
// populateNameserverPorts ensures that all nameservers have a port number.
220+
// populateNameserverPorts ensures that all nameservers have a port number
221+
// If not, the the default DNS server port of 53 will be appended.
209222
func populateNameserverPorts(servers []string) {
210223
for i := range servers {
211224
_, port, _ := net.SplitHostPort(servers[i])
@@ -216,7 +229,7 @@ func populateNameserverPorts(servers []string) {
216229
}
217230

218231
// checkDNSPropagation checks if the expected record has been propagated to all authoritative nameservers.
219-
func checkDNSPropagation(logger *zap.Logger, fqdn string, recType uint16, expectedValue string, checkAuthoritativeServers bool, resolvers []string) (bool, error) {
232+
func checkDNSPropagation(ctx context.Context, logger *zap.Logger, fqdn string, recType uint16, expectedValue string, checkAuthoritativeServers bool, resolvers []string) (bool, error) {
220233
logger = logger.Named("propagation")
221234

222235
if !strings.HasSuffix(fqdn, ".") {
@@ -227,7 +240,7 @@ func checkDNSPropagation(logger *zap.Logger, fqdn string, recType uint16, expect
227240
// dereference (follow) a CNAME record if we are targeting a CNAME record
228241
// itself
229242
if recType != dns.TypeCNAME {
230-
r, err := dnsQuery(fqdn, recType, resolvers, true)
243+
r, err := dnsQuery(ctx, fqdn, recType, resolvers, true)
231244
if err != nil {
232245
return false, fmt.Errorf("CNAME dns query: %v", err)
233246
}
@@ -237,7 +250,7 @@ func checkDNSPropagation(logger *zap.Logger, fqdn string, recType uint16, expect
237250
}
238251

239252
if checkAuthoritativeServers {
240-
authoritativeServers, err := lookupNameservers(logger, fqdn, resolvers)
253+
authoritativeServers, err := lookupNameservers(ctx, logger, fqdn, resolvers)
241254
if err != nil {
242255
return false, fmt.Errorf("looking up authoritative nameservers: %v", err)
243256
}
@@ -246,13 +259,13 @@ func checkDNSPropagation(logger *zap.Logger, fqdn string, recType uint16, expect
246259
}
247260
logger.Debug("checking authoritative nameservers", zap.Strings("resolvers", resolvers))
248261

249-
return checkAuthoritativeNss(fqdn, recType, expectedValue, resolvers)
262+
return checkAuthoritativeNss(ctx, fqdn, recType, expectedValue, resolvers)
250263
}
251264

252265
// checkAuthoritativeNss queries each of the given nameservers for the expected record.
253-
func checkAuthoritativeNss(fqdn string, recType uint16, expectedValue string, nameservers []string) (bool, error) {
266+
func checkAuthoritativeNss(ctx context.Context, fqdn string, recType uint16, expectedValue string, nameservers []string) (bool, error) {
254267
for _, ns := range nameservers {
255-
r, err := dnsQuery(fqdn, recType, []string{ns}, true)
268+
r, err := dnsQuery(ctx, fqdn, recType, []string{ns}, true)
256269
if err != nil {
257270
return false, fmt.Errorf("querying authoritative nameservers: %v", err)
258271
}
@@ -293,15 +306,15 @@ func checkAuthoritativeNss(fqdn string, recType uint16, expectedValue string, na
293306
}
294307

295308
// lookupNameservers returns the authoritative nameservers for the given fqdn.
296-
func lookupNameservers(logger *zap.Logger, fqdn string, resolvers []string) ([]string, error) {
309+
func lookupNameservers(ctx context.Context, logger *zap.Logger, fqdn string, resolvers []string) ([]string, error) {
297310
var authoritativeNss []string
298311

299-
zone, err := findZoneByFQDN(logger, fqdn, resolvers)
312+
zone, err := FindZoneByFQDN(ctx, logger, fqdn, resolvers)
300313
if err != nil {
301314
return nil, fmt.Errorf("could not determine the zone for '%s': %w", fqdn, err)
302315
}
303316

304-
r, err := dnsQuery(zone, dns.TypeNS, resolvers, true)
317+
r, err := dnsQuery(ctx, zone, dns.TypeNS, resolvers, true)
305318
if err != nil {
306319
return nil, fmt.Errorf("querying NS resolver for zone '%s' recursively: %v", zone, err)
307320
}
@@ -330,11 +343,14 @@ func updateDomainWithCName(r *dns.Msg, fqdn string) string {
330343
return fqdn
331344
}
332345

333-
// recursiveNameservers are used to pre-check DNS propagation. It
346+
// RecursiveNameservers are used to pre-check DNS propagation. It
334347
// picks user-configured nameservers (custom) OR the defaults
335348
// obtained from resolv.conf and defaultNameservers if none is
336349
// configured and ensures that all server addresses have a port value.
337-
func recursiveNameservers(custom []string) []string {
350+
//
351+
// EXPERIMENTAL: This API was previously unexported, and may be
352+
// be unexported again in the future. Do not rely on it at this time.
353+
func RecursiveNameservers(custom []string) []string {
338354
var servers []string
339355
if len(custom) == 0 {
340356
servers = systemOrDefaultNameservers(defaultResolvConf, defaultNameservers)

dnsutil_test.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package certmagic
77
// It has been modified.
88

99
import (
10+
"context"
1011
"net"
1112
"reflect"
1213
"runtime"
@@ -34,7 +35,7 @@ func TestLookupNameserversOK(t *testing.T) {
3435
t.Run(test.fqdn, func(t *testing.T) {
3536
t.Parallel()
3637

37-
nss, err := lookupNameservers(zap.NewNop(), test.fqdn, recursiveNameservers(nil))
38+
nss, err := lookupNameservers(context.Background(), zap.NewNop(), test.fqdn, RecursiveNameservers(nil))
3839
if err != nil {
3940
t.Errorf("Expected no error, got: %v", err)
4041
}
@@ -68,7 +69,7 @@ func TestLookupNameserversErr(t *testing.T) {
6869
t.Run(test.desc, func(t *testing.T) {
6970
t.Parallel()
7071

71-
_, err := lookupNameservers(zap.NewNop(), test.fqdn, nil)
72+
_, err := lookupNameservers(context.Background(), zap.NewNop(), test.fqdn, nil)
7273
if err == nil {
7374
t.Errorf("expected error, got none")
7475
}
@@ -93,28 +94,28 @@ var findXByFqdnTestCases = []struct {
9394
fqdn: "scholar.google.com.",
9495
zone: "google.com.",
9596
primaryNs: "ns1.google.com.",
96-
nameservers: recursiveNameservers(nil),
97+
nameservers: RecursiveNameservers(nil),
9798
},
9899
{
99100
desc: "domain is a non-existent subdomain",
100101
fqdn: "foo.google.com.",
101102
zone: "google.com.",
102103
primaryNs: "ns1.google.com.",
103-
nameservers: recursiveNameservers(nil),
104+
nameservers: RecursiveNameservers(nil),
104105
},
105106
{
106107
desc: "domain is a eTLD",
107108
fqdn: "example.com.ac.",
108109
zone: "ac.",
109110
primaryNs: "a0.nic.ac.",
110-
nameservers: recursiveNameservers(nil),
111+
nameservers: RecursiveNameservers(nil),
111112
},
112113
{
113114
desc: "domain is a cross-zone CNAME",
114115
fqdn: "cross-zone-example.assets.sh.",
115116
zone: "assets.sh.",
116117
primaryNs: "gina.ns.cloudflare.com.",
117-
nameservers: recursiveNameservers(nil),
118+
nameservers: RecursiveNameservers(nil),
118119
},
119120
{
120121
desc: "NXDOMAIN",
@@ -160,7 +161,7 @@ func TestFindZoneByFqdn(t *testing.T) {
160161
}
161162
clearFqdnCache()
162163

163-
zone, err := findZoneByFQDN(zap.NewNop(), test.fqdn, test.nameservers)
164+
zone, err := FindZoneByFQDN(context.Background(), zap.NewNop(), test.fqdn, test.nameservers)
164165
if test.expectedError != "" {
165166
if err == nil {
166167
t.Errorf("test %d: expected error, got none", i)
@@ -219,7 +220,7 @@ func TestRecursiveNameserversAddsPort(t *testing.T) {
219220
}
220221
custom := []string{"127.0.0.1", "ns1.google.com:43"}
221222
expectations := []want{{port: "53"}, {port: "43"}}
222-
results := recursiveNameservers(custom)
223+
results := RecursiveNameservers(custom)
223224

224225
if !reflect.DeepEqual(custom, []string{"127.0.0.1", "ns1.google.com:43"}) {
225226
t.Errorf("Expected custom nameservers to be unmodified. got %v", custom)
@@ -247,12 +248,12 @@ func TestRecursiveNameserversAddsPort(t *testing.T) {
247248
}
248249

249250
func TestRecursiveNameserversDefaults(t *testing.T) {
250-
results := recursiveNameservers(nil)
251+
results := RecursiveNameservers(nil)
251252
if len(results) < 1 {
252253
t.Errorf("%v Expected at least 1 records as default when nil custom", results)
253254
}
254255

255-
results = recursiveNameservers([]string{})
256+
results = RecursiveNameservers([]string{})
256257
if len(results) < 1 {
257258
t.Errorf("%v Expected at least 1 records as default when empty custom", results)
258259
}

solvers.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ type DNSManager struct {
379379
func (m *DNSManager) createRecord(ctx context.Context, dnsName, recordType, recordValue string) (zoneRecord, error) {
380380
logger := m.logger()
381381

382-
zone, err := findZoneByFQDN(logger, dnsName, recursiveNameservers(m.Resolvers))
382+
zone, err := FindZoneByFQDN(ctx, logger, dnsName, RecursiveNameservers(m.Resolvers))
383383
if err != nil {
384384
return zoneRecord{}, fmt.Errorf("could not determine zone for domain %q: %v", dnsName, err)
385385
}
@@ -439,7 +439,7 @@ func (m *DNSManager) wait(ctx context.Context, zrec zoneRecord) error {
439439

440440
// how we'll do the checks
441441
checkAuthoritativeServers := len(m.Resolvers) == 0
442-
resolvers := recursiveNameservers(m.Resolvers)
442+
resolvers := RecursiveNameservers(m.Resolvers)
443443

444444
recType := dns.TypeTXT
445445
if zrec.record.Type == "CNAME" {
@@ -464,7 +464,7 @@ func (m *DNSManager) wait(ctx context.Context, zrec zoneRecord) error {
464464
zap.Strings("resolvers", resolvers))
465465

466466
var ready bool
467-
ready, err = checkDNSPropagation(logger, absName, recType, zrec.record.Value, checkAuthoritativeServers, resolvers)
467+
ready, err = checkDNSPropagation(ctx, logger, absName, recType, zrec.record.Value, checkAuthoritativeServers, resolvers)
468468
if err != nil {
469469
return fmt.Errorf("checking DNS propagation of %q (relative=%s zone=%s resolvers=%v): %w", absName, zrec.record.Name, zrec.zone, resolvers, err)
470470
}

0 commit comments

Comments
 (0)