@@ -870,27 +870,39 @@ def __init__(self, host, port=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
870870 def set_tunnel (self , host , port = None , headers = None ):
871871 """Set up host and port for HTTP CONNECT tunnelling.
872872
873- In a connection that uses HTTP CONNECT tunneling , the host passed to the
874- constructor is used as a proxy server that relays all communication to
875- the endpoint passed to `set_tunnel`. This done by sending an HTTP
873+ In a connection that uses HTTP CONNECT tunnelling , the host passed to
874+ the constructor is used as a proxy server that relays all communication
875+ to the endpoint passed to `set_tunnel`. This done by sending an HTTP
876876 CONNECT request to the proxy server when the connection is established.
877877
878878 This method must be called before the HTTP connection has been
879879 established.
880880
881881 The headers argument should be a mapping of extra HTTP headers to send
882882 with the CONNECT request.
883+
884+ As HTTP/1.1 is used for HTTP CONNECT tunnelling request, as per the RFC
885+ (https://tools.ietf.org/html/rfc7231#section-4.3.6), a HTTP Host:
886+ header must be provided, matching the authority-form of the request
887+ target provided as the destination for the CONNECT request. If a
888+ HTTP Host: header is not provided via the headers argument, one
889+ is generated and transmitted automatically.
883890 """
884891
885892 if self .sock :
886893 raise RuntimeError ("Can't set up tunnel for established connection" )
887894
888895 self ._tunnel_host , self ._tunnel_port = self ._get_hostport (host , port )
889896 if headers :
890- self ._tunnel_headers = headers
897+ self ._tunnel_headers = headers . copy ()
891898 else :
892899 self ._tunnel_headers .clear ()
893900
901+ if not any (header .lower () == "host" for header in self ._tunnel_headers ):
902+ encoded_host = self ._tunnel_host .encode ("idna" ).decode ("ascii" )
903+ self ._tunnel_headers ["Host" ] = "%s:%d" % (
904+ encoded_host , self ._tunnel_port )
905+
894906 def _get_hostport (self , host , port ):
895907 if port is None :
896908 i = host .rfind (':' )
@@ -915,8 +927,9 @@ def set_debuglevel(self, level):
915927 self .debuglevel = level
916928
917929 def _tunnel (self ):
918- connect = b"CONNECT %s:%d HTTP/1.0\r \n " % (
919- self ._tunnel_host .encode ("ascii" ), self ._tunnel_port )
930+ connect = b"CONNECT %s:%d %s\r \n " % (
931+ self ._tunnel_host .encode ("idna" ), self ._tunnel_port ,
932+ self ._http_vsn_str .encode ("ascii" ))
920933 headers = [connect ]
921934 for header , value in self ._tunnel_headers .items ():
922935 headers .append (f"{ header } : { value } \r \n " .encode ("latin-1" ))
0 commit comments