From 134cb01cda50f02725575808130b05d2d776693f Mon Sep 17 00:00:00 2001
From: Serhiy Storchaka <storchaka@gmail.com>
Date: Sun, 18 Mar 2018 09:55:53 +0200
Subject: [PATCH] bpo-32056: Improve exceptions in aifc, wave and sunau.
 (GH-5951)

---
 Lib/aifc.py                                        |  4 ++
 Lib/sunau.py                                       |  2 +
 Lib/test/test_aifc.py                              | 35 ++++++++++--
 Lib/test/test_sunau.py                             | 37 +++++++++++++
 Lib/test/test_wave.py                              | 62 ++++++++++++++++++++++
 Lib/wave.py                                        | 14 ++++-
 .../2018-03-01-17-49-56.bpo-32056.IlpfgE.rst       |  3 ++
 7 files changed, 150 insertions(+), 7 deletions(-)
 create mode 100644 Misc/NEWS.d/next/Library/2018-03-01-17-49-56.bpo-32056.IlpfgE.rst

diff --git a/Lib/aifc.py b/Lib/aifc.py
index 3d2dc56de198..1916e7ef8e7e 100644
--- a/Lib/aifc.py
+++ b/Lib/aifc.py
@@ -467,6 +467,10 @@ def _read_comm_chunk(self, chunk):
         self._nframes = _read_long(chunk)
         self._sampwidth = (_read_short(chunk) + 7) // 8
         self._framerate = int(_read_float(chunk))
+        if self._sampwidth <= 0:
+            raise Error('bad sample width')
+        if self._nchannels <= 0:
+            raise Error('bad # of channels')
         self._framesize = self._nchannels * self._sampwidth
         if self._aifc:
             #DEBUG: SGI's soundeditor produces a bad size :-(
diff --git a/Lib/sunau.py b/Lib/sunau.py
index dbad3db8392d..129502b0b417 100644
--- a/Lib/sunau.py
+++ b/Lib/sunau.py
@@ -208,6 +208,8 @@ def initfp(self, file):
             raise Error('unknown encoding')
         self._framerate = int(_read_u32(file))
         self._nchannels = int(_read_u32(file))
+        if not self._nchannels:
+            raise Error('bad # of channels')
         self._framesize = self._framesize * self._nchannels
         if self._hdr_size > 24:
             self._info = file.read(self._hdr_size - 24)
diff --git a/Lib/test/test_aifc.py b/Lib/test/test_aifc.py
index 8fd306a36592..ff52f5b6feb8 100644
--- a/Lib/test/test_aifc.py
+++ b/Lib/test/test_aifc.py
@@ -268,7 +268,8 @@ def test_read_no_comm_chunk(self):
 
     def test_read_no_ssnd_chunk(self):
         b = b'FORM' + struct.pack('>L', 4) + b'AIFC'
-        b += b'COMM' + struct.pack('>LhlhhLL', 38, 0, 0, 0, 0, 0, 0)
+        b += b'COMM' + struct.pack('>LhlhhLL', 38, 1, 0, 8,
+                                   0x4000 | 12, 11025<<18, 0)
         b += b'NONE' + struct.pack('B', 14) + b'not compressed' + b'\x00'
         with self.assertRaisesRegex(aifc.Error, 'COMM chunk and/or SSND chunk'
                                                 ' missing'):
@@ -276,13 +277,35 @@ def test_read_no_ssnd_chunk(self):
 
     def test_read_wrong_compression_type(self):
         b = b'FORM' + struct.pack('>L', 4) + b'AIFC'
-        b += b'COMM' + struct.pack('>LhlhhLL', 23, 0, 0, 0, 0, 0, 0)
+        b += b'COMM' + struct.pack('>LhlhhLL', 23, 1, 0, 8,
+                                   0x4000 | 12, 11025<<18, 0)
         b += b'WRNG' + struct.pack('B', 0)
         self.assertRaises(aifc.Error, aifc.open, io.BytesIO(b))
 
+    def test_read_wrong_number_of_channels(self):
+        for nchannels in 0, -1:
+            b = b'FORM' + struct.pack('>L', 4) + b'AIFC'
+            b += b'COMM' + struct.pack('>LhlhhLL', 38, nchannels, 0, 8,
+                                       0x4000 | 12, 11025<<18, 0)
+            b += b'NONE' + struct.pack('B', 14) + b'not compressed' + b'\x00'
+            b += b'SSND' + struct.pack('>L', 8) + b'\x00' * 8
+            with self.assertRaisesRegex(aifc.Error, 'bad # of channels'):
+                aifc.open(io.BytesIO(b))
+
+    def test_read_wrong_sample_width(self):
+        for sampwidth in 0, -1:
+            b = b'FORM' + struct.pack('>L', 4) + b'AIFC'
+            b += b'COMM' + struct.pack('>LhlhhLL', 38, 1, 0, sampwidth,
+                                       0x4000 | 12, 11025<<18, 0)
+            b += b'NONE' + struct.pack('B', 14) + b'not compressed' + b'\x00'
+            b += b'SSND' + struct.pack('>L', 8) + b'\x00' * 8
+            with self.assertRaisesRegex(aifc.Error, 'bad sample width'):
+                aifc.open(io.BytesIO(b))
+
     def test_read_wrong_marks(self):
         b = b'FORM' + struct.pack('>L', 4) + b'AIFF'
