Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 42 additions & 9 deletions src/dns_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,31 @@ use crate::{
};
use std::{
collections::{HashMap, HashSet},
ops::BitOr,
time::SystemTime,
};

/// Bitflags-style type for filtering by IP version.
#[derive(Clone, Copy)]
pub(crate) struct IpType(u8);

impl IpType {
pub const V4: Self = Self(0b01);
pub const V6: Self = Self(0b10);
pub const BOTH: Self = Self(0b11);

fn contains(self, other: Self) -> bool {
self.0 & other.0 == other.0
}
}

impl BitOr for IpType {
type Output = Self;
fn bitor(self, rhs: Self) -> Self {
Self(self.0 | rhs.0)
}
}

/// Associate a DnsRecord with the interface it was received on.
pub(crate) struct DnsRecordIntf {
pub(crate) record: DnsRecordBox,
Expand Down Expand Up @@ -673,23 +695,34 @@ impl DnsCache {
.collect()
}

pub(crate) fn remove_addrs_on_disabled_intf(&mut self, disabled_if_index: u32) {
/// Removes cached address records on a disabled interface, filtered by IP version.
/// Use `IpType::V4` for A records only, `IpType::V6` for AAAA only,
/// or `IpType::V4 | IpType::V6` for both.
pub(crate) fn remove_addrs_on_disabled_intf(
&mut self,
disabled_if_index: u32,
ip_type: IpType,
) {
for (host, records) in self.addr.iter_mut() {
records.retain(|record| {
let Some(dns_addr) = record.record.any().downcast_ref::<DnsAddress>() else {
return false; // invalid address record.
};

// Remove the record if it is on this interface.
// Remove the record if it is on this interface and matches the IP version filter.
if dns_addr.interface_id.index == disabled_if_index {
debug!(
"removing ADDR on disabled intf: {:?} host {host}",
dns_addr.interface_id.name
);
false
} else {
true
let rr_type = dns_addr.record.entry.ty;
let version_matches = (rr_type == RRType::A && ip_type.contains(IpType::V4))
|| (rr_type == RRType::AAAA && ip_type.contains(IpType::V6));
if version_matches {
debug!(
"removing ADDR on disabled intf: {:?} host {host}",
dns_addr.interface_id.name
);
return false;
}
}
true
});
}
}
Expand Down
76 changes: 69 additions & 7 deletions src/service_daemon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
#[cfg(feature = "logging")]
use crate::log::{debug, error, trace};
use crate::{
dns_cache::{current_time_millis, DnsCache},
dns_cache::{current_time_millis, DnsCache, IpType},
dns_parser::{
ip_address_rr_type, DnsAddress, DnsEntryExt, DnsIncoming, DnsOutgoing, DnsPointer,
DnsRecordBox, DnsRecordExt, DnsSrv, DnsTxt, InterfaceId, RRType, ScopedIp,
Expand Down Expand Up @@ -793,6 +793,8 @@ pub enum IfKind {
Name(String),

/// By an IPv4 or IPv6 address.
/// This is used to look up the interface. The semantics is to identify an interface of
/// IPv4 or IPv6, not a specific address on the interface.
Addr(IpAddr),

/// 127.0.0.1 (or anything in 127.0.0.0/8), enabled by default.
Expand All @@ -802,6 +804,12 @@ pub enum IfKind {

/// ::1/128, enabled by default.
LoopbackV6,

/// By interface index, IPv4 only.
IndexV4(u32),

/// By interface index, IPv6 only.
IndexV6(u32),
}

impl IfKind {
Expand All @@ -815,6 +823,8 @@ impl IfKind {
Self::Addr(addr) => addr == &intf.ip(),
Self::LoopbackV4 => intf.is_loopback() && intf.ip().is_ipv4(),
Self::LoopbackV6 => intf.is_loopback() && intf.ip().is_ipv6(),
Self::IndexV4(idx) => intf.index == Some(*idx) && intf.ip().is_ipv4(),
Self::IndexV6(idx) => intf.index == Some(*idx) && intf.ip().is_ipv6(),
}
}
}
Expand Down Expand Up @@ -1431,26 +1441,30 @@ impl Zeroconf {

fn enable_interface(&mut self, kinds: Vec<IfKind>) {
debug!("enable_interface: {:?}", kinds);
let interfaces = my_ip_interfaces_inner(true, self.include_apple_p2p);

for if_kind in kinds {
self.if_selections.push(IfSelection {
if_kind,
if_kind: resolve_addr_to_index(if_kind, &interfaces),
selected: true,
});
}

self.apply_intf_selections(my_ip_interfaces_inner(true, self.include_apple_p2p));
self.apply_intf_selections(interfaces);
}

fn disable_interface(&mut self, kinds: Vec<IfKind>) {
debug!("disable_interface: {:?}", kinds);
let interfaces = my_ip_interfaces_inner(true, self.include_apple_p2p);

for if_kind in kinds {
self.if_selections.push(IfSelection {
if_kind,
if_kind: resolve_addr_to_index(if_kind, &interfaces),
selected: false,
});
}

self.apply_intf_selections(my_ip_interfaces_inner(true, self.include_apple_p2p));
self.apply_intf_selections(interfaces);
}

fn set_multicast_loop_v4(&mut self, on: bool) {
Expand Down Expand Up @@ -1704,7 +1718,7 @@ impl Zeroconf {
/// If no more addresses on the interface, remove the interface as well.
fn del_interface_addr(&mut self, intf: &Interface) {
let if_index = intf.index.unwrap_or(0);
trace!(
debug!(
"del_interface_addr: {} ({if_index}) addr {}",
intf.name,
intf.ip()
Expand All @@ -1726,6 +1740,8 @@ impl Zeroconf {
if let Some(sock) = self.ipv4_sock.as_mut() {
if let Err(e) = sock.pktinfo.leave_multicast_v4(&GROUP_ADDR_V4, &ipv4) {
debug!("leave multicast group for addr {ipv4}: {e}");
} else {
debug!("leave multicast for {ipv4}");
}
}
}
Expand All @@ -1749,7 +1765,22 @@ impl Zeroconf {
debug!("del_interface_addr: removing interface {}", intf.name);
self.my_intfs.remove(&if_index);
self.dns_registry_map.remove(&if_index);
self.cache.remove_addrs_on_disabled_intf(if_index);
self.cache
.remove_addrs_on_disabled_intf(if_index, IpType::BOTH);
} else {
// Interface still has addresses of the other IP version.
// Remove cached address records for the disabled IP version
// only if no more addresses of that version remain.
let is_v4 = intf.addr.ip().is_ipv4();
let version_gone = if is_v4 {
my_intf.next_ifaddr_v4().is_none()
} else {
my_intf.next_ifaddr_v6().is_none()
};
if version_gone {
let ip_type = if is_v4 { IpType::V4 } else { IpType::V6 };
self.cache.remove_addrs_on_disabled_intf(if_index, ip_type);
}
}
}

Expand Down Expand Up @@ -2322,6 +2353,22 @@ impl Zeroconf {
return true; // We still return true to indicate that we read something.
};

// Drop packets for an IP version that has been disabled on this interface.
// This is needed because some times the socket layer may still receive packets
// for an IP version even after we left the multicast group for that IP version.
// We want to drop such packets to avoid unnecessary processing.
let is_ipv4 = event_key == IPV4_SOCK_EVENT_KEY;
if (is_ipv4 && my_intf.next_ifaddr_v4().is_none())
|| (!is_ipv4 && my_intf.next_ifaddr_v6().is_none())
{
debug!(
"handle_read: dropping {} packet on intf {} (disabled)",
if is_ipv4 { "IPv4" } else { "IPv6" },
my_intf.name
);
return true;
}

buf.truncate(sz); // reduce potential processing errors

match DnsIncoming::new(buf, my_intf.into()) {
Expand Down Expand Up @@ -4664,6 +4711,21 @@ fn handle_expired_probes(
waiting_services
}

/// Resolves `IfKind::Addr(ip)` to `IndexV4(if_index)` or `IndexV6(if_index)`.
fn resolve_addr_to_index(if_kind: IfKind, interfaces: &[Interface]) -> IfKind {
if let IfKind::Addr(addr) = &if_kind {
if let Some(intf) = interfaces.iter().find(|intf| &intf.ip() == addr) {
let if_index = intf.index.unwrap_or(0);
return if addr.is_ipv4() {
IfKind::IndexV4(if_index)
} else {
IfKind::IndexV6(if_index)
};
}
}
if_kind
}

#[cfg(test)]
mod tests {
use super::{
Expand Down
2 changes: 2 additions & 0 deletions src/service_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,8 @@ impl ServiceInfo {
IfKind::Addr(a) => *a == addr,
IfKind::LoopbackV4 => matches!(addr, IpAddr::V4(ipv4) if ipv4.is_loopback()),
IfKind::LoopbackV6 => matches!(addr, IpAddr::V6(ipv6) if ipv6.is_loopback()),
IfKind::IndexV4(idx) => intf.index == Some(*idx) && addr.is_ipv4(),
IfKind::IndexV6(idx) => intf.index == Some(*idx) && addr.is_ipv6(),
IfKind::All => true,
});

Expand Down
12 changes: 6 additions & 6 deletions tests/mdns_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -755,19 +755,19 @@ fn test_disable_interface_cache() {
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap();
let instance_name = now.as_micros().to_string();
let service_ip_addr = my_ip_interfaces()
let ipv4_list: Vec<_> = my_ip_interfaces()
.iter()
.find(|iface| iface.ip().is_ipv4())
.map(|iface| iface.ip())
.unwrap();
.filter(|ip| ip.is_ipv4() && !ip.is_loopback())
.collect();

let host_name = "disabled_intf_host.local.";
let port = 5201;
let my_service = ServiceInfo::new(
ty_domain,
&instance_name,
host_name,
service_ip_addr,
&ipv4_list[..],
port,
None,
)
Expand All @@ -783,8 +783,8 @@ fn test_disable_interface_cache() {
sleep(Duration::from_secs(1));

// Disable the interface for the client.
println!("Disabling interface with IP: {service_ip_addr}");
client.disable_interface(service_ip_addr).unwrap();
println!("Disabling interface with IP: {:?}", ipv4_list);
client.disable_interface(ipv4_list).unwrap();

// Browse for the service.
let handle = client.browse(ty_domain).unwrap();
Expand Down