-
-
Save mthrok/844f86a6855e1414a35deeb94274383a to your computer and use it in GitHub Desktop.
A quick investigation into what causes the mel-spectrogram discrepancy between torchaudio and librosa
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
from typing import Callable, Optional | |
from warnings import warn | |
import librosa | |
import torch | |
from torch import Tensor | |
from torchaudio import functional as F | |
from torchaudio import transforms as T | |
from torchaudio.compliance import kaldi | |
class MyMelScale(torch.nn.Module):
    r"""Turn a normal STFT into a mel frequency STFT, using a conversion
    matrix. This uses triangular filter banks.
    User can control which device the filter bank (`fb`) is (e.g. fb.to(spec_f.device)).
    Args:
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        n_stft (int, optional): Number of bins in STFT. Calculated from first input
            if None is given. See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
        fb_norm (str or None, optional): Filter-bank normalization. ``None`` and
            ``'slaney'`` are forwarded to ``F.create_fb_matrix``; ``'librosa_slaney'``
            builds the kernel with ``librosa.filters.mel`` instead. (Default: ``None``)
    """
    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

    def __init__(self,
                 n_mels: int = 128,
                 sample_rate: int = 16000,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 n_stft: Optional[int] = None,
                 fb_norm: Optional[str] = None) -> None:
        super(MyMelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.f_min = f_min
        self.fb_norm = fb_norm
        assert f_min <= self.f_max, 'Require f_min: {} <= f_max: {}'.format(f_min, self.f_max)
        # When n_stft is unknown, defer filter-bank creation to the first forward() call.
        fb = torch.empty(0) if n_stft is None else self._build_fb(n_stft)
        self.register_buffer('fb', fb)

    def _build_fb(self, n_stft: int) -> Tensor:
        """Create the (n_stft, n_mels) filter-bank matrix for ``n_stft`` frequency bins."""
        if self.fb_norm == 'librosa_slaney':
            # librosa wants n_fft, not the bin count; n_stft == n_fft // 2 + 1.
            # Keyword arguments: librosa >= 0.10 rejects positional sr/n_fft.
            mel_kernel_librosa = librosa.filters.mel(
                sr=self.sample_rate, n_fft=int(2 * (n_stft - 1)), n_mels=self.n_mels,
                fmin=self.f_min, fmax=self.f_max, norm='slaney'
            )
            # librosa returns (n_mels, n_stft); transpose to torchaudio's layout.
            return torch.Tensor(mel_kernel_librosa.T)
        return F.create_fb_matrix(
            n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate,
            norm=self.fb_norm
        )

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).
        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        # pack batch
        shape = specgram.size()
        specgram = specgram.reshape(-1, shape[-2], shape[-1])

        if self.fb.numel() == 0:
            # Lazy initialization: infer the bin count from the first input.
            tmp_fb = self._build_fb(specgram.size(1))
            # Attributes cannot be reassigned outside __init__ so workaround
            self.fb.resize_(tmp_fb.size())
            self.fb.copy_(tmp_fb)

        # (channel, frequency, time).transpose(...) dot (frequency, n_mels)
        # -> (channel, time, n_mels).transpose(...)
        mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)

        # unpack batch
        mel_specgram = mel_specgram.reshape(shape[:-2] + mel_specgram.shape[-2:])
        return mel_specgram