-        b += b'COMM' + struct.pack('>LhlhhLL', 18, 0, 0, 0, 0, 0, 0)
+        b += b'COMM' + struct.pack('>LhlhhLL', 18, 1, 0, 8,
+                                   0x4000 | 12, 11025<<18, 0)
         b += b'SSND' + struct.pack('>L', 8) + b'\x00' * 8
         b += b'MARK' + struct.pack('>LhB', 3, 1, 1)
         with self.assertWarns(UserWarning) as cm:
@@ -293,7 +316,8 @@ def test_read_wrong_marks(self):
 
     def test_read_comm_kludge_compname_even(self):
         b = b'FORM' + struct.pack('>L', 4) + b'AIFC'
-        b += b'COMM' + struct.pack('>LhlhhLL', 18, 0, 0, 0, 0, 0, 0)
+        b += b'COMM' + struct.pack('>LhlhhLL', 18, 1, 0, 8,
+                                   0x4000 | 12, 11025<<18, 0)
         b += b'NONE' + struct.pack('B', 4) + b'even' + b'\x00'
         b += b'SSND' + struct.pack('>L', 8) + b'\x00' * 8
         with self.assertWarns(UserWarning) as cm:
@@ -303,7 +327,8 @@ def test_read_comm_kludge_compname_even(self):
 
     def test_read_comm_kludge_compname_odd(self):
         b = b'FORM' + struct.pack('>L', 4) + b'AIFC'
-        b += b'COMM' + struct.pack('>LhlhhLL', 18, 0, 0, 0, 0, 0, 0)
+        b += b'COMM' + struct.pack('>LhlhhLL', 18, 1, 0, 8,
+                                   0x4000 | 12, 11025<<18, 0)
         b += b'NONE' + struct.pack('B', 3) + b'odd'
         b += b'SSND' + struct.pack('>L', 8) + b'\x00' * 8
         with self.assertWarns(UserWarning) as cm:
diff --git a/Lib/test/test_sunau.py b/Lib/test/test_sunau.py
index 966224b1df5a..470a1007b4d4 100644
--- a/Lib/test/test_sunau.py
+++ b/Lib/test/test_sunau.py
@@ -1,6 +1,8 @@
 import unittest
 from test import audiotests
 from audioop import byteswap
+import io
+import struct
 import sys
 import sunau
 
@@ -121,5 +123,40 @@ class SunauMiscTests(audiotests.AudioMiscTests, unittest.TestCase):
     module = sunau
 
 
+class SunauLowLevelTest(unittest.TestCase):
+
+    def test_read_bad_magic_number(self):
+        b = b'SPA'
+        with self.assertRaises(EOFError):
+            sunau.open(io.BytesIO(b))
+        b = b'SPAM'
+        with self.assertRaisesRegex(sunau.Error, 'bad magic number'):
+            sunau.open(io.BytesIO(b))
+
+    def test_read_too_small_header(self):
+        b = struct.pack('>LLLLL', sunau.AUDIO_FILE_MAGIC, 20, 0,
+                        sunau.AUDIO_FILE_ENCODING_LINEAR_8, 11025)
+        with self.assertRaisesRegex(sunau.Error, 'header size too small'):
+            sunau.open(io.BytesIO(b))
+
+    def test_read_too_large_header(self):
+        b = struct.pack('>LLLLLL', sunau.AUDIO_FILE_MAGIC, 124, 0,
+                        sunau.AUDIO_FILE_ENCODING_LINEAR_8, 11025, 1)
+        b += b'\0' * 100
+        with self.assertRaisesRegex(sunau.Error, 'header size ridiculously large'):
+            sunau.open(io.BytesIO(b))
+
+    def test_read_wrong_encoding(self):
+        b = struct.pack('>LLLLLL', sunau.AUDIO_FILE_MAGIC, 24, 0, 0, 11025, 1)
+        with self.assertRaisesRegex(sunau.Error, r'encoding not \(yet\) supported'):
+            sunau.open(io.BytesIO(b))
+
+    def test_read_wrong_number_of_channels(self):
+        b = struct.pack('>LLLLLL', sunau.AUDIO_FILE_MAGIC, 24, 0,
+                        sunau.AUDIO_FILE_ENCODING_LINEAR_8, 11025, 0)
+        with self.assertRaisesRegex(sunau.Error, 'bad # of channels'):
+            sunau.open(io.BytesIO(b))
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/Lib/test/test_wave.py b/Lib/test/test_wave.py
index c5d2e02450ef..8a42f8e47105 100644
--- a/Lib/test/test_wave.py
+++ b/Lib/test/test_wave.py
@@ -2,6 +2,8 @@
 from test import audiotests
 from test import support
 from audioop import byteswap
+import io
+import struct
 import sys
 import wave
 
@@ -111,5 +113,65 @@ def test__all__(self):
         support.check__all__(self, wave, blacklist=blacklist)
 
 
