"""Tests for Python-implemented IStream / IPersistStreamInit COM objects.

The test objects are exercised both as plain Python objects and after being
wrapped with win32com.server.util.wrap, so that calls go through the
pythoncom gateways exactly as they would from a non-Python COM client.
"""
import pythoncom
import win32com.server.util
import win32com.test.util
import unittest

from pywin32_testutil import str2bytes


class Persists:
    """Minimal Python implementation of IPersistStreamInit.

    Load reads a fixed 26 bytes from the supplied stream; Save writes the
    whole of ``self.data`` back to it.
    """
    _public_methods_ = [ 'GetClassID', 'IsDirty', 'Load', 'Save',
                         'GetSizeMax', 'InitNew' ]
    _com_interfaces_ = [ pythoncom.IID_IPersistStreamInit ]

    def __init__(self):
        self.data = str2bytes("abcdefg")
        # Starts out dirty; only Save(..., clearDirty=true) resets it.
        self.dirty = 1

    def GetClassID(self):
        return pythoncom.IID_NULL

    def IsDirty(self):
        return self.dirty

    def Load(self, stream):
        # 26 matches the length of the alphabet payload used in StreamTest.testit.
        self.data = stream.Read(26)

    def Save(self, stream, clearDirty):
        stream.Write(self.data)
        if clearDirty:
            self.dirty = 0

    def GetSizeMax(self):
        return 1024

    def InitNew(self):
        pass


class Stream:
    """Simplified in-memory IStream used by the tests.

    Note that Write does not write at the current position: it replaces the
    entire buffer and resets the position to 0. Seek clamps the resulting
    position to the range [0, len(data)].
    """
    _public_methods_ = [ 'Read', 'Write', 'Seek' ]
    _com_interfaces_ = [ pythoncom.IID_IStream ]

    def __init__(self, data):
        self.data = data
        self.index = 0

    def Read(self, amount):
        """Return up to ``amount`` bytes from the current position and advance."""
        result = self.data[self.index : self.index + amount]
        self.index = self.index + amount
        return result

    def Write(self, data):
        """Replace the whole buffer with ``data``; return the byte count written."""
        self.data = data
        self.index = 0
        return len(data)

    def Seek(self, dist, origin):
        """Move the position relative to ``origin`` and return the new position.

        Raises ValueError for an unrecognised ``origin``.
        """
        if origin==pythoncom.STREAM_SEEK_SET:
            self.index = dist
        elif origin==pythoncom.STREAM_SEEK_CUR:
            self.index = self.index + dist
        elif origin==pythoncom.STREAM_SEEK_END:
            self.index = len(self.data)+dist
        else:
            raise ValueError('Unknown Seek type: ' +str(origin))
        # Clamp into the valid range of the buffer.
        if self.index < 0:
            self.index = 0
        else:
            self.index = min(self.index, len(self.data))
        return self.index


class BadStream(Stream):
    """ PyGStream::Read could formerly overflow buffer if the python
    implementation returned more data than requested.

    This stream deliberately returns one byte more than requested so the
    gateway's guard against that overflow can be tested.
    """
    def Read(self, amount):
        return str2bytes('x')*(amount+1)


class StreamTest(win32com.test.util.TestCase):
    def _readWrite(self, data, write_stream, read_stream = None):
        """Write ``data`` via one stream and check it reads back via another.

        Reads the full payload from offset 0, then ``data[1:-1]`` from offset 1.
        """
        if read_stream is None: read_stream = write_stream
        write_stream.Write(data)
        read_stream.Seek(0, pythoncom.STREAM_SEEK_SET)
        got = read_stream.Read(len(data))
        self.assertEqual(data, got)
        read_stream.Seek(1, pythoncom.STREAM_SEEK_SET)
        got = read_stream.Read(len(data)-2)
        self.assertEqual(data[1:-1], got)

    def testit(self):
        mydata = str2bytes('abcdefghijklmnopqrstuvwxyz')

        # First test the objects just as Python objects...
        s = Stream(mydata)
        p = Persists()

        p.Load(s)
        p.Save(s, 0)
        self.assertEqual(s.data, mydata)

        # Wrap the Python objects as COM objects, and make the calls as if
        # they were non-Python COM objects.
        s2 = win32com.server.util.wrap(s, pythoncom.IID_IStream)
        p2 = win32com.server.util.wrap(p, pythoncom.IID_IPersistStreamInit)

        self._readWrite(mydata, s, s)
        self._readWrite(mydata, s, s2)
        self._readWrite(mydata, s2, s)
        self._readWrite(mydata, s2, s2)

        self._readWrite(str2bytes("string with\0a NULL"), s2, s2)
        # reset the stream
        s.Write(mydata)
        p2.Load(s2)
        p2.Save(s2, 0)
        self.assertEqual(s.data, mydata)

    def testseek(self):
        s = Stream(str2bytes('yo'))
        s = win32com.server.util.wrap(s, pythoncom.IID_IStream)
        # we used to die in py3k passing a value > 32bits
        s.Seek(0x100000000, pythoncom.STREAM_SEEK_SET)

    def testerrors(self):
        # setup a test logger to capture tracebacks etc.
        records, old_log = win32com.test.util.setup_test_logger()
        try:
            ## check for buffer overflow in Read method
            badstream = BadStream('Check for buffer overflow')
            badstream2 = win32com.server.util.wrap(badstream, pythoncom.IID_IStream)
            self.assertRaises(pythoncom.com_error, badstream2.Read, 10)
        finally:
            # Always restore the logger, even if the assertion above fails,
            # so the capturing logger does not leak into later tests.
            win32com.test.util.restore_test_logger(old_log)
        # expecting 2 pythoncom errors to have been raised by the gateways.
        self.assertEqual(len(records), 2)
        # assertTrue replaces failUnless, which was removed in Python 3.12.
        self.assertTrue(records[0].msg.startswith('pythoncom error'))
        self.assertTrue(records[1].msg.startswith('pythoncom error'))


if __name__=='__main__':
    unittest.main()