Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

If i need a file of tflite format,how to convert the stft and istft use conv1d? #30

Open
panhu opened this issue Oct 11, 2022 · 9 comments

Comments

@panhu
Copy link

panhu commented Oct 11, 2022

Hi,i find a method https://github.com/huyanxin/phasen/blob/master/model/conv_stft.py use conv1d and conv1d_transpose instead stft and istft,but it is pytorch.When i replace tensorflow with pytorch,the result is error.Can i know you code about conv1d and conv1d_transpose instead stft and istft? Because later I want to compress it and move it to the chip。
Thank you vary much!

@Le-Xiaohuai-speech
Copy link
Owner

initialize the weights of convolutional layers by the basis function of the FFT.

@panhu
Copy link
Author

panhu commented Oct 20, 2022

Thanks,This is modified code:

import os
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Lambda, Input,Conv1D, Conv2D, BatchNormalization, Conv2DTranspose, Concatenate, LayerNormalization, PReLU
from tensorflow.keras.callbacks import ReduceLROnPlateau, CSVLogger, EarlyStopping, ModelCheckpoint
#from tensorflow.keras.layers import Conv1DTranspose

import soundfile as sf
import librosa
from random import seed
import numpy as np
import tqdm
from scipy.signal import get_window

from modules import DprnnBlock
from utils import reshape, transpose, ParallelModelCheckpoints
from data_loader import *

seed(42)

def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
if win_type == 'None' or win_type is None:
window = np.ones(win_len)
else:
window = get_window(win_type, win_len, fftbins=True)**0.5

N = fft_len
fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
real_kernel = np.real(fourier_basis)
imag_kernel = np.imag(fourier_basis)
kernel = np.concatenate([real_kernel, imag_kernel], 1).T

if invers :
    kernel = np.linalg.pinv(kernel).T 

kernel = kernel*window
kernel = kernel[:, None, :]
return tf.convert_to_tensor(kernel,dtype=tf.float32),tf.convert_to_tensor(window[None,:,None],dtype=tf.float32)

#kernel = init_kernels(400, 100, 512, win_type='hanning', invers=False)

class ConvSTFT(tf.keras.layers.Layer):
def init(self,win_len,win_inc,fft_len = None,win_type='hamming',feature_type='real',fix=True):
super(ConvSTFT,self).init()

       if fft_len == None:
          self.fft_len = np.int(2**np.ceil(np.log2(win_len)))
       else:
          self.fft_len = fft_len

       kernel,_ = init_kernels(win_len,win_inc,self.fft_len,win_type)
       self.weight = tf.reshape(kernel,[400,1,402])
       self.feature_type = feature_type
       self.stride = win_inc
       self.win_len = win_len
       self.dim = self.fft_len

   def call(self,inputs):
       
       outputs = tf.nn.conv1d(inputs,self.weight,stride= self.stride,padding='VALID')
       
       if self.feature_type == 'complex':
          return outputs
       else:
          #outputs = tf.reshape(outputs,[1,1,-1])
          dim = self.dim//2 + 1
          real = outputs[:,:,:dim]
          imag = outputs[:,:,dim:] 
       return real,imag

class ConviSTFT(tf.keras.layers.Layer):
def init(self,win_len,win_inc,fft_len=None,win_type='hamming',feature_type='real',fix=True):
super(ConviSTFT,self).init()
if fft_len == None:
self.fft_len = np.int(2**np.ceil(np.log2(win_len)))
else:
self.fft_len = fft_len

      kernel,window = init_kernels(win_len,win_inc,self.fft_len,win_type,invers= True)
      self.weight = tf.Variable(kernel,trainable=False)
      self.weight = tf.reshape(self.weight,[400,1,402])
      self.feature_type = feature_type
      self.win_type = win_type
      self.win_len = win_len
      self.win_inc = win_inc
      self.stride = win_inc
      self.dim = self.fft_len

  def call(self,inputs):
      
      outputs = tf.nn.conv1d_transpose(inputs,filters=self.weight,output_shape=([8,1599,400]),strides=1,padding='VALID')

      #outputs = tf.reshape(outputs,[8,1,-1])

      return outputs
      #t = tf.tile(self.window,[1,1,25597])**2
      #t = to_float(t)
      #t = tf.reshape(t,[1,25597,400])

      #self.enframe = tf.reshape(self.enframe,[400,1,400])

