Skip to content
Merged
233 changes: 153 additions & 80 deletions comtypes/test/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
STGTY_STORAGE = 1

STATFLAG_DEFAULT = 0
STATFLAG_NONAME = 1

STGC_DEFAULT = 0
STGM_CREATE = 0x00001000
STGM_DIRECT = 0x00000000
Expand All @@ -36,6 +38,7 @@

STG_E_PATHNOTFOUND = -2147287038
STG_E_INVALIDFLAG = -2147286785
STG_E_ACCESSDENIED = -2147287035 # 0x80030005

_ole32 = OleDLL("ole32")

Expand All @@ -48,29 +51,32 @@ def _get_pwcsname(stat: tagSTATSTG) -> WSTRING:
return WSTRING.from_address(ctypes.addressof(stat) + tagSTATSTG.pwcsName.offset)


class Test_IStorage(unittest.TestCase):
RW_EXCLUSIVE = STGM_READWRITE | STGM_SHARE_EXCLUSIVE
RW_EXCLUSIVE_TX = RW_EXCLUSIVE | STGM_TRANSACTED
RW_EXCLUSIVE_CREATE = RW_EXCLUSIVE | STGM_CREATE
CREATE_TESTDOC = STGM_DIRECT | STGM_CREATE | RW_EXCLUSIVE
CREATE_TEMP_TESTDOC = CREATE_TESTDOC | STGM_DELETEONRELEASE
RW_EXCLUSIVE = STGM_READWRITE | STGM_SHARE_EXCLUSIVE
RW_EXCLUSIVE_TX = RW_EXCLUSIVE | STGM_TRANSACTED
RW_EXCLUSIVE_CREATE = RW_EXCLUSIVE | STGM_CREATE
CREATE_TESTDOC = STGM_DIRECT | STGM_CREATE | RW_EXCLUSIVE
CREATE_TEMP_TESTDOC = CREATE_TESTDOC | STGM_DELETEONRELEASE


def _create_docfile(mode: int, name: Optional[str] = None) -> IStorage:
stg = POINTER(IStorage)()
_StgCreateDocfile(name, mode, 0, byref(stg))
return stg # type: ignore


def _create_docfile(self, mode: int, name: Optional[str] = None) -> IStorage:
stg = POINTER(IStorage)()
_StgCreateDocfile(name, mode, 0, byref(stg))
return stg # type: ignore
FIXED_TEST_FILETIME = SystemTimeToFileTime(SYSTEMTIME(wYear=2000, wMonth=1, wDay=1))

FIXED_TEST_FILETIME = SystemTimeToFileTime(SYSTEMTIME(wYear=2000, wMonth=1, wDay=1))

def test_CreateStream(self):
storage = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC)
class Test_CreateStream(unittest.TestCase):
def test_creates_and_writes_to_stream_in_docfile(self):
storage = _create_docfile(mode=CREATE_TEMP_TESTDOC)
# When created with `StgCreateDocfile(NULL, ...)`, `pwcsName` is a
# temporary filename. The file really exists on disk because Windows
# creates an actual temporary file for the compound storage.
stat = storage.Stat(STATFLAG_DEFAULT)
filepath = Path(stat.pwcsName)
self.assertTrue(filepath.exists())
stream = storage.CreateStream("example", self.RW_EXCLUSIVE_CREATE, 0, 0)
stream = storage.CreateStream("example", RW_EXCLUSIVE_CREATE, 0, 0)
test_data = b"Some data"
pv = (c_ubyte * len(test_data)).from_buffer(bytearray(test_data))
stream.RemoteWrite(pv, len(test_data))
Expand All @@ -89,114 +95,157 @@ def test_CreateStream(self):
del stat # `pwcsName` is expected to be freed here.
# `DidAlloc` checks are skipped to avoid using a dangling pointer.

# TODO: Auto-generated methods based on type info are remote-side and hard
# to call from the client.
# If a proper invocation method or workaround is found, testing
# becomes possible.
# See: https://github.com/enthought/comtypes/issues/607
# def test_RemoteOpenStream(self):
# pass

def test_CreateStorage(self):
parent = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC)
child = parent.CreateStorage("child", self.RW_EXCLUSIVE_TX, 0, 0)

# TODO: Auto-generated methods based on type info are remote-side and hard
# to call from the client.
# If a proper invocation method or workaround is found, testing
# becomes possible.
# See: https://github.com/enthought/comtypes/issues/607
# class Test_RemoteOpenStream(unittest.TestCase):
# def test_RemoteOpenStream(self):
# pass


