Skip to content

Commit e88a7c0

Browse files
committed
fix IP address/network comparison methods
1 parent 92972ae commit e88a7c0

3 files changed

Lines changed: 205 additions & 133 deletions

File tree

Lib/ipaddress.py

Lines changed: 100 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,12 @@ def v6_int_to_packed(address):
150150
raise ValueError("Address negative or too large for IPv6")
151151

152152

153+
def _check_ip_version(a, b):
154+
if a.version != b.version:
155+
# does this need to raise a ValueError?
156+
raise TypeError(f"{a} and {b} are not of the same version")
157+
158+
153159
def _split_optional_netmask(address):
154160
"""Helper to split the netmask and raise AddressValueError if needed"""
155161
addr = str(address).split('/')
@@ -213,7 +219,7 @@ def summarize_address_range(first, last):
213219
214220
Raise:
215221
TypeError:
216-
If the first and last objects are not IP addresses.
222+
If the first or last objects are not IP addresses.
217223
If the first and last objects are not the same version.
218224
ValueError:
219225
If the last object is not greater than the first.
@@ -223,9 +229,7 @@ def summarize_address_range(first, last):
223229
if (not (isinstance(first, _BaseAddress) and
224230
isinstance(last, _BaseAddress))):
225231
raise TypeError('first and last must be IP addresses, not networks')
226-
if first.version != last.version:
227-
raise TypeError("%s and %s are not of the same version" % (
228-
first, last))
232+
_check_ip_version(first, last)
229233
if first > last:
230234
raise ValueError('last IP address must be greater than first')
231235

