Skip to content

Commit

Permalink
Merge pull request #1 from bnestor/dev/pastorep/sample-generation-ref…
Browse files Browse the repository at this point in the history
…inement

Dev/pastorep/sample generation refinement
  • Loading branch information
bnestor authored Sep 17, 2024
2 parents 6cb9573 + 01e9f7e commit d722d3d
Show file tree
Hide file tree
Showing 39 changed files with 2,566 additions and 2,353 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ model_type: "FastAI"
model_local_threshold: 0.5
model_global_threshold: 3
model_path: "./model"
model_name: "model.pkl"
model_name: "model.pt"
hls_stream_type: "LiveHLS"
hls_polling_interval: 60
hls_hydrophone_id: "rpi_orcasound_lab"
Expand Down
8 changes: 4 additions & 4 deletions InferenceSystem/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
-f https://download.pytorch.org/whl/torch_stable.html
fastai==1.0.61
librosa==0.8.0
fastai
librosa==0.10
pydub==0.24.1
pandas
numpy
torchaudio==0.6.0
torchaudio
git+https://github.com/fastaudio/[email protected]
ipython
spacy
Expand All @@ -20,7 +20,7 @@ azure-cosmos
azure-cosmosdb-table
azure-storage-blob
ffmpeg-python
numba==0.48
numba
opencv-python
boto3
pytz
Expand Down
4 changes: 2 additions & 2 deletions InferenceSystem/src/LiveInferenceOrchestrator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Live inference orchestrator
# Rename these files
from model.podcast_inference import OrcaDetectionModel
from model.fastai_inference import FastAIModel
from model.fastai_inference import FastAI2Model