class MK_M(tf.keras.layers.Layer):
def init(self,**kwargs):
super(MK_M,self).init(**kwargs)

  def call(self,inputs):
        [noisy_real,noisy_imag,mask] = inputs
        noisy_real = noisy_real[:,:,:,0]
        noisy_imag = noisy_imag[:,:,:,0]            

        mask_real = mask[:,:,:,0]
        mask_imag = mask[:,:,:,1]

        enh_real = noisy_real*mask_real - noisy_imag*mask_imag
        enh_imag = noisy_real * mask_imag + noisy_imag*mask_real

        return [enh_real,enh_imag]

class Overlap_addLayer(tf.keras.layers.Layer):
def init(self,**kwargs):
super(Overlap_addLayer,self).init(**kwargs)

  def call(self,inputs):
     return tf.signal.overlap_and_add(inputs,200)

class DPCRN_model():
'''
Class to create and train the DPCRN model
'''

def __init__(self, batch_size = 1,
                   length_in_s = 5,
                   fs = 16000,
                   norm = 'iLN',
                   numUnits = 128,
                   numDP = 2,
                   block_len = 400,
                   block_shift = 200,
                   max_epochs = 200,
                   lr = 1e-3):

    # defining default cost function
    self.cost_function = self.snr_cost
    self.model = None
    # defining default parameters
    self.fs = fs
    self.length_in_s = length_in_s
    self.batch_size = batch_size
    # number of the hidden layer size in the LSTM
    self.numUnits = numUnits
    # number of the DPRNN modules
    self.numDP = numDP
    # frame length and hop length in STFT
    self.block_len = block_len
    self.block_shift = block_shift
    self.lr = lr
    self.max_epochs = max_epochs
    # window for STFT: sine win
    win = np.sin(np.arange(.5,self.block_len-.5+1)/self.block_len*np.pi)
    #print(win)
    self.win = tf.constant(win,dtype = 'float32')
    
    self.L = (16000*length_in_s-self.block_len)//self.block_shift + 1
    
    self.multi_gpu = False
    # iLN for instant Layer norm and BN for Batch norm
    self.input_norm = norm
    
@staticmethod
def snr_cost(s_estimate, s_true):
    '''
    Static Method defining the cost function. 
    The negative signal to noise ratio is calculated here. The loss is 
    always calculated over the last dimension. 
    '''
    # calculating the SNR
    snr = tf.reduce_mean(tf.math.square(s_true), axis=-1, keepdims=True) / \
        (tf.reduce_mean(tf.math.square(s_true-s_estimate), axis=-1, keepdims=True) + 1e-8)
    # using some more lines, because TF has no log10
    num = tf.math.log(snr + 1e-8) 
    denom = tf.math.log(tf.constant(10, dtype=num.dtype))
    loss = -10*(num / (denom))

    return loss

@staticmethod
def sisnr_cost(s_hat, s):
    '''
    Static Method defining the cost function. 
    The negative signal to noise ratio is calculated here. The loss is 
    always calculated over the last dimension. 
    '''
    def norm(x):
        return tf.reduce_sum(x**2, axis=-1, keepdims=True)

    s_target = tf.reduce_sum(
        s_hat * s, axis=-1, keepdims=True) * s / norm(s)
    upp = norm(s_target)
    low = norm(s_hat - s_target)

    return -10 * tf.math.log(upp /low) / tf.math.log(10.0)  

def spectrum_loss(self,y_true):
    '''
    spectrum MSE loss 
    '''
    enh_real = self.enh_real
    enh_imag = self.enh_imag
    enh_mag = tf.sqrt(enh_real**2 + enh_imag**2 + 1e-8)
    
    true_real,true_imag = self.stftLayer(y_true, mode='real_imag')
    true_mag = tf.sqrt(true_real**2 + true_imag**2 + 1e-8)
    
    loss_real = tf.reduce_mean((enh_real - true_real)**2,)
    loss_imag = tf.reduce_mean((enh_imag - true_imag)**2,)
    loss_mag = tf.reduce_mean((enh_mag - true_mag)**2,) 
    
    return loss_real + loss_imag + loss_mag

