diff --git a/src/inner_test.zip b/src/inner_test.zip
new file mode 100644
index 0000000..5740669
Binary files /dev/null and b/src/inner_test.zip differ
diff --git a/src/test_ziparchives_inmem.nim b/src/test_ziparchives_inmem.nim
new file mode 100644
index 0000000..1744106
--- /dev/null
+++ b/src/test_ziparchives_inmem.nim
@@ -0,0 +1,21 @@
+import zippy/ziparchives
+
+proc test_case() =
+  var archive = open_zip_archive("inner_test.zip")
+  defer: archive.close()
+
+  for fname in archive.walk_files:
+    let inner_bytes = archive.extract_file(fname)
+
+    # First test scenario: Extract file-by-file
+    # var inner_archive = open_zip_archive_bytes(inner_bytes)
+    # defer: inner_archive.close()
+
+    # for ifname in inner_archive.walk_files:
+    #   let ifbytes = inner_archive.extract_file(ifname)
+    #   writeFile(ifname, ifbytes)
+
+    # Second test scenario: Extract whole inner archives
+    extractAllBytes(inner_bytes, fname & ".d")
+
+test_case()
\ No newline at end of file
diff --git a/src/zippy/ziparchives.nim b/src/zippy/ziparchives.nim
index ccc58a9..8b74ee7 100644
--- a/src/zippy/ziparchives.nim
+++ b/src/zippy/ziparchives.nim
@@ -16,6 +16,9 @@ type
   ZipArchiveRecordKind = enum
     FileRecord, DirectoryRecord
 
+  ZipArchiveReaderMode = enum
+    MemfileMode, StringMode
+
   ZipArchiveRecord = object
     kind: ZipArchiveRecordKind
     fileHeaderOffset: int
@@ -26,7 +29,9 @@ type
     filePermissions: set[FilePermission]
 
   ZipArchiveReader* = ref object
+    mode: ZipArchiveReaderMode
     memFile: MemFile
+    byteString: string
     records: OrderedTable[string, ZipArchiveRecord]
 
 iterator walkFiles*(reader: ZipArchiveReader): string =
@@ -36,6 +41,16 @@ iterator walkFiles*(reader: ZipArchiveReader): string =
     if record.kind == FileRecord:
       yield record.path
 
+proc getDataPtr(reader: ZipArchiveReader): ptr UncheckedArray[uint8] =
+  case reader.mode
+  of MemfileMode: cast[ptr UncheckedArray[uint8]](reader.memFile.mem)
+  of StringMode: cast[ptr UncheckedArray[uint8]](reader.byteString.cstring)
+
+proc getDataLen(reader: ZipArchiveReader): int =
+  case reader.mode
+  of MemfileMode: reader.memFile.size
+  of StringMode: reader.byteString.len
+
 proc extractFile*(
   reader: ZipArchiveReader, path: string
 ): string {.raises: [ZippyError].} =
