Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions Lib/test/test_wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from test.support.os_helper import FakePath, unlink
import io
import os
import re
import struct
import tempfile
import sys
Expand Down Expand Up @@ -323,14 +324,14 @@ def test_read_wrong_number_of_channels(self):
b = b'RIFF' + struct.pack('<L', 36) + b'WAVE'
b += b'fmt ' + struct.pack('<LHHLLHH', 16, 1, 0, 11025, 11025, 1, 8)
b += b'data' + struct.pack('<L', 0)
with self.assertRaisesRegex(wave.Error, 'bad # of channels'):
with self.assertRaisesRegex(wave.Error, 'bad # of channels: 0'):
wave.open(io.BytesIO(b))

def test_read_wrong_sample_width(self):
b = b'RIFF' + struct.pack('<L', 36) + b'WAVE'
b += b'fmt ' + struct.pack('<LHHLLHH', 16, 1, 1, 11025, 11025, 1, 0)
b += b'data' + struct.pack('<L', 0)
with self.assertRaisesRegex(wave.Error, 'bad sample width'):
with self.assertRaisesRegex(wave.Error, 'bad sample width: 0'):
wave.open(io.BytesIO(b))

def test_open_in_write_raises(self):
Expand Down Expand Up @@ -430,6 +431,36 @@ def test_setframerate_rounds(self, arg, expected):
f.setframerate(arg)
self.assertEqual(f.getframerate(), expected)

@support.subTests('nchannels', (0, -1))
def test_setnchannels_error_includes_value(self, nchannels):
with wave.open(io.BytesIO(), 'wb') as f:
with self.assertRaisesRegex(wave.Error,
re.escape(f'bad # of channels: {nchannels!r}')):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For readability purposes, can you put the message outside?

message = re.escape(...)
with wave.open(...):
   with self.assertRaisesRegex(..., err):
        f.setchannels(channels)

f.setnchannels(nchannels)
with self.assertRaises(wave.Error):
f.close()

@support.subTests('sampwidth', (0, 5))
def test_setsampwidth_error_includes_value(self, sampwidth):
with wave.open(io.BytesIO(), 'wb') as f:
f.setnchannels(1)
with self.assertRaisesRegex(wave.Error,
re.escape(f'bad sample width: {sampwidth!r}')):
f.setsampwidth(sampwidth)
with self.assertRaises(wave.Error):
f.close()

@support.subTests('arg', (-1, 0, 0.4))
def test_setframerate_error_includes_value(self, arg):
with wave.open(io.BytesIO(), 'wb') as f:
f.setnchannels(1)
f.setsampwidth(2)
with self.assertRaisesRegex(wave.Error,
re.escape(f'bad frame rate: {arg!r}')):
f.setframerate(arg)
with self.assertRaises(wave.Error):
f.close()

def test_write_odd_data_chunk_pads_and_updates_riff_size(self):
# gh-117716: odd-sized data chunks must be padded with one zero byte.
with io.BytesIO() as output:
Expand Down
21 changes: 11 additions & 10 deletions Lib/wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def readframes(self, nframes):

def _read_fmt_chunk(self, chunk):
try:
self._format, self._nchannels, self._framerate, dwAvgBytesPerSec, wBlockAlign = struct.unpack_from('<HHLLH', chunk.read(14))
self._format, nchannels, self._framerate, dwAvgBytesPerSec, wBlockAlign = struct.unpack_from('<HHLLH', chunk.read(14))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't change this. It is a breaking change. Previously the attribute would have been already set after passing unpack_from but now it will be deferred.

except struct.error:
raise EOFError from None
if self._format not in (WAVE_FORMAT_PCM, WAVE_FORMAT_IEEE_FLOAT, WAVE_FORMAT_EXTENSIBLE):
Expand All @@ -415,9 +415,10 @@ def _read_fmt_chunk(self, chunk):
raise Error(subformat_msg)
self._sampwidth = (sampwidth + 7) // 8
if not self._sampwidth:
raise Error('bad sample width')
if not self._nchannels:
raise Error('bad # of channels')
raise Error(f'bad sample width: {sampwidth!r}')
if not nchannels:
raise Error(f'bad # of channels: {nchannels!r}')
self._nchannels = nchannels
self._framesize = self._nchannels * self._sampwidth
self._comptype = 'NONE'
self._compname = 'not compressed'
Expand Down Expand Up @@ -495,7 +496,7 @@ def setnchannels(self, nchannels):
if self._datawritten:
raise Error('cannot change parameters after starting to write')
if nchannels < 1:
raise Error('bad # of channels')
raise Error(f'bad # of channels: {nchannels!r}')
self._nchannels = nchannels

def getnchannels(self):
Expand All @@ -510,7 +511,7 @@ def setsampwidth(self, sampwidth):
if sampwidth not in (4, 8):
raise Error('unsupported sample width for IEEE float format')
elif sampwidth < 1 or sampwidth > 4:
raise Error('bad sample width')
raise Error(f'bad sample width: {sampwidth!r}')
self._sampwidth = sampwidth

def getsampwidth(self):
Expand All @@ -521,10 +522,10 @@ def getsampwidth(self):
def setframerate(self, framerate):
if self._datawritten:
raise Error('cannot change parameters after starting to write')
framerate = int(round(framerate))
if framerate <= 0:
raise Error('bad frame rate')
self._framerate = framerate
rounded = int(round(framerate))
if rounded <= 0:
raise Error(f'bad frame rate: {framerate!r}')
self._framerate = rounded
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
rounded = int(round(framerate))
if rounded <= 0:
raise Error(f'bad frame rate: {framerate!r}')
self._framerate = rounded
rounded_framerate = int(round(framerate))
if rounded_framerate <= 0:
raise Error(f'bad frame rate: {framerate!r}')
self._framerate = rounded_framerate


def getframerate(self):
if not self._framerate:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Error messages in :mod:`wave` for invalid channel count, sample width, and
frame rate now include the offending value.
Loading