Skip to content

Commit

Permalink
Add splitmode and suppressnonspeechtokens parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
pierreguillot committed Aug 27, 2024
1 parent f4b8903 commit 01d4ff2
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 26 deletions.
100 changes: 81 additions & 19 deletions source/wvp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,15 +310,16 @@ void Wvp::Plugin::reset()
{
auto const createContext = [this]() -> struct whisper_context*
{
auto params = whisper_context_default_params();
if(mModelIndex == 0)
{
return whisper_init_from_buffer_with_params(const_cast<void*>(Wvp::model), Wvp::model_size, whisper_context_default_params());
return whisper_init_from_buffer_with_params(const_cast<void*>(Wvp::model), Wvp::model_size, params);
}
auto const models = getModelPaths();
if(mModelIndex <= models.size())
{
auto const path = models.at(mModelIndex - 1).string();
return whisper_init_from_file_with_params(path.c_str(), whisper_context_default_params());
return whisper_init_from_file_with_params(path.c_str(), params);
}
return nullptr;
};
Expand Down Expand Up @@ -359,6 +360,33 @@ Wvp::Plugin::ParameterList Wvp::Plugin::getParameterDescriptors() const
param.quantizeStep = 1.0f;
list.push_back(std::move(param));
}
{
ParameterDescriptor param;
param.identifier = "splitmode";
param.name = "Split Mode";
param.description = "The model splits the text on sentences, words or tokens";
param.unit = "";
param.valueNames = {"Sentences", "Words", "Tokens"};
param.minValue = 0.0f;
param.maxValue = 2.0f;
param.defaultValue = 2.0f;
param.isQuantized = true;
param.quantizeStep = 1.0f;
list.push_back(std::move(param));
}
{
ParameterDescriptor param;
param.identifier = "suppressnonspeechtokens";
param.name = "Suppress Non-Speech Tokens";
param.description = "The model suppresses non-speech tokens";
param.unit = "";
param.minValue = 0.0f;
param.maxValue = 1.0f;
param.defaultValue = 1.0f;
param.isQuantized = true;
param.quantizeStep = 1.0f;
list.push_back(std::move(param));
}
return list;
}