@@ -44,7 +59,7 @@ proc extractFile*(
     raise newException(ZippyError, "No file record found for " & path)
 
   let
-    src = cast[ptr UncheckedArray[uint8]](reader.memFile.mem)
+    src = reader.getDataPtr()
     record =
       try:
         reader.records[path]
@@ -53,7 +68,7 @@ proc extractFile*(
 
   var pos = record.fileHeaderOffset
 
-  if pos + fileHeaderLen > reader.memFile.size:
+  if pos + fileHeaderLen > reader.getDataLen():
     failArchiveEOF()
 
   if read32(src, pos) != fileHeaderSignature:
@@ -73,7 +88,7 @@ proc extractFile*(
 
   pos += fileHeaderLen + fileNameLen + extraFieldLen
 
-  if pos + record.compressedSize > reader.memFile.size:
+  if pos + record.compressedSize > reader.getDataLen():
     failArchiveEOF()
 
   case record.kind:
@@ -93,7 +108,8 @@ proc extractFile*(
     raise newException(ZippyError, "Verifying crc32 failed")
 
 proc close*(reader: ZipArchiveReader) {.raises: [OSError].} =
-  reader.memFile.close()
+  if reader.mode == MemfileMode:
+    reader.memFile.close()
 
 proc parseMsDosDateTime(time, date: uint16): Time =
   let
@@ -155,9 +171,9 @@ proc utf8ify(fileName: string): string =
   $runes
 
 proc findEndOfCentralDirectory(reader: ZipArchiveReader): int =
-  let src = cast[ptr UncheckedArray[uint8]](reader.memFile.mem)
+  let src = reader.getDataPtr()
 
-  result = reader.memFile.size - 22 # Work backwards in the file starting here
+  result = reader.getDataLen() - 22 # Work backwards in the file starting here
   while true:
     if result < 0:
       failArchiveEOF()
@@ -170,7 +186,7 @@ proc findStartOfCentralDirectory(
   reader: ZipArchiveReader,
   start, numRecordEntries: int
 ): int =
-  let src = cast[ptr UncheckedArray[uint8]](reader.memFile.mem)
+  let src = reader.getDataPtr()
 
   result = start # Work backwards in the file starting here
   var numRecordsFound: int
@@ -183,17 +199,16 @@ proc findStartOfCentralDirectory(
         return
     dec result
 
-proc openZipArchive*(
-  zipPath: string
-): ZipArchiveReader {.raises: [IOError, OSError, ZippyError].} =
-  result = ZipArchiveReader()
-  result.memFile = memfiles.open(zipPath)
+proc openZipArchiveInternal(
+  reader: ZipArchiveReader
+): ZipArchiveReader {.raises: [IOError, OSError, ZippyError].} =
+  result = reader
 
   try:
-    let src = cast[ptr UncheckedArray[uint8]](result.memFile.mem)
+    let src = result.getDataPtr()
 
     let eocd = result.findEndOfCentralDirectory()
-    if eocd + 22 > result.memFile.size:
+    if eocd + 22 > result.getDataLen():
       failArchiveEOF()
 
     var zip64 = false
@@ -217,7 +232,7 @@ proc openZipArchive*(
         raise newException(ZippyError, "Unsupported archive, num disks")
 
       var pos = zip64EndOfCentralDirectoryStart
-      if pos + 64 > result.memFile.size:
+      if pos + 64 > result.getDataLen():
         failArchiveEOF()
 
       if read32(src, pos) != zip64EndOfCentralDirectorySignature:
@@ -268,11 +283,11 @@ proc openZipArchive*(
 
     var pos = socdOffset + centralDirectoryStart
 
-    if eocd + 22 > result.memFile.size:
+    if eocd + 22 > result.getDataLen():
       failArchiveEOF()
 
     for _ in 0 ..< numCentralDirectoryRecords:
-      if pos + 46 > result.memFile.size:
+      if pos + 46 > result.getDataLen():
         failArchiveEOF()
 
       if read32(src, pos) != centralDirectoryFileHeaderSignature:
@@ -306,7 +321,7 @@ proc openZipArchive*(
 
       pos += 46
 
-      if pos + fileNameLen > result.memFile.size:
+      if pos + fileNameLen > result.getDataLen():
         failArchiveEOF()
 
       var fileName = newString(fileNameLen)
@@ -321,7 +336,7 @@ proc openZipArchive*(
         var extraFieldsOffset = pos
 
         while extraFieldsOffset < pos + extraFieldLen:
-          if pos + 4 > result.memFile.size:
+          if pos + 4 > result.getDataLen():
             failArchiveEOF()
 
           let
@@ -395,10 +410,24 @@ proc openZipArchive*(
     result.close()
     raise e
 
-proc extractAll*(
-  zipPath, dest: string
-) {.raises: [IOError, OSError, ZippyError].} =
-  ## Extracts the files stored in archive to the destination directory.
+proc openZipArchive*(
+  zipPath: string
+): ZipArchiveReader {.raises: [IOError, OSError, ZippyError].} =
+  result = ZipArchiveReader()
+  result.mode = MemfileMode
+  result.memFile = memfiles.open(zipPath)
+  return openZipArchiveInternal(result)
+
+proc openZipArchiveBytes*(
+  byteString: string
+): ZipArchiveReader {.raises: [IOError, OSError, ZippyError].} =
+  result = ZipArchiveReader()
+  result.mode = StringMode
+  result.byteString = byteString
+  return openZipArchiveInternal(result)
+
+
+proc checkExtractDestination(dest: string) {.raises: [IOError, OSError, ZippyError].} =
   ## The path to the destination directory must exist.
   ## The destination directory itself must not exist (it is not overwitten).
   if dest == "" or dirExists(dest):
@@ -410,9 +439,10 @@ proc extractAll*(
   if head != "" and not dirExists(head):
     raise newException(ZippyError, "Path to " & dest & " does not exist")
 
-  let
-    reader = openZipArchive(zipPath)
-    src = cast[ptr UncheckedArray[uint8]](reader.memFile.mem)
+proc extractAllInternal(
+  reader: ZipArchiveReader, dest: string
+) {.raises: [IOError, OSError, ZippyError].} =
+  let src = reader.getDataPtr()
 
   # Verify some things before attempting to write the files
   for _, record in reader.records:
@@ -452,6 +482,22 @@ proc extractAll*(
   finally:
     reader.close()
 
+proc extractAll*(
+  zipPath, dest: string
+) {.raises: [IOError, OSError, ZippyError].} =
+  ## Extracts the files stored in archive to the destination directory.
+  checkExtractDestination(dest)
+  let reader = openZipArchive(zipPath)
+  extractAllInternal(reader, dest)
+
+proc extractAllBytes*(
+  zipBytes, dest: string
+) {.raises: [IOError, OSError, ZippyError].} =
+  ## Extracts the files stored in byte-string archive to the destination directory.
+  checkExtractDestination(dest)
+  let reader = openZipArchiveBytes(zipBytes)
+  extractAllInternal(reader, dest)
+
 when (NimMajor, NimMinor, NimPatch) >= (1, 6, 0):
   # For some reason `sink Table | OrderedTable` does not work, so work around:
   template createZipArchiveImpl(