def spectrum_loss_phasen(self, s_hat,s,gamma = 0.3):
    
    true_real,true_imag = self.stftLayer(s, mode='real_imag')
    hat_real,hat_imag = self.stftLayer(s_hat, mode='real_imag')

    true_mag = tf.sqrt(true_real**2 + true_imag**2+1e-9)
    hat_mag = tf.sqrt(hat_real**2 + hat_imag**2+1e-9)

    true_real_cprs = (true_real / true_mag )*true_mag**gamma
    true_imag_cprs = (true_imag / true_mag )*true_mag**gamma
    hat_real_cprs = (hat_real / hat_mag )* hat_mag**gamma
    hat_imag_cprs = (hat_imag / hat_mag )* hat_mag**gamma

    loss_mag = tf.reduce_mean((hat_mag**gamma - true_mag**gamma)**2,)         
    loss_real = tf.reduce_mean((hat_real_cprs - true_real_cprs)**2,)
    loss_imag = tf.reduce_mean((hat_imag_cprs - true_imag_cprs)**2,)

    return 0.7 * loss_mag + 0.3 * ( loss_imag + loss_real ) 

def lossWrapper(self):
    '''
    A wrapper function which returns the loss function. This is done to
    to enable additional arguments to the loss function if necessary.
    '''
    def lossFunction(y_true,y_pred):
        # calculating loss and squeezing single dimensions away
        loss = tf.squeeze(self.cost_function(y_pred,y_true)) 
        mag_loss = tf.math.log(self.spectrum_loss(y_true) + 1e-8)
        # calculate mean over batches
        loss = tf.reduce_mean(loss)
        return loss + mag_loss 
    
    return lossFunction

'''
In the following some helper layers are defined.
'''  
def seg2frame(self, x):
    '''
    split signal x to frames
    '''
    frames = tf.signal.frame(x, self.block_len, self.block_shift)
    if self.win is not None:
        frames = self.win*frames
    return frames

def stftLayer(self, x, mode ='mag_pha'):
    '''
    Method for an STFT helper layer used with a Lambda layer
    mode: 'mag_pha'   return magnitude and phase spectrogram
          'real_imag' return real and imaginary parts
    '''
    # creating frames from the continuous waveform
    frames = tf.signal.frame(x, self.block_len, self.block_shift)
    
    if self.win is not None:
        frames = self.win*frames
    # calculating the fft over the time frames. rfft returns NFFT/2+1 bins.
    #print('win.............',type(frames))
    #frames = frames.numpy()
    #print('win.............',type(frames))
    stft_dat = tf.signal.rfft(frames)
    stft_dat = tf.convert_to_tensor(stft_dat)
    # calculating magnitude and phase from the complex signal
    output_list = []
    if mode == 'mag_pha':
        mag = tf.math.abs(stft_dat)
        phase = tf.math.angle(stft_dat)
        output_list = [mag, phase]
    elif mode == 'real_imag':
        real = tf.math.real(stft_dat)
        imag = tf.math.imag(stft_dat)
        output_list = [real, imag]            
    # returning magnitude and phase as list
    return output_list

def fftLayer(self, x):
    '''
    Method for an fft helper layer used with a Lambda layer.
    The layer calculates the rFFT on the last dimension and returns
    the magnitude and phase of the STFT.
    '''
    # calculating the fft over the time frames. rfft returns NFFT/2+1 bins.
    stft_dat = tf.signal.rfft(x)
    # calculating magnitude and phase from the complex signal
    mag = tf.abs(stft_dat)
    phase = tf.math.angle(stft_dat)
    # returning magnitude and phase as list
    return [mag, phase]


def ifftLayer(self, x,mode = 'mag_pha'):
    '''
    Method for an inverse FFT layer used with an Lambda layer. This layer
    calculates time domain frames from magnitude and phase information. 
    As input x a list with [mag,phase] is required.
    '''
    if mode == 'mag_pha':
    # calculating the complex representation
        s1_stft = (tf.cast(x[0], tf.complex64) * 
                    tf.exp( (1j * tf.cast(x[1], tf.complex64))))
    elif mode == 'real_imag':
        s1_stft = tf.cast(x[0], tf.complex64) + 1j * tf.cast(x[1], tf.complex64)
    # returning the time domain frames
    return tf.signal.irfft(s1_stft)  

def overlapAddLayer(self, x):
    '''
    Method for an overlap and add helper layer used with a Lambda layer.
    This layer reconstructs the waveform from a framed signal.
    '''
    # calculating and returning the reconstructed waveform
    '''
    if self.move_dc:
        x = x - tf.expand_dims(tf.reduce_mean(x,axis = -1),2)
    '''
    return tf.signal.overlap_and_add(x, self.block_shift)              
 
