diff --git a/src/Whisper.jl b/src/Whisper.jl index 9183b7f..9f64c4f 100644 --- a/src/Whisper.jl +++ b/src/Whisper.jl @@ -14,6 +14,26 @@ function __init__() register_datadeps() end +function create_ctx(model) + whisper_init_from_file(DataDeps.resolve("whisper-ggml-$model/ggml-$model.bin", "__FILE__")) +end + +struct WhisperParams + ref::Base.RefValue{whisper_full_params} +end +Base.unsafe_convert(::Type{Ptr{whisper_full_params}}, wparams::WhisperParams) = + Base.unsafe_convert(Ptr{whisper_full_params}, getfield(wparams, :ref)) +Base.convert(::Type{whisper_full_params}, wparams::WhisperParams) = + getfield(wparams, :ref)[] +Base.getproperty(wparams::WhisperParams, field::Symbol) = + getproperty(Base.unsafe_convert(Ptr{whisper_full_params}, wparams), field) +Base.setproperty!(wparams::WhisperParams, field::Symbol, value) = + setproperty!(Base.unsafe_convert(Ptr{whisper_full_params}, wparams), field, value) +function create_params(flag=LibWhisper.WHISPER_SAMPLING_GREEDY) + raw_params = whisper_full_default_params(flag) + return WhisperParams(Ref(raw_params)) +end + """ transcribe(model, data) -> String @@ -24,9 +44,17 @@ automatically downloaded from HuggingFace on first use. - `data`: `Vector{Float32}` containing 16kHz sampled audio """ function transcribe(model, data) - ctx = whisper_init_from_file(DataDeps.resolve("whisper-ggml-$model/ggml-$model.bin", "__FILE__") ) - wparams = whisper_full_default_params(LibWhisper.WHISPER_SAMPLING_GREEDY) + ctx = create_ctx(model) + wparams = create_params(flag) + + ret = transcribe(ctx, wparams, data) + whisper_free(ctx) + + return ret +end + +function transcribe(ctx, wparams, data) ret = whisper_full_parallel(ctx, wparams, data, length(data), 1) if ret != 0 @@ -44,8 +72,6 @@ function transcribe(model, data) @debug "Time for inference: ", t0-t1 end - whisper_free(ctx) - return result end