Skip to content

Commit

Permalink
Add WhisperParams for mutable configuration
Browse files Browse the repository at this point in the history
Adds WhisperParams, which have a setproperty! method for changing params
Adds create_ctx and create_params helpers
Adds a transcribe method that takes a ctx and wparams
  • Loading branch information
jpsamaroo committed May 9, 2023
1 parent aa2a1e9 commit f80602d
Showing 1 changed file with 30 additions and 4 deletions.
34 changes: 30 additions & 4 deletions src/Whisper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,26 @@ function __init__()
register_datadeps()
end

function create_ctx(model)
whisper_init_from_file(DataDeps.resolve("whisper-ggml-$model/ggml-$model.bin", "__FILE__"))
end

struct WhisperParams
ref::Base.RefValue{whisper_full_params}
end
Base.unsafe_convert(::Type{Ptr{whisper_full_params}}, wparams::WhisperParams) =
Base.unsafe_convert(Ptr{whisper_full_params}, getfield(wparams, :ref))
Base.convert(::Type{whisper_full_params}, wparams::WhisperParams) =
getfield(wparams, :ref)[]
Base.getproperty(wparams::WhisperParams, field::Symbol) =
getproperty(Base.unsafe_convert(Ptr{whisper_full_params}, wparams), field)
Base.setproperty!(wparams::WhisperParams, field::Symbol, value) =
setproperty!(Base.unsafe_convert(Ptr{whisper_full_params}, wparams), field, value)
function create_params(flag=LibWhisper.WHISPER_SAMPLING_GREEDY)
raw_params = whisper_full_default_params(flag)
return WhisperParams(Ref(raw_params))
end

"""
transcribe(model, data) -> String
Expand All @@ -24,9 +44,17 @@ automatically downloaded from HuggingFace on first use.
- `data`: `Vector{Float32}` containing 16kHz sampled audio
"""
function transcribe(model, data)
ctx = whisper_init_from_file(DataDeps.resolve("whisper-ggml-$model/ggml-$model.bin", "__FILE__") )
wparams = whisper_full_default_params(LibWhisper.WHISPER_SAMPLING_GREEDY)
ctx = create_ctx(model)
wparams = create_params(flag)

ret = transcribe(ctx, wparams, data)

whisper_free(ctx)

return ret
end

function transcribe(ctx, wparams, data)
ret = whisper_full_parallel(ctx, wparams, data, length(data), 1)

if ret != 0
Expand All @@ -44,8 +72,6 @@ function transcribe(model, data)
@debug "Time for inference: ", t0-t1
end

whisper_free(ctx)

return result
end

Expand Down

0 comments on commit f80602d

Please sign in to comment.