from orca_hls_utils.DateRangeHLSStream import DateRangeHLSStream
from orca_hls_utils.HLSStream import HLSStream
Expand Down Expand Up @@ -118,7 +118,7 @@ def populate_metadata_json(
whalecall_classification_model = OrcaDetectionModel(model_path, threshold=model_local_threshold, min_num_positive_calls_threshold=model_global_threshold)
elif model_type == "FastAI":
model_name = config_params["model_name"]
whalecall_classification_model = FastAIModel(model_path=model_path, model_name=model_name, threshold=model_local_threshold, min_num_positive_calls_threshold=model_global_threshold)
whalecall_classification_model = FastAI2Model(model_path=model_path, model_name=model_name, threshold=model_local_threshold, min_num_positive_calls_threshold=model_global_threshold)
else:
raise ValueError("model_type should be one of AudioSet / FastAIModel")

Expand Down
191 changes: 172 additions & 19 deletions InferenceSystem/src/model/fastai_inference.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,73 @@
from fastai.basic_train import load_learner
# from fastai.basic_train import load_learner
from fastai.vision.all import *
import pandas as pd
from pydub import AudioSegment
from librosa import get_duration
from pathlib import Path
from numpy import floor
from audio.data import AudioConfig, SpectrogramConfig, AudioList
# from audio.data import AudioConfig, SpectrogramConfig, AudioList
from fastai.data.all import *
import torchaudio
import os
import shutil
import tempfile
from dataclasses import dataclass

import torch
import scipy.io.wavfile

def load_model(mPath, mName="stg2-rn18.pkl"):
    """Return the learner produced by ``load_learner(mPath, mName)``.

    Args:
        mPath: First positional argument forwarded to ``load_learner``.
        mName: Second positional argument forwarded to ``load_learner``.
    """
    learner = load_learner(mPath, mName)
    return learner

def _load_model(mPath, mName="stg2-rn18.pkl"):
    """Load a model from an exported ``.pkl`` learner or a ``.pt`` state dict.

    Args:
        mPath: Directory that contains the model file. For ``.pt`` models a
            ``dummy.wav`` file is also written into this directory.
        mName: Model file name; its extension selects the loading strategy.

    Returns:
        For ``.pkl``: whatever ``load_learner`` returns for the joined path.
        For ``.pt``: a ``Learner`` wrapping ``models.resnet18()`` whose weights
        were loaded from the checkpoint with ``strict=False``.

    Raises:
        NotImplementedError: If ``mName`` ends with neither ``.pkl`` nor ``.pt``.
    """
    model_file = os.path.join(mPath, mName)

    if mName.endswith('.pkl'):
        return load_learner(model_file)

    if not mName.endswith('.pt'):
        raise NotImplementedError("Only .pkl and .pt models are supported")

    import warnings

    import torch

    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    checkpoint = torch.load(model_file, map_location=torch.device('cpu'))

    # Write a one-second all-zero wav (20000 samples at 20000 Hz) so the
    # DataBlock below has at least one item to build DataLoaders from.
    wave = np.zeros((20000))
    scipy.io.wavfile.write(os.path.join(mPath, 'dummy.wav'), 20000, wave)
    testpath = Path(mPath)

    spec_config = SpectrogramConfig2()
    config = AudioConfig2(sg_cfg=spec_config)
    audio_transform = AudioTransform(config, mode='test')

    def get_wav_files(source):
        # Restrict item discovery to wav files so the model file itself is ignored.
        return get_files(source, extensions=['.wav',])

    audio_block = DataBlock(
        blocks=(TransformBlock, CategoryBlock),
        get_items=get_wav_files,
        get_x=audio_transform,
        get_y=label_func,
        splitter=RandomSplitter(),
        item_tfms=[],
        batch_tfms=[]
    )
    dls = audio_block.dataloaders(testpath, bs=1)

    learner = Learner(dls, models.resnet18(), metrics=[accuracy])

    # strict=False tolerates key mismatches; report them instead of silently
    # leaving part of the network at its initial weights.
    load_result = learner.model.load_state_dict(checkpoint, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            f"Checkpoint {model_file} did not match the model exactly: "
            f"missing keys {list(load_result.missing_keys)}, "
            f"unexpected keys {list(load_result.unexpected_keys)}",
            stacklevel=2,
        )

    return learner



def get_wave_file(wav_file):
'''
Expand Down Expand Up @@ -46,16 +101,91 @@ def extract_segments(audioPath, sampleDict, destnPath, suffix):
end_time, output_file_path)


class FastAIModel():
@dataclass
class SpectrogramConfig2:
    """Settings for building a mel spectrogram from an audio clip.

    Every field has a default, so ``SpectrogramConfig2()`` is a usable
    configuration on its own.
    """

    f_min: float = 0.0  # Minimum frequency to display
    f_max: float = 10000.0  # Maximum frequency to display
    hop_length: int = 256  # Hop length
    n_fft: int = 2560  # Number of samples for Fourier transform
    n_mels: int = 256  # Number of Mel bins
    pad: int = 0  # Padding
    to_db_scale: bool = True  # Convert to dB scale
    top_db: int = 100  # Top decibel sound
    # Window length; the default is None, so the annotation must allow it.
    win_length: "int | None" = None
    n_mfcc: int = 20  # Number of MFCC features

@dataclass
class AudioConfig2:
    """Settings describing how an audio clip is prepared before the spectrogram step."""

    standardize: bool = False  # Standardization flag
    # Spectrogram configuration; callers in this file pass a SpectrogramConfig2.
    sg_cfg: "SpectrogramConfig2 | None" = None
    # Clip length in milliseconds (4000 = 4 seconds); AudioTransform converts
    # it to samples as duration / 1000 * resample_to.
    duration: int = 4000
    resample_to: int = 20000  # Resample rate in Hz


class AudioTransform(Transform):
    """Turn an audio file path into a mel spectrogram repeated over 3 leading channels.

    ``config`` supplies ``resample_to``, ``duration`` (milliseconds) and the
    spectrogram settings in ``config.sg_cfg``. ``mode='train'`` additionally
    applies frequency and time masking; any other mode skips the masking.
    """

    def __init__(self, config, mode='test'):
        self.config = config
        self.mode = mode

        sg = config.sg_cfg
        self.to_db_scale = torchaudio.transforms.AmplitudeToDB(top_db=sg.top_db)
        self.spectrogrammer = torchaudio.transforms.MelSpectrogram(
            sample_rate=config.resample_to,
            n_fft=sg.n_fft,
            hop_length=sg.hop_length,
            n_mels=sg.n_mels,
            f_min=sg.f_min,
            f_max=sg.f_max
        )
        self.time_masking = torchaudio.transforms.TimeMasking(time_mask_param=80)
        self.freq_masking = torchaudio.transforms.FrequencyMasking(freq_mask_param=80)

    def encodes(self, fn: Path):
        cfg = self.config

        # Load, average over the first axis (reduce to mono), then resample.
        signal, source_rate = torchaudio.load(fn)
        signal = signal.mean(dim=0)
        signal = torchaudio.functional.resample(signal, source_rate, cfg.resample_to)

        # Force the clip to exactly `duration` milliseconds worth of samples.
        target_len = int(cfg.duration/1000 * cfg.resample_to)
        current_len = signal.shape[0]
        if current_len >= target_len:
            signal = signal[:target_len]
        else:
            signal = F.pad(signal, (0, target_len - current_len))

        spec = self.spectrogrammer(signal)

        # Masking is applied in training mode only: frequency first, then time.
        if self.mode == 'train':
            spec = self.freq_masking(spec)
            spec = self.time_masking(spec)

        if cfg.sg_cfg.to_db_scale:
            spec = self.to_db_scale(spec)

        # Add a leading axis and expand it to size 3.
        return spec.unsqueeze(0).expand(3, -1, -1)


# Integrate with fastai's DataBlock (customized for your use case)
def label_func(f):
    """Return the name of the directory that directly contains ``f``."""
    containing_dir = f.parent
    return containing_dir.name

class FastAI2Model():
def __init__(self, model_path, model_name="stg2-rn18.pkl", threshold=0.5, min_num_positive_calls_threshold=3):
self.model = load_model(model_path, model_name)
self.model = _load_model(model_path, model_name)
self.threshold = threshold
self.min_num_positive_calls_threshold = min_num_positive_calls_threshold

def predict(self, wav_file_path):
'''
Function which generates local predictions using wavefile
'''
device = "cuda" if torch.cuda.is_available() else "cpu"

# Creates local directory to save 2 second clips
# local_dir = "./fastai_dir/"
Expand Down Expand Up @@ -90,8 +220,7 @@ def predict(self, wav_file_path):
)

# Defining the audio config needed to create on-the-fly mel spectrograms
config = AudioConfig(standardize=False,
sg_cfg=SpectrogramConfig(
spec_cfg = SpectrogramConfig2(
f_min=0.0, # Minimum frequency to Display
f_max=10000, # Maximum Frequency to Display
hop_length=256,
Expand All @@ -102,23 +231,46 @@ def predict(self, wav_file_path):
top_db=100, # Top decible sound
win_length=None,
n_mfcc=20)

# Defining the audio config needed to create on-the-fly mel spectrograms
config = AudioConfig2(standardize=False,
sg_cfg=spec_cfg,
duration=4000, # 4 sec padding or snip
resample_to = 20000, # Every sample at 20000 frequency
)
config.duration = 4000 # 4 sec padding or snip
config.resample_to = 20000 # Every sample at 20000 frequency
config.downmix=True
# config.downmix=True

# Creating a Audio DataLoader
test_data_folder = Path(local_dir)
tfms = None
test = AudioList.from_folder(
test_data_folder, config=config).split_none().label_empty()
testdb = test.transform(tfms).databunch(bs=32)

# Creating a Audio DataLoader
audio_transform = AudioTransform(config, mode='test')


# Define your DataBlock
audio_block = DataBlock(
blocks=(TransformBlock, CategoryBlock),
get_items=get_files,
get_x=audio_transform,
get_y=label_func,
splitter=RandomSplitter(),
item_tfms=[],
batch_tfms=[]
)

# Create DataLoaders
test_dls = audio_block.dataloaders(test_data_folder, bs=1)

# Scoring each 2 sec clip
predictions = []
pathList = list(pd.Series(test_data_folder.ls()).astype('str'))
for item in testdb.x:
predictions.append(self.model.predict(item)[2][1])
for pathname, item in zip(test_dls.items, [item[0] for item in test_dls.train_ds]):
print(item.shape)
# with torch.amp.autocast(device_type=device, enabled=False):
print(self.model.predict(item))
predictions.append(self.model.predict(item)[2][1].cpu().data.numpy().tolist())
pathList.append(str(pathname))
# for item in testdb.x:
# predictions.append(self.model.predict(item)[2][1])

# clean folder
shutil.rmtree(local_dir)
Expand Down Expand Up @@ -157,7 +309,8 @@ def predict(self, wav_file_path):
'duration_s': 1.0,
'confidence': [prediction.confidence[prediction.shape[0]-1]]
})
submission = submission.append(lastLine, ignore_index=True)
# submission = submission.append(lastLine, ignore_index=True)
submission = pd.concat((submission, lastLine), ignore_index=True)
submission = submission[['wav_filename', 'start_time_s', 'duration_s', 'confidence']]

# initialize output JSON
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading

0 comments on commit d722d3d

Please sign in to comment.