Skip to content

Instantly share code, notes, and snippets.

@mthrok
Forked from eldrin/my_melspec.py
Last active September 18, 2022 11:35
Show Gist options
  • Save mthrok/844f86a6855e1414a35deeb94274383a to your computer and use it in GitHub Desktop.
Save mthrok/844f86a6855e1414a35deeb94274383a to your computer and use it in GitHub Desktop.
Quick investigation into what causes the mel-spectrogram discrepancy between torchaudio and librosa
import math
from typing import Callable, Optional
from warnings import warn
import librosa
import torch
from torch import Tensor
from torchaudio import functional as F
from torchaudio import transforms as T
from torchaudio.compliance import kaldi
class MyMelScale(torch.nn.Module):
    r"""Turn a normal STFT into a mel frequency STFT, using a conversion
    matrix. This uses triangular filter banks.

    Users can control which device the filter bank (``fb``) lives on
    (e.g. ``fb.to(spec_f.device)``).

    Args:
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        n_stft (int, optional): Number of bins in STFT. Calculated from first input
            if None is given. See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
        fb_norm (str or None, optional): How the filter bank is built.
            ``'librosa_slaney'`` builds it with ``librosa.filters.mel(..., norm='slaney')``;
            any other value (e.g. ``None`` or ``'slaney'``) is passed as ``norm`` to
            torchaudio's mel filter-bank builder. (Default: ``None``)
    """
    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

    def __init__(self,
                 n_mels: int = 128,
                 sample_rate: int = 16000,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 n_stft: Optional[int] = None,
                 fb_norm: Optional[str] = None) -> None:
        super().__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.f_min = f_min
        self.fb_norm = fb_norm
        assert f_min <= self.f_max, 'Require f_min: {} <= f_max: {}'.format(f_min, self.f_max)
        # An empty buffer means "not built yet": forward() builds the bank from
        # the first input's frequency axis. This applies to every fb_norm,
        # including 'librosa_slaney'.
        fb = torch.empty(0) if n_stft is None else self._build_fb(n_stft)
        self.register_buffer('fb', fb)

    def _build_fb(self, n_stft: int) -> Tensor:
        """Return the ``(n_stft, n_mels)`` mel filter-bank matrix for ``n_stft`` bins."""
        if self.fb_norm == 'librosa_slaney':
            # Keyword arguments: ``sr`` / ``n_fft`` are keyword-only since librosa 0.10.
            mel_kernel_librosa = librosa.filters.mel(
                sr=self.sample_rate, n_fft=int(2 * (n_stft - 1)), n_mels=self.n_mels,
                fmin=self.f_min, fmax=self.f_max, norm='slaney',
            )
            # librosa returns (n_mels, n_stft); transpose to the (n_stft, n_mels)
            # layout that forward() multiplies with.
            return torch.Tensor(mel_kernel_librosa.T)
        # torchaudio renamed ``create_fb_matrix`` to ``melscale_fbanks``; use the
        # new name when available and fall back to the old one on older releases.
        build_fb = getattr(F, 'melscale_fbanks', None) or F.create_fb_matrix
        return build_fb(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate,
                        norm=self.fb_norm)

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        # pack batch: (..., freq, time) -> (batch, freq, time)
        shape = specgram.size()
        specgram = specgram.reshape(-1, shape[-2], shape[-1])

        if self.fb.numel() == 0:
            # Lazily build the filter bank from the frequency axis of this input.
            tmp_fb = self._build_fb(specgram.size(1))
            # Attributes cannot be reassigned outside __init__ so workaround:
            # resize the registered buffer in place and copy the values into it.
            self.fb.resize_(tmp_fb.size())
            self.fb.copy_(tmp_fb)

        # (batch, freq, time).transpose(...) dot (freq, n_mels)
        # -> (batch, time, n_mels).transpose(...) -> (batch, n_mels, time)
        mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)

        # unpack batch
        return mel_specgram.reshape(shape[:-2] + mel_specgram.shape[-2:])
