Skip to content

Commit 878438d

Browse files
authored
Avoid creating tasks if the name is already in the cache (#27)
1 parent 1bc7763 commit 878438d

File tree

2 files changed

+91
-20
lines changed

2 files changed

+91
-20
lines changed

src/aiohttp_asyncmdnsresolver/_impl.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import socket
77
import sys
88
from ipaddress import IPv4Address, IPv6Address
9-
from typing import Any
9+
from typing import TYPE_CHECKING, Any, Union
1010

1111
from aiohttp.resolver import AsyncResolver, ResolveResult
1212
from zeroconf import (
@@ -19,6 +19,8 @@
1919

2020
DEFAULT_TIMEOUT = 5.0
2121

22+
ResolverType = Union[AddressResolver, AddressResolverIPv4, AddressResolverIPv6]
23+
2224
_FAMILY_TO_RESOLVER_CLASS: dict[
2325
socket.AddressFamily,
2426
type[AddressResolver] | type[AddressResolverIPv4] | type[AddressResolverIPv6],
@@ -69,26 +71,31 @@ def __init__(
6971
self._aiozc_owner = async_zeroconf is None
7072
self._aiozc = async_zeroconf or AsyncZeroconf()
7173

74+
def _make_resolver(self, host: str, family: socket.AddressFamily) -> ResolverType:
75+
"""Create an mDNS resolver."""
76+
resolver_class = _FAMILY_TO_RESOLVER_CLASS[family]
77+
return resolver_class(host if host[-1] == "." else f"{host}.")
78+
79+
def _addresses_from_info_or_raise(
80+
self, info: ResolverType, port: int, family: socket.AddressFamily
81+
) -> list[ResolveResult]:
82+
"""Get addresses from info or raise OSError."""
83+
ip_version = _FAMILY_TO_IP_VERSION[family]
84+
if addresses := info.ip_addresses_by_version(ip_version):
85+
if TYPE_CHECKING:
86+
assert info.server is not None
87+
return [
88+
_to_resolve_result(info.server, port, address) for address in addresses
89+
]
90+
raise OSError(None, "MDNS lookup failed")
91+
7292
async def _resolve_mdns(
73-
self, host: str, port: int, family: socket.AddressFamily
93+
self, info: ResolverType, port: int, family: socket.AddressFamily
7494
) -> list[ResolveResult]:
7595
"""Resolve a host name to an IP address using mDNS."""
76-
resolver_class = _FAMILY_TO_RESOLVER_CLASS[family]
77-
ip_version: IPVersion = _FAMILY_TO_IP_VERSION[family]
78-
if host[-1] != ".":
79-
host += "."
80-
info = resolver_class(host)
81-
if (
82-
info.load_from_cache(self._aiozc.zeroconf)
83-
or (
84-
self._mdns_timeout
85-
and await info.async_request(
86-
self._aiozc.zeroconf, self._mdns_timeout * 1000
87-
)
88-
)
89-
) and (addresses := info.ip_addresses_by_version(ip_version)):
90-
return [_to_resolve_result(host, port, address) for address in addresses]
91-
raise OSError(None, "MDNS lookup failed")
96+
if self._mdns_timeout:
97+
await info.async_request(self._aiozc.zeroconf, self._mdns_timeout * 1000)
98+
return self._addresses_from_info_or_raise(info, port, family)
9299

93100
async def close(self) -> None:
94101
"""Close the resolver."""
@@ -107,7 +114,10 @@ async def resolve(
107114
"""Resolve a host name to an IP address."""
108115
if not host.endswith(".local") and not host.endswith(".local."):
109116
return await super().resolve(host, port, family)
110-
return await self._resolve_mdns(host, port, family)
117+
info = self._make_resolver(host, family)
118+
if info.load_from_cache(self._aiozc.zeroconf):
119+
return self._addresses_from_info_or_raise(info, port, family)
120+
return await self._resolve_mdns(info, port, family)
111121

112122

113123
class AsyncDualMDNSResolver(_AsyncMDNSResolverBase):
@@ -128,7 +138,10 @@ async def resolve(
128138
"""Resolve a host name to an IP address."""
129139
if not host.endswith(".local") and not host.endswith(".local."):
130140
return await super().resolve(host, port, family)
131-
resolve_via_mdns = self._resolve_mdns(host, port, family)
141+
info = self._make_resolver(host, family)
142+
if info.load_from_cache(self._aiozc.zeroconf):
143+
return self._addresses_from_info_or_raise(info, port, family)
144+
resolve_via_mdns = self._resolve_mdns(info, port, family)
132145
resolve_via_dns = super().resolve(host, port, family)
133146
loop = asyncio.get_running_loop()
134147
if sys.version_info >= (3, 12):

tests/test_impl.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,31 @@ async def _take_a_while_to_resolve(*args: Any, **kwargs: Any) -> bool:
410410
assert result["host"] == "127.0.0.2"
411411

412412

413+
@pytest.mark.asyncio
414+
async def test_async_dual_mdns_resolver_from_cache(
415+
dual_resolver: AsyncMDNSResolver,
416+
) -> None:
417+
"""Test AsyncDualMDNSResolver can resolve from cache."""
418+
with (
419+
patch(
420+
"aiohttp_asyncmdnsresolver._impl.AsyncResolver.resolve",
421+
side_effect=OSError,
422+
),
423+
patch.object(IPv4HostResolver, "load_from_cache", return_value=True),
424+
patch.object(
425+
IPv4HostResolver,
426+
"ip_addresses_by_version",
427+
return_value=[IPv4Address("127.0.0.2")],
428+
),
429+
):
430+
results = await dual_resolver.resolve("localhost.local.")
431+
assert results is not None
432+
assert len(results) == 1
433+
result = results[0]
434+
assert result["hostname"] == "localhost.local."
435+
assert result["host"] == "127.0.0.2"
436+
437+
413438
@pytest.mark.asyncio
414439
async def test_different_results_async_dual_mdns_resolver(
415440
dual_resolver: AsyncMDNSResolver,
@@ -443,6 +468,39 @@ async def test_different_results_async_dual_mdns_resolver(
443468
assert result["host"] == "127.0.0.1"
444469

445470

471+
@pytest.mark.asyncio
472+
async def test_different_results_async_dual_mdns_resolver_zero_timeout(
473+
dual_resolver: AsyncMDNSResolver,
474+
) -> None:
475+
"""Test AsyncDualMDNSResolver resolves using mDNS and DNS.
476+
477+
Test when both resolvers return different results with zero timeout
478+
for mDNS.
479+
"""
480+
dual_resolver._mdns_timeout = 0
481+
with (
482+
patch(
483+
"aiohttp_asyncmdnsresolver._impl.AsyncResolver.resolve",
484+
return_value=[
485+
ResolveResult(hostname="localhost.local.", host="127.0.0.1", port=0) # type: ignore[typeddict-item]
486+
],
487+
),
488+
patch.object(IPv4HostResolver, "load_from_cache", return_value=False),
489+
patch.object(IPv4HostResolver, "async_request", return_value=True),
490+
patch.object(
491+
IPv4HostResolver,
492+
"ip_addresses_by_version",
493+
return_value=[],
494+
),
495+
):
496+
results = await dual_resolver.resolve("localhost.local.")
497+
assert results is not None
498+
assert len(results) == 1
499+
result = results[0]
500+
assert result["hostname"] == "localhost.local."
501+
assert result["host"] == "127.0.0.1"
502+
503+
446504
@pytest.mark.asyncio
447505
async def test_failed_mdns_async_dual_mdns_resolver(
448506
dual_resolver: AsyncMDNSResolver,

0 commit comments

Comments
 (0)