class Test_CreateStorage(unittest.TestCase):
def test_creates_child_storage_in_parent(self):
parent = _create_docfile(mode=CREATE_TEMP_TESTDOC)
child = parent.CreateStorage("child", RW_EXCLUSIVE_TX, 0, 0)
self.assertEqual("child", child.Stat(STATFLAG_DEFAULT).pwcsName)

def test_OpenStorage(self):
parent = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC)

class Test_OpenStorage(unittest.TestCase):
def test_opens_existing_child_storage(self):
parent = _create_docfile(mode=CREATE_TEMP_TESTDOC)
with self.assertRaises(COMError) as cm:
parent.OpenStorage("child", None, self.RW_EXCLUSIVE_TX, None, 0)
parent.OpenStorage("child", None, RW_EXCLUSIVE_TX, None, 0)
self.assertEqual(cm.exception.hresult, STG_E_PATHNOTFOUND)
parent.CreateStorage("child", self.RW_EXCLUSIVE_TX, 0, 0)
child = parent.OpenStorage("child", None, self.RW_EXCLUSIVE_TX, None, 0)
parent.CreateStorage("child", RW_EXCLUSIVE_TX, 0, 0)
child = parent.OpenStorage("child", None, RW_EXCLUSIVE_TX, None, 0)
self.assertEqual("child", child.Stat(STATFLAG_DEFAULT).pwcsName)

def test_RemoteCopyTo(self):
src_stg = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC)
src_stg.CreateStorage("child", self.RW_EXCLUSIVE_TX, 0, 0)
dst_stg = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC)

class Test_RemoteCopyTo(unittest.TestCase):
def test_copies_storage_content_to_destination(self):
src_stg = _create_docfile(mode=CREATE_TEMP_TESTDOC)
src_stg.CreateStorage("child", RW_EXCLUSIVE_TX, 0, 0)
dst_stg = _create_docfile(mode=CREATE_TEMP_TESTDOC)
src_stg.RemoteCopyTo(0, None, None, dst_stg)
src_stg.Commit(STGC_DEFAULT)
del src_stg
opened_stg = dst_stg.OpenStorage("child", None, self.RW_EXCLUSIVE_TX, None, 0)
opened_stg = dst_stg.OpenStorage("child", None, RW_EXCLUSIVE_TX, None, 0)
self.assertEqual("child", opened_stg.Stat(STATFLAG_DEFAULT).pwcsName)

def test_MoveElementTo(self):
src_stg = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC)
src_stg.CreateStorage("foo", self.RW_EXCLUSIVE_TX, 0, 0)
dst_stg = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC)

class Test_MoveElementTo(unittest.TestCase):
def test_moves_element_to_new_location_and_renames(self):
src_stg = _create_docfile(mode=CREATE_TEMP_TESTDOC)
src_stg.CreateStorage("foo", RW_EXCLUSIVE_TX, 0, 0)
dst_stg = _create_docfile(mode=CREATE_TEMP_TESTDOC)
src_stg.MoveElementTo("foo", dst_stg, "bar", STGMOVE_MOVE)
opened_stg = dst_stg.OpenStorage("bar", None, self.RW_EXCLUSIVE_TX, None, 0)
opened_stg = dst_stg.OpenStorage("bar", None, RW_EXCLUSIVE_TX, None, 0)
self.assertEqual("bar", opened_stg.Stat(STATFLAG_DEFAULT).pwcsName)
with self.assertRaises(COMError) as cm:
src_stg.OpenStorage("foo", None, self.RW_EXCLUSIVE_TX, None, 0)
src_stg.OpenStorage("foo", None, RW_EXCLUSIVE_TX, None, 0)
self.assertEqual(cm.exception.hresult, STG_E_PATHNOTFOUND)

def test_Revert(self):
storage = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC)
foo = storage.CreateStorage("foo", self.RW_EXCLUSIVE_TX, 0, 0)
foo.CreateStorage("bar", self.RW_EXCLUSIVE_TX, 0, 0)
bar = foo.OpenStorage("bar", None, self.RW_EXCLUSIVE_TX, None, 0)

class Test_Revert(unittest.TestCase):
def test_reverts_pending_changes_to_storage(self):
storage = _create_docfile(mode=CREATE_TEMP_TESTDOC)
foo = storage.CreateStorage("foo", RW_EXCLUSIVE_TX, 0, 0)
foo.CreateStorage("bar", RW_EXCLUSIVE_TX, 0, 0)
bar = foo.OpenStorage("bar", None, RW_EXCLUSIVE_TX, None, 0)
self.assertEqual("bar", bar.Stat(STATFLAG_DEFAULT).pwcsName)
foo.Revert()
with self.assertRaises(COMError) as cm:
foo.OpenStorage("bar", None, self.RW_EXCLUSIVE_TX, None, 0)
foo.OpenStorage("bar", None, RW_EXCLUSIVE_TX, None, 0)
self.assertEqual(cm.exception.hresult, STG_E_PATHNOTFOUND)

