From f80602dba51bb5d0abf7eea01bfc1487c9fd2025 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 9 May 2023 09:35:47 -0500 Subject: [PATCH] Add WhisperParams for mutable configuration Adds WhisperParams, which have a setproperty! method for changing params Adds create_ctx and create_params helpers Adds a transcribe method that takes a ctx and wparams --- src/Whisper.jl | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/src/Whisper.jl b/src/Whisper.jl index 9183b7f..9f64c4f 100644 --- a/src/Whisper.jl +++ b/src/Whisper.jl @@ -14,6 +14,26 @@ function __init__() register_datadeps() end +function create_ctx(model) + whisper_init_from_file(DataDeps.resolve("whisper-ggml-$model/ggml-$model.bin", "__FILE__")) +end + +struct WhisperParams + ref::Base.RefValue{whisper_full_params} +end +Base.unsafe_convert(::Type{Ptr{whisper_full_params}}, wparams::WhisperParams) = + Base.unsafe_convert(Ptr{whisper_full_params}, getfield(wparams, :ref)) +Base.convert(::Type{whisper_full_params}, wparams::WhisperParams) = + getfield(wparams, :ref)[] +Base.getproperty(wparams::WhisperParams, field::Symbol) = + getproperty(Base.unsafe_convert(Ptr{whisper_full_params}, wparams), field) +Base.setproperty!(wparams::WhisperParams, field::Symbol, value) = + setproperty!(Base.unsafe_convert(Ptr{whisper_full_params}, wparams), field, value) +function create_params(flag=LibWhisper.WHISPER_SAMPLING_GREEDY) + raw_params = whisper_full_default_params(flag) + return WhisperParams(Ref(raw_params)) +end + """ transcribe(model, data) -> String @@ -24,9 +44,17 @@ automatically downloaded from HuggingFace on first use. - `data`: `Vector{Float32}` containing 16kHz sampled audio """ function transcribe(model, data) - ctx = whisper_init_from_file(DataDeps.resolve("whisper-ggml-$model/ggml-$model.bin", "__FILE__") ) - wparams = whisper_full_default_params(LibWhisper.WHISPER_SAMPLING_GREEDY) + ctx = create_ctx(model) + wparams = create_params(flag) + + ret = transcribe(ctx, wparams, data) + whisper_free(ctx) + + return ret +end + +function transcribe(ctx, wparams, data) ret = whisper_full_parallel(ctx, wparams, data, length(data), 1) if ret != 0 @@ -44,8 +72,6 @@ function transcribe(model, data) @debug "Time for inference: ", t0-t1 end - whisper_free(ctx) - return result end