
Accept variable-length batch prompts for Whisper #1784

Closed
Wants to merge 17 commits

Conversation

MahmoudAshraf97
Contributor

This is a continuation of #1457. The final goal is to enable continuous batching for Whisper models, which brings large speedups at large batch sizes.

@BBC-Esq

BBC-Esq commented Sep 18, 2024

Just got notice of this, cool! However, it makes me wonder...is it really necessary? For example, the WhisperS2T library already offers batching. I'm including an in-depth analysis of scripts from that source code for you to see how it's done, because I don't think anyone has really analyzed it fully...hence my question as to whether modifying CTranslate2 is necessary. Maybe it is...maybe your mod just makes it even faster, but I'd like to understand why it's necessary.

Here's my summary, but I'll leave it to you to analyze the relevant source code files within WhisperS2T yourself because I know you're experienced enough to deduce which scripts to look at. ;-)

BBC SUMMARY OF WHISPERS2T PIPELINE HANDLING MULTIPLE AUDIO FILES
**Step-by-Step Timeline of How WhisperS2T Processes Multiple Audio Files**

This guide outlines the chronological steps taken by the WhisperS2T library when processing a list of audio files. It explains how the library handles multiple files, segments them, keeps track of which file each segment originates from, and processes them with the `ctranslate2` backend.

---

### **1. Initialization**

**Script**: `model.py`  
**Class**: `WhisperModelCT2`  
**Method**: `__init__`

- **Action**: An instance of `WhisperModelCT2` is created.
- **Details**:
  - Loads the Whisper model using `ctranslate2`.
  - Initializes the tokenizer from the `tokenizer.json` file.
  - Sets up ASR options and generation parameters.
  - Initializes helper components such as `self.preprocessor` and `self.data_loader`.
  - **Note**: This step is common regardless of the number of audio files.

---

### **2. Transcription Start with Multiple Audio Files**

**Script**: `model.py`  
**Class**: `WhisperModelCT2`  
**Method**: `transcribe`

- **Action**: The `transcribe` method is called with a list of audio files.
- **Details**:
  - **Parameters**: `audio_files` (list), `lang_codes`, `tasks`, `initial_prompts`, `batch_size`.
  - Uses `fix_batch_param` to ensure that `lang_codes`, `tasks`, and `initial_prompts` are lists matching the length of `audio_files`.
  - Initializes a list of empty lists `responses = [[] for _ in audio_files]` to store results for each file.
  - Sets up a progress bar using `tqdm`, total progress is `len(audio_files) * 100`.
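
A rough sketch of the `fix_batch_param` broadcasting described above (illustrative only, not the library's actual code; the names mirror the description):

```python
def fix_batch_param_sketch(param, default, n_files):
    """Broadcast a scalar/None parameter into a list matching the number of audio files."""
    if param is None:
        return [default] * n_files
    if not isinstance(param, (list, tuple)):
        return [param] * n_files
    assert len(param) == n_files, "parameter list must match the number of audio files"
    return list(param)


# e.g. lang_codes = fix_batch_param_sketch("en", "en", len(audio_files))
# and the per-file result buckets described above: responses = [[] for _ in audio_files]
```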

---

### **3. Data Loading and Preprocessing for Multiple Files**

**Script**: `__init__.py`  
**Class**: `WhisperModel`  
**Method**: `transcribe`

- **Action**: Calls the data loader to generate batches that may include segments from multiple audio files.
- **Details**:
  - **Calls**: `self.data_loader(...)`, an instance of `WhisperDataLoader`.
  - **Parameters**: `audio_files`, `lang_codes`, `tasks`, `initial_prompts`, `batch_size`, `use_vad=False`.
  - The data loader will process each audio file and segment it, keeping track of the file ID.

---

### **4. Generating Data Batches from Multiple Audio Files**

**Script**: `data.py`  
**Class**: `WhisperDataLoader`  
**Method**: `__call__`

- **Action**: Invokes `get_data_loader` to process multiple audio files without VAD.
- **Details**:
  - Initializes `segmented_audio_signal` and `pbar_update_len`.
  - Iterates over each audio file and associated parameters:
    - **For each audio file**:
      - Loads the audio signal.
      - Segments the audio into chunks.
      - Adds metadata including `file_id` to keep track of which file the segment comes from.
      - Segments from all files are accumulated into `segmented_audio_signal`.
  - Batches are created from `segmented_audio_signal`, which may include segments from different files.

---

### **5. Loading Audio Files**

**Script**: `audio.py`  
**Function**: `audio_batch_generator`

- **Action**: Generates audio signals from the list of `audio_files`.
- **Details**:
  - Iterates over `audio_files`.
  - If an audio file is a NumPy array, it yields it directly.
  - If it's a file path, it loads the audio using `load_audio`.

---

### **6. Converting Audio Files**

**Script**: `audio.py`  
**Function**: `load_audio`

- **Action**: Attempts to open each file as a WAV file.
- **Details**:
  - **Checks**: If the file is a 16kHz mono WAV file.
  - If not, uses `ffmpeg` to convert the audio to 16kHz mono WAV.
    - **Command**:
      ```bash
      ffmpeg -hide_banner -loglevel panic -i "{input_file}" -threads 1 -acodec pcm_s16le -ac 1 -af aresample=resampler=soxr -ar 16000 "{wav_file}" -y
      ```
  - Reads the frames from the WAV file and converts them to a NumPy array.
  - Normalizes the audio data to float32 between -1 and 1.
  - Returns the `audio_signal` NumPy array.
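
As a concrete illustration of the conversion and normalization described above, a minimal loader might look like this (a sketch mirroring the ffmpeg command shown earlier, not WhisperS2T's actual `load_audio`):

```python
import subprocess
import tempfile
import wave

import numpy as np


def load_audio_sketch(input_file: str, sr: int = 16000) -> np.ndarray:
    """Illustrative loader: convert to 16 kHz mono PCM WAV with ffmpeg, then normalize."""
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        wav_file = tmp.name
    subprocess.run(
        [
            "ffmpeg", "-hide_banner", "-loglevel", "panic",
            "-i", input_file, "-threads", "1",
            "-acodec", "pcm_s16le", "-ac", "1",
            "-af", "aresample=resampler=soxr",
            "-ar", str(sr), wav_file, "-y",
        ],
        check=True,
    )
    with wave.open(wav_file, "rb") as wf:
        frames = wf.readframes(wf.getnframes())
    # 16-bit PCM -> float32 normalized to [-1.0, 1.0]
    return np.frombuffer(frames, np.int16).astype(np.float32) / 32768.0
```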

---

### **7. Segmenting Each Audio File**

**Script**: `data.py`  
**Class**: `BasicSegmenter`  
**Method**: `__call__`

- **Action**: Segments each audio signal into fixed-length segments.
- **Details**:
  - For each audio signal:
    - Calculates `audio_duration` from `audio_signal`.
    - Generates `start_ends`, a list of segment start and end times based on `max_seg_len`.
    - Returns `start_ends` and `audio_signal`.
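
A minimal sketch of this kind of fixed-length segmentation (illustrative; `max_seg_len` and the return shape follow the description above, not the library's exact code):

```python
import numpy as np


def basic_segment_sketch(audio_signal: np.ndarray, sr: int = 16000, max_seg_len: float = 29.0):
    """Split an audio signal into consecutive (start, end) windows of at most max_seg_len seconds."""
    audio_duration = len(audio_signal) / sr
    start_ends = []
    start = 0.0
    while start < audio_duration:
        end = min(start + max_seg_len, audio_duration)
        start_ends.append((start, end))
        start = end
    return start_ends, audio_signal
```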

---

### **8. Preparing Segmented Audio with File IDs**

**Script**: `data.py`  
**Class**: `WhisperDataLoader`  
**Method**: `get_segmented_audio_signal`

- **Action**: Prepares segmented audio signals, keeping track of which file each segment comes from.
- **Details**:
  - **Parameters**: `start_ends`, `audio_signal`, `file_id`, `lang`, `task`, `initial_prompt`.
  - Tokenizes the `initial_prompt` if provided.
  - Prepares the `prompt` using `self.tokenizer.sot_sequence`.
  - If `self.merge_chunks` is `True`, merges segments using `stitch_speech_segments`.
    - **Function**: `stitch_speech_segments`
  - For each segment:
    - Extracts the audio samples.
    - Constructs `seg_metadata` including `file_id`, `start_time`, `end_time`, and other metadata.
    - Adds tuples of `(audio, prompt, initial_prompt_tokens, seq_len, seg_metadata)` to `segmented_audio_signal`.
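
The bookkeeping described above could be sketched roughly as follows (illustrative only; the tuple layout and field names mirror the description, not the actual implementation):

```python
def build_segment_entries_sketch(start_ends, audio_signal, file_id, prompt,
                                 initial_prompt_tokens=None, sr=16000):
    """Attach file_id / timing metadata to each audio chunk so batches stay traceable."""
    segmented_audio_signal = []
    for start, end in start_ends:
        audio = audio_signal[int(start * sr): int(end * sr)]
        seg_metadata = {"file_id": file_id, "start_time": start, "end_time": end}
        segmented_audio_signal.append(
            (audio, prompt, initial_prompt_tokens or [], len(audio), seg_metadata)
        )
    return segmented_audio_signal
```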

---

### **9. Accumulating Segments from All Files**

**Script**: `data.py`  
**Class**: `WhisperDataLoader`  
**Method**: `get_data_loader`

- **Action**: Accumulates segments from all audio files into a single list.
- **Details**:
  - Segments from each audio file are added to `segmented_audio_signal`.
  - The `file_id` in `seg_metadata` ensures segments can be associated back to the correct audio file.

---

### **10. Batching Data from Multiple Files**

**Script**: `data.py`  
**Class**: `WhisperDataLoader`  
**Method**: `data_collate_fn`

- **Action**: Prepares batches that may include segments from multiple audio files.
- **Details**:
  - Batches are created from `segmented_audio_signal`.
  - Pads or trims audio samples to a fixed length using `pad_or_trim`.
    - **Function**: `pad_or_trim`
  - Converts audio samples to PyTorch tensors and stacks them into `signal_batch`.
  - Prepares `prompt_batch` and `seq_len` tensors.
  - Batches include `seg_metadata` with `file_id` for each segment.
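
For illustration, padding/trimming and stacking into a batch tensor could look roughly like this (a sketch of the idea, not the library's `pad_or_trim` / `data_collate_fn`):

```python
import numpy as np
import torch


def pad_or_trim_sketch(audio: np.ndarray, n_samples: int) -> np.ndarray:
    """Right-pad with zeros or truncate so every segment has exactly n_samples samples."""
    if len(audio) >= n_samples:
        return audio[:n_samples]
    return np.pad(audio, (0, n_samples - len(audio)))


def collate_sketch(audio_segments, n_samples):
    """Stack a list of variable-length segments into one (batch, n_samples) tensor."""
    return torch.stack(
        [torch.from_numpy(pad_or_trim_sketch(a, n_samples).astype(np.float32))
         for a in audio_segments]
    )
```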

---

### **11. Computing Log-Mel Spectrograms**

**Script**: `audio.py`  
**Class**: `LogMelSpectogram`  
**Method**: `forward`

- **Action**: Computes log-Mel spectrograms for each batch.
- **Details**:
  - Adjusts `seq_len` based on `hop_length`.
  - Computes the Short-Time Fourier Transform (STFT) using `TorchSTFT`.
    - **Class**: `TorchSTFT`
    - **Method**: `forward`
  - Calculates the power spectrum and applies Mel filters.
  - Computes the log-Mel spectrogram.
  - Clips and scales the spectrogram.
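
The general shape of such a log-Mel front end, following the standard Whisper recipe (80 Mel bins, 25 ms window, 10 ms hop), might look like this sketch using `torchaudio` (an approximation, not the library's `LogMelSpectogram`):

```python
import torch
import torchaudio

mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000, n_fft=400, hop_length=160, n_mels=80, power=2.0
)


def log_mel_sketch(signal_batch: torch.Tensor) -> torch.Tensor:
    """signal_batch: (batch, samples) float32 waveform -> (batch, 80, frames) log-Mel features."""
    mel = mel_transform(signal_batch)
    log_mel = torch.clamp(mel, min=1e-10).log10()           # logarithmic scaling
    log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)   # clip dynamic range
    return (log_mel + 4.0) / 4.0                            # scale to roughly [-1, 1]
```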

---

### **12. Generating Transcriptions for Batches**

**Script**: `model.py`  
**Class**: `WhisperModelCT2`  
**Method**: `generate_segment_batched`

- **Action**: Generates transcriptions using the `ctranslate2` backend for each batch.
- **Details**:
  - Converts `features` (log-Mel spectrograms) to `StorageView`.
  - Calls `self.model.generate` with `features`, `prompts`, and `generate_kwargs`.
  - Decodes token IDs to text using `self.tokenizer.decode_batch`.
  - Prepares the response with transcribed text and metadata.
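
For reference, the core of this step expressed with the public CTranslate2 Whisper API looks roughly like the following (the model directory, feature array, and prompt tokens are placeholders):

```python
import ctranslate2
import numpy as np

# Placeholder model directory and dummy features (batch of two 30-second segments).
model = ctranslate2.models.Whisper("whisper-small-ct2", device="cuda")

features = np.zeros((2, 80, 3000), dtype=np.float32)        # log-Mel features
features = ctranslate2.StorageView.from_array(features)

prompts = [["<|startoftranscript|>", "<|en|>", "<|transcribe|>", "<|notimestamps|>"]] * 2
results = model.generate(features, prompts, beam_size=1)

for result in results:
    token_ids = result.sequences_ids[0]
    # token_ids would then be decoded to text with the tokenizer (decode_batch in WhisperS2T)
```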

---

### **13. Collecting Results with File IDs**

**Script**: `__init__.py`  
**Class**: `WhisperModel`  
**Method**: `transcribe`

- **Action**: Collects and stores transcription results, associating them with the correct audio files.
- **Details**:
  - For each result in the batch:
    - Uses `_seg_metadata['file_id']` to identify which audio file the segment belongs to.
    - Appends the result to `responses[_seg_metadata['file_id']]`.
    - Includes `start_time` and `end_time` from `_seg_metadata` in the response.
  - Updates the progress bar accordingly.
  - After processing all batches, the method returns the `responses` list, where each element corresponds to an audio file and contains all its transcribed segments.
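
Conceptually, the routing back to source files is just indexing by `file_id` (a toy illustration with made-up data, not the library's code):

```python
audio_files = ["a.wav", "b.wav"]
responses = [[] for _ in audio_files]

batch_results = [
    {"text": "hello", "seg_metadata": {"file_id": 0, "start_time": 0.0, "end_time": 29.0}},
    {"text": "world", "seg_metadata": {"file_id": 1, "start_time": 0.0, "end_time": 14.5}},
]

for res in batch_results:
    meta = res["seg_metadata"]
    responses[meta["file_id"]].append(
        {"text": res["text"], "start_time": meta["start_time"], "end_time": meta["end_time"]}
    )
# responses[0] now holds all segments of a.wav, responses[1] those of b.wav
```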

---

### **14. Final Output**

- **Action**: Returns the final transcription results, organized by audio file.
- **Details**:
  - Each entry in the `responses` list corresponds to an input audio file.
  - Within each entry, there is a list of transcribed segments with their associated metadata.
  - The library ensures that segments from different files are correctly associated back to their original files.

---

### **Detailed Processing Flow**

1. **User Input**: The user calls the `transcribe` method of an instance of `WhisperModelCT2`, passing in a list of audio files and optional parameters like language codes and initial prompts.

2. **Parameter Preparation**:
   - The method ensures that `lang_codes`, `tasks`, and `initial_prompts` are lists matching the length of `audio_files`.

3. **Data Loading**:
   - The `WhisperDataLoader` is invoked to prepare data batches.
   - It uses `audio_batch_generator` to yield audio signals from `audio_files`.
   - Each audio file is assigned a unique `file_id`.

4. **Audio Loading**:
   - For each audio file:
     - If it's a NumPy array, it's used directly.
     - If it's a file path, `load_audio` is called to load and preprocess the audio.

5. **Audio Conversion**:
   - In `load_audio`, the file is first read as a 16kHz mono WAV file if possible.
   - If the format doesn't match, `ffmpeg` is used to convert the audio to 16kHz mono WAV.

6. **Audio Normalization**:
   - The audio frames are converted to a NumPy array of `float32`, normalized between -1 and 1.

7. **Segmenting Audio**:
   - The audio signal is segmented into fixed-length segments using `BasicSegmenter`.
   - Segments are represented as start and end times.

8. **Preparing Segments with Metadata**:
   - Each segment is associated with prompts and metadata, including `file_id`, `start_time`, and `end_time`.
   - Segments from all files are accumulated into `segmented_audio_signal`.

9. **Batch Preparation**:
   - Segments are batched together from `segmented_audio_signal`.
   - Audio samples are padded or trimmed to a fixed length.
   - Batches include `signal_batch`, `prompt_batch`, `seq_len`, and `seg_metadata`.

10. **Feature Extraction**:
    - The preprocessor computes log-Mel spectrograms for each audio batch.
    - This involves STFT computation and applying Mel filters.

11. **Model Inference**:
    - The `ctranslate2` model generates transcriptions from the features and prompts.
    - The model outputs sequences of token IDs.

12. **Decoding Transcriptions**:
    - Token IDs are decoded into text using the tokenizer.
    - Additional metadata like average log probability is calculated.

13. **Collecting Results**:
    - For each result, the `file_id` from `seg_metadata` is used to associate the transcription with the correct audio file in `responses`.
    - Responses are accumulated for each audio file.
    - The progress bar is updated accordingly.

14. **Final Output**:
    - Once all batches are processed, the method returns the final transcription results.
    - Each entry in the `responses` list contains all the transcribed segments for that audio file.

---

### **Key Components and Their Roles**

- **Scripts**:
  - `audio.py`: Handles audio loading, conversion, and feature extraction.
  - `data.py`: Manages data loading, segmentation, batching, and keeps track of file IDs.
  - `__init__.py`: Defines the base `WhisperModel` class.
  - `model.py`: Implements the `WhisperModelCT2` class interfacing with `ctranslate2`.

- **Classes and Functions**:
  - `WhisperModelCT2`: Main class that orchestrates the transcription process.
  - `WhisperDataLoader`: Loads and batches audio data from multiple files, keeping track of file IDs.
  - `BasicSegmenter`: Segments audio into fixed lengths.
  - `load_audio`: Loads and preprocesses audio files.
  - `audio_batch_generator`: Yields audio signals from files or arrays.
  - `pad_or_trim`: Pads or trims audio arrays to a fixed length.
  - `LogMelSpectogram`: Computes log-Mel spectrograms.
  - `TorchSTFT`: Performs the Short-Time Fourier Transform.
  - `generate_segment_batched`: Generates transcriptions using `ctranslate2`.
  - `stitch_speech_segments`: Merges segments when needed.
  - `fix_batch_param`: Ensures parameters are correctly formatted as lists.

---

### **Additional Notes**

- **Tracking Segments**: The library assigns a unique `file_id` to each audio file. This `file_id` is included in the `seg_metadata` for each segment, ensuring that when segments are processed in batches, the results can be correctly associated back to their original files.

- **Batch Processing**: Batches may include segments from different audio files. The library efficiently processes these batches while maintaining the association between segments and their source files.

- **Response Structure**: The final `responses` list is structured so that each element corresponds to an input audio file. Within each element, there is a list of transcription results for that file's segments, each including start and end times.

- **Parallelism and Efficiency**: By batching segments from multiple files, the library can utilize computational resources more efficiently, especially when processing small audio files.

---

By following this step-by-step timeline, you can understand how the WhisperS2T library processes multiple audio files, segments them, keeps track of which file each segment originates from, and ultimately processes them with the `ctranslate2` backend to generate transcriptions. The key components work together to ensure that the transcriptions are accurate and correctly associated with their respective audio files.

---

**References to Additional Source Code**

In analyzing how the library handles multiple audio files, we relied on the provided source code snippets. If there are other relevant files or functions (e.g., `speech_segmenter.py`, `hf_utils.py`, or additional methods in `tokenizer.py`) that influence the behavior, having access to them would provide a more comprehensive understanding. However, based on the provided code, we've outlined the processing flow as accurately as possible.

@BBC-Esq

BBC-Esq commented Sep 18, 2024

Forgot to attach an outline of how WhisperS2T nonetheless performs batching even if only a single audio file is selected:

BBC OUTLINE OF WHISPERS2T PIPELINE FOR SINGLE AUDIO FILE
**Step-by-Step Timeline of Audio File Processing in WhisperS2T**

1. **Transcription Initiation**:
   - **Script**: `model.py`
   - **Class**: `WhisperModelCT2`
   - **Method**: `transcribe`
   - **Action**: The user initiates transcription by calling the `transcribe` method of the `WhisperModelCT2` class, passing in the audio file(s) to be transcribed.

2. **Data Loading Preparation**:
   - **Script**: `__init__.py`
   - **Class**: `WhisperModel`
   - **Method**: `transcribe` (inherited by `WhisperModelCT2`)
   - **Action**: The `transcribe` method invokes the data loader to prepare the audio data for processing. It creates an instance of `WhisperDataLoader`.

3. **Data Loader Invocation**:
   - **Script**: `data.py`
   - **Class**: `WhisperDataLoader`
   - **Method**: `__call__`
   - **Action**: The `WhisperDataLoader`'s `__call__` method is invoked with the audio files and related parameters. Since voice activity detection (VAD) is not used here, it calls the `get_data_loader` method.

4. **Audio Batch Generation**:
   - **Script**: `audio.py`
   - **Function**: `audio_batch_generator`
   - **Action**: The `audio_batch_generator` function iterates over the audio files, yielding each one for processing. It determines whether each item is a file path or a numpy array.

5. **Audio Loading and Format Checking**:
   - **Script**: `audio.py`
   - **Function**: `load_audio`
   - **Action**: For each audio file, the `load_audio` function attempts to open the file using `wave.open` to check if it is a 16kHz mono WAV file.

6. **Audio Conversion with FFmpeg (if necessary)**:
   - **Script**: `audio.py`
   - **Action**:
     - If the audio file is not in the required format (16kHz, mono, PCM_s16le), `ffmpeg` is used to convert it.
     - **Command Used**:
       ```
       ffmpeg -hide_banner -loglevel panic -i "{input_file}" -threads 1 -acodec pcm_s16le -ac 1 -af aresample=resampler={RESAMPLING_ENGINE} -ar 16000 "{wav_file}" -y
       ```
     - **Variables**:
       - `input_file`: Path to the original audio file.
       - `RESAMPLING_ENGINE`: Resampling engine, either 'soxr' or 'swr'.
       - `wav_file`: Temporary output file path.
     - **Action**: Converts the audio to a 16kHz mono WAV file with 16-bit PCM encoding.

7. **Reading Audio Frames and Conversion to Numpy Array**:
   - **Script**: `audio.py`
   - **Function**: `load_audio`
   - **Action**:
     - The frames from the WAV file are read using `wave.readframes`.
     - The byte data is converted to a numpy array of type `float32` with values normalized between -1.0 and 1.0 using:
       ```
       audio_signal = np.frombuffer(x, np.int16).flatten().astype(np.float32) / 32768.0
       ```

8. **Basic Audio Segmentation**:
   - **Script**: `data.py`
   - **Class**: `BasicSegmenter`
   - **Method**: `__call__`
   - **Action**:
     - The `BasicSegmenter` divides the audio signal into segments of a maximum length (`max_speech_len`, default 29.0 seconds).
     - It returns a list of start and end times for each segment.

9. **Segmented Audio Signal Preparation**:
   - **Script**: `data.py`
   - **Method**: `get_segmented_audio_signal`
   - **Action**:
     - For each segment, the corresponding portion of the audio signal is extracted.
     - Prompts are prepared using the tokenizer's `sot_sequence` method, including language and task tokens.
     - A list of tuples is created containing:
       - The audio segment (numpy array).
       - The prompt sequence (list of token IDs).
       - Any initial prompt tokens.
       - The sequence length (number of samples).
       - Segment metadata (file ID, start time, end time).

10. **Batch Creation for Data Loader**:
    - **Script**: `data.py`
    - **Method**: `get_data_loader`
    - **Action**:
      - Segmented audio signals are accumulated.
      - Batches are created based on the specified `batch_size`.
      - The method yields batches for further processing.

11. **Data Collation and Padding/Trimming**:
    - **Script**: `data.py`
    - **Method**: `data_collate_fn`
    - **Action**:
      - The `data_collate_fn` method is called for each batch.
      - **Script**: `audio.py`
      - **Function**: `pad_or_trim`
      - **Action**:
        - Each audio segment is padded or trimmed to a consistent length (`N_SAMPLES`, default 1500000 samples or 93.75 seconds at 16kHz).
        - Padding is done using zeros if the audio is shorter than `N_SAMPLES`.
        - Trimming is done by selecting the first `N_SAMPLES` samples if longer.

12. **Conversion to Torch Tensors**:
    - **Script**: `data.py`
    - **Method**: `data_collate_fn`
    - **Action**:
      - The padded or trimmed audio signals are converted to PyTorch tensors.
      - The tensors are stacked to form a batch tensor for processing.
      - Prompts are prepared and padded as needed.

13. **Feature Extraction via Log-Mel Spectrogram**:
    - **Script**: `audio.py`
    - **Class**: `LogMelSpectogram`
    - **Method**: `forward`
    - **Action**:
      - The batch of audio tensors is passed to the `LogMelSpectogram` preprocessor.
      - **STFT Computation**:
        - **Class**: `TorchSTFT`
        - **Method**: `forward`
        - **Action**: Computes the Short-Time Fourier Transform (STFT) of the audio signals.
      - **Mel Filter Application**:
        - Mel filter banks are applied to the STFT results to obtain mel spectrograms.
      - **Logarithmic Scaling**:
        - The mel spectrograms are converted to log scale using `torch.log10`.
      - **Clipping and Scaling**:
        - The log-mel spectrograms are clipped to prevent extreme values and scaled appropriately.

14. **Transcription with ctranslate2 Backend**:
    - **Script**: `model.py`
    - **Class**: `WhisperModelCT2`
    - **Method**: `generate_segment_batched`
    - **Action**:
      - The preprocessed features (log-mel spectrograms) and prompts are passed to the `generate_segment_batched` method.
      - **Method**: `self.model.generate`
      - **Action**:
        - The ctranslate2 model (`self.model`) generates transcriptions based on the input features and prompts.
        - It returns the transcription results, including text and optional scores.

15. **Result Aggregation**:
    - **Script**: `model.py`
    - **Method**: `generate_segment_batched`
    - **Action**:
      - Transcription results are collected and formatted.
      - If word-level timestamps are enabled (not default), additional alignment is performed (not detailed here).
    - **Script**: `__init__.py`
    - **Method**: `transcribe`
    - **Action**:
      - Results from each batch are aggregated.
      - The final transcriptions for all audio files are compiled.

16. **Transcription Completion**:
    - **Script**: `__init__.py`
    - **Method**: `transcribe`
    - **Action**:
      - The transcriptions are returned to the user as the output of the `transcribe` method.

**Summary of Key Processing Steps:**

- **Audio Conversion to Mono and Resampling**:
  - **Function**: `load_audio` in `audio.py`.
  - **Action**: Uses `ffmpeg` to convert audio files to 16kHz mono WAV format if they are not already in that format.

- **Conversion to Numpy Array**:
  - **Function**: `load_audio` in `audio.py`.
  - **Action**: Reads audio frames and converts them to a numpy array of type `float32` with normalized values between -1.0 and 1.0.

- **Segmentation and Padding/Trimming**:
  - **Class**: `BasicSegmenter` in `data.py`.
  - **Function**: `pad_or_trim` in `audio.py`.
  - **Action**:
    - Segments the audio into manageable chunks.
    - Pads or trims audio segments to a consistent length for batch processing.

- **Feature Extraction (Conversion to Log-Mel Spectrograms)**:
  - **Class**: `LogMelSpectogram` in `audio.py`.
  - **Action**: Converts audio signals into log-mel spectrogram features suitable for input to the transcription model.

- **Processing by ctranslate2 Backend**:
  - **Class**: `WhisperModelCT2` in `model.py`.
  - **Method**: `generate_segment_batched`
  - **Action**: Uses the ctranslate2 backend to generate transcriptions from the processed audio features.

This step-by-step timeline outlines how the WhisperS2T library processes an audio file, detailing the conversion to mono, resampling, conversion to an array, and subsequent processing leading up to transcription by the ctranslate2 backend.

@MahmoudAshraf97
Contributor Author

No speedup gains will be noticed unless continuous batching is implemented, which is different from regular batching. In regular batching the speedups eventually plateau because the longest sequence often runs alone until completion, which means that at large batch sizes the effective batch size drops to 1 once all the other sequences have finished. This is an analysis of how generation efficiency drops sharply with larger batch sizes, even when the GPU is not yet saturated:

```
efficiency = generation loops needed / actual generation loops (calculated by the longest seq in the output)
# time is for decoding only, 75 * 30s segments

# efficiency:  1.0 / batch_size=1
# efficiency:  0.87 / batch_size=2
# efficiency:  0.80 / batch_size=4
# efficiency:  0.67 / batch_size=8
# efficiency:  0.56 / batch_size=16
# efficiency:  0.45 / batch_size=32
```
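
As a worked illustration of the metric (made-up numbers, not from the benchmark above): if a batch of 4 segments produces outputs of 40, 25, 10, and 5 tokens, the decoder runs for 40 loops (the longest sequence) while on average only (40 + 25 + 10 + 5) / 4 = 20 loops of useful work are done per slot, so efficiency = 20 / 40 = 0.5.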

Achieving 100% efficiency is already possible with the transformers backend regardless of the batch size; also check #1333.
AFAIK, there is no public implementation of continuous batching for Whisper in any backend so far.

@BBC-Esq

BBC-Esq commented Sep 18, 2024

When I run WhisperS2T it has no problem fully saturating the CUDA cores. I'll look more into the distinction you're drawing, but my initial impression is that as long as you construct the pipeline elegantly before the data is sent to ctranslate2, the distinction you're drawing is not on point. You've reviewed the WhisperS2T pipeline, I think you said, and seen how it aggregates chunks of an audio file into a batch, and beyond that even aggregates chunks from different audio files into a batch...all the while keeping track of all the timestamp information, which files the various chunks came from, etc.? Sanity test for me, perhaps. :-)

@BBC-Esq

BBC-Esq commented Sep 18, 2024

> no speedup gains will be noticed unless continuous batching is implemented which is different from regular batching, in regular batching the speedups will eventually plateau because the longest sequence often runs alone until completion which means in large batch sizes the actual batch size will be 1 because all other sequences have already finished, this is an analysis of how the generation efficiency drops greatly with larger batch sizes even if the GPU is not yet saturated
>
> ```
> efficiency = generation loops needed / actual generation loops (calculated by the longest seq in the output)
> # time is for decoding only, 75 * 30s segments
>
> # efficiency:  1.0 / batch_size=1
> # efficiency:  0.87 / batch_size=2
> # efficiency:  0.80 / batch_size=4
> # efficiency:  0.67 / batch_size=8
> # efficiency:  0.56 / batch_size=16
> # efficiency:  0.45 / batch_size=32
> ```
>
> achieving 100% efficiency is already possible with transformers backend regardless of the batch size also check #1333 AFAIK, there is no public implementation of continuous batching for whisper in any backend so far

Regarding your specific test results...can you link in the private repo I created the actual audio files and the script you used...and the overall processing time? I'd like to do a comparison if you don't mind.

@MahmoudAshraf97
Contributor Author

The CUDA cores will be saturated of course, with one tiny problem: it's outputting garbage that will be discarded anyway.
Continuous vs. naive batching is a pretty common topic with a lot of resources; check this for a thorough explanation:
https://www.anyscale.com/blog/continuous-batching-llm-inference
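
To make the distinction concrete, here is a toy sketch of the continuous-batching scheduling policy (not from any particular library; `decode_step` and `is_finished` are stand-ins for one decoder iteration and an end-of-sequence check):

```python
from collections import deque


def continuous_batching_sketch(pending, decode_step, is_finished, max_batch=8):
    """Keep the decoder's batch full: whenever a sequence finishes, pull the next one in."""
    queue = deque(pending)
    active = [queue.popleft() for _ in range(min(max_batch, len(queue)))]
    finished = []
    while active:
        active = [decode_step(seq) for seq in active]   # one generation loop for every slot
        next_active = []
        for seq in active:
            if is_finished(seq):
                finished.append(seq)
                if queue:                               # refill the freed slot immediately
                    next_active.append(queue.popleft())
            else:
                next_active.append(seq)
        active = next_active
    return finished


# Naive (static) batching would instead run each batch until *all* of its sequences
# finish, so one long sequence keeps the whole batch occupied.
```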

@BBC-Esq

BBC-Esq commented Sep 18, 2024

> The CUDA cores will be saturated ofcourse with one tiny problem, it's outputting garbage that will be discarded anyways Continuous vs naive batching is a pretty common topic with a lot of resources, check this for a thourough explanation https://www.anyscale.com/blog/continuous-batching-llm-inference

Would you mind sharing the audio files you tested like I mentioned?

@MahmoudAshraf97
Contributor Author

I tested with this:
https://youtu.be/s0XopkGcN9U?si=TDHuCQWy9J93n5T6
although any audio file should produce similar results.

@MahmoudAshraf97
Contributor Author

75 segments, not files: 34 min / 30 s ~= 75 segments.

@BBC-Esq

BBC-Esq commented Sep 19, 2024

@MahmoudAshraf97 Can you provide the script you used for your benchmark please?

@MahmoudAshraf97
Contributor Author

This is a good starting point to reproduce using any inference engine. `features` is the mel-transformed audio; encoding and decoding are split into different stages to correctly measure the decoding performance:

```python
%%timeit
# `features` is the mel-transformed audio; `model`, `prompt_id`, and `model.eot_id`
# are assumed to be provided by the inference engine being benchmarked.
# Encoding and decoding are split into separate stages so that decoding
# performance can be measured on its own.
import torch

batch_size = 4
encoder_batch_size = 4

total_cycles = 0
total_tokens = 0

# Encode all segments up front.
encoder_outputs = []
for i in range(0, len(features), encoder_batch_size):
    encoder_outputs.extend(
        model.encoder.get_audio_features(
            features[i : i + encoder_batch_size]
        ).unbind()
    )
    torch.cuda.empty_cache()

# Decode in batches and track how many generation loops were actually needed.
filtered_outputs = []
for i in range(0, len(encoder_outputs), batch_size):
    inputs = list(encoder_outputs[i : i + batch_size])
    decoder_input_ids = prompt_id.repeat(len(inputs), 1)
    outputs = model.decoder.generate(
        decoder_input_ids, inputs, model.eot_id, max_new_tokens=124, num_beams=1
    )

    # Drop the 4 prompt tokens and any end-of-transcript tokens.
    batch_outputs = [
        [token for token in output[0][4:] if token != model.eot_id]
        for output in outputs
    ]
    filtered_outputs.extend(batch_outputs)
    # The batch runs for as many loops as its longest sequence.
    total_cycles += max(len(output) for output in batch_outputs)
    total_tokens += sum(len(output) for output in batch_outputs)
    torch.cuda.empty_cache()

print("efficiency: ", total_tokens / batch_size / total_cycles)
```

@MahmoudAshraf97
Contributor Author

Mathematically, both are identical

@BBC-Esq

BBC-Esq commented Sep 19, 2024

> This is a good starting point to reproduce using any inference engine: features is the mel-transformed audio, the encoding and decoding is split in different stages to correctly measure the decoding performance
>
> ```python
> %%timeit
> batch_size = 4
> encoder_batch_size = 4
>
> total_cycles = 0
> total_tokens = 0
> encoder_outputs = []
> for i in range(0, len(features), encoder_batch_size):
>     encoder_outputs.extend(
>         model.encoder.get_audio_features(
>             features[i : i + encoder_batch_size]
>         ).unbind()
>     )
>     torch.cuda.empty_cache()
> filtered_outputs = []
> for i in range(0,len(encoder_outputs),batch_size):
>     inputs = list(encoder_outputs[i:i+batch_size])
>     decoder_input_ids = prompt_id.repeat(len(inputs), 1)
>     outputs = model.decoder.generate(decoder_input_ids, inputs, model.eot_id, max_new_tokens=124, num_beams=1)
>
>     for output in outputs:
>         filtered_outputs.append([token for token in output[0][4:] if token != model.eot_id])
>     total_cycles += max([len(output) for output in filtered_outputs])
>     total_tokens += sum([len(output) for output in filtered_outputs])
>     torch.cuda.empty_cache()
>
> print("effeciency: ",total_tokens / batch_size/total_cycles)
> ```

I don't do "good starting points". When I ask for a script I expect a script, and out of respect, I do the same when someone I'm supposedly collaborating with asks me for a SCRIPT.

@MahmoudAshraf97
Contributor Author

My dear friend, we are in the same boat here:

  1. I haven't given up on benchmarking Whisper, but FW is not in a stable state right now with this many unmerged commits and fixes; that's why I put it on hold and will resume once the organization moves forward with the pending PRs.
  2. The full code is part of private work under a signed NDA, which is why I cannot share more than this, which was also shared here. Aside from that, this PR has nothing to do with the other work we've been doing together, so I'd politely ask to keep comments here solely related to what this PR is trying to do; anything else can be communicated in private via email or through an issue in your private repo. It's also probably best to delete all comments on this PR, and I'll do the same, to prevent off-topic distraction for maintainers.

@BBC-Esq

BBC-Esq commented Sep 19, 2024

I deleted all comments that were unduly inflammatory but left the ones that, while a little inflammatory, still directly pertain to this pull request. Next time please lead with the "NDA" reason. Also, you didn't clarify whether this prevents you from sending the TensorRT code on the private repo like you promised. This is all voluntary, so...tell me if you will NOT do it rather than saying you will but then not doing it. Thanks.

@freddierice

How can I help? I would like to use this in my pipeline.

@MahmoudAshraf97
Contributor Author

How can I help? I would like to use this in my pipeline.

Even after this PR is accepted, we will need to find a way to stop the generation once a single sequence finishes rather than wait for all of them to finish. This is easily done in transformers using a custom stopping criteria, but I have no idea how to do it in CT2.
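
A minimal sketch of the transformers-side approach mentioned above (assumed usage, not code from this PR): stop the whole `generate()` call as soon as any sequence in the batch emits the end-of-transcript token, so the finished slot can be swapped out.

```python
import torch
from transformers import StoppingCriteria, StoppingCriteriaList


class StopOnAnyFinished(StoppingCriteria):
    def __init__(self, eot_token_id: int):
        self.eot_token_id = eot_token_id

    def __call__(self, input_ids: torch.LongTensor, scores, **kwargs) -> bool:
        # True as soon as any sequence in the batch has produced the EOT token.
        return bool((input_ids == self.eot_token_id).any())


# Usage with a Hugging Face Whisper model (model, input_features, eot_token_id assumed):
# outputs = model.generate(
#     input_features,
#     stopping_criteria=StoppingCriteriaList([StopOnAnyFinished(eot_token_id)]),
# )
```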

@freddierice

I tested the code and it works for my use case as is! I'm curious what the speedup is between this and full IFB (in-flight batching), but I don't have the time to dedicate to that big of a code change, unfortunately. Thank you for this, it's a huge win 🙇.

@Jiltseb

Jiltseb commented Oct 22, 2024

> no speedup gains will be noticed unless continuous batching is implemented which is different from regular batching, in regular batching the speedups will eventually plateau because the longest sequence often runs alone until completion which means in large batch sizes the actual batch size will be 1 because all other sequences have already finished, this is an analysis of how the generation efficiency drops greatly with larger batch sizes even if the GPU is not yet saturated
>
> ```
> efficiency = generation loops needed / actual generation loops (calculated by the longest seq in the output)
> # time is for decoding only, 75 * 30s segments
>
> # efficiency:  1.0 / batch_size=1
> # efficiency:  0.87 / batch_size=2
> # efficiency:  0.80 / batch_size=4
> # efficiency:  0.67 / batch_size=8
> # efficiency:  0.56 / batch_size=16
> # efficiency:  0.45 / batch_size=32
> ```
>
> achieving 100% efficiency is already possible with transformers backend regardless of the batch size also check #1333 AFAIK, there is no public implementation of continuous batching for whisper in any backend so far

Hi @MahmoudAshraf97, cool! How different is it from vllm-whisper with continuous batching?

@MahmoudAshraf97
Contributor Author

Hi Jilt, TRT-LLM just released continuous batching support for Whisper last week. I'm still improving my implementation using the transformers backend to get better KV cache utilization, and I will release it publicly as soon as it's ready.
It should not be different from a higher-level POV, but if the vLLM implementation has matured it's going to be faster, although there's a lack of benchmarks comparing the two; I'll try to compare them when I can.

@Jiltseb

Jiltseb commented Oct 22, 2024

> Hi Jilt, TRT-LLM just released continuous batching support for whisper last week, I'm still improving my implementation using transformers backend to have better KV cache utilization and will release it publicly as soon as it's ready. It should not be different from a higher level POV, but if the vLLM implementation matured it's going to be faster, although there's lack of benchmarks to compare both, I'll try to compare them when I can

I have tried vllm-whisper with CB, but the results seem to suffer because they used 3-second audio chunks plus padding instead of relying on semantics to split the audio.
vllm-project/vllm#5964 (comment)

@MahmoudAshraf97 deleted the branch OpenNMT:master on November 21, 2024 15:56
@MahmoudAshraf97 deleted the master branch on November 21, 2024 15:56