class MyMelSpectrogram(torch.nn.Module):
    r"""Compute a mel spectrogram from a raw waveform.

    The waveform is first turned into a (power) spectrogram by
    :class:`torchaudio.transforms.Spectrogram`. That spectrogram is then
    projected onto mel bins by :class:`MyMelScale`, whose filter bank is
    selected by ``fb_norm``.

    Sources
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows.
            (Default: ``win_length // 2``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency; ``None`` lets
            :class:`MyMelScale` use ``sample_rate // 2``. (Default: ``None``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float or None, optional): Exponent forwarded to :class:`Spectrogram`.
            (Default: ``2.``)
        normalized (bool, optional): Forwarded to :class:`Spectrogram`. (Default: ``False``)
        fb_norm (str or None, optional): Filter-bank option forwarded to :class:`MyMelScale`
            (``None``, ``'slaney'`` or ``'librosa_slaney'``). (Default: ``None``)
        wkwargs (Dict[..., ...] or None, optional): Arguments for window function.
            (Default: ``None``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav')
        >>> mel_specgram = MyMelSpectrogram(sample_rate)(waveform)  # (channel, n_mels, time)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']

    def __init__(self,
                 sample_rate: int = 16000,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 pad: int = 0,
                 n_mels: int = 128,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: Optional[float] = 2.,
                 normalized: bool = False,
                 fb_norm: Optional[str] = None,
                 wkwargs: Optional[dict] = None) -> None:
        super().__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        # Window defaults to the full FFT size; hop defaults to half the window.
        self.win_length = n_fft if win_length is None else win_length
        self.hop_length = self.win_length // 2 if hop_length is None else hop_length
        self.pad = pad
        self.power = power
        self.normalized = normalized
        self.n_mels = n_mels  # number of mel frequency bins
        self.f_max = f_max
        self.f_min = f_min
        self.fb_norm = fb_norm

        self.spectrogram = T.Spectrogram(
            n_fft=self.n_fft,
            win_length=self.win_length,
            hop_length=self.hop_length,
            pad=self.pad,
            window_fn=window_fn,
            power=self.power,
            normalized=self.normalized,
            wkwargs=wkwargs,
        )
        # The number of STFT bins is known up front (n_fft // 2 + 1), so the
        # mel filter bank is built eagerly rather than on the first forward().
        self.mel_scale = MyMelScale(
            n_mels=self.n_mels,
            sample_rate=self.sample_rate,
            f_min=self.f_min,
            f_max=self.f_max,
            n_stft=self.n_fft // 2 + 1,
            fb_norm=self.fb_norm,
        )

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        return self.mel_scale(self.spectrogram(waveform))
import torch
import torchaudio
import librosa
import matplotlib.pyplot as plt
from torchaudio import functional as Fa
# Compare the raw mel filter banks produced by torchaudio and librosa for the
# same FFT size and sampling rate. Every kernel is shown as
# (frequency bin, mel bin) so the four panels are directly comparable.
n_fft = 2048
sr = 6000
n_mels = 128
n_freqs = n_fft // 2 + 1

# torchaudio renamed ``create_fb_matrix`` to ``melscale_fbanks``; use the new
# name when available and fall back to the old one on older releases.
_create_fb = getattr(Fa, 'melscale_fbanks', None) or Fa.create_fb_matrix

mel_kernel_torchaudio = _create_fb(
    n_freqs,
    n_mels=n_mels,
    f_min=0.,
    f_max=sr / 2.,
    sample_rate=sr,
    norm=None,
)
mel_kernel_torchaudio_slaney = _create_fb(
    n_freqs,
    n_mels=n_mels,
    f_min=0.,
    f_max=sr / 2.,
    sample_rate=sr,
    norm='slaney',
)
# ``sr`` / ``n_fft`` must be passed by keyword: they are keyword-only since librosa 0.10.
mel_kernel_librosa_htk = librosa.filters.mel(
    sr=sr,
    n_fft=n_fft,
    n_mels=n_mels,
    fmin=0.,
    fmax=sr / 2.,
    norm='slaney',
    htk=True,
)
mel_kernel_librosa_slaney = librosa.filters.mel(
    sr=sr,
    n_fft=n_fft,
    n_mels=n_mels,
    fmin=0.,
    fmax=sr / 2.,
    norm='slaney',
    htk=False,
)

# torchaudio kernels are already (n_freqs, n_mels); librosa's are (n_mels, n_freqs)
# and are transposed to match.
kernel_panels = [
    ('torchaudio[None]', mel_kernel_torchaudio),
    ('torchaudio[slaney]', mel_kernel_torchaudio_slaney),
    ('librosa[htk + slaney]', mel_kernel_librosa_htk.T),
    ('librosa[auditory_toolbox + slaney]', mel_kernel_librosa_slaney.T),
]
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
fig.suptitle('mel-Kernel')
for ax, (title, kernel) in zip(axs.flat, kernel_panels):
    ax.set_title(title)
    ax.imshow(kernel, aspect='auto')
    ax.set_xlabel('mel bin')
for ax in axs[:, 0]:
    ax.set_ylabel('frequency bin')
plt.show()
import torch
import torchaudio
import librosa
import matplotlib.pyplot as plt
from torchaudio import functional as Fa
# assume MyMelSpectrogram is declared somewhere
from my_melspec import MyMelSpectrogram
# Compare full mel spectrograms: MyMelSpectrogram with three filter-bank
# variants against librosa.feature.melspectrogram on the same waveform.
fn = 'test/torchaudio_unittest/assets/steam-train-whistle-daniel_simon.wav'
sr = 6000  # sampling rate
# librosa default
n_fft = 2048
win_len = None
hop_len = 512

# Load with librosa (resampled to ``sr``) so both pipelines see identical input.
# Alternatively, load with ``torchaudio.load`` and down-mix with ``waveform.mean(0)``.
waveform, sample_rate = librosa.load(fn, sr=sr)
waveform = torch.Tensor(waveform)

melspecs = {}
for fb_norm in [None, 'slaney', 'librosa_slaney']:
    # instantiate mel-spectrogram transform
    torch_melspectrogram = MyMelSpectrogram(
        sample_rate,
        n_fft=n_fft,
        win_length=win_len,
        hop_length=hop_len,
        fb_norm=fb_norm,
    )
    # compute
    X1 = torch_melspectrogram(waveform)
    # ``y`` must be passed by keyword: it is keyword-only since librosa 0.10.
    X2 = librosa.feature.melspectrogram(
        y=waveform.numpy(),
        sr=sample_rate,
        n_fft=n_fft,
        hop_length=hop_len,
        win_length=win_len,
    )
    # for plot
    melspecs[fb_norm] = (X1, X2)
    # Compute the error in NumPy explicitly instead of relying on
    # Tensor/ndarray operator interop.
    mse = ((X1.numpy() - X2) ** 2).mean()
    print(f'Mean Squared Error[fb:{fb_norm}]:\t{mse:.4f}')

spec_panels = [
    ('torchaudio / filter bank [None]', melspecs[None][0]),
    ('torchaudio / filter bank [slaney]', melspecs['slaney'][0]),
    ('torchaudio / filter bank [librosa_slaney]', melspecs['librosa_slaney'][0]),
    ('librosa', melspecs[None][1]),
]
fig, axs = plt.subplots(1, 4, figsize=(20, 5))
for ax, (title, spec) in zip(axs, spec_panels):
    ax.set_title(title)
    ax.set_xlabel('frame')
    ax.imshow(librosa.power_to_db(spec), aspect='auto')
axs[0].set_ylabel('mel bin')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment