4141# from aiomysql.utils import _convert_to_str
4242from .cursors import Cursor
4343from .utils import _ConnectionContextManager , _ContextManager
44- # from .log import logger
44+ from .log import logger
4545
4646
4747DEFAULT_USER = getpass .getuser ()
@@ -55,7 +55,7 @@ def connect(host="localhost", user=None, password="",
5555 connect_timeout = None , read_default_group = None ,
5656 no_delay = None , autocommit = False , echo = False ,
5757 local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
58- program_name = '' ):
58+ program_name = '' , server_public_key = None ):
5959 """See connections.Connection.__init__() for information about
6060 defaults."""
6161 coro = _connect (host = host , user = user , password = password , db = db ,
@@ -93,7 +93,7 @@ def __init__(self, host="localhost", user=None, password="",
9393 connect_timeout = None , read_default_group = None ,
9494 no_delay = None , autocommit = False , echo = False ,
9595 local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
96- program_name = '' ):
96+ program_name = '' , server_public_key = None ):
9797 """
9898 Establish a connection to the MySQL database. Accepts several
9999 arguments:
@@ -134,6 +134,8 @@ def __init__(self, host="localhost", user=None, password="",
134134 (default: Server Default)
135135 :param program_name: Program name string to provide when
136136 handshaking with MySQL. (default: sys.argv[0])
137+ :param server_public_key: SHA256 authentication plugin public
138+ key value.
137139 :param loop: asyncio loop
138140 """
139141 self ._loop = loop or asyncio .get_event_loop ()
@@ -174,6 +176,8 @@ def __init__(self, host="localhost", user=None, password="",
174176 self ._client_auth_plugin = auth_plugin
175177 self ._server_auth_plugin = ""
176178 self ._auth_plugin_used = ""
179+ self .server_public_key = server_public_key
180+ self .salt = None
177181
178182 # TODO somehow import version from __init__.py
179183 self ._connect_attrs = {
@@ -711,6 +715,20 @@ async def _request_authentication(self):
711715 if auth_plugin in ('' , 'mysql_native_password' ):
712716 authresp = _auth .scramble_native_password (
713717 self ._password .encode ('latin1' ), self .salt )
718+ elif auth_plugin == 'caching_sha2_password' :
719+ if self ._password :
720+ authresp = _auth .scramble_caching_sha2 (
721+ self ._password .encode ('latin1' ), self .salt
722+ )
723+ # Else: empty password
724+ elif auth_plugin == 'sha256_password' :
725+ if self ._ssl_context and self .server_capabilities & CLIENT .SSL :
726+ authresp = self ._password .encode ('latin1' ) + b'\0 '
727+ elif self ._password :
728+ authresp = b'\1 ' # request public key
729+ else :
730+ authresp = b'\0 ' # empty password
731+
714732 elif auth_plugin in ('' , 'mysql_clear_password' ):
715733 authresp = self ._password .encode ('latin1' ) + b'\0 '
716734
@@ -767,9 +785,21 @@ async def _request_authentication(self):
767785 auth_packet .read_all ()) + b'\0 '
768786 self .write_packet (data )
769787 await self ._read_packet ()
788+ elif auth_packet .is_extra_auth_data ():
789+ if auth_plugin == "caching_sha2_password" :
790+ await self .caching_sha2_password_auth (auth_packet )
791+ elif auth_plugin == "sha256_password" :
792+ await self .sha256_password_auth (auth_packet )
793+ else :
794+ raise OperationalError ("Received extra packet "
795+ "for auth method %r" , auth_plugin )
770796
771797 async def _process_auth (self , plugin_name , auth_packet ):
772- if plugin_name == b"mysql_native_password" :
798+ if plugin_name == b"caching_sha2_password" :
799+ return self .caching_sha2_password_auth (auth_packet )
800+ elif plugin_name == b"sha256_password" :
801+ return self .sha256_password_auth (auth_packet )
802+ elif plugin_name == b"mysql_native_password" :
773803 # https://dev.mysql.com/doc/internals/en/
774804 # secure-password-authentication.html#packet-Authentication::
775805 # Native41
@@ -798,6 +828,125 @@ async def _process_auth(self, plugin_name, auth_packet):
798828
799829 return pkt
800830
831+ async def caching_sha2_password_auth (self , pkt ):
832+ # No password fast path
833+ if not self ._password :
834+ self .write_packet (b'' )
835+ pkt = await self ._read_packet ()
836+ pkt .check_error ()
837+ return pkt
838+
839+ if pkt .is_auth_switch_request ():
840+ # Try from fast auth
841+ logger .debug ("caching sha2: Trying fast path" )
842+ self .salt = pkt .read_all ()
843+ scrambled = _auth .scramble_caching_sha2 (
844+ self ._password .encode ('latin1' ), self .salt
845+ )
846+
847+ self .write_packet (scrambled )
848+ pkt = await self ._read_packet ()
849+ pkt .check_error ()
850+
851+ # else: fast auth is tried in initial handshake
852+
853+ if not pkt .is_extra_auth_data ():
854+ raise OperationalError (
855+ "caching sha2: Unknown packet "
856+ "for fast auth: {0}" .format (pkt ._data [:1 ])
857+ )
858+
859+ # magic numbers:
860+ # 2 - request public key
861+ # 3 - fast auth succeeded
862+ # 4 - need full auth
863+
864+ pkt .advance (1 )
865+ n = pkt .read_uint8 ()
866+
867+ if n == 3 :
868+ logger .debug ("caching sha2: succeeded by fast path." )
869+ pkt = await self ._read_packet ()
870+ pkt .check_error () # pkt must be OK packet
871+ return pkt
872+
873+ if n != 4 :
874+ raise OperationalError ("caching sha2: Unknown "
875+ "result for fast auth: {0}" .format (n ))
876+
877+ logger .debug ("caching sha2: Trying full auth..." )
878+
879+ if self ._ssl_context :
880+ logger .debug ("caching sha2: Sending plain "
881+ "password via secure connection" )
882+ self .write_packet (self ._password .encode ('latin1' ) + b'\0 ' )
883+ pkt = await self ._read_packet ()
884+ pkt .check_error ()
885+ return pkt
886+
887+ if not self .server_public_key :
888+ self .write_packet (b'\x02 ' )
889+ pkt = await self ._read_packet () # Request public key
890+ pkt .check_error ()
891+
892+ if not pkt .is_extra_auth_data ():
893+ raise OperationalError (
894+ "caching sha2: Unknown packet "
895+ "for public key: {0}" .format (pkt ._data [:1 ])
896+ )
897+
898+ self .server_public_key = pkt ._data [1 :]
899+ logger .debug (self .server_public_key .decode ('ascii' ))
900+
901+ data = _auth .sha2_rsa_encrypt (
902+ self ._password .encode ('latin1' ), self .salt ,
903+ self .server_public_key
904+ )
905+ self .write_packet (data )
906+ pkt = await self ._read_packet ()
907+ pkt .check_error ()
908+
909+ async def sha256_password_auth (self , pkt ):
910+ if self ._ssl_context :
911+ logger .debug ("sha256: Sending plain password" )
912+ data = self ._password .encode ('latin1' ) + b'\0 '
913+ self .write_packet (data )
914+ pkt = await self ._read_packet ()
915+ pkt .check_error ()
916+ return pkt
917+
918+ if pkt .is_auth_switch_request ():
919+ self .salt = pkt .read_all ()
920+ if not self .server_public_key and self ._password :
921+ # Request server public key
922+ logger .debug ("sha256: Requesting server public key" )
923+ self .write_packet (b'\1 ' )
924+ pkt = await self ._read_packet ()
925+ pkt .check_error ()
926+
927+ if pkt .is_extra_auth_data ():
928+ self .server_public_key = pkt ._data [1 :]
929+ logger .debug (
930+ "Received public key:\n " ,
931+ self .server_public_key .decode ('ascii' )
932+ )
933+
934+ if self ._password :
935+ if not self .server_public_key :
936+ raise OperationalError ("Couldn't receive server's public key" )
937+
938+ data = _auth .sha2_rsa_encrypt (
939+ self ._password .encode ('latin1' ), self .salt ,
940+ self .server_public_key
941+ )
942+ else :
943+ data = b''
944+
945+ self .write_packet (data )
946+ pkt = await self ._read_packet ()
947+ pkt .check_error ()
948+ return pkt
949+
801950 # _mysql support
802951 def thread_id (self ):
803952 return self .server_thread_id [0 ]
0 commit comments