Expand All @@ -369,6 +397,14 @@ void Wvp::Plugin::setParameter(std::string paramid, float newval)
auto const max = static_cast<float>(getModelPaths().size());
mModelIndex = static_cast<size_t>(std::floor(std::clamp(newval, 0.0f, max)));
}
else if(paramid == "splitmode")
{
mSplitMode = static_cast<size_t>(std::floor(std::clamp(newval, 0.0f, 2.0f)));
}
else if(paramid == "suppressnonspeechtokens")
{
mSuppressNonSpeechTokens = newval > 0.5f;
}
else
{
std::cerr << "Invalid parameter : " << paramid << "\n";
Expand All @@ -381,6 +417,14 @@ float Wvp::Plugin::getParameter(std::string paramid) const
{
return static_cast<float>(mModelIndex);
}
if(paramid == "splitmode")
{
return static_cast<float>(mSplitMode);
}
if(paramid == "suppressnonspeechtokens")
{
return mSuppressNonSpeechTokens ? 1.0f : 0.0f;
}
std::cerr << "Invalid parameter : " << paramid << "\n";
return 0.0f;
}
Expand Down Expand Up @@ -408,15 +452,16 @@ Wvp::Plugin::OutputExtraList Wvp::Plugin::getOutputExtraDescriptors(size_t outpu
Wvp::Plugin::FeatureList Wvp::Plugin::getCurrentFeatures(size_t timeOffset)
{
auto params = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
params.token_timestamps = true;
params.no_context = true;
params.print_progress = false;
params.print_timestamps = false;
params.print_special = false;
params.translate = false;
params.suppress_non_speech_tokens = true;
params.suppress_non_speech_tokens = mSuppressNonSpeechTokens;
params.language = nullptr;
params.split_on_word = false;
params.token_timestamps = mSplitMode >= 1;
params.max_len = mSplitMode == 1;
params.split_on_word = mSplitMode == 1;
static auto const minSize = gModelSampleRate + gModelSampleRate / 10;
if(mBufferPosition < minSize)
{
Expand All @@ -434,22 +479,39 @@ Wvp::Plugin::FeatureList Wvp::Plugin::getCurrentFeatures(size_t timeOffset)
auto const nsegments = whisper_full_n_segments(mHandle.get());
for(int i = 0; i < nsegments; ++i)
{
auto const* text = whisper_full_get_segment_text(mHandle.get(), i);
auto const ntokens = whisper_full_n_tokens(mHandle.get(), i);
for(int j = 0; j < ntokens; ++j)
if(mSplitMode < 2)
{
auto const data = whisper_full_get_token_data(mHandle.get(), i, j);
if(data.id < whisper_token_eot(mHandle.get()))
auto const* text = whisper_full_get_segment_text(mHandle.get(), i);
auto const t0 = whisper_full_get_segment_t0(mHandle.get(), i);
auto const t1 = whisper_full_get_segment_t1(mHandle.get(), i);
Feature feature;
feature.hasTimestamp = true;
auto const time = Vamp::RealTime::fromSeconds(static_cast<double>(t0) / 100.0);
feature.timestamp = time + offset;
feature.hasDuration = true;
feature.duration = Vamp::RealTime::fromSeconds(static_cast<double>(t1) / 100.0) - time;
feature.label = text;
feature.values.push_back(1.0);
fl.push_back(std::move(feature));
}
else
{
auto const ntokens = whisper_full_n_tokens(mHandle.get(), i);
for(int j = 0; j < ntokens; ++j)
{
Feature feature;
feature.hasTimestamp = true;
auto const time = Vamp::RealTime::fromSeconds(static_cast<double>(data.t0) / 100.0);
feature.timestamp = time + offset;
feature.hasDuration = true;
feature.duration = Vamp::RealTime::fromSeconds(static_cast<double>(data.t1) / 100.0) - time;
feature.label = whisper_full_get_token_text(mHandle.get(), i, j);
feature.values.push_back(data.p);
fl.push_back(std::move(feature));
auto const data = whisper_full_get_token_data(mHandle.get(), i, j);
if(!mSuppressNonSpeechTokens || data.id < whisper_token_eot(mHandle.get()))
{
Feature feature;
feature.hasTimestamp = true;
auto const time = Vamp::RealTime::fromSeconds(static_cast<double>(data.t0) / 100.0);
feature.timestamp = time + offset;
feature.hasDuration = true;
feature.duration = Vamp::RealTime::fromSeconds(static_cast<double>(data.t1) / 100.0) - time;
feature.label = whisper_full_get_token_text(mHandle.get(), i, j);
feature.values.push_back(data.p);
fl.push_back(std::move(feature));
}
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions source/wvp.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ namespace Wvp
size_t mAdvancement{0};
size_t mBlockSize{0};
size_t mModelIndex{0};
size_t mSplitMode{2};
bool mSuppressNonSpeechTokens{true};
std::set<size_t> mRanges;
};
} // namespace Wvp
36 changes: 29 additions & 7 deletions test/whisper.ptldoc
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
<layout value="0cbca15363f34e82af09c599fdff328c"/>
<layout value="3a1a8c77cc504338a8cc8452b77cc31f"/>
<timeZoom MiscModelVersion="131075" globalRange_start="0.0" globalRange_end="8.607256235827665"
minimumLength="0.01160997732426304" visibleRange_start="0.005108623311593341"
visibleRange_end="4.926555326685285">
minimumLength="0.01160997732426304" visibleRange_start="0.0"
visibleRange_end="8.607256235827665">
<grid MiscModelVersion="131075" tickReference="0.0" mainTickInterval="0"
tickPowerBase="2.0" tickDivisionFactor="10.0"/>
</timeZoom>
<transport MiscModelVersion="131075" startPlayhead="1.15140832130236" looping="0"
<transport MiscModelVersion="131075" startPlayhead="3.052635596681737" looping="0"
loopRange_start="0.0" loopRange_end="0.0" stopAtLoopEnd="0" autoScroll="1"
gain="1.0" magnetize="0"/>
<groups MiscModelVersion="131075" identifier="0cbca15363f34e82af09c599fdff328c"
Expand All @@ -28,7 +28,7 @@
</zoom>
</groups>
<groups MiscModelVersion="131075" identifier="3a1a8c77cc504338a8cc8452b77cc31f"
name="Group 2" height="170" colour="ff3c0066" expanded="0" referenceid="">
name="Group 2" height="86" colour="ff3c0066" expanded="0" referenceid="">
<layout value="c03605ef7aea48b19ed64beb9c18759d"/>
<zoom MiscModelVersion="131075" globalRange_start="0.0" globalRange_end="1.0"
minimumLength="0.0" visibleRange_start="0.4499999992549419" visibleRange_end="0.5500000603497028">
Expand All @@ -37,7 +37,7 @@
</zoom>
</groups>
<tracks MiscModelVersion="131075" identifier="3c7b46ce4d534244b77669c0ed91a216"
name="Whisper" input="" height="55" font="Nunito Sans; 32.0 Bold"
name="Whisper" input="" height="165" font="Nunito Sans; 32.0 Bold"
showInGroup="1" zoomValueMode="0" zoomLink="1">
<file path="" commit="">
<args/>
Expand All @@ -46,6 +46,8 @@
details="Automatic speech recognition using OpenAI's Whisper model.&#10;Whisper models by OpenAI. Whisper.cpp by Georgi Gerganov. Whisper Vamp Plugin by Pierre Guillot. Copyright 2024 Ircam. All rights reserved.">
<defaultState blockSize="1024" stepSize="0" windowType="3">
<parameters key="model" value="0.0"/>
<parameters key="splitmode" value="2.0"/>
<parameters key="suppressnonspeechtokens" value="1.0"/>
</defaultState>
<parameters>
<value identifier="model" name="Model" description="The model used to generate the tokens"
Expand All @@ -59,6 +61,21 @@
<valueNames value="ggml-base"/>
</value>
</parameters>
<parameters>
<value identifier="splitmode" name="Split Mode" description="The model splits the text on sentences, words or tokens"
unit="" minValue="0.0" maxValue="2.0" defaultValue="2.0" isQuantized="1"
quantizeStep="1.0">
<valueNames value="Sentences"/>
<valueNames value="Words"/>
<valueNames value="Tokens"/>
</value>
</parameters>
<parameters>
<value identifier="suppressnonspeechtokens" name="Suppress Non-Speech Tokens"
description="The model suppresses non-speech tokens" unit=""
minValue="0.0" maxValue="1.0" defaultValue="1.0" isQuantized="1"
quantizeStep="1.0"/>
</parameters>
<output identifier="token" name="Token" description="Tokens generated by speech-to-text transcription"
unit="" hasFixedBinCount="1" binCount="0" hasKnownExtents="0"
minValue="0.0" maxValue="0.0" isQuantized="0" quantizeStep="0.0"
Expand All @@ -75,7 +92,12 @@
</description>
<key identifier="ircamwhisper:whisper" feature="token"/>
<state blockSize="1024" stepSize="0" windowType="3">
<parameters key="model" value="3.0"/>
<parameters key="model" value="5.0"/>
<parameters key="nonspeech" value="1.0"/>
<parameters key="split" value="1.0"/>
<parameters key="splitmode" value="2.0"/>
<parameters key="splitonword" value="1.0"/>
<parameters key="suppressnonspeechtokens" value="0.0"/>
</state>
<colours map="7" background="0" foreground="ff05ffe2" duration="667e7a7d"
text="ff05ffe2" shadow="0"/>
Expand All @@ -93,7 +115,7 @@
</binZoom>
</tracks>
<tracks MiscModelVersion="131075" identifier="24b3f47a03a745c2980a09cb873d5542"
name="Fast Fourier Transform" input="" height="55" font="Nunito Sans; 14.0 Regular"
name="Fast Fourier Transform" input="" height="29" font="Nunito Sans; 14.0 Regular"
showInGroup="1" zoomValueMode="1" zoomLink="1">
<file path="" commit="">
<args/>
Expand Down

0 comments on commit 01d4ff2

Please sign in to comment.