How to use a fine-tuned segmentation model for diarization? #840

Open
Arche151 opened this issue Jul 17, 2024 · 6 comments
@Arche151

I have a WhisperX Python script for transcribing meetings, but the speaker diarization for German is really bad, unfortunately.

After some research I came across the fine-tuned German segmentation model diarizers-community/speaker-segmentation-fine-tuned-callhome-deu, but I haven't figured out how to get WhisperX to use it.

Here's my Python script:

import os
import sys
import torch  
import whisperx
import ffmpeg

# Hardcoded Hugging Face token
HF_TOKEN = 'xyz'

def convert_to_wav(audio_path):
    output_path = os.path.splitext(audio_path)[0] + ".wav"
    ffmpeg.input(audio_path).output(output_path).run(quiet=True, overwrite_output=True)
    return output_path

def transcribe_audio(audio_path, num_speakers=None):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 16
    compute_type = "float16" if torch.cuda.is_available() else "int8"

    # Load WhisperX model
    print(f"Loading WhisperX model on {device}...")
    model = whisperx.load_model("flozi00/whisper-large-v3-german-ct2", device, compute_type=compute_type)

    # Load and transcribe audio
    print("Loading and transcribing audio...")
    audio = whisperx.load_audio(audio_path)
    result = model.transcribe(audio, batch_size=batch_size)

    # Load alignment model
    print("Loading alignment model...")
    model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
    result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)

    # Load diarization model and assign speaker labels
    print("Loading diarization model...")
    diarize_model = whisperx.DiarizationPipeline(use_auth_token=HF_TOKEN, device=device)
    diarize_segments = diarize_model(audio, min_speakers=num_speakers, max_speakers=num_speakers)
    result = whisperx.assign_word_speakers(diarize_segments, result)

    # Add speaker label if missing
    for segment in result["segments"]:
        if 'speaker' not in segment:
            segment['speaker'] = 'Unknown'

    return result["segments"]

def save_transcription(transcription, audio_path, speaker_mapping):
    output_path = os.path.splitext(audio_path)[0] + ".txt"
    with open(output_path, 'w') as f:
        for segment in transcription:
            speaker = speaker_mapping.get(segment['speaker'], segment['speaker'])
            f.write(f"{speaker}: {segment['text']}\n")
    print(f"Transcription saved to {output_path}")

def display_first_10_lines(transcription):
    for segment in transcription[:10]:
        print(f"{segment['speaker']}: {segment['text']}")
    print()

def get_speaker_names(unique_speakers):
    speaker_mapping = {}
    for speaker in unique_speakers:
        name = input(f"Enter the name for {speaker}: ")
        speaker_mapping[speaker] = name
    return speaker_mapping

def main():
    audio_path = input("Enter the filename of the audio to be transcribed: ").strip().strip("'")
    if not os.path.isfile(audio_path):
        print(f"Error: The file '{audio_path}' does not exist.")
        sys.exit(1)

    num_speakers = int(input("Enter the number of speakers: "))

    if not audio_path.endswith(".wav"):
        print("Converting audio to WAV format...")
        audio_path = convert_to_wav(audio_path)

    print("Transcribing audio...")
    transcription = transcribe_audio(audio_path, num_speakers)

    print("\nFirst 10 lines of the transcription:")
    display_first_10_lines(transcription)

    unique_speakers = sorted(set(segment['speaker'] for segment in transcription))
    speaker_mapping = get_speaker_names(unique_speakers)

    save_transcription(transcription, audio_path, speaker_mapping)

if __name__ == "__main__":
    main()

I'd greatly appreciate any help!

@Dream-gamer commented Aug 13, 2024

This worked for me:

import torch
import whisperx
from diarizers import SegmentationModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the standard WhisperX diarization pipeline (pyannote under the hood)
diarize_model = whisperx.DiarizationPipeline(use_auth_token="Your_hftoken", device=device)

# Load the fine-tuned German segmentation checkpoint and convert it to a pyannote model
segmentation_model = SegmentationModel().from_pretrained('diarizers-community/speaker-segmentation-fine-tuned-callhome-deu')
fine_tuned_model = segmentation_model.to_pyannote_model()

# Swap the pipeline's segmentation model for the fine-tuned one
diarize_model.model._segmentation.model = fine_tuned_model.to(device)

# `audio` is the waveform returned by whisperx.load_audio() in the main script
diarize_segments = diarize_model(audio, num_speakers=2)
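For reference, here is roughly how the swap could slot into the transcribe_audio() function from the script in the original post. This is only a sketch: load_german_diarization_pipeline is an illustrative helper name, and HF_TOKEN, device, audio, num_speakers and result are the variables already defined in that script.

import whisperx
from diarizers import SegmentationModel

def load_german_diarization_pipeline(hf_token, device):
    # Build the default WhisperX diarization pipeline, then swap its
    # segmentation model for the German fine-tuned checkpoint.
    diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=device)
    segmentation = SegmentationModel().from_pretrained(
        "diarizers-community/speaker-segmentation-fine-tuned-callhome-deu"
    )
    diarize_model.model._segmentation.model = segmentation.to_pyannote_model().to(device)
    return diarize_model

# Inside transcribe_audio(), replacing the original DiarizationPipeline block:
diarize_model = load_german_diarization_pipeline(HF_TOKEN, device)
diarize_segments = diarize_model(audio, min_speakers=num_speakers, max_speakers=num_speakers)
result = whisperx.assign_word_speakers(diarize_segments, result)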

@Arche151 (Author)

@Dream-gamer

I will try that out, thanks! :) Heads up: you accidentally shared your HF token.

Do you also happen to know how I can use speaker embeddings that I extracted from reference speakers in the diarization pipeline? I asked about that in the pyannote GitHub, but unfortunately didn't get a response. pyannote/pyannote-audio#1750

@Dream-gamer

Hey, thanks for the heads-up lol. Do let me know if you run into any errors with the code above, because I just tried it today and it worked.
As for your follow-up question, I haven't really used speaker embeddings, but I am currently working on improving my diarized transcriptions, so I will play around with that as well and get back to you if I find anything.

@slaesh commented Aug 30, 2024

Thanks @Dream-gamer, loading the fine-tuned model works just fine!

@Arche151 I'm having the same problem, but using the fine-tuned model doesn't improve the diarization accuracy that much. How is it going on your side? Have you taken any further steps to improve it? =)

@Dream-gamer commented Sep 3, 2024

Glad to hear it. Yeah, it didn't noticeably increase the accuracy for me either. I have been using LLMs like Gemini to parse the generated transcript and produce a corrected version. You can use a prompt like "Here is the speaker-separated transcript. Some of the words in the transcript are under the wrong speaker labels. Correct them and give the corrected transcript:" This has given me much better results for Hindi-English transcripts.
Also, if that doesn't improve performance, you can use Gemini for the diarization itself in Google's AI Studio: https://aistudio.google.com/app/prompts/new_chat?pli=1
Just upload the audio file, give some context in the prompt (e.g. the number of speakers and the language in the audio), and ask it to generate a speaker-separated transcript.
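A minimal sketch of the transcript-correction idea, assuming the google-generativeai Python package and a Gemini API key; the model name, prompt wording, and file name are placeholders rather than anything specific from this thread:

import os
import google.generativeai as genai

# Assumes `pip install google-generativeai` and GOOGLE_API_KEY set in the environment.
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
model = genai.GenerativeModel("gemini-1.5-flash")  # placeholder model name

# Load the speaker-labelled transcript produced by the WhisperX script.
with open("meeting.txt", "r", encoding="utf-8") as f:
    transcript = f.read()

prompt = (
    "Here is the speaker-separated transcript. Some of the words in the "
    "transcript are under the wrong speaker labels. Correct them and give "
    "the corrected transcript:\n\n" + transcript
)

# Ask the model for a corrected, speaker-separated transcript.
response = model.generate_content(prompt)
print(response.text)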

@slaesh commented Sep 3, 2024

Thanks! :)
Yeah, I also thought about doing something similar for the transcription itself, but I'm running everything offline, so Gemini isn't an option ;D
The transcription itself isn't a big issue for me; it's more the diarization.
If there is a question and someone else answers quickly with just one or two words, most of the time it isn't really separated, and both the question and the answer get assigned to the same speaker. Weirdly, sometimes even the question gets attributed to the answerer :D