def mk_mask(self,x):
    '''
    Method for complex ratio mask and add helper layer used with a Lambda layer.
    '''
    [noisy_real,noisy_imag,mask] = x
    noisy_real = noisy_real[:,:,:,0]
    noisy_imag = noisy_imag[:,:,:,0]
    
    mask_real = mask[:,:,:,0]
    mask_imag = mask[:,:,:,1]
    
    enh_real = noisy_real * mask_real - noisy_imag * mask_imag
    enh_imag = noisy_real * mask_imag + noisy_imag * mask_real
    
    return [enh_real,enh_imag]
    
def build_DPCRN_model(self, name = 'model0'):

    # input layer for time signal
    time_dat = Input(batch_shape=(8, 320000))
    # calculate STFT
    
    time_dat_1 = tf.reshape(time_dat,[8,320000,1])
    real,imag = ConvSTFT(400,200,400,win_type='hanning',feature_type='real')(time_dat_1)
    print(real.shape)
   
    real = tf.reshape(real,[8,-1,201,1])
    imag = tf.reshape(imag,[8,-1,201,1])

    input_complex_spec = Concatenate(axis = -1)([real,imag])
    '''encoder'''
    #print(input_complex_spec.shape)
    if self.input_norm == 'iLN':    
        input_complex_spec = LayerNormalization(axis = [-1,-2], name = 'input_norm')(input_complex_spec)
    elif self.input_norm == 'BN':    
        input_complex_spec =BatchNormalization(name = 'input_norm')(input_complex_spec)
    
    # causal padding [1,0],[0,2]
    input_complex_spec = tf.pad(input_complex_spec,[[0,0],[1,0],[0,2],[0,0]])
    conv_1 = Conv2D(32, (2,5),(1,2),name = name+'_conv_1',padding = "VALID")(input_complex_spec)
    bn_1 = BatchNormalization(name = name+'_bn_1')(conv_1)
    out_1 = PReLU(shared_axes=[1,2])(bn_1)
    # causal padding [1,0],[0,1]
    out_1_1 = tf.pad(out_1,[[0,0],[1,0],[0,1],[0,0]])
    conv_2 = Conv2D(32, (2,3),(1,2),name = name+'_conv_2',padding = "VALID")(out_1_1)
    bn_2 = BatchNormalization(name = name+'_bn_2')(conv_2)
    out_2 = PReLU(shared_axes=[1,2])(bn_2)
    # causal padding [1,0],[1,1]
    out_2_1 = tf.pad(out_2,[[0,0],[1,0],[1,1],[0,0]])
    conv_3 = Conv2D(32, (2,3),(1,1),name = name+'_conv_3',padding = "VALID")(out_2_1)
    bn_3 = BatchNormalization(name = name+'_bn_3')(conv_3)
    out_3 = PReLU(shared_axes=[1,2])(bn_3)
    # causal padding [1,0],[1,1]
    out_3_1 = tf.pad(out_3,[[0,0],[1,0],[1,1],[0,0]])
    conv_4 = Conv2D(64, (2,3),(1,1),name = name+'_conv_4',padding = "VALID")(out_3_1)
    bn_4 = BatchNormalization(name = name+'_bn_4')(conv_4)
    out_4 = PReLU(shared_axes=[1,2])(bn_4)
    # causal padding [1,0],[1,1]
    out_4_1 = tf.pad(out_4,[[0,0],[1,0],[1,1],[0,0]])
    conv_5 = Conv2D(128, (2,3),(1,1),name = name+'_conv_5',padding = "VALID")(out_4_1)
    bn_5 = BatchNormalization(name = name +'_bn_5')(conv_5)
    out_5 = PReLU(shared_axes=[1,2])(bn_5)
    
    dp_in = out_5

    print(dp_in.shape)
    for i in range(self.numDP):
        
        dp_in = DprnnBlock(numUnits = self.numUnits, batch_size = self.batch_size, L = -1,width = 50,channel = 128, causal=True)(dp_in)#self.DPRNN_kernal(dp_in,str(i),last_dp = 0)
   
    dp_out = dp_in
    
    '''decoder'''
    skipcon_1 = Concatenate(axis = -1)([out_5,dp_out])

    deconv_1 = Conv2DTranspose(64,(2,3),(1,1),name = name+'_dconv_1',padding = 'same')(skipcon_1)
    dbn_1 = BatchNormalization(name = name+'_dbn_1')(deconv_1)
    dout_1 = PReLU(shared_axes=[1,2])(dbn_1)

    skipcon_2 = Concatenate(axis = -1)([out_4,dout_1])
    
    deconv_2 = Conv2DTranspose(32,(2,3),(1,1),name = name+'_dconv_2',padding = 'same')(skipcon_2)
    dbn_2 = BatchNormalization(name = name+'_dbn_2')(deconv_2)
    dout_2 = PReLU(shared_axes=[1,2])(dbn_2)
    
    skipcon_3 = Concatenate(axis = -1)([out_3,dout_2])
    
    deconv_3 = Conv2DTranspose(32,(2,3),(1,1),name = name+'_dconv_3',padding = 'same')(skipcon_3)
    dbn_3 = BatchNormalization(name = name+'_dbn_3')(deconv_3)
    dout_3 = PReLU(shared_axes=[1,2])(dbn_3)
    
    skipcon_4 = Concatenate(axis = -1)([out_2,dout_3])

    deconv_4 = Conv2DTranspose(32,(2,3),(1,2),name = name+'_dconv_4',padding = 'same')(skipcon_4)
    dbn_4 = BatchNormalization(name = name+'_dbn_4')(deconv_4)
    dout_4 = PReLU(shared_axes=[1,2])(dbn_4)
    
    skipcon_5 = Concatenate(axis = -1)([out_1,dout_4])
    
    deconv_5 = Conv2DTranspose(2,(2,5),(1,2),name = name+'_dconv_5',padding = 'valid')(skipcon_5)
    
    '''no activation'''        
    deconv_5 = deconv_5[:,:-1,:-2]

    #output_mask = Activation('tanh')(dbn_5)
    output_mask = deconv_5

    #enh_spec = Lambda(self.mk_mask)([real,imag,output_mask])
    enh_spec = MK_M(name='mask')([real,imag,output_mask])

    self.enh_real, self.enh_imag = enh_spec[0],enh_spec[1]

    #enh_frame = Lambda(self.ifftLayer,arguments = {'mode':'real_imag'})(enh_spec)
    #enh_frame = ifft_Layer(name='ifft_layer')(enh_spec)

    s1_stft = tf.cast(enh_spec[0], tf.complex64) + 1j * tf.cast(enh_spec[1], tf.complex64)
    s1_stft = tf.to_float(s1_stft)
    #enh_frame = tf.nn.conv1d_transpose(enh_spec,filters=kernel,output_shape=(8,1599,400),strides=100,padding="VALID")
    enh_frame = ConviSTFT(400,100,400,win_type='hanning',feature_type='complex')(s1_stft)
    enh_frame = tf.reshape(enh_frame,[8,1599,400])
   
    enh_frame = enh_frame * self.win

    enh_time = Overlap_addLayer(name='overlayer')(enh_frame)        

    self.model = Model(time_dat,enh_time)
    self.model.summary()

    return self.model