+class WaveLowLevelTest(unittest.TestCase):
+
+    def test_read_no_chunks(self):
+        b = b'SPAM'
+        with self.assertRaises(EOFError):
+            wave.open(io.BytesIO(b))
+
+    def test_read_no_riff_chunk(self):
+        b = b'SPAM' + struct.pack('<L', 0)
+        with self.assertRaisesRegex(wave.Error,
+                                    'file does not start with RIFF id'):
+            wave.open(io.BytesIO(b))
+
+    def test_read_not_wave(self):
+        b = b'RIFF' + struct.pack('<L', 4) + b'SPAM'
+        with self.assertRaisesRegex(wave.Error,
+                                    'not a WAVE file'):
+            wave.open(io.BytesIO(b))
+
+    def test_read_no_fmt_no_data_chunk(self):
+        b = b'RIFF' + struct.pack('<L', 4) + b'WAVE'
+        with self.assertRaisesRegex(wave.Error,
+                                    'fmt chunk and/or data chunk missing'):
+            wave.open(io.BytesIO(b))
+
+    def test_read_no_data_chunk(self):
+        b = b'RIFF' + struct.pack('<L', 28) + b'WAVE'
+        b += b'fmt ' + struct.pack('<LHHLLHH', 16, 1, 1, 11025, 11025, 1, 8)
+        with self.assertRaisesRegex(wave.Error,
+                                    'fmt chunk and/or data chunk missing'):
+            wave.open(io.BytesIO(b))
+
+    def test_read_no_fmt_chunk(self):
+        b = b'RIFF' + struct.pack('<L', 12) + b'WAVE'
+        b += b'data' + struct.pack('<L', 0)
+        with self.assertRaisesRegex(wave.Error, 'data chunk before fmt chunk'):
+            wave.open(io.BytesIO(b))
+
+    def test_read_wrong_form(self):
+        b = b'RIFF' + struct.pack('<L', 36) + b'WAVE'
+        b += b'fmt ' + struct.pack('<LHHLLHH', 16, 2, 1, 11025, 11025, 1, 1)
+        b += b'data' + struct.pack('<L', 0)
+        with self.assertRaisesRegex(wave.Error, 'unknown format: 2'):
+            wave.open(io.BytesIO(b))
+
+    def test_read_wrong_number_of_channels(self):
+        b = b'RIFF' + struct.pack('<L', 36) + b'WAVE'
+        b += b'fmt ' + struct.pack('<LHHLLHH', 16, 1, 0, 11025, 11025, 1, 8)
+        b += b'data' + struct.pack('<L', 0)
+        with self.assertRaisesRegex(wave.Error, 'bad # of channels'):
+            wave.open(io.BytesIO(b))
+
+    def test_read_wrong_sample_width(self):
+        b = b'RIFF' + struct.pack('<L', 36) + b'WAVE'
+        b += b'fmt ' + struct.pack('<LHHLLHH', 16, 1, 1, 11025, 11025, 1, 0)
+        b += b'data' + struct.pack('<L', 0)
+        with self.assertRaisesRegex(wave.Error, 'bad sample width'):
+            wave.open(io.BytesIO(b))
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/Lib/wave.py b/Lib/wave.py
index cf94d5af72b4..f155879a9a76 100644
--- a/Lib/wave.py
+++ b/Lib/wave.py
@@ -253,12 +253,22 @@ def readframes(self, nframes):
     #
 
     def _read_fmt_chunk(self, chunk):
-        wFormatTag, self._nchannels, self._framerate, dwAvgBytesPerSec, wBlockAlign = struct.unpack_from('<HHLLH', chunk.read(14))
+        try:
+            wFormatTag, self._nchannels, self._framerate, dwAvgBytesPerSec, wBlockAlign = struct.unpack_from('<HHLLH', chunk.read(14))
+        except struct.error:
+            raise EOFError from None
         if wFormatTag == WAVE_FORMAT_PCM:
-            sampwidth = struct.unpack_from('<H', chunk.read(2))[0]
+            try:
+                sampwidth = struct.unpack_from('<H', chunk.read(2))[0]
+            except struct.error:
+                raise EOFError from None
             self._sampwidth = (sampwidth + 7) // 8
+            if not self._sampwidth:
+                raise Error('bad sample width')
         else:
             raise Error('unknown format: %r' % (wFormatTag,))
+        if not self._nchannels:
+            raise Error('bad # of channels')
         self._framesize = self._nchannels * self._sampwidth
         self._comptype = 'NONE'
         self._compname = 'not compressed'
diff --git a/Misc/NEWS.d/next/Library/2018-03-01-17-49-56.bpo-32056.IlpfgE.rst b/Misc/NEWS.d/next/Library/2018-03-01-17-49-56.bpo-32056.IlpfgE.rst
new file mode 100644
index 000000000000..421aa3767794
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2018-03-01-17-49-56.bpo-32056.IlpfgE.rst
@@ -0,0 +1,3 @@
+Improved exceptions raised for invalid number of channels and sample width
+when read an audio file in modules :mod:`aifc`, :mod:`wave` and
+:mod:`sunau`.