# TODO: Auto-generated methods based on type info are remote-side and hard
# to call from the client.
# If a proper invocation method or workaround is found, testing
# becomes possible.
# See: https://github.com/enthought/comtypes/issues/607
# def test_RemoteEnumElements(self):
# pass

def test_DestroyElement(self):
storage = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC)
storage.CreateStorage("example", self.RW_EXCLUSIVE_TX, 0, 0)

# TODO: Auto-generated methods based on type info are remote-side and hard
# to call from the client.
# If a proper invocation method or workaround is found, testing
# becomes possible.
# See: https://github.com/enthought/comtypes/issues/607
# class Test_RemoteEnumElements(unittest.TestCase):
# def test_RemoteEnumElements(self):
# pass


class Test_DestroyElement(unittest.TestCase):
def test_destroys_existing_element_in_storage(self):
storage = _create_docfile(mode=CREATE_TEMP_TESTDOC)
storage.CreateStorage("example", RW_EXCLUSIVE_TX, 0, 0)
storage.DestroyElement("example")
with self.assertRaises(COMError) as cm:
storage.OpenStorage("example", None, self.RW_EXCLUSIVE_TX, None, 0)
storage.OpenStorage("example", None, RW_EXCLUSIVE_TX, None, 0)
self.assertEqual(cm.exception.hresult, STG_E_PATHNOTFOUND)

def test_RenameElement(self):
storage = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC)
storage.CreateStorage("example", self.RW_EXCLUSIVE_TX, 0, 0)
def test_fails_to_destroy_non_existent_element(self):
storage = _create_docfile(mode=CREATE_TEMP_TESTDOC)
with self.assertRaises(COMError) as cm:
storage.DestroyElement("non_existent")
self.assertEqual(cm.exception.hresult, STG_E_PATHNOTFOUND)


class Test_RenameElement(unittest.TestCase):
def test_renames_element_in_storage(self):
storage = _create_docfile(mode=CREATE_TEMP_TESTDOC)
storage.CreateStorage("example", RW_EXCLUSIVE_TX, 0, 0)
storage.RenameElement("example", "sample")
sample = storage.OpenStorage("sample", None, self.RW_EXCLUSIVE_TX, None, 0)
sample = storage.OpenStorage("sample", None, RW_EXCLUSIVE_TX, None, 0)
self.assertEqual("sample", sample.Stat(STATFLAG_DEFAULT).pwcsName)
with self.assertRaises(COMError) as cm:
storage.OpenStorage("example", None, self.RW_EXCLUSIVE_TX, None, 0)
storage.OpenStorage("example", None, RW_EXCLUSIVE_TX, None, 0)
self.assertEqual(cm.exception.hresult, STG_E_PATHNOTFOUND)

def test_SetElementTimes(self):
storage = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC)
def test_fails_if_destination_exists(self):
storage = _create_docfile(mode=CREATE_TEMP_TESTDOC)
storage.CreateStorage("foo", RW_EXCLUSIVE_TX, 0, 0)
storage.CreateStorage("bar", RW_EXCLUSIVE_TX, 0, 0)
# Rename "foo" to "bar" (which already exists)
with self.assertRaises(COMError) as cm:
storage.RenameElement("foo", "bar")
self.assertEqual(cm.exception.hresult, STG_E_ACCESSDENIED)

def test_fails_if_takes_same_name(self):
storage = _create_docfile(mode=CREATE_TEMP_TESTDOC)
storage.CreateStorage("foo", RW_EXCLUSIVE_TX, 0, 0)
# Rename "foo" to "foo" (same name)
with self.assertRaises(COMError) as cm:
storage.RenameElement("foo", "foo")
self.assertEqual(cm.exception.hresult, STG_E_ACCESSDENIED)