But,i got a error:

ValueError: Depth of output (402) is not a multiple of the number of groups (400) for 'Adam/gradients/convi_stft/conv1d_transpose_grad/Conv2D' (op: 'Conv2D') with input shapes: [8,1,1599,400], [1,400,1,402].

Thanks!

@Le-Xiaohuai-speech
Copy link
Owner

it looks like the output dimensions of the iSTFT do not match the groups number

@panhu
Copy link
Author

panhu commented Oct 20, 2022

Yes,but i changed the size of kernel(filter) and stride is invalid。

@panhu
Copy link
Author

panhu commented Oct 20, 2022

Can you help me verify the code of tf.nn.conv1d_transpose(ConviSTFT)?
Thanks!

@Le-Xiaohuai-speech
Copy link
Owner

Ok,I'll get back to you later. Please post the code of ConviSTFT agian? you can send the .py file to [email protected]

@panhu
Copy link
Author

panhu commented Oct 20, 2022

OK

@panhu
Copy link
Author

panhu commented Oct 24, 2022

Hi:
When i "load_model" got a new error,this is my code:

modelparh = r"dpcrn_4.h5"
model = tf.keras.models.load_model(modelparh,custom_objects={"DprnnBlock":DprnnBlock,"ConvSTFT":ConvSTFT,"MK_M":MK_M,"ConviSTFT":ConviSTFT,
"Overlap_addLayer":Overlap_addLayer})

The error is :

ValueError: Unknown loss function:lossFunction

@panhu
Copy link
Author

panhu commented Oct 24, 2022

When i use:
model = tf.keras.models.load_model(modelparh,custom_objects={"DprnnBlock":DprnnBlock,"ConvSTFT":ConvSTFT,"MK_M":MK_M,"ConviSTFT":ConviSTFT,
"Overlap_addLayer":Overlap_addLayer,"lossFunction":DPCRN_model.lossWrapper})

The error is:

TypeError: lossWrapper() takes 1 positional argument but 2 were given

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants