Skip to content

Commit

Permalink
refactor IO implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasIsensee committed Oct 6, 2024
1 parent 9726fdd commit f010181
Show file tree
Hide file tree
Showing 7 changed files with 232 additions and 368 deletions.
161 changes: 78 additions & 83 deletions src/JLD2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,13 @@ using FileIO: load, save
export load, save
using Requires: @require
using PrecompileTools: @setup_workload, @compile_workload
export jldopen, @load, @save, save_object, load_object, jldsave

export jldopen, @load, @save, save_object, load_object, printtoc
export jldsave

# abstract types and other definitions that need to be defined early on
include("types.jl")




include("macros_utils.jl")
include("mmapio.jl")
include("bufferedio.jl")
Expand Down Expand Up @@ -117,9 +115,7 @@ mutable struct JLDFile{T<:IO}
JLDWriteSession(), Dict{String,Any}(), IdDict(), IdDict(), Dict{RelOffset,WeakRef}(),
DATA_START, Dict{RelOffset,GlobalHeap}(),
GlobalHeap(0, 0, 0, Int64[]), Dict{RelOffset,Group{JLDFile{T}}}(), UNDEFINED_ADDRESS)
if !(io isa ReadOnlyBuffer)
finalizer(jld_finalizer, f)
end
finalizer(jld_finalizer, f)
f
end
end
Expand Down Expand Up @@ -163,9 +159,6 @@ end
# Maps an IO backend type to another backend type: `MmapIO` maps to `IOStream`,
# and `IOStream` maps to `nothing`.
# NOTE(review): presumably consulted to pick the next backend to try when opening
# with the requested one fails, with `nothing` meaning "no further fallback" —
# confirm at the call site.
FallbackType(::Type{MmapIO}) = IOStream
FallbackType(::Type{IOStream}) = nothing

# The delimiter is excluded by default
read_bytestring(io::Union{IOStream, IOBuffer}) = String(readuntil(io, 0x00))

