diff --git a/src/dns_cache.rs b/src/dns_cache.rs index c06ac34..921df96 100644 --- a/src/dns_cache.rs +++ b/src/dns_cache.rs @@ -11,9 +11,31 @@ use crate::{ }; use std::{ collections::{HashMap, HashSet}, + ops::BitOr, time::SystemTime, }; +/// Bitflags-style type for filtering by IP version. +#[derive(Clone, Copy)] +pub(crate) struct IpType(u8); + +impl IpType { + pub const V4: Self = Self(0b01); + pub const V6: Self = Self(0b10); + pub const BOTH: Self = Self(0b11); + + fn contains(self, other: Self) -> bool { + self.0 & other.0 == other.0 + } +} + +impl BitOr for IpType { + type Output = Self; + fn bitor(self, rhs: Self) -> Self { + Self(self.0 | rhs.0) + } +} + /// Associate a DnsRecord with the interface it was received on. pub(crate) struct DnsRecordIntf { pub(crate) record: DnsRecordBox, @@ -673,23 +695,34 @@ impl DnsCache { .collect() } - pub(crate) fn remove_addrs_on_disabled_intf(&mut self, disabled_if_index: u32) { + /// Removes cached address records on a disabled interface, filtered by IP version. + /// Use `IpType::V4` for A records only, `IpType::V6` for AAAA only, + /// or `IpType::V4 | IpType::V6` for both. + pub(crate) fn remove_addrs_on_disabled_intf( + &mut self, + disabled_if_index: u32, + ip_type: IpType, + ) { for (host, records) in self.addr.iter_mut() { records.retain(|record| { let Some(dns_addr) = record.record.any().downcast_ref::() else { return false; // invalid address record. }; - // Remove the record if it is on this interface. + // Remove the record if it is on this interface and matches the IP version filter. if dns_addr.interface_id.index == disabled_if_index { - debug!( - "removing ADDR on disabled intf: {:?} host {host}", - dns_addr.interface_id.name - ); - false - } else { - true + let rr_type = dns_addr.record.entry.ty; + let version_matches = (rr_type == RRType::A && ip_type.contains(IpType::V4)) + || (rr_type == RRType::AAAA && ip_type.contains(IpType::V6)); + if version_matches { + debug!( + "removing ADDR on disabled intf: {:?} host {host}", + dns_addr.interface_id.name + ); + return false; + } } + true }); } } diff --git a/src/service_daemon.rs b/src/service_daemon.rs index 4ab6992..39e1927 100644 --- a/src/service_daemon.rs +++ b/src/service_daemon.rs @@ -31,7 +31,7 @@ #[cfg(feature = "logging")] use crate::log::{debug, error, trace}; use crate::{ - dns_cache::{current_time_millis, DnsCache}, + dns_cache::{current_time_millis, DnsCache, IpType}, dns_parser::{ ip_address_rr_type, DnsAddress, DnsEntryExt, DnsIncoming, DnsOutgoing, DnsPointer, DnsRecordBox, DnsRecordExt, DnsSrv, DnsTxt, InterfaceId, RRType, ScopedIp, @@ -793,6 +793,8 @@ pub enum IfKind { Name(String), /// By an IPv4 or IPv6 address. + /// This is used to look up the interface. The semantics is to identify an interface of + /// IPv4 or IPv6, not a specific address on the interface. Addr(IpAddr), /// 127.0.0.1 (or anything in 127.0.0.0/8), enabled by default. @@ -802,6 +804,12 @@ pub enum IfKind { /// ::1/128, enabled by default. LoopbackV6, + + /// By interface index, IPv4 only. + IndexV4(u32), + + /// By interface index, IPv6 only. + IndexV6(u32), } impl IfKind { @@ -815,6 +823,8 @@ impl IfKind { Self::Addr(addr) => addr == &intf.ip(), Self::LoopbackV4 => intf.is_loopback() && intf.ip().is_ipv4(), Self::LoopbackV6 => intf.is_loopback() && intf.ip().is_ipv6(), + Self::IndexV4(idx) => intf.index == Some(*idx) && intf.ip().is_ipv4(), + Self::IndexV6(idx) => intf.index == Some(*idx) && intf.ip().is_ipv6(), } } } @@ -1431,26 +1441,30 @@ impl Zeroconf { fn enable_interface(&mut self, kinds: Vec) { debug!("enable_interface: {:?}", kinds); + let interfaces = my_ip_interfaces_inner(true, self.include_apple_p2p); + for if_kind in kinds { self.if_selections.push(IfSelection { - if_kind, + if_kind: resolve_addr_to_index(if_kind, &interfaces), selected: true, }); } - self.apply_intf_selections(my_ip_interfaces_inner(true, self.include_apple_p2p)); + self.apply_intf_selections(interfaces); } fn disable_interface(&mut self, kinds: Vec) { debug!("disable_interface: {:?}", kinds); + let interfaces = my_ip_interfaces_inner(true, self.include_apple_p2p); + for if_kind in kinds { self.if_selections.push(IfSelection { - if_kind, + if_kind: resolve_addr_to_index(if_kind, &interfaces), selected: false, }); } - self.apply_intf_selections(my_ip_interfaces_inner(true, self.include_apple_p2p)); + self.apply_intf_selections(interfaces); } fn set_multicast_loop_v4(&mut self, on: bool) { @@ -1704,7 +1718,7 @@ impl Zeroconf { /// If no more addresses on the interface, remove the interface as well. fn del_interface_addr(&mut self, intf: &Interface) { let if_index = intf.index.unwrap_or(0); - trace!( + debug!( "del_interface_addr: {} ({if_index}) addr {}", intf.name, intf.ip() @@ -1726,6 +1740,8 @@ impl Zeroconf { if let Some(sock) = self.ipv4_sock.as_mut() { if let Err(e) = sock.pktinfo.leave_multicast_v4(&GROUP_ADDR_V4, &ipv4) { debug!("leave multicast group for addr {ipv4}: {e}"); + } else { + debug!("leave multicast for {ipv4}"); } } } @@ -1749,7 +1765,22 @@ impl Zeroconf { debug!("del_interface_addr: removing interface {}", intf.name); self.my_intfs.remove(&if_index); self.dns_registry_map.remove(&if_index); - self.cache.remove_addrs_on_disabled_intf(if_index); + self.cache + .remove_addrs_on_disabled_intf(if_index, IpType::BOTH); + } else { + // Interface still has addresses of the other IP version. + // Remove cached address records for the disabled IP version + // only if no more addresses of that version remain. + let is_v4 = intf.addr.ip().is_ipv4(); + let version_gone = if is_v4 { + my_intf.next_ifaddr_v4().is_none() + } else { + my_intf.next_ifaddr_v6().is_none() + }; + if version_gone { + let ip_type = if is_v4 { IpType::V4 } else { IpType::V6 }; + self.cache.remove_addrs_on_disabled_intf(if_index, ip_type); + } } } @@ -2322,6 +2353,22 @@ impl Zeroconf { return true; // We still return true to indicate that we read something. }; + // Drop packets for an IP version that has been disabled on this interface. + // This is needed because some times the socket layer may still receive packets + // for an IP version even after we left the multicast group for that IP version. + // We want to drop such packets to avoid unnecessary processing. + let is_ipv4 = event_key == IPV4_SOCK_EVENT_KEY; + if (is_ipv4 && my_intf.next_ifaddr_v4().is_none()) + || (!is_ipv4 && my_intf.next_ifaddr_v6().is_none()) + { + debug!( + "handle_read: dropping {} packet on intf {} (disabled)", + if is_ipv4 { "IPv4" } else { "IPv6" }, + my_intf.name + ); + return true; + } + buf.truncate(sz); // reduce potential processing errors match DnsIncoming::new(buf, my_intf.into()) { @@ -4664,6 +4711,21 @@ fn handle_expired_probes( waiting_services } +/// Resolves `IfKind::Addr(ip)` to `IndexV4(if_index)` or `IndexV6(if_index)`. +fn resolve_addr_to_index(if_kind: IfKind, interfaces: &[Interface]) -> IfKind { + if let IfKind::Addr(addr) = &if_kind { + if let Some(intf) = interfaces.iter().find(|intf| &intf.ip() == addr) { + let if_index = intf.index.unwrap_or(0); + return if addr.is_ipv4() { + IfKind::IndexV4(if_index) + } else { + IfKind::IndexV6(if_index) + }; + } + } + if_kind +} + #[cfg(test)] mod tests { use super::{ diff --git a/src/service_info.rs b/src/service_info.rs index 22261b0..72137fe 100644 --- a/src/service_info.rs +++ b/src/service_info.rs @@ -497,6 +497,8 @@ impl ServiceInfo { IfKind::Addr(a) => *a == addr, IfKind::LoopbackV4 => matches!(addr, IpAddr::V4(ipv4) if ipv4.is_loopback()), IfKind::LoopbackV6 => matches!(addr, IpAddr::V6(ipv6) if ipv6.is_loopback()), + IfKind::IndexV4(idx) => intf.index == Some(*idx) && addr.is_ipv4(), + IfKind::IndexV6(idx) => intf.index == Some(*idx) && addr.is_ipv6(), IfKind::All => true, }); diff --git a/tests/mdns_test.rs b/tests/mdns_test.rs index d618f5f..5440f8c 100644 --- a/tests/mdns_test.rs +++ b/tests/mdns_test.rs @@ -755,11 +755,11 @@ fn test_disable_interface_cache() { .duration_since(SystemTime::UNIX_EPOCH) .unwrap(); let instance_name = now.as_micros().to_string(); - let service_ip_addr = my_ip_interfaces() + let ipv4_list: Vec<_> = my_ip_interfaces() .iter() - .find(|iface| iface.ip().is_ipv4()) .map(|iface| iface.ip()) - .unwrap(); + .filter(|ip| ip.is_ipv4() && !ip.is_loopback()) + .collect(); let host_name = "disabled_intf_host.local."; let port = 5201; @@ -767,7 +767,7 @@ fn test_disable_interface_cache() { ty_domain, &instance_name, host_name, - service_ip_addr, + &ipv4_list[..], port, None, ) @@ -783,8 +783,8 @@ fn test_disable_interface_cache() { sleep(Duration::from_secs(1)); // Disable the interface for the client. - println!("Disabling interface with IP: {service_ip_addr}"); - client.disable_interface(service_ip_addr).unwrap(); + println!("Disabling interface with IP: {:?}", ipv4_list); + client.disable_interface(ipv4_list).unwrap(); // Browse for the service. let handle = client.browse(ty_domain).unwrap();