2626 ProgrammingError )
2727
2828from pymysql .connections import TEXT_TYPES , MAX_PACKET_LEN , DEFAULT_CHARSET
29- # from pymysql.connections import dump_packet
30- from pymysql .connections import _scramble
31- from pymysql .connections import _scramble_323
29+ from pymysql .connections import _auth
30+
3231from pymysql .connections import pack_int24
3332
3433from pymysql .connections import MysqlPacket
4443from .utils import _ConnectionContextManager , _ContextManager
4544# from .log import logger
4645
46+
4747DEFAULT_USER = getpass .getuser ()
4848
4949
@@ -54,7 +54,8 @@ def connect(host="localhost", user=None, password="",
5454 client_flag = 0 , cursorclass = Cursor , init_command = None ,
5555 connect_timeout = None , read_default_group = None ,
5656 no_delay = None , autocommit = False , echo = False ,
57- local_infile = False , loop = None , ssl = None , auth_plugin = '' ):
57+ local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
58+ program_name = '' ):
5859 """See connections.Connection.__init__() for information about
5960 defaults."""
6061 coro = _connect (host = host , user = user , password = password , db = db ,
@@ -67,7 +68,7 @@ def connect(host="localhost", user=None, password="",
6768 read_default_group = read_default_group ,
6869 no_delay = no_delay , autocommit = autocommit , echo = echo ,
6970 local_infile = local_infile , loop = loop , ssl = ssl ,
70- auth_plugin = auth_plugin )
71+ auth_plugin = auth_plugin , program_name = program_name )
7172 return _ConnectionContextManager (coro )
7273
7374
@@ -91,7 +92,8 @@ def __init__(self, host="localhost", user=None, password="",
9192 client_flag = 0 , cursorclass = Cursor , init_command = None ,
9293 connect_timeout = None , read_default_group = None ,
9394 no_delay = None , autocommit = False , echo = False ,
94- local_infile = False , loop = None , ssl = None , auth_plugin = '' ):
95+ local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
96+ program_name = '' ):
9597 """
9698 Establish a connection to the MySQL database. Accepts several
9799 arguments:
@@ -125,6 +127,13 @@ def __init__(self, host="localhost", user=None, password="",
125127 (default: False)
126128 :param local_infile: boolean to enable the use of LOAD DATA LOCAL
127129 command. (default: False)
130+ :param ssl: Optional SSL Context to force SSL
131+ :param auth_plugin: String to manually specify the authentication
132+ plugin to use, i.e you will want to use mysql_clear_password
133+ when using IAM authentication with Amazon RDS.
134+ (default: Server Default)
135+ :param program_name: Program name string to provide when
136+ handshaking with MySQL. (default: sys.argv[0])
128137 :param loop: asyncio loop
129138 """
130139 self ._loop = loop or asyncio .get_event_loop ()
@@ -166,6 +175,17 @@ def __init__(self, host="localhost", user=None, password="",
166175 self ._server_auth_plugin = ""
167176 self ._auth_plugin_used = ""
168177
178+ # TODO somehow import version from __init__.py
179+ self ._connect_attrs = {
180+ '_client_name' : 'aiomysql' ,
181+ '_pid' : str (os .getpid ()),
182+ '_client_version' : '0.0.16' ,
183+ }
184+ if program_name :
185+ self ._connect_attrs ["program_name" ] = program_name
186+ elif sys .argv :
187+ self ._connect_attrs ["program_name" ] = sys .argv [0 ]
188+
169189 self ._unix_socket = unix_socket
170190 if charset :
171191 self ._charset = charset
@@ -673,8 +693,10 @@ async def _request_authentication(self):
673693 charset_id = charset_by_name (self .charset ).id
674694 if isinstance (self .user , str ):
675695 _user = self .user .encode (self .encoding )
696+ else :
697+ _user = self .user
676698
677- data_init = struct .pack ('<iIB23s' , self .client_flag , 1 ,
699+ data_init = struct .pack ('<iIB23s' , self .client_flag , MAX_PACKET_LEN ,
678700 charset_id , b'' )
679701
680702 data = data_init + _user + b'\0 '
@@ -687,7 +709,8 @@ async def _request_authentication(self):
687709 auth_plugin = self ._server_auth_plugin
688710
689711 if auth_plugin in ('' , 'mysql_native_password' ):
690- authresp = _scramble (self ._password .encode ('latin1' ), self .salt )
712+ authresp = _auth .scramble_native_password (
713+ self ._password .encode ('latin1' ), self .salt )
691714 elif auth_plugin in ('' , 'mysql_clear_password' ):
692715 authresp = self ._password .encode ('latin1' ) + b'\0 '
693716
@@ -715,6 +738,15 @@ async def _request_authentication(self):
715738
716739 self ._auth_plugin_used = auth_plugin
717740
741+ # Sends the server a few pieces of client info
742+ if self .server_capabilities & CLIENT .CONNECT_ATTRS :
743+ connect_attrs = b''
744+ for k , v in self ._connect_attrs .items ():
745+ k , v = k .encode ('utf8' ), v .encode ('utf8' )
746+ connect_attrs += struct .pack ('B' , len (k )) + k
747+ connect_attrs += struct .pack ('B' , len (v )) + v
748+ data += struct .pack ('B' , len (connect_attrs )) + connect_attrs
749+
718750 self .write_packet (data )
719751 auth_packet = await self ._read_packet ()
720752
@@ -727,27 +759,28 @@ async def _request_authentication(self):
727759 plugin_name = auth_packet .read_string ()
728760 if (self .server_capabilities & CLIENT .PLUGIN_AUTH and
729761 plugin_name is not None ):
730- auth_packet = await self ._process_auth (
731- plugin_name , auth_packet )
762+ await self ._process_auth (plugin_name , auth_packet )
732763 else :
733764 # send legacy handshake
734- data = _scramble_323 (self ._password .encode ('latin1' ),
735- self .salt ) + b'\0 '
765+ data = _auth .scramble_old_password (
766+ self ._password .encode ('latin1' ),
767+ auth_packet .read_all ()) + b'\0 '
736768 self .write_packet (data )
737- auth_packet = await self ._read_packet ()
769+ await self ._read_packet ()
738770
739771 async def _process_auth (self , plugin_name , auth_packet ):
740772 if plugin_name == b"mysql_native_password" :
741773 # https://dev.mysql.com/doc/internals/en/
742774 # secure-password-authentication.html#packet-Authentication::
743775 # Native41
744- data = _scramble (self ._password .encode ('latin1' ),
745- auth_packet .read_all ())
776+ data = _auth .scramble_native_password (
777+ self ._password .encode ('latin1' ),
778+ auth_packet .read_all ())
746779 elif plugin_name == b"mysql_old_password" :
747780 # https://dev.mysql.com/doc/internals/en/
748781 # old-password-authentication.html
749- data = _scramble_323 (self ._password .encode ('latin1' ),
750- auth_packet .read_all ()) + b'\0 '
782+ data = _auth . scramble_old_password (self ._password .encode ('latin1' ),
783+ auth_packet .read_all ()) + b'\0 '
751784 elif plugin_name == b"mysql_clear_password" :
752785 # https://dev.mysql.com/doc/internals/en/
753786 # clear-text-authentication.html
0 commit comments