# Registry of file handles created by `jldopen`, keyed by the `realpath` of the file.
# Values are `WeakRef`s to the `JLDFile`, so the registry by itself does not keep a
# handle alive. Handles opened with `parallel_read=true` are not registered.
const OPEN_FILES = Dict{String,WeakRef}()
# Held by `jldopen` while it looks up or updates `OPEN_FILES`.
const OPEN_FILES_LOCK = ReentrantLock()
function jldopen(fname::AbstractString, wr::Bool, create::Bool, truncate::Bool, iotype::T=DEFAULT_IOTYPE;
Expand All @@ -178,7 +171,6 @@ function jldopen(fname::AbstractString, wr::Bool, create::Bool, truncate::Bool,
) where T<:Union{Type{IOStream},Type{MmapIO}}
mmaparrays && @warn "mmaparrays keyword is currently ignored" maxlog=1
verify_compressor(compress)
exists = ispath(fname)

# Can only open multiple in parallel if mode is "r"
if parallel_read && (wr, create, truncate) != (false, false, false)
Expand All @@ -188,37 +180,34 @@ function jldopen(fname::AbstractString, wr::Bool, create::Bool, truncate::Bool,
lock(OPEN_FILES_LOCK)

f = try
exists = ispath(fname)
if exists
rname = realpath(fname)
# catch existing file system entities that are not regular files
!isfile(rname) && throw(ArgumentError("not a regular file: $fname"))

f = get(OPEN_FILES, rname, (;value=nothing)).value
# If in serial, return existing handle. In parallel always generate a new handle
if haskey(OPEN_FILES, rname)
ref = OPEN_FILES[rname]
f = ref.value
if !isnothing(f)
if parallel_read
f.writable && throw(ArgumentError("Tried to open file in a parallel context but it is open in write-mode elsewhere in a serial context."))
else
if truncate
throw(ArgumentError("attempted to truncate a file that was already open"))
elseif !isa(f, JLDFile{iotype})
throw(ArgumentError("attempted to open file with $iotype backend, but already open with a different backend"))
elseif f.writable != wr
current = wr ? "read/write" : "read-only"
previous = f.writable ? "read/write" : "read-only"
throw(ArgumentError("attempted to open file $(current), but file was already open $(previous)"))
elseif f.compress != compress
throw(ArgumentError("attempted to open file with compress=$(compress), but file was already open with compress=$(f.compress)"))
elseif f.mmaparrays != mmaparrays
throw(ArgumentError("attempted to open file with mmaparrays=$(mmaparrays), but file was already open with mmaparrays=$(f.mmaparrays)"))
end

f = f::JLDFile{iotype}
f.n_times_opened += 1
return f
if !isnothing(f)
if parallel_read
f.writable && throw(ArgumentError("Tried to open file in a parallel context but it is open in write-mode elsewhere in a serial context."))
else
if truncate
throw(ArgumentError("attempted to truncate a file that was already open"))
elseif !isa(f, JLDFile{iotype})
throw(ArgumentError("attempted to open file with $iotype backend, but already open with a different backend"))

Check warning on line 198 in src/JLD2.jl

View check run for this annotation

Codecov / codecov/patch

src/JLD2.jl#L198

Added line #L198 was not covered by tests
elseif f.writable != wr
current = wr ? "read/write" : "read-only"
previous = f.writable ? "read/write" : "read-only"
throw(ArgumentError("attempted to open file $(current), but file was already open $(previous)"))
elseif f.compress != compress
throw(ArgumentError("attempted to open file with compress=$(compress), but file was already open with compress=$(f.compress)"))
elseif f.mmaparrays != mmaparrays
throw(ArgumentError("attempted to open file with mmaparrays=$(mmaparrays), but file was already open with mmaparrays=$(f.mmaparrays)"))
end
f = f::JLDFile{iotype}
f.n_times_opened += 1
return f
end
end
end
Expand All @@ -228,71 +217,51 @@ function jldopen(fname::AbstractString, wr::Bool, create::Bool, truncate::Bool,
rname = realpath(fname)
f = JLDFile(io, rname, wr, created, plain, compress, mmaparrays)

if !parallel_read
OPEN_FILES[rname] = WeakRef(f)
end
!parallel_read && (OPEN_FILES[rname] = WeakRef(f))

f
catch e
rethrow(e)
finally
unlock(OPEN_FILES_LOCK)
end
if f.written
f.base_address = 512
if f isa JLDFile{MmapIO}
f.root_group = Group{JLDFile{MmapIO}}(f)
f.types_group = Group{JLDFile{MmapIO}}(f)
elseif f isa JLDFile{IOStream}
f.root_group = Group{JLDFile{IOStream}}(f)
f.types_group = Group{JLDFile{IOStream}}(f)
end
else
try
load_file_metadata!(f)
catch e
close(f)
throw(e)
end
try
initialize_fileobject!(f)
catch e
close(f)
throw(e)
end
merge!(f.typemap, typemap)
return f
end

function load_file_metadata!(f)
function initialize_fileobject!(f::JLDFile)
if f.written
f.base_address = 512
f.root_group = Group{typeof(f)}(f)
f.types_group = Group{typeof(f)}(f)
return
end
superblock = find_superblock(f)
f.end_of_data = superblock.end_of_file_address
f.base_address = superblock.base_address
f.root_group_offset = superblock.root_group_object_header_address
if superblock.version >= 2
verify_file_header(f)
else
@warn "This file was not written with JLD2. Some things may not work."
if f.writable
close(f)
throw(UnsupportedVersionException("This file can not be edited by JLD2. Please open in read-only mode."))
end
elseif f.writable
close(f)
throw(UnsupportedVersionException("This file can not be edited by JLD2. Please open in read-only mode."))

Check warning on line 251 in src/JLD2.jl

View check run for this annotation

Codecov / codecov/patch

src/JLD2.jl#L250-L251

Added lines #L250 - L251 were not covered by tests
end
#try
f.root_group = load_group(f, f.root_group_offset)

if haskey(f.root_group.written_links, "_types")
types_group_offset = f.root_group.written_links["_types"]::RelOffset
f.types_group = f.loaded_groups[types_group_offset] = load_group(f, types_group_offset)
i = 0
for (offset::RelOffset) in values(f.types_group.written_links)
f.datatype_locations[offset] = CommittedDatatype(offset, i += 1)
end
resize!(f.datatypes, length(f.datatype_locations))
else
f.types_group = Group{typeof(f)}(f)
end
# catch e
# show(e)
# f.types_group = Group{typeof(f)}(f)
f.root_group = load_group(f, f.root_group_offset)

# end
nothing
types_offset = get(f.root_group.written_links, "_types", UNDEFINED_ADDRESS)
if types_offset != UNDEFINED_ADDRESS
f.types_group = f.loaded_groups[types_offset] = load_group(f, types_offset)
for (i, offset::RelOffset) in enumerate(values(f.types_group.written_links))
f.datatype_locations[offset] = CommittedDatatype(offset, i)
end
resize!(f.datatypes, length(f.datatype_locations))
else
f.types_group = Group{typeof(f)}(f)
end
end

"""
Expand Down Expand Up @@ -320,6 +289,32 @@ function jldopen(fname::Union{AbstractString, IO}, mode::AbstractString="r"; iot
end
end


"""
    jldopen(io::IO, writable::Bool, create::Bool, truncate::Bool;
            plain=false, compress=false, typemap=Dict{String,Any}())

Open a JLD2 file on top of an existing `IO` object instead of a file path.

Two modes are supported:

- read/write: `writable` is `true` and `io` is seekable and writable. `io` is wrapped
  in an `RWBuffer`, and `truncate` is forwarded to `JLDFile` as its `created` flag.
- read-only: `writable`, `create` and `truncate` are all `false`. `io` is wrapped in a
  `ReadOnlyBuffer`.

The entries of `typemap` are merged into the typemap of the returned file.

# Throws
- `ArgumentError` if `io` is not readable.
- `ArgumentError` if the requested flags match neither supported mode, e.g. write
  access on an `io` that is not seekable or not writable.
"""
function jldopen(io::IO, writable::Bool, create::Bool, truncate::Bool;
                 plain::Bool=false,
                 compress=false,
                 typemap::Dict{String}=Dict{String,Any}(),
                 )
    verify_compressor(compress)
    # NOTE(review): `readable` and `seekable` are read as fields of `io`, so this only
    # works for IO types that carry those fields (e.g. `IOBuffer`) — confirm whether
    # other IO types are meant to be accepted here.
    !io.readable && throw(ArgumentError("IO object is not readable"))
    f = if io.seekable && writable && iswritable(io)
        # A more lightweight wrapper that only ensures the API is defined could be
        # used here instead.
        created = truncate
        JLDFile(RWBuffer(io), "RWBuffer", writable, created, plain, compress, false)
    elseif (false == writable == create == truncate)
        # Pure read: rely on `io` implementing `read` and `bytesavailable`.
        JLDFile(ReadOnlyBuffer(io), "ReadOnlyBuffer", false, false, plain, compress, false)
    else
        # Without this branch `f` was left unassigned and the call failed later with
        # an unrelated `UndefVarError`.
        throw(ArgumentError(
            "cannot open IO object with writable=$(writable), create=$(create), " *
            "truncate=$(truncate): write access requires a seekable and writable IO " *
            "object, and read-only access requires create=false and truncate=false"))
    end
    initialize_fileobject!(f)
    merge!(f.typemap, typemap)
    return f
end

"""
load_datatypes(f::JLDFile)
Expand Down Expand Up @@ -490,8 +485,8 @@ include("data/custom_serialization.jl")
include("data/writing_datatypes.jl")
include("data/reconstructing_datatypes.jl")

include("general_io.jl")
include("dataio.jl")
include("general_io.jl")
include("loadsave.jl")
include("backwards_compatibility.jl")
include("inlineunion.jl")
Expand Down
56 changes: 11 additions & 45 deletions src/bufferedio.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,17 @@ function Base.seek(io::BufferedWriter, offset::Integer)
buffer_offset = offset - io.file_position
buffer_offset < 0 && throw(ArgumentError("cannot seek before start of buffer"))
ensureroom(io, buffer_offset - bufferpos(io))
io.curptr += buffer_offset - bufferpos(io)
io.curptr = pointer(io.buffer) + buffer_offset

Check warning on line 39 in src/bufferedio.jl

View check run for this annotation

Codecov / codecov/patch

src/bufferedio.jl#L35-L39

Added lines #L35 - L39 were not covered by tests
end

function finish!(io::BufferedWriter)
f = io.f
bufferpos(io) == length(io.buffer) ||
throw(InternalError("BufferedWriter: buffer not written to end; position is $(bufferpos(io)) but length is $(length(io.buffer))"))
seek(f, io.file_position)
jlwrite(f, io.buffer)
seek(io.f, io.file_position)
jlwrite(io.f, io.buffer)
nothing
end

jlwrite(io::BufferedWriter, x::Union{UInt8,Int8}) = _write(io, x)

mutable struct BufferedReader{io} <: MemoryBackedIO
f::io
buffer::Vector{UInt8}
Expand All @@ -61,6 +58,7 @@ function BufferedReader(io)
buf = Vector{UInt8}()
BufferedReader(io, buf, Int64(position(io)), Ptr{Nothing}(pointer(buf)))
end

function Base.show(io::IO, ::BufferedReader)
    # Print a fixed label only; none of the reader's fields are shown.
    return print(io, "BufferedReader")
end

function readmore!(io::BufferedReader, n::Integer)
Expand All @@ -77,27 +75,6 @@ end
function ensureroom(io::BufferedReader, n::Integer)
    # More data is requested whenever advancing `n` bytes from the current buffer
    # offset would reach or pass the end of the buffer.
    needs_more = bufferpos(io) + n >= length(io.buffer)
    # Yields `false` when nothing had to be read, otherwise the result of `readmore!`.
    return needs_more && readmore!(io, n)
end

function _read(io::BufferedReader, T::DataType)
n = jlsizeof(T)
ensureroom(io, n)
v = jlunsafe_load(Ptr{T}(io.curptr))
io.curptr += n
v
end
jlread(io::BufferedReader, T::Type{UInt8}) = _read(io, T)
jlread(io::BufferedReader, T::Type{Int8}) = _read(io, T)
jlread(io::BufferedReader, T::PlainType) = _read(io, T)

function jlread(io::BufferedReader, ::Type{T}, n::Int) where T
m = jlsizeof(T) * n
ensureroom(io, m)
arr = Vector{T}(undef, n)
unsafe_copyto!(pointer(arr), Ptr{T}(io.curptr), n)
io.curptr += m
arr
end
jlread(io::BufferedReader, ::Type{T}, n::Integer) where {T} =
jlread(io, T, Int(n))

function Base.position(io::BufferedReader)
    # Reported position is `io.file_position` plus the current offset in the buffer.
    offset_in_buffer = bufferpos(io)
    return io.file_position + offset_in_buffer
end

Expand All @@ -108,18 +85,13 @@ Get the current position in the buffer.
"""
function bufferpos(io::Union{BufferedReader, BufferedWriter})
    # Distance between the cursor pointer and the first element of the buffer.
    buffer_start = pointer(io.buffer)
    return Int(io.curptr - buffer_start)
end

function adjust_position!(io::BufferedReader, position::Integer)
position < 0 && throw(ArgumentError("cannot seek before start of buffer"))
if position > length(io.buffer)
readmore!(io, position - length(io.buffer))
end
io.curptr = pointer(io.buffer, position+1)
position
function Base.seek(io::BufferedReader, offset::Integer)
pos = offset - io.file_position
pos < 0 && throw(ArgumentError("cannot seek before start of buffer"))
ensureroom(io, offset - position(io))
io.curptr = pointer(io.buffer) + pos
nothing
end

Base.seek(io::BufferedReader, offset::Integer) = adjust_position!(io, offset - io.file_position)
Base.skip(io::BufferedReader, offset::Integer) = adjust_position!(io, bufferpos(io) + offset)

function finish!(io::BufferedReader)
    # Seek the wrapped stream to the position this reader has reached, i.e.
    # `io.file_position` plus the current offset in the buffer.
    reached = io.file_position + bufferpos(io)
    return seek(io.f, reached)
end

function truncate_and_close(io::IOStream, endpos::Integer)
Expand All @@ -129,15 +101,9 @@ end

function Base.close(::BufferedReader)
    # Closing a BufferedReader is a no-op; the wrapped stream is not touched.
    return nothing
end

Check warning on line 102 in src/bufferedio.jl

View check run for this annotation

Codecov / codecov/patch

src/bufferedio.jl#L102

Added line #L102 was not covered by tests


# We sometimes need to compute checksums. We do this by first calling begin_checksum when
# starting to handle whatever needs checksumming, and calling end_checksum afterwards. Note
# that we never compute nested checksums, but we may compute multiple checksums
# simultaneously.

function begin_checksum_read(io::IO)
    # Start a checksummed read by wrapping `io` in a BufferedReader.
    reader = BufferedReader(io)
    return reader
end

function begin_checksum_write(io::IOStream, sz::Integer)
function begin_checksum_write(io::IO, sz::Integer)
BufferedWriter(io, sz)
end
function end_checksum(io::Union{BufferedReader,BufferedWriter})
Expand Down
Loading

0 comments on commit f010181

Please sign in to comment.