diff --git a/test_wakeonlan.py b/test_wakeonlan.py index 5ade761..d158b5b 100644 --- a/test_wakeonlan.py +++ b/test_wakeonlan.py @@ -7,7 +7,7 @@ import unittest from unittest import mock -from wakeonlan import create_magic_packet, main, send_magic_packet +from wakeonlan import create_magic_packet, create_socket, main, send_magic_packet class TestCreateMagicPacket(unittest.TestCase): @@ -253,6 +253,84 @@ def test_invalid_secureon(self) -> None: create_magic_packet('01:23:45:67:89:ab/invalid') +class TestCreateSocket(unittest.TestCase): + """ + Test :func:`create_socket`. + + """ + + def test_ipv4_broadcast(self) -> None: + """ + Test if IPv4 works. + + """ + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as server: + server.bind(('', 1234)) + with create_socket(port=1234) as client: + client.send(b'Hello server!') + data, addr = server.recvfrom(1024) + self.assertEqual(data, b'Hello server!') + self.assertEqual(addr[0], socket.gethostbyname(socket.gethostname())) + + def test_ipv6_broadcast(self) -> None: + """ + Test if IPv6 works. + + """ + with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as server: + server.bind(('', 1234)) + with create_socket(port=1234) as client: + client.send(b'Hello server!') + data, addr = server.recvfrom(1024) + self.assertEqual(data, b'Hello server!') + self.assertEqual( + addr[0], f'::ffff:{socket.gethostbyname(socket.gethostname())}' + ) + + def test_interface(self) -> None: + """ + Test if IPv4 works. + + """ + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as server: + server.bind(('', 1234)) + with create_socket(interface='127.0.0.1', port=1234) as client: + client.send(b'Hello server!') + data, addr = server.recvfrom(1024) + self.assertEqual(data, b'Hello server!') + self.assertEqual(addr[0], '127.0.0.1') + + def test_explicit_ipv4(self) -> None: + """ + Test if explicit IPv4 works. + + """ + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as server: + server.bind(('', 1234)) + with create_socket( + ip_address='localhost', port=1234, address_family=socket.AF_INET + ) as client: + client.send(b'Hello server!') + data, addr = server.recvfrom(1024) + self.assertEqual(data, b'Hello server!') + self.assertEqual(addr[0], '127.0.0.1') + + def test_explicit_ipv6(self) -> None: + """ + Test if explicit IPv4 works. + + """ + with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as server: + server.bind(('', 1234)) + with create_socket( + ip_address='localhost', port=1234, address_family=socket.AF_INET6 + ) as client: + client.send(b'Hello server!') + data, addr = server.recvfrom(1024) + self.assertEqual(data, b'Hello server!') + self.assertEqual(addr[0], '::1') + + class TestSendMagicPacket(unittest.TestCase): """ Test :ref:`send_magic_packet`. @@ -550,14 +628,14 @@ def test_main(self, send_magic_packet: mock.Mock) -> None: ip_address='host.example', port=1337, interface=None, - address_family=None, + address_family=socket.AF_UNSPEC, ), mock.call( '00:11:22:33:44:55', ip_address='host.example', port=1337, interface='192.168.0.2', - address_family=None, + address_family=socket.AF_UNSPEC, ), mock.call( '00:11:22:33:44:55', diff --git a/wakeonlan/__init__.py b/wakeonlan/__init__.py index 5918455..5155e60 100755 --- a/wakeonlan/__init__.py +++ b/wakeonlan/__init__.py @@ -5,7 +5,6 @@ """ import argparse -import ipaddress import socket @@ -52,12 +51,61 @@ def create_magic_packet(macaddress: str) -> bytes: return bytes.fromhex('F' * 12 + macaddress * 16 + secureon) +def create_socket( + *, + ip_address: str = BROADCAST_IP, + port: int = DEFAULT_PORT, + interface: str | None = None, + address_family: socket.AddressFamily = socket.AF_UNSPEC, +) -> socket.socket: + """ + Create a socket that’s suitable for sending magic packets. + + Args: + ip_address: The hostname to connect to. + port: The port to connect to. + interface: The IP address of the network adapter to use. + address_family: The address family to send the magic packet to. + Use this to force the use of IPv4 or IPv6. The default is + to auto detect. + + Returns: + A socket you can use for sending magic packets. + + """ + # This is based on the example for a connection that supports both IPv4 + # and IPv6 in https://docs.python.org/3/library/socket.html#example + # This also matches the getaddrinfo man page, which states applications + # should try using the addresses in order. + # https://man7.org/linux/man-pages/man3/getaddrinfo.3.html + address_infos = socket.getaddrinfo( + ip_address, port, address_family, socket.SOCK_DGRAM + ) + sock: socket.socket | None = None + for index, (family, type, proto, canonname, addr) in enumerate(address_infos, 1): + try: # pragma: nocover + sock = socket.socket(family, type, proto) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + if interface: + sock.bind((interface, 0)) + sock.connect(addr) + break + except OSError: # pragma: nocover + if sock: + sock.close() + sock = None + if index == len(address_infos): + raise + assert sock, 'sock should be defined at this point' + return sock + + def send_magic_packet( *macs: str, ip_address: str = BROADCAST_IP, port: int = DEFAULT_PORT, interface: str | None = None, - address_family: socket.AddressFamily | None = None, + address_family: socket.AddressFamily = socket.AF_UNSPEC, ) -> None: """ Wake up computers having any of the given mac addresses. @@ -81,27 +129,16 @@ def send_magic_packet( """ packets = [create_magic_packet(mac) for mac in macs] - if address_family is None: - address_family = ( - socket.AF_INET6 if _is_ipv6_address(ip_address) else socket.AF_INET - ) - - with socket.socket(address_family, socket.SOCK_DGRAM) as sock: - if interface is not None: - sock.bind((interface, 0)) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) - sock.connect((ip_address, port)) + with create_socket( + ip_address=ip_address, + port=port, + interface=interface, + address_family=address_family, + ) as sock: for packet in packets: sock.send(packet) -def _is_ipv6_address(ip_address: str) -> bool: - try: - return isinstance(ipaddress.ip_address(ip_address), ipaddress.IPv6Address) - except ValueError: - return False - - def main(argv: list[str] | None = None) -> None: """ Run wake on lan as a CLI application. @@ -147,7 +184,7 @@ def main(argv: list[str] | None = None) -> None: ip_address=args.ip, port=args.port, interface=args.interface, - address_family=socket.AF_INET6 if args.ipv6 else None, + address_family=socket.AF_INET6 if args.ipv6 else socket.AF_UNSPEC, )