class MyMelSpectrogram(torch.nn.Module):
    r"""Create MelSpectrogram for a raw audio signal.

    Composes :class:`torchaudio.transforms.Spectrogram` with :class:`MyMelScale`,
    so the mel filter bank can be built either by torchaudio or by librosa
    (``fb_norm='librosa_slaney'``).

    Sources
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows.
            (Default: ``win_length // 2``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``None``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float or None, optional): Exponent for the magnitude spectrogram. (Default: ``2.``)
        normalized (bool, optional): Whether to normalize the spectrogram. (Default: ``False``)
        fb_norm (str or None, optional): Filter-bank normalization mode, forwarded to
            :class:`MyMelScale`. (Default: ``None``)
        wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform)  # (channel, n_mels, time)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']

    def __init__(self,
                 sample_rate: int = 16000,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 pad: int = 0,
                 n_mels: int = 128,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: Optional[float] = 2.,
                 normalized: bool = False,
                 fb_norm: Optional[str] = None,
                 wkwargs: Optional[dict] = None) -> None:
        super().__init__()
        # Resolve defaults up front: window falls back to n_fft, hop to half a window.
        resolved_win = n_fft if win_length is None else win_length
        resolved_hop = resolved_win // 2 if hop_length is None else hop_length

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = resolved_win
        self.hop_length = resolved_hop
        self.pad = pad
        self.power = power
        self.normalized = normalized
        self.n_mels = n_mels  # number of mel frequency bins
        self.f_min = f_min
        self.f_max = f_max
        self.fb_norm = fb_norm

        # STFT front-end followed by the mel projection.
        self.spectrogram = T.Spectrogram(
            n_fft=self.n_fft, win_length=self.win_length,
            hop_length=self.hop_length, pad=self.pad, window_fn=window_fn,
            power=self.power, normalized=self.normalized, wkwargs=wkwargs)
        self.mel_scale = MyMelScale(
            self.n_mels, self.sample_rate, self.f_min, self.f_max,
            self.n_fft // 2 + 1, self.fb_norm)

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).
        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        return self.mel_scale(self.spectrogram(waveform))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
import torchaudio
import librosa
import matplotlib.pyplot as plt
from torchaudio import functional as Fa

# Compare the mel filter-bank kernels produced by torchaudio and librosa.
n_fft = 2048
sr = 6000

# torchaudio kernel, no normalization.
mel_kernel_torchaudio = Fa.create_fb_matrix(
    int(n_fft // 2 + 1),
    n_mels=128,
    f_min=0.,
    f_max=sr/2.,
    sample_rate=sr,
    norm=None
)
# torchaudio kernel with Slaney-style area normalization.
mel_kernel_torchaudio_slaney = Fa.create_fb_matrix(
    int(n_fft // 2 + 1),
    n_mels=128,
    f_min=0.,
    f_max=sr/2.,
    sample_rate=sr,
    norm='slaney'
)
# librosa kernel: HTK mel scale + Slaney normalization.
# NOTE: keyword arguments are used because librosa >= 0.10 no longer accepts
# sr/n_fft positionally; this form also works on older librosa versions.
mel_kernel_librosa_htk = librosa.filters.mel(
    sr=sr,
    n_fft=n_fft,
    n_mels=128,
    fmin=0.,
    fmax=sr/2.,
    norm='slaney',
    htk=True,
)
# librosa kernel: Slaney (auditory toolbox) mel scale + Slaney normalization.
mel_kernel_librosa_slaney = librosa.filters.mel(
    sr=sr,
    n_fft=n_fft,
    n_mels=128,
    fmin=0.,
    fmax=sr/2.,
    norm='slaney',
    htk=False,
)

# Plot the four kernels side by side; librosa kernels are transposed so every
# panel shares the (frequency bin, mel bin) orientation.
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
fig.suptitle('mel-Kernel')
axs[0][0].set_title('torchaudio[None]')
axs[0][0].imshow(mel_kernel_torchaudio, aspect='auto')
axs[0][0].set_ylabel('frequency bin')
axs[0][0].set_xlabel('mel bin')
axs[0][1].set_title('torchaudio[slaney]')
axs[0][1].imshow(mel_kernel_torchaudio_slaney, aspect='auto')
axs[0][1].set_xlabel('mel bin')
axs[1][0].set_title('librosa[htk + slaney]')
axs[1][0].imshow(mel_kernel_librosa_htk.T, aspect='auto')
axs[1][0].set_xlabel('mel bin')
# Fixed typo in the title: 'audiotory_toolbox' -> 'auditory_toolbox'.
axs[1][1].set_title('librosa[auditory_toolbox + slaney]')
axs[1][1].imshow(mel_kernel_librosa_slaney.T, aspect='auto')
axs[1][1].set_xlabel('mel bin')
plt.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
import torchaudio
import librosa
import matplotlib.pyplot as plt
from torchaudio import functional as Fa

# assume MyMelSpectrogram is declared somewhere
from my_melspec import MyMelSpectrogram

# Input audio and STFT configuration (librosa defaults).
fn = 'test/torchaudio_unittest/assets/steam-train-whistle-daniel_simon.wav'
sr = 6000  # sampling rate
n_fft = 2048
win_len = None
hop_len = 512

# Load with librosa so both pipelines see the exact same waveform.
# (torchaudio.load with normalization=True plus a channel mean would work too.)
waveform, sample_rate = librosa.load(fn, sr=sr)
waveform = torch.Tensor(waveform)

melspecs = {}
for fb_norm in [None, 'slaney', 'librosa_slaney']:
    # torchaudio-based transform with the requested filter-bank normalization.
    transform = MyMelSpectrogram(
        sample_rate,
        n_fft=n_fft,
        win_length=win_len,
        hop_length=hop_len,
        fb_norm=fb_norm,
    )
    X1 = transform(waveform)
    # Reference mel-spectrogram computed directly by librosa.
    X2 = librosa.feature.melspectrogram(
        waveform.numpy(),
        sr=sample_rate,
        n_fft=n_fft,
        hop_length=hop_len,
        win_length=win_len,
    )
    melspecs[fb_norm] = (X1, X2)  # keep both for the plots below
    # Report the discrepancy between the two implementations.
    mse = ((X1 - X2) ** 2).mean()
    print(f'Mean Squared Error[fb:{fb_norm}]:\t{mse:.4f}')

# Plot each torchaudio variant next to the librosa reference.
fig, axs = plt.subplots(1, 4, figsize=(20, 5))
panels = [
    ('torchaudio / filter bank [None]', melspecs[None][0]),
    ('torchaudio / filter bank [slaney]', melspecs['slaney'][0]),
    ('torchaudio / filter bank [librosa_slaney]', melspecs['librosa_slaney'][0]),
    ('librosa', melspecs[None][1]),
]
for ax, (title, spec) in zip(axs, panels):
    ax.set_title(title)
    ax.set_xlabel('frame')
    ax.imshow(librosa.power_to_db(spec), aspect='auto')
axs[0].set_ylabel('mel bin')
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment