Skip to content

Commit c779b46

Browse files
authored
Add remove method to ZipFile (#111)
* Add remove method to ZipFile (refer to python/cpython#103033)
* Make use of `ZipFileWithRemove`
1 parent a157c27 commit c779b46

5 files changed

Lines changed: 127 additions & 81 deletions

File tree

.github/workflows/ruff.yml

Lines changed: 0 additions & 8 deletions
This file was deleted.

darkseid/archivers/zip.py

Lines changed: 33 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
from __future__ import annotations
22

33
import logging
4-
import shutil
5-
import tempfile
64
import zipfile
75
from typing import TYPE_CHECKING
86

7+
from darkseid.zipfile_remove import ZipFileWithRemove
8+
99
if TYPE_CHECKING:
1010
from pathlib import Path
11-
from typing import cast
1211

1312
import rarfile
1413

@@ -71,8 +70,20 @@ def remove_file(self: ZipArchiver, archive_file: str) -> bool:
7170
Returns:
7271
bool: True if the file was successfully removed, False otherwise.
7372
"""
74-
75-
return self._rebuild([archive_file])
73+
try:
74+
with ZipFileWithRemove(self.path, "a") as zf:
75+
zf.remove(archive_file)
76+
except KeyError:
77+
return False
78+
except (zipfile.BadZipfile, OSError):
79+
logger.exception(
80+
"Error writing zip archive %s :: %s",
81+
self.path,
82+
archive_file,
83+
)
84+
return False
85+
else:
86+
return True
7687

7788
def remove_files(self: ZipArchiver, filename_lst: list[str]) -> bool:
    """
    Removes several files from the ZIP archive.

    Names in ``filename_lst`` that are not present in the archive are
    ignored. The archive is only opened for writing when at least one of
    the requested names is present.

    Args:
        filename_lst (list[str]): The names of the archive members to remove.

    Returns:
        bool: True if all files were successfully removed (or there was
            nothing to remove), False if the archive could not be modified.
    """
    files = set(self.get_filename_list())
    filenames_to_remove = [filename for filename in filename_lst if filename in files]
    if not filenames_to_remove:
        return True

    # Bound before the ``try`` so the error handler can always reference it,
    # even when opening the archive fails before the loop starts.
    current = None
    try:
        with ZipFileWithRemove(self.path, "a") as zf:
            for current in filenames_to_remove:
                try:
                    zf.remove(current)
                except KeyError:
                    # The name was repeated in ``filename_lst`` and an earlier
                    # iteration already removed it; nothing left to do.
                    continue
    except (zipfile.BadZipfile, OSError):
        logger.exception(
            "Error writing zip archive %s :: %s",
            self.path,
            current if current is not None else filenames_to_remove,
        )
        return False
    return True
89112

90113
def write_file(self: ZipArchiver, archive_file: str, data: str) -> bool:
91114
"""
@@ -98,22 +121,10 @@ def write_file(self: ZipArchiver, archive_file: str, data: str) -> bool:
98121
Returns:
99122
bool: True if the write operation was successful, False otherwise.
100123
"""
101-
102-
# At the moment, no other option but to rebuild the whole
103-
# zip archive w/o the indicated file. Very sucky, but maybe
104-
# another solution can be found
105-
files = self.get_filename_list()
106-
if archive_file in files:
107-
self._rebuild([archive_file])
108-
109124
try:
110-
# now just add the archive file as a new one
111-
with zipfile.ZipFile(
112-
self.path,
113-
mode="a",
114-
allowZip64=True,
115-
compression=zipfile.ZIP_DEFLATED,
116-
) as zf:
125+
with ZipFileWithRemove(self.path, "a") as zf:
126+
if archive_file in set(zf.namelist()):
127+
zf.remove(archive_file)
117128
zf.writestr(archive_file, data)
118129
except (zipfile.BadZipfile, OSError):
119130
logger.exception(
@@ -143,39 +154,6 @@ def get_filename_list(self: ZipArchiver) -> list[str]:
143154
logger.exception("Error listing files in zip archive: %s", self.path)
144155
return []
145156

146-
def _rebuild(self: ZipArchiver, exclude_list: list[str]) -> bool:
147-
"""
148-
Rebuilds the ZIP archive excluding specified files.
149-
150-
Args:
151-
exclude_list (list[str]): The list of files to exclude from the rebuild.
152-
153-
Returns:
154-
bool: True if the rebuild was successful, False otherwise.
155-
"""
156-
157-
try:
158-
with zipfile.ZipFile(
159-
tempfile.NamedTemporaryFile(dir=self.path.parent, delete=False),
160-
"w",
161-
allowZip64=True,
162-
) as zout:
163-
with zipfile.ZipFile(self.path, mode="r") as zin:
164-
for item in zin.infolist():
165-
buffer = zin.read(item.filename)
166-
if item.filename not in exclude_list:
167-
zout.writestr(item, buffer)
168-
169-
# replace with the new file
170-
self.path.unlink(missing_ok=True)
171-
zout.close() # Required on Windows
172-
shutil.move(cast(str, zout.filename), self.path)
173-
except (zipfile.BadZipfile, OSError):
174-
logger.exception("Error rebuilding zip file: %s", self.path)
175-
return False
176-
else:
177-
return True
178-
179157
def copy_from_archive(self: ZipArchiver, other_archive: Archiver) -> bool:
180158
"""
181159
Copies files from another archive to the ZIP archive.
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""ZipFile with Remove method."""
2+
3+
# From https://github.com/python/cpython/pull/103033
4+
# Not linted to compare against above PR
5+
import contextlib
6+
from zipfile import ZipFile, ZipInfo
7+
8+
9+
class ZipFileWithRemove(ZipFile):
    """ZipFile subclass that can remove members from an archive in place.

    Adapted from https://github.com/python/cpython/pull/103033.
    """

    def remove(self, zinfo_or_arcname):
        """Remove a member from the archive.

        Args:
            zinfo_or_arcname: Either the ``ZipInfo`` object of the member or
                the member's name inside the archive.

        Raises:
            ValueError: If the archive was not opened in mode 'w', 'x' or 'a',
                is already closed, or has an open writing handle.
            KeyError: If the member is not part of the archive.
            OSError: If entry data cannot be read back while compacting.
        """
        if self.mode not in ("w", "x", "a"):
            raise ValueError("remove() requires mode 'w', 'x', or 'a'")
        if not self.fp:
            raise ValueError("Attempt to write to ZIP archive that was already closed")
        if self._writing:
            raise ValueError("Can't write to ZIP archive while an open writing handle exists")

        # Make sure we have an existing info object
        if isinstance(zinfo_or_arcname, ZipInfo):
            zinfo = zinfo_or_arcname
            # make sure zinfo exists
            if zinfo not in self.filelist:
                raise KeyError("There is no item %r in the archive" % zinfo_or_arcname)
        else:
            # get the info object (raises KeyError for an unknown name)
            zinfo = self.getinfo(zinfo_or_arcname)

        return self._remove_members({zinfo})

    def _remove_members(self, members, *, remove_physical=True, chunk_size=2**20):
        """Remove members in a zip file.

        All members (as zinfo) should exist in the zip; otherwise the zip file
        will erroneously end in an inconsistent state.

        Args:
            members: Collection of ``ZipInfo`` objects to drop.
            remove_physical: When true, the entries that follow a removed
                member are moved towards the start of the file and
                ``start_dir`` is lowered accordingly; when false only the
                in-memory bookkeeping is updated.
            chunk_size: Maximum number of bytes copied per read/write step.

        Raises:
            ValueError: If ``chunk_size`` is not positive.
            OSError: If the underlying file returns no data before a whole
                entry has been copied (e.g. truncated archive). Without this
                check the copy loop would never terminate.
        """
        if chunk_size < 1:
            # read(0) never makes progress and read(<0) reads past the entry
            raise ValueError("chunk_size must be a positive integer")

        fp = self.fp
        entry_offset = 0  # total size of the removed entries seen so far
        member_seen = False

        # get a sorted filelist by header offset, in case the dir order
        # doesn't match the actual entry order
        filelist = sorted(self.filelist, key=lambda x: x.header_offset)
        last_index = len(filelist) - 1
        for i, info in enumerate(filelist):
            is_member = info in members

            # entries located before the first removed member stay untouched
            if not (member_seen or is_member):
                continue

            # get the total size of the entry: it extends to the next entry's
            # header, or to the central directory for the last entry
            offset = filelist[i + 1].header_offset if i < last_index else self.start_dir
            entry_size = offset - info.header_offset

            if is_member:
                member_seen = True
                entry_offset += entry_size

                # update caches
                self.filelist.remove(info)
                with contextlib.suppress(KeyError):
                    del self.NameToInfo[info.filename]
                continue

            # update the header and move entry data to the new position
            if remove_physical:
                old_header_offset = info.header_offset
                info.header_offset -= entry_offset
                read_size = 0
                while read_size < entry_size:
                    fp.seek(old_header_offset + read_size)
                    data = fp.read(min(entry_size - read_size, chunk_size))
                    if not data:
                        # no progress is possible; bail out instead of spinning
                        raise OSError(
                            "Unexpected end of file while moving entry %r" % info.filename
                        )
                    fp.seek(info.header_offset + read_size)
                    fp.write(data)
                    fp.flush()
                    read_size += len(data)

        # Avoid missing entry if entries have a duplicated name.
        # Reverse the order as NameToInfo normally stores the last added one.
        for info in reversed(self.filelist):
            self.NameToInfo.setdefault(info.filename, info)

        # update state
        if remove_physical:
            self.start_dir -= entry_offset
        self._didModify = True

        # seek to the start of the central dir
        fp.seek(self.start_dir)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ testpaths = "tests"
135135
exclude = "*~,.git/*,.mypy_cache/*,.pytest_cache/*,.venv*,__pycache__/*,cache/*,dist/*,node_modules/*,test-results/*,typings/*"
136136

137137
[tool.ruff]
138-
extend-exclude = ["typings"]
138+
extend-exclude = ["node_modules", "darkseid/zipfile_remove"]
139139
target-version = "py310"
140140
line-length = 100
141141

tests/test_archiver_zip.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -96,23 +96,6 @@ def test_get_filename_list(zip_archiver, archive_file, data):
9696
assert archive_file in filenames
9797

9898

99-
@pytest.mark.parametrize(
100-
("exclude_list", "data"), [(["test.txt"], "Hello, World!")], ids=["exclude_file"]
101-
)
102-
def test_rebuild(zip_archiver, exclude_list, data):
103-
# Arrange
104-
zip_archiver.write_file("test.txt", data)
105-
zip_archiver.write_file("keep.txt", "Keep this file")
106-
107-
# Act
108-
result = zip_archiver._rebuild(exclude_list) # noqa: SLF001
109-
110-
# Assert
111-
assert result is True
112-
assert "test.txt" not in zip_archiver.get_filename_list()
113-
assert "keep.txt" in zip_archiver.get_filename_list()
114-
115-
11699
@pytest.mark.parametrize(
117100
("other_archive_files", "data"), [(["test.txt"], ["Hello, World!"])], ids=["simple_copy"]
118101
)

0 commit comments

Comments
 (0)