FastaiV2 model saves, loads, and works for inference
bnestor committed Sep 18, 2024
1 parent a043a96 commit 205b769
Showing 5 changed files with 262 additions and 171 deletions.
@@ -2,7 +2,7 @@ model_type: "FastAI"
model_local_threshold: 0.5
model_global_threshold: 3
model_path: "./model"
model_name: "model.pt"
model_name: "stg2-rn18.pkl"
hls_stream_type: "LiveHLS"
hls_polling_interval: 60
hls_hydrophone_id: "rpi_orcasound_lab"
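
These config keys map directly onto the FastAI2Model constructor further down in this commit. As a rough sketch (not part of this commit), assuming the config is YAML and using a hypothetical file path, they could be wired together like this; the threshold semantics are inferred from the parameter names:

import yaml
from model.fastai_inference import FastAI2Model  # import path assumed

with open("config/FastAI_LiveHLS.yml") as f:  # hypothetical config filename
    cfg = yaml.safe_load(f)

model = FastAI2Model(
    cfg["model_path"],                          # "./model"
    model_name=cfg["model_name"],               # "stg2-rn18.pkl" after this change
    threshold=cfg["model_local_threshold"],     # assumed: per-window probability cutoff
    min_num_positive_calls_threshold=cfg["model_global_threshold"],  # assumed: windows needed for a positive clip
)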
78 changes: 78 additions & 0 deletions InferenceSystem/src/audiotransform.py
@@ -0,0 +1,78 @@
import torchaudio
from pathlib import Path
from fastai.vision.all import Transform, F
from dataclasses import dataclass


def label_func(f): return f.parent.name



@dataclass
class SpectrogramConfig2:
    f_min: float = 0.0        # Minimum frequency to display
    f_max: float = 10000.0    # Maximum frequency to display
    hop_length: int = 256     # Hop length
    n_fft: int = 2560         # Number of samples for Fourier transform
    n_mels: int = 256         # Number of Mel bins
    pad: int = 0              # Padding
    to_db_scale: bool = True  # Convert to dB scale
    top_db: int = 100         # Top decibel sound
    win_length: int = None    # Window length
    n_mfcc: int = 20          # Number of MFCC features

@dataclass
class AudioConfig2:
    standardize: bool = False  # Standardization flag
    sg_cfg: dataclass = None   # Spectrogram configuration
    duration: int = 4000       # Clip duration in milliseconds (4000 ms = 4 seconds)
    resample_to: int = 20000   # Resample rate in Hz


class AudioTransform(Transform):
    def __init__(self, config, mode='test'):
        self.config = config
        self.to_db_scale = torchaudio.transforms.AmplitudeToDB(top_db=self.config.sg_cfg.top_db)
        self.spectrogrammer = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.config.resample_to,
            n_fft=self.config.sg_cfg.n_fft,
            hop_length=self.config.sg_cfg.hop_length,
            n_mels=self.config.sg_cfg.n_mels,
            f_min=self.config.sg_cfg.f_min,
            f_max=self.config.sg_cfg.f_max
        )
        self.time_masking = torchaudio.transforms.TimeMasking(time_mask_param=80)
        self.freq_masking = torchaudio.transforms.FrequencyMasking(freq_mask_param=80)
        self.mode = mode

    def encodes(self, fn: Path):
        wave, sr = torchaudio.load(fn)
        wave = wave.mean(dim=0)  # reduce to mono
        # resample to
        wave = torchaudio.functional.resample(wave, sr, self.config.resample_to)

        # pad or truncate to config.duration
        max_len = int(self.config.duration/1000 * self.config.resample_to)

        # print(wave.shape)
        if wave.shape[0] < max_len:
            wave = F.pad(wave, (0, max_len - wave.shape[0]))  # Pad if shorter than max_len
        else:
            wave = wave[:max_len]  # Truncate if longer than max_len

        # print(wave.shape)

        # Generate the MelSpectrogram
        spec = self.spectrogrammer(wave)

        # during training only!
        if self.mode == 'train':
            spec = self.time_masking(self.freq_masking(spec))

        # Convert the MelSpectrogram to decibel scale if specified
        if self.config.sg_cfg.to_db_scale:
            spec = self.to_db_scale(spec)

        # print('spec', spec.shape)
        spec = spec.unsqueeze(0).expand(3, -1, -1)
        return spec
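
A minimal usage sketch for this transform (not part of this commit); "example.wav" is a placeholder path, and the printed shape assumes torchaudio's default centered STFT:

from pathlib import Path
from audiotransform import AudioTransform, AudioConfig2, SpectrogramConfig2

cfg = AudioConfig2(sg_cfg=SpectrogramConfig2())  # 4 s clips resampled to 20 kHz
tfm = AudioTransform(cfg, mode='test')           # 'test' skips the masking augmentations

spec = tfm(Path("example.wav"))                  # Transform.__call__ dispatches to encodes
print(spec.shape)                                # torch.Size([3, 256, 313]): channels x n_mels x frames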
124 changes: 39 additions & 85 deletions InferenceSystem/src/model/fastai_inference.py
@@ -1,10 +1,13 @@
# from fastai.basic_train import load_learner
from fastai.vision.all import *
from fastai.vision.all import load_learner, load_model, accuracy, models,Learner, DataBlock, CategoryBlock, RandomSplitter, TransformBlock, get_files
import pandas as pd
from pydub import AudioSegment

import sys
from audiotransform import AudioTransform, SpectrogramConfig2, AudioConfig2, label_func
from librosa import get_duration
from pathlib import Path
from numpy import floor
import numpy as np
# from audio.data import AudioConfig, SpectrogramConfig, AudioList
from fastai.data.all import *
import torchaudio
@@ -16,10 +19,28 @@
import torch
import scipy.io.wavfile

# import decorators








spec_config = SpectrogramConfig2()
config = AudioConfig2(sg_cfg=spec_config)

audio_transform = AudioTransform(config, mode='test')



def _load_model(mPath, mName="stg2-rn18.pkl"):
global audio_transform
if mName.endswith('.pkl'):
tmp = load_learner(os.path.join(mPath, mName))
tmp = load_learner(os.path.join(mPath, mName))
# tmp = learner.load(os.path.join(mPath, mName))
# tmp = load_model(os.path.join(mPath, mName))
elif mName.endswith('.pt'):
import torch
# it is a pytorch model
@@ -35,10 +56,7 @@ def _load_model(mPath, mName="stg2-rn18.pkl"):
testpath = Path(mPath)


spec_config = SpectrogramConfig2()
config = AudioConfig2(sg_cfg=spec_config)

audio_transform = AudioTransform(config, mode='test')



get_wav_files = lambda x: get_files(x, extensions=['.wav',])
@@ -47,7 +65,7 @@ def _load_model(mPath, mName="stg2-rn18.pkl"):
blocks=(TransformBlock, CategoryBlock),
get_items=get_wav_files,
get_x=audio_transform,
get_y=label_func,
get_y=lambda x: random.choice(['negative', 'positive']),
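# note: get_y is effectively a dummy labeler here; the inference path below only consumes get_x, so these random labels are never used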
splitter=RandomSplitter(),
item_tfms=[],
batch_tfms=[]
@@ -62,7 +80,10 @@ def _load_model(mPath, mName="stg2-rn18.pkl"):
# load the model
for n, p in checkpoint.items():
print(n, p.shape)
tmp.model.load_state_dict(checkpoint, strict=False)
for n, p in tmp.model.state_dict().items():
print(n, p.shape)
tmp.model.load_state_dict(checkpoint, strict=True)
raise NotImplementedError("pytorch loading is currently throwing an indexing error")
else:raise NotImplementedError("Only .pkl and .pt models are supported")

return tmp
@@ -101,80 +122,10 @@ def extract_segments(audioPath, sampleDict, destnPath, suffix):
end_time, output_file_path)


@dataclass
class SpectrogramConfig2:
    f_min: float = 0.0        # Minimum frequency to display
    f_max: float = 10000.0    # Maximum frequency to display
    hop_length: int = 256     # Hop length
    n_fft: int = 2560         # Number of samples for Fourier transform
    n_mels: int = 256         # Number of Mel bins
    pad: int = 0              # Padding
    to_db_scale: bool = True  # Convert to dB scale
    top_db: int = 100         # Top decibel sound
    win_length: int = None    # Window length
    n_mfcc: int = 20          # Number of MFCC features

@dataclass
class AudioConfig2:
    standardize: bool = False  # Standardization flag
    sg_cfg: dataclass = None   # Spectrogram configuration
    duration: int = 4000       # Duration in samples (e.g., 4000 for 4 seconds)
    resample_to: int = 20000   # Resample rate in Hz


class AudioTransform(Transform):
    def __init__(self, config, mode='test'):
        self.config = config
        self.to_db_scale = torchaudio.transforms.AmplitudeToDB(top_db=self.config.sg_cfg.top_db)
        self.spectrogrammer = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.config.resample_to,
            n_fft=self.config.sg_cfg.n_fft,
            hop_length=self.config.sg_cfg.hop_length,
            n_mels=self.config.sg_cfg.n_mels,
            f_min=self.config.sg_cfg.f_min,
            f_max=self.config.sg_cfg.f_max
        )
        self.time_masking = torchaudio.transforms.TimeMasking(time_mask_param=80)
        self.freq_masking = torchaudio.transforms.FrequencyMasking(freq_mask_param=80)
        self.mode = mode

    def encodes(self, fn: Path):
        wave, sr = torchaudio.load(fn)
        wave = wave.mean(dim=0)  # reduce to mono
        # resample to
        wave = torchaudio.functional.resample(wave, sr, self.config.resample_to)

        # pad or truncate to config.duration
        max_len = int(self.config.duration/1000 * self.config.resample_to)

        # print(wave.shape)
        if wave.shape[0] < max_len:
            wave = F.pad(wave, (0, max_len - wave.shape[0]))  # Pad if shorter than max_len
        else:
            wave = wave[:max_len]  # Truncate if longer than max_len

        # print(wave.shape)

        # Generate the MelSpectrogram
        spec = self.spectrogrammer(wave)

        # during training only!
        if self.mode == 'train':
            spec = self.time_masking(self.freq_masking(spec))

        # Convert the MelSpectrogram to decibel scale if specified
        if self.config.sg_cfg.to_db_scale:
            spec = self.to_db_scale(spec)

        # print('spec', spec.shape)
        spec = spec.unsqueeze(0).expand(3, -1, -1)
        return spec


# Integrate with fastai's DataBlock (customized for your use case)
def label_func(f):
    return f.parent.name

class FastAI2Model():
def __init__(self, model_path, model_name="stg2-rn18.pkl", threshold=0.5, min_num_positive_calls_threshold=3):
self.model = _load_model(model_path, model_name)
@@ -197,13 +148,15 @@ def predict(self, wav_file_path):
os.makedirs(local_dir)

# infer clip length
max_length = get_duration(filename=wav_file_path)
# max_length = get_duration(filename=wav_file_path)
max_length = get_duration(path=wav_file_path)

print(os.path.basename(wav_file_path))
print("Length of Audio Clip:{0}".format(max_length))
#max_length = 60
# Generating 2 sec proposal with 1 sec hop length
twoSecList = []
for i in range(int(floor(max_length)-1)):
for i in range(int(np.floor(max_length)-1)):
twoSecList.append([i, i+2])
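# e.g. for a 5 s clip: int(np.floor(5)) - 1 = 4 windows -> [0, 2], [1, 3], [2, 4], [3, 5]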

# Creating a proposal dictionary
@@ -218,6 +171,8 @@ def predict(self, wav_file_path):
local_dir,
""
)

print(local_dir)

# Defining the audio config needed to create on-the-fly mel spectrograms
spec_cfg = SpectrogramConfig2(
@@ -258,15 +213,14 @@ def predict(self, wav_file_path):
)

# Create DataLoaders
test_dls = audio_block.dataloaders(test_data_folder, bs=1)
test_dls = audio_block.dataloaders(test_data_folder, bs=2)

# Scoring each 2 sec clip
predictions = []
pathList = list(pd.Series(test_data_folder.ls()).astype('str'))
# pathList = list(pd.Series(test_data_folder.ls()).astype('str'))
pathList= []
for pathname, item in zip(test_dls.items, [item[0] for item in test_dls.train_ds]):
print(item.shape)
# with torch.amp.autocast(device_type=device, enabled=False):
print(self.model.predict(item))
predictions.append(self.model.predict(item)[2][1].cpu().data.numpy().tolist())
pathList.append(str(pathname))
# for item in testdb.x:
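
The rest of predict() is not shown in this diff. As a rough sketch (an assumption about the downstream logic, not code from this commit), the per-window scores collected above could be reduced to a clip-level decision using the two thresholds from the config:

import numpy as np

def aggregate_predictions(predictions, local_threshold=0.5, min_num_positive_calls_threshold=3):
    # count 2 s windows whose positive-class probability clears the local threshold
    positive_windows = int(np.sum(np.array(predictions) > local_threshold))
    # declare the whole clip positive once enough windows fire
    return positive_windows, positive_windows >= min_num_positive_calls_threshold

# e.g. aggregate_predictions([0.1, 0.7, 0.9, 0.6, 0.2]) -> (3, True)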