class Test_SetElementTimes(unittest.TestCase):
def test_sets_modification_time_for_element(self):
storage = _create_docfile(mode=CREATE_TEMP_TESTDOC)
sub_name = "SubStorageElement"
orig_stat = storage.CreateStorage(sub_name, self.CREATE_TESTDOC, 0, 0).Stat(
orig_stat = storage.CreateStorage(sub_name, CREATE_TESTDOC, 0, 0).Stat(
STATFLAG_DEFAULT
)
storage.SetElementTimes(
sub_name,
None, # pctime (creation time)
None, # patime (access time)
self.FIXED_TEST_FILETIME, # pmtime (modification time)
FIXED_TEST_FILETIME, # pmtime (modification time)
)
storage.Commit(STGC_DEFAULT)
modified_stat = storage.OpenStorage(
sub_name, None, self.RW_EXCLUSIVE_TX, None, 0
sub_name, None, RW_EXCLUSIVE_TX, None, 0
).Stat(STATFLAG_DEFAULT)
self.assertEqual(CompareFileTime(orig_stat.ctime, modified_stat.ctime), 0)
self.assertEqual(CompareFileTime(orig_stat.atime, modified_stat.atime), 0)
self.assertNotEqual(CompareFileTime(orig_stat.mtime, modified_stat.mtime), 0)
self.assertEqual(
CompareFileTime(self.FIXED_TEST_FILETIME, modified_stat.mtime), 0
)
self.assertEqual(CompareFileTime(FIXED_TEST_FILETIME, modified_stat.mtime), 0)
with self.assertRaises(COMError) as cm:
storage.SetElementTimes("NonExistent", None, None, self.FIXED_TEST_FILETIME)
storage.SetElementTimes("NonExistent", None, None, FIXED_TEST_FILETIME)
self.assertEqual(cm.exception.hresult, STG_E_PATHNOTFOUND)

def test_SetClass(self):
storage = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC)

class Test_SetClass(unittest.TestCase):
def test_sets_clsid(self):
storage = _create_docfile(mode=CREATE_TEMP_TESTDOC)
# Initial value is CLSID_NULL.
self.assertEqual(storage.Stat(STATFLAG_DEFAULT).clsid, comtypes.GUID())
new_clsid = comtypes.GUID.create_new()
Expand All @@ -206,16 +255,32 @@ def test_SetClass(self):
storage.SetClass(comtypes.GUID())
self.assertEqual(storage.Stat(STATFLAG_DEFAULT).clsid, comtypes.GUID())

def test_Stat(self):

class Test_SetStateBits(unittest.TestCase):
def test_sets_and_updates_storage_state_bits(self):
storage = _create_docfile(mode=CREATE_TEMP_TESTDOC)
# Initial state bits should be 0
self.assertEqual(storage.Stat(STATFLAG_DEFAULT).grfStateBits, 0)
# 1. Set all bits
bits1, mask1 = 0xABCD1234, 0xFFFFFFFF
storage.SetStateBits(bits1, mask1)
self.assertEqual(storage.Stat(STATFLAG_DEFAULT).grfStateBits, bits1)
# 2. Partial update using mask (only lower 16 bits)
bits2, mask2 = 0x00005678, 0x0000FFFF
storage.SetStateBits(bits2, mask2)
# Expected: 0xABCD (original upper) + 0x5678 (new lower) = 0xABCD5678
self.assertEqual(storage.Stat(STATFLAG_DEFAULT).grfStateBits, 0xABCD5678)


class Test_Stat(unittest.TestCase):
def test_returns_correct_stat_information_for_docfile(self):
with tempfile.TemporaryDirectory() as t:
tmpdir = Path(t)
tmpfile = tmpdir / "test_docfile.cfs"
self.assertFalse(tmpfile.exists())
# When created with `StgCreateDocfile(filepath_string, ...)`, the
# compound file is created at that location.
storage = self._create_docfile(
name=str(tmpfile), mode=self.CREATE_TEMP_TESTDOC
)
storage = _create_docfile(name=str(tmpfile), mode=CREATE_TEMP_TESTDOC)
self.assertTrue(tmpfile.exists())
with self.assertRaises(COMError) as cm:
storage.Stat(0xFFFFFFFF) # Invalid flag
Expand Down Expand Up @@ -244,7 +309,7 @@ def test_Stat(self):
# greater than 0 bytes.
self.assertGreaterEqual(stat.cbSize, 0)
# `grfMode` should reflect the access mode flags from creation.
self.assertEqual(stat.grfMode, self.RW_EXCLUSIVE | STGM_DIRECT)
self.assertEqual(stat.grfMode, RW_EXCLUSIVE | STGM_DIRECT)
self.assertEqual(stat.grfLocksSupported, 0)
self.assertEqual(stat.clsid, comtypes.GUID()) # CLSID_NULL for new creation.
self.assertEqual(stat.grfStateBits, 0)
Expand All @@ -254,3 +319,11 @@ def test_Stat(self):
self.assertEqual(malloc.DidAlloc(name_ptr), 1)
del stat # `pwcsName` is expected to be freed here.
# `DidAlloc` checks are skipped to avoid using a dangling pointer.

def test_stat_returns_none_for_pwcsname_with_noname_flag(self):
storage = _create_docfile(mode=CREATE_TEMP_TESTDOC)
# Using `STATFLAG_NONAME` should return `None` for `pwcsName`.
stat = storage.Stat(STATFLAG_NONAME)
self.assertIsNone(stat.pwcsName)
# Verify other fields are still present.
self.assertEqual(stat.type, STGTY_STORAGE)