@@ -291,82 +291,89 @@ def test_join(self):
291291 bytes_join (b'' , NULL )
292292
293293
294- class PyBytesWriterTest (unittest .TestCase ):
294+ class BytesWriterTest (unittest .TestCase ):
295295 SMALL_BUFFER = 256 # bytes
296+ result_type = bytes
296297
297298 def create_writer (self , alloc = 0 , string = b'' ):
298- return _testcapi .PyBytesWriter (alloc , string )
299+ return _testcapi .PyBytesWriter (alloc , string , 0 )
299300
300301 def test_create (self ):
301302 # Test PyBytesWriter_Create()
302303 writer = self .create_writer ()
303304 self .assertEqual (writer .get_size (), 0 )
304305 self .assertEqual (writer .get_allocated (), self .SMALL_BUFFER )
305- self .assertEqual (writer .finish (), b'' )
306+ self .assertEqual (writer .finish (), self . result_type ( b'' ) )
306307
307308 writer = self .create_writer (3 , b'abc' )
308309 self .assertEqual (writer .get_size (), 3 )
309310 self .assertEqual (writer .get_allocated (), self .SMALL_BUFFER )
310- self .assertEqual (writer .finish (), b'abc' )
311+ self .assertEqual (writer .finish (), self . result_type ( b'abc' ) )
311312
312313 writer = self .create_writer (10 , b'abc' )
313314 self .assertEqual (writer .get_size (), 10 )
314315 self .assertEqual (writer .get_allocated (), self .SMALL_BUFFER )
315- self .assertEqual (writer .finish_with_size (3 ), b'abc' )
316+ self .assertEqual (writer .finish_with_size (3 ), self . result_type ( b'abc' ) )
316317
317318 def test_write_bytes (self ):
318319 # Test PyBytesWriter_WriteBytes()
319320 writer = self .create_writer ()
320321 writer .write_bytes (b'Hello World!' , - 1 )
321- self .assertEqual (writer .finish (), b'Hello World!' )
322+ self .assertEqual (writer .finish (), self . result_type ( b'Hello World!' ) )
322323
323324 writer = self .create_writer ()
324325 writer .write_bytes (b'Hello ' , - 1 )
325326 writer .write_bytes (b'World! <truncated>' , 6 )
326- self .assertEqual (writer .finish (), b'Hello World!' )
327+ self .assertEqual (writer .finish (), self . result_type ( b'Hello World!' ) )
327328
328329 def test_resize (self ):
329330 # Test PyBytesWriter_Resize()
330331 writer = self .create_writer ()
331332 writer .resize (len (b'number=123456' ), b'number=123456' )
332333 writer .resize (len (b'number=123456' ), b'' )
333334 self .assertEqual (writer .get_size (), len (b'number=123456' ))
334- self .assertEqual (writer .finish (), b'number=123456' )
335+ self .assertEqual (writer .finish (), self . result_type ( b'number=123456' ) )
335336
336337 writer = self .create_writer ()
337338 writer .resize (0 , b'' )
338339 writer .resize (len (b'number=123456' ), b'number=123456' )
339- self .assertEqual (writer .finish (), b'number=123456' )
340+ self .assertEqual (writer .finish (), self . result_type ( b'number=123456' ) )
340341
341342 writer = self .create_writer ()
342343 writer .resize (len (b'number=' ), b'number=' )
343344 writer .resize (len (b'number=123456' ), b'123456' )
344- self .assertEqual (writer .finish (), b'number=123456' )
345+ self .assertEqual (writer .finish (), self . result_type ( b'number=123456' ) )
345346
346347 writer = self .create_writer ()
347348 writer .resize (len (b'number=' ), b'number=' )
348349 writer .resize (len (b'number=' ), b'' )
349350 writer .resize (len (b'number=123456' ), b'123456' )
350- self .assertEqual (writer .finish (), b'number=123456' )
351+ self .assertEqual (writer .finish (), self . result_type ( b'number=123456' ) )
351352
352353 writer = self .create_writer ()
353354 writer .resize (len (b'number' ), b'number' )
354355 writer .resize (len (b'number=' ), b'=' )
355356 writer .resize (len (b'number=123' ), b'123' )
356357 writer .resize (len (b'number=123456' ), b'456' )
357- self .assertEqual (writer .finish (), b'number=123456' )
358+ self .assertEqual (writer .finish (), self . result_type ( b'number=123456' ) )
358359
359360 def test_format_i (self ):
360361 # Test PyBytesWriter_Format()
361362 writer = self .create_writer ()
362363 writer .format_i (b'x=%i' , 123456 )
363- self .assertEqual (writer .finish (), b'x=123456' )
364+ self .assertEqual (writer .finish (), self . result_type ( b'x=123456' ) )
364365
365366 writer = self .create_writer ()
366367 writer .format_i (b'x=%i, ' , 123 )
367368 writer .format_i (b'y=%i' , 456 )
368- self .assertEqual (writer .finish (), b'x=123, y=456' )
369+ self .assertEqual (writer .finish (), self . result_type ( b'x=123, y=456' ) )
369370
370371
372+ class ByteArrayWriterTest (BytesWriterTest ):
373+ result_type = bytearray
374+
375+ def create_writer (self , alloc = 0 , string = b'' ):
376+ return _testcapi .PyBytesWriter (alloc , string , 1 )
377+
371378if __name__ == "__main__" :
372379 unittest .main ()
0 commit comments