22
33from __future__ import annotations
44
5+ import asyncio
56import socket
7+ import sys
68from ipaddress import IPv4Address , IPv6Address
79from typing import Any
810
@@ -51,7 +53,7 @@ def _to_resolve_result(
5153 )
5254
5355
54- class AsyncMDNSResolver (AsyncResolver ):
56+ class _AsyncMDNSResolverBase (AsyncResolver ):
5557 """Use the `aiodns`/`zeroconf` packages to make asynchronous DNS lookups."""
5658
5759 def __init__ (
@@ -67,14 +69,6 @@ def __init__(
6769 self ._aiozc_owner = async_zeroconf is None
6870 self ._aiozc = async_zeroconf or AsyncZeroconf ()
6971
70- async def resolve (
71- self , host : str , port : int = 0 , family : socket .AddressFamily = socket .AF_INET
72- ) -> list [ResolveResult ]:
73- """Resolve a host name to an IP address."""
74- if host .endswith (".local" ) or host .endswith (".local." ):
75- return await self ._resolve_mdns (host , port , family )
76- return await super ().resolve (host , port , family )
77-
7872 async def _resolve_mdns (
7973 self , host : str , port : int , family : socket .AddressFamily
8074 ) -> list [ResolveResult ]:
@@ -102,3 +96,89 @@ async def close(self) -> None:
10296 await self ._aiozc .async_close ()
10397 await super ().close ()
10498 self ._aiozc = None # type: ignore[assignment] # break ref cycles early
99+
100+
101+ class AsyncMDNSResolver (_AsyncMDNSResolverBase ):
102+ """Use the `aiodns`/`zeroconf` packages to make asynchronous DNS lookups."""
103+
104+ async def resolve (
105+ self , host : str , port : int = 0 , family : socket .AddressFamily = socket .AF_INET
106+ ) -> list [ResolveResult ]:
107+ """Resolve a host name to an IP address."""
108+ if not host .endswith (".local" ) and not host .endswith (".local." ):
109+ return await super ().resolve (host , port , family )
110+ return await self ._resolve_mdns (host , port , family )
111+
112+
113+ class AsyncDualMDNSResolver (_AsyncMDNSResolverBase ):
114+ """Use the `aiodns`/`zeroconf` packages to make asynchronous DNS lookups.
115+
116+ This resolver is a variant of `AsyncMDNSResolver` that resolves .local
117+ names with both mDNS and regular DNS.
118+
119+ - The first successful result from either resolver is returned.
120+ - If both resolvers fail, an exception is raised.
121+ - If both resolvers return results at the same time, the results are
122+ combined and duplicates are removed.
123+ """
124+
125+ async def resolve (
126+ self , host : str , port : int = 0 , family : socket .AddressFamily = socket .AF_INET
127+ ) -> list [ResolveResult ]:
128+ """Resolve a host name to an IP address."""
129+ if not host .endswith (".local" ) and not host .endswith (".local." ):
130+ return await super ().resolve (host , port , family )
131+ resolve_via_mdns = self ._resolve_mdns (host , port , family )
132+ resolve_via_dns = super ().resolve (host , port , family )
133+ loop = asyncio .get_running_loop ()
134+ if sys .version_info >= (3 , 12 ):
135+ mdns_task = asyncio .Task (resolve_via_mdns , loop = loop , eager_start = True )
136+ dns_task = asyncio .Task (resolve_via_dns , loop = loop , eager_start = True )
137+ else :
138+ mdns_task = loop .create_task (resolve_via_mdns )
139+ dns_task = loop .create_task (resolve_via_dns )
140+ await asyncio .wait ((mdns_task , dns_task ), return_when = asyncio .FIRST_COMPLETED )
141+ if mdns_task .done () and mdns_task .exception ():
142+ await asyncio .wait ((dns_task ,), return_when = asyncio .ALL_COMPLETED )
143+ elif dns_task .done () and dns_task .exception ():
144+ await asyncio .wait ((mdns_task ,), return_when = asyncio .ALL_COMPLETED )
145+ resolve_results : list [ResolveResult ] = []
146+ exceptions : list [BaseException ] = []
147+ seen_results : set [tuple [str , int , str ]] = set ()
148+ for task in (mdns_task , dns_task ):
149+ if task .done ():
150+ if exc := task .exception ():
151+ exceptions .append (exc )
152+ else :
153+ # If we have multiple results, we need to remove duplicates
154+ # and combine the results. We put the mDNS results first
155+ # to prioritize them.
156+ for result in task .result ():
157+ result_key = (
158+ result ["hostname" ],
159+ result ["port" ],
160+ result ["host" ],
161+ )
162+ if result_key not in seen_results :
163+ seen_results .add (result_key )
164+ resolve_results .append (result )
165+ else :
166+ task .cancel ()
167+ try :
168+ await task # clear log traceback
169+ except asyncio .CancelledError :
170+ if (
171+ sys .version_info >= (3 , 11 )
172+ and (current_task := asyncio .current_task ())
173+ and current_task .cancelling ()
174+ ):
175+ raise
176+
177+ if resolve_results :
178+ return resolve_results
179+
180+ exception_strings = ", " .join (
181+ exc .strerror or str (exc ) if isinstance (exc , OSError ) else str (exc )
182+ for exc in exceptions
183+ )
184+ raise OSError (None , exception_strings )
0 commit comments