2020
2121# Valid content types (ordered from newest, to most obsolete)
2222versions = {
23- "aes128gcm" : {"padding " : 2 },
24- "aesgcm" : {"padding " : 2 },
25- "aesgcm128" : {"padding " : 1 }
23+ "aes128gcm" : {"pad " : 2 },
24+ "aesgcm" : {"pad " : 2 },
25+ "aesgcm128" : {"pad " : 1 },
2626}
2727
2828
@@ -34,8 +34,8 @@ def __init__(self, message):
3434# TODO: turn this into a class so that we don't grow/stomp keys.
3535
3636
37- def derive_key (mode , salt = None , key = None , dh = None , keyid = None ,
38- auth_secret = None , version = "aesgcm" , ** kwargs ):
37+ def derive_key (mode , version , salt = None , key = None , dh = None , auth_secret = None ,
38+ keyid = None , keymap = None , keylabels = None ):
3939 """Derive the encryption key
4040
4141 :param mode: operational mode (encrypt or decrypt)
@@ -48,6 +48,10 @@ def derive_key(mode, salt=None, key=None, dh=None, keyid=None,
4848 :type dh: str
4949 :param keyid: key identifier label
5050 :type keyid: str
51+ :param keymap: map of keyids to keys
52+ :type keymap: map
53+ :param keylabels: map of keyids to labels
54+ :type keylabels: map
5155 :param auth_secret: authorization secret
5256 :type auth_secret: str
5357 :param version: Content Type identifier
@@ -61,69 +65,47 @@ def derive_key(mode, salt=None, key=None, dh=None, keyid=None,
6165 def build_info (base , info_context ):
6266 return b"Content-Encoding: " + base + b"\0 " + info_context
6367
64- def derive_dh (mode , keyid , dh , version = "aesgcm" ):
65-
68+ def derive_dh (mode , version , dh , keyid , keymap , keylabels ):
6669 def length_prefix (key ):
6770 return struct .pack ("!H" , len (key )) + key
6871
6972 if keyid is None :
7073 raise ECEException (u"'keyid' is not specified with 'dh'" )
71- if keyid not in keys :
74+ if keyid not in keymap :
7275 raise ECEException (u"'keyid' doesn't identify a key: " + keyid )
7376 if mode == "encrypt" :
74- sender_pub_key = key or keys [keyid ].get_pubkey ()
77+ sender_pub_key = key or keymap [keyid ].get_pubkey ()
7578 receiver_pub_key = dh
7679 elif mode == "decrypt" :
7780 sender_pub_key = dh
78- receiver_pub_key = key or keys [keyid ].get_pubkey ()
81+ receiver_pub_key = key or keymap [keyid ].get_pubkey ()
7982 else :
8083 raise ECEException (u"unknown 'mode' specified: " + mode )
8184 if version == "aes128gcm" :
8285 context = b"WebPush: info\x00 " + receiver_pub_key + sender_pub_key
8386 else :
84- label = labels .get (keyid , 'P-256' ).encode ('utf-8' )
87+ label = keylabels .get (keyid , 'P-256' ).encode ('utf-8' )
8588 context = (label + b"\0 " + length_prefix (receiver_pub_key ) +
8689 length_prefix (sender_pub_key ))
8790
88- return keys [keyid ].get_ecdh_key (dh ), context
89-
90- # handle the older, ill formatted args.
91- pad_size = kwargs .get ('padSize' , 2 )
92- auth_secret = kwargs .get ('authSecret' , auth_secret )
93- secret = key
94-
95- # handle old cases where version is explicitly None.
96- if not version :
97- if pad_size == 1 :
98- version = "aesgcm128"
99- else :
100- version = "aesgcm"
91+ return keymap [keyid ].get_ecdh_key (dh ), context
10192
10293 if version not in versions :
103- raise ECEException (u"invalid version specified " )
94+ raise ECEException (u"Invalid version" )
10495 if salt is None or len (salt ) != 16 :
10596 raise ECEException (u"'salt' must be a 16 octet value" )
10697 if dh is not None :
107- (secret , context ) = derive_dh (mode = mode , keyid = keyid , dh = dh ,
108- version = version )
109- elif keyid in keys :
110- if isinstance (keys [keyid ], ecc .ECC ):
111- secret = keys [keyid ].get_privkey ()
112- else :
113- secret = keys [keyid ]
98+ (secret , context ) = derive_dh (mode = mode , version = version , dh = dh ,
99+ keyid = keyid , keymap = keymap ,
100+ keylabels = keylabels )
101+ elif keyid in keymap :
102+ secret = keymap [keyid ]
103+ else :
104+ secret = key
105+
114106 if secret is None :
115107 raise ECEException (u"unable to determine the secret" )
116108
117- if auth_secret is not None :
118- hkdf_auth = HKDF (
119- algorithm = hashes .SHA256 (),
120- length = 32 ,
121- salt = auth_secret ,
122- info = build_info (b"auth" , b"" ),
123- backend = default_backend ()
124- )
125- secret = hkdf_auth .derive (secret )
126-
127109 if version == "aesgcm" :
128110 keyinfo = build_info (b"aesgcm" , context )
129111 nonceinfo = build_info (b"nonce" , context )
@@ -134,6 +116,20 @@ def length_prefix(key):
134116 keyinfo = b"Content-Encoding: aes128gcm\x00 "
135117 nonceinfo = b"Content-Encoding: nonce\x00 "
136118
119+ if auth_secret is not None :
120+ if version == "aes128gcm" :
121+ info = context
122+ else :
123+ info = build_info (b'auth' , b'' )
124+ hkdf_auth = HKDF (
125+ algorithm = hashes .SHA256 (),
126+ length = 32 ,
127+ salt = auth_secret ,
128+ info = info ,
129+ backend = default_backend ()
130+ )
131+ secret = hkdf_auth .derive (secret )
132+
137133 hkdf_key = HKDF (
138134 algorithm = hashes .SHA256 (),
139135 length = 16 ,
@@ -161,8 +157,8 @@ def iv(base, counter):
161157 return base [:4 ] + struct .pack ("!Q" , counter ^ mask )
162158
163159
164- def decrypt (content , salt , key = None , keyid = None , dh = None , rs = 4096 ,
165- auth_secret = None , version = "aesgcm" , ** kwargs ):
160+ def decrypt (content , salt , key = None , keyid = None , keymap = None , keylabels = None ,
161+ dh = None , rs = 4096 , auth_secret = None , version = "aesgcm" , ** kwargs ):
166162 """
167163 Decrypt a data block
168164
@@ -218,12 +214,17 @@ def decrypt_record(key, nonce, counter, content):
218214 data = data [pad_size + pad :]
219215 return data
220216
217+ if version not in versions :
218+ raise ECEException (u"Invalid version" )
219+
221220 # handle old, malformed args
222- pad_size = kwargs .get ('padSize' , 2 )
221+ pad_size = kwargs .get ('padSize' , versions [ version ][ 'pad' ] )
223222 auth_secret = kwargs .get ('authSecret' , auth_secret )
223+ if keymap is None :
224+ keymap = keys
225+ if keylabels is None :
226+ keylabels = labels
224227
225- if version not in versions :
226- raise ECEException (u"Invalid version" )
227228 if version == "aes128gcm" :
228229 try :
229230 content_header = parse_content_header (content )
@@ -232,14 +233,12 @@ def decrypt_record(key, nonce, counter, content):
232233 ex .message )
233234 salt = content_header ['salt' ]
234235 keyid = content_header ['key_id' ] or '' if keyid is None else keyid
235- pad_size = 2
236236 content = content_header ['content' ]
237237
238- (key_ , nonce_ ) = derive_key (mode = "decrypt" , salt = salt ,
239- key = key , keyid = keyid , dh = dh ,
240- auth_secret = auth_secret ,
241- padSize = pad_size ,
242- version = version )
238+ (key_ , nonce_ ) = derive_key (mode = "decrypt" , version = version ,
239+ salt = salt , key = key ,
240+ dh = dh , auth_secret = auth_secret ,
241+ keyid = keyid , keymap = keymap , keylabels = keylabels )
243242 if rs <= pad_size :
244243 raise ECEException (u"Record size too small" )
245244 rs += 16 # account for tags
@@ -257,8 +256,8 @@ def decrypt_record(key, nonce, counter, content):
257256 return result
258257
259258
260- def encrypt (content , salt = None , key = None , keyid = None , dh = None , rs = 4096 ,
261- auth_secret = None , pad_size = 2 , version = "aesgcm" , ** kwargs ):
259+ def encrypt (content , salt = None , key = None , keyid = None , keymap = None , keylabels = None ,
260+ dh = None , rs = 4096 , auth_secret = None , version = "aesgcm" , ** kwargs ):
262261 """
263262 Encrypt a data block
264263
@@ -288,7 +287,7 @@ def encrypt_record(key, nonce, counter, buf):
288287 modes .GCM (iv (nonce , counter )),
289288 backend = default_backend ()
290289 ).encryptor ()
291- data = encryptor .update (b"\0 \0 " + buf )
290+ data = encryptor .update (( b"\0 " * pad_size ) + buf )
292291 data += encryptor .finalize ()
293292 data += encryptor .tag
294293 return data
@@ -324,26 +323,33 @@ def compose_aes128gcm(salt, content, rs=4096, key_id=""):
324323 header += key_id .encode ('utf-8' )
325324 return header + content
326325
326+ if version not in versions :
327+ raise ECEException (u"Invalid version" )
328+
327329 # handle the older, ill formatted args.
328- pad_size = kwargs .get ('padSize' , pad_size )
330+ pad_size = kwargs .get ('padSize' , versions [ version ][ 'pad' ] )
329331 auth_secret = kwargs .get ('authSecret' , auth_secret )
332+ if keymap is None :
333+ keymap = keys
334+ if keylabels is None :
335+ keylabels = labels
330336 if salt is None :
331337 salt = os .urandom (16 )
332338 version = "aes128gcm"
333339
334- (key_ , nonce_ ) = derive_key (mode = "encrypt" , salt = salt ,
335- key = key , keyid = keyid , dh = dh ,
336- auth_secret = auth_secret , padSize = pad_size ,
337- version = version )
340+ (key_ , nonce_ ) = derive_key (mode = "encrypt" , version = version ,
341+ salt = salt , key = key ,
342+ dh = dh , auth_secret = auth_secret ,
343+ keyid = keyid , keymap = keymap , keylabels = keylabels )
338344 if rs <= pad_size :
339345 raise ECEException (u"Record size too small" )
340346 rs -= pad_size # account for padding
341-
347+
342348 result = b""
343349 counter = 0
344350
345- # the extra padSize on the loop ensures that we produce a padding only
346- # record if the data length is an exact multiple of rs-padSize
351+ # the extra pad_size on the loop ensures that we produce a padding only
352+ # record if the data length is an exact multiple of rs-pad_size
347353 for i in list (range (0 , len (content ) + pad_size , rs )):
348354 result += encrypt_record (key_ , nonce_ , counter , content [i :i + rs ])
349355 counter += 1
0 commit comments