@@ -316,40 +320,39 @@ def collapse_addresses(addresses):
316320
TypeError: If passed a list of mixed version objects.
317321
318322
"""
319-
addrs = []
320323
ips = []
321324
nets = []
322325

323-
# split IP addresses and networks
326+
# split IP addresses/interfaces and networks
324327
for ip in addresses:
325328
if isinstance(ip, _BaseAddress):
326-
if ips and ips[-1].version != ip.version:
327-
raise TypeError("%s and %s are not of the same version" % (
328-
ip, ips[-1]))
329-
ips.append(ip)
330-
elif ip._prefixlen == ip.max_prefixlen:
331-
if ips and ips[-1].version != ip.version:
332-
raise TypeError("%s and %s are not of the same version" % (
333-
ip, ips[-1]))
334-
try:
335-
ips.append(ip.ip)
336-
except AttributeError:
337-
ips.append(ip.network_address)
329+
if ips:
330+
_check_ip_version(ips[-1], ip)
331+
if hasattr(ip, "ip") and isinstance(ip.ip, _BaseAddress):
332+
ips.append(ip.ip) # interface IP address
333+
else:
334+
ips.append(ip)
335+
elif isinstance(ip, _BaseNetwork):
336+
if ip.prefixlen == ip.max_prefixlen:
337+
if ips:
338+
_check_ip_version(ips[-1], ip)
339+
ips.append(ip.network_address) # network address
340+
else:
341+
if nets:
342+
_check_ip_version(nets[-1], ip)
343+
nets.append(ip)
338344
else:
339-
if nets and nets[-1].version != ip.version:
340-
raise TypeError("%s and %s are not of the same version" % (
341-
ip, nets[-1]))
342-
nets.append(ip)
345+
raise TypeError(f"{ip} is not an IP object")
343346

344347
# sort and dedup
345348
ips = sorted(set(ips))
346-
347349
# find consecutive address ranges in the sorted sequence and summarize them
350+
nets_from_range = []
348351
if ips:
349352
for first, last in _find_address_range(ips):
350-
addrs.extend(summarize_address_range(first, last))
353+
nets_from_range.extend(summarize_address_range(first, last))
351354

352-
return _collapse_addresses_internal(addrs + nets)
355+
return _collapse_addresses_internal(nets_from_range + nets)
353356

354357

355358
def get_mixed_type_key(obj):
@@ -550,6 +553,13 @@ def __reduce__(self):
550553
return self.__class__, (str(self),)
551554

552555

556+
def _base_address_eq(a, b):
557+
return a._ip == b._ip and a.version == b.version
558+
559+
def _base_address_lt(a, b):
560+
_check_ip_version(a, b)
561+
return a._ip < b._ip
562+
553563
_address_fmt_re = None
554564

555565
@functools.total_ordering
@@ -567,21 +577,15 @@ def __int__(self):
567577
return self._ip
568578

569579
def __eq__(self, other):
570-
try:
571-
return (self._ip == other._ip
572-
and self.version == other.version)
573-
except AttributeError:
580+
if not isinstance(other, _BaseAddress):
574581
return NotImplemented
582+
return self._ip == other._ip and self.version == other.version
575583

576584
def __lt__(self, other):
577585
if not isinstance(other, _BaseAddress):
578586
return NotImplemented
579-
if self.version != other.version:
580-
raise TypeError('%s and %s are not of the same version' % (
581-
self, other))
582-
if self._ip != other._ip:
583-
return self._ip < other._ip
584-
return False
587+
_check_ip_version(self, other)
588+
return self._ip < other._ip
585589

586590
# Shorthand for Integer addition and subtraction. This is not
587591
# meant to ever support addition/subtraction of addresses.
@@ -708,40 +712,39 @@ def __getitem__(self, n):
708712
def __lt__(self, other):
709713
if not isinstance(other, _BaseNetwork):
710714
return NotImplemented
711-
if self.version != other.version:
712-
raise TypeError('%s and %s are not of the same version' % (
713-
self, other))
715+
_check_ip_version(self, other)
714716
if self.network_address != other.network_address:
715717
return self.network_address < other.network_address
716718
if self.netmask != other.netmask:
717719
return self.netmask < other.netmask
718720
return False
719721

720722
def __eq__(self, other):
721-
try:
722-
return (self.version == other.version and
723-
self.network_address == other.network_address and
724-
int(self.netmask) == int(other.netmask))
725-
except AttributeError:
723+
if not isinstance(other, _BaseNetwork):
726724
return NotImplemented
725+
return (self.version == other.version
726+
and self.network_address == other.network_address
727+
and int(self.netmask._ip) == int(other.netmask))
727728

728729
def __hash__(self):
729730
return hash((int(self.network_address), int(self.netmask)))
730731

731732
def __contains__(self, other):
732-
# always false if one is v4 and the other is v6.
733-
if self.version != other.version:
734-
return False
735-
# dealing with another network.
736733
if isinstance(other, _BaseNetwork):
734+
# should __contains__ actually implement subnet_of()
735+
# and supernet_of() instead?
737736
return False
738-
# dealing with another address
739-
else:
740-
# address
741-
return other._ip & self.netmask._ip == self.network_address._ip
737+
if isinstance(other, _BaseAddress):
738+
return (
739+
self.version == other.version
740+
and (other._ip & self.netmask._ip) == self.network_address._ip
741+
)
742+
return NotImplemented
742743

743744
def overlaps(self, other):
744745
"""Tell if self is partly contained in other."""
746+
if not isinstance(other, _BaseNetwork):
747+
raise TypeError("%s is not a network object" % other)
745748
return self.network_address in other or (
746749
self.broadcast_address in other or (
747750
other.network_address in self or (
@@ -821,13 +824,9 @@ def address_exclude(self, other):
821824
ValueError: If other is not completely contained by self.
822825
823826
"""
824-
if not self.version == other.version:
825-
raise TypeError("%s and %s are not of the same version" % (
826-
self, other))
827-
828827
if not isinstance(other, _BaseNetwork):
829828
raise TypeError("%s is not a network object" % other)
830-
829+
_check_ip_version(self, other)
831830
if not other.subnet_of(self):
832831
raise ValueError('%s not contained in %s' % (other, self))
833832
if other == self:
@@ -870,7 +869,7 @@ def compare_networks(self, other):
870869
'HostA._ip < HostB._ip'
871870
872871
Args:
873-
other: An IP object.
872+
other: An IP network object.
874873
875874
Returns:
876875
If the IP versions of self and other are the same, returns:
@@ -892,10 +891,9 @@ def compare_networks(self, other):
892891
TypeError if the IP versions are different.
893892
894893
"""
895-
# does this need to raise a ValueError?
896-
if self.version != other.version:
897-
raise TypeError('%s and %s are not of the same type' % (
898-
self, other))
894+
if not isinstance(other, _BaseNetwork):
895+
raise TypeError("%s is not a network object" % other)
896+
_check_ip_version(self, other)
899897
# self.version == other.version below here:
900898
if self.network_address < other.network_address:
901899
return -1
@@ -1026,15 +1024,13 @@ def is_multicast(self):
10261024

10271025
@staticmethod
10281026
def _is_subnet_of(a, b):
1029-
try:
1030-
# Always false if one is v4 and the other is v6.
1031-
if a.version != b.version:
1032-
raise TypeError(f"{a} and {b} are not of the same version")
1033-
return (b.network_address <= a.network_address and
1034-
b.broadcast_address >= a.broadcast_address)
1035-
except AttributeError:
1036-
raise TypeError(f"Unable to test subnet containment "
1037-
f"between {a} and {b}")
1027+
if not isinstance(a, _BaseNetwork):
1028+
raise TypeError(f"{a} is not a network object")
1029+
if not isinstance(b, _BaseNetwork):
1030+
raise TypeError(f"{b} is not a network object")
1031+
_check_ip_version(a, b)
1032+
return (b.network_address <= a.network_address and
1033+
b.broadcast_address >= a.broadcast_address)
10381034

10391035
def subnet_of(self, other):
10401036
"""Return True if this network is a subnet of other."""
@@ -1429,28 +1425,27 @@ def __str__(self):
14291425
self._prefixlen)
14301426

14311427
def __eq__(self, other):
1428+
if not isinstance(other, IPv4Interface):
1429+
if isinstance(other, IPv4Address):
1430+
# avoid falling back to IPv4Address.__eq__(other, self)
1431+
return False
1432+
return NotImplemented
1433+
# An interface with an associated network is NOT the
1434+
# same as an unassociated address. That's why the hash
1435+
# takes the extra info into account.
14321436
address_equal = IPv4Address.__eq__(self, other)
1433-
if address_equal is NotImplemented or not address_equal:
1434-
return address_equal
1435-
try:
1436-
return self.network == other.network
1437-
except AttributeError:
1438-
# An interface with an associated network is NOT the
1439-
# same as an unassociated address. That's why the hash
1440-
# takes the extra info into account.
1441-
return False
1437+
return address_equal and self.network == other.network
14421438

14431439
def __lt__(self, other):
1440+
# We *do* allow addresses and interfaces to be sorted. The
1441+
# unassociated address is considered less than all interfaces.
14441442
address_less = IPv4Address.__lt__(self, other)
1445-
if address_less is NotImplemented:
1446-
return NotImplemented
1447-
try:
1448-
return (self.network < other.network or
1449-
self.network == other.network and address_less)
1450-
except AttributeError:
1451-
# We *do* allow addresses and interfaces to be sorted. The
1452-
# unassociated address is considered less than all interfaces.
1453-
return False
1443+
if isinstance(other, IPv4Interface):
1444+
assert address_less is not NotImplemented
1445+
# compare interfaces by their network first
1446+
return (self.network < other.network
1447+
or (self.network == other.network and address_less))
1448+
return address_less
14541449

14551450
def __hash__(self):
14561451
return hash((self._ip, self._prefixlen, int(self.network.network_address)))
@@ -2219,28 +2214,27 @@ def __str__(self):
22192214
self._prefixlen)
22202215

22212216
def __eq__(self, other):
2217+
if not isinstance(other, IPv6Interface):
2218+
if isinstance(other, IPv6Address):
2219+
# avoid falling back to IPv6Address.__eq__(other, self)
2220+
return False
2221+
return NotImplemented
2222+
# An interface with an associated network is NOT the
2223+
# same as an unassociated address. That's why the hash
2224+
# takes the extra info into account.
22222225
address_equal = IPv6Address.__eq__(self, other)
2223-
if address_equal is NotImplemented or not address_equal:
2224-
return address_equal
2225-
try:
2226-
return self.network == other.network
2227-
except AttributeError:
2228-
# An interface with an associated network is NOT the
2229-
# same as an unassociated address. That's why the hash
2230-
# takes the extra info into account.
2231-
return False
2226+
return address_equal and self.network == other.network
22322227

22332228
def __lt__(self, other):
2229+
# We *do* allow addresses and interfaces to be sorted. The
2230+
# unassociated address is considered less than all interfaces.
22342231
address_less = IPv6Address.__lt__(self, other)
2235-
if address_less is NotImplemented:
2236-
return address_less
2237-
try:
2238-
return (self.network < other.network or
2239-
self.network == other.network and address_less)
2240-
except AttributeError:
2241-
# We *do* allow addresses and interfaces to be sorted. The
2242-
# unassociated address is considered less than all interfaces.
2243-
return False
2232+
if isinstance(other, IPv6Interface):
2233+
assert address_less is not NotImplemented
2234+
# compare interfaces by their network first
2235+
return (self.network < other.network
2236+
or (self.network == other.network and address_less))
2237+
return address_less
22442238

22452239
def __hash__(self):
22462240
return hash((self._ip, self._prefixlen, int(self.network.network_address)))

0 commit comments

Comments
 (0)