from __future__ import annotations

import re
from functools import reduce
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

import einx
from x_transformers import Encoder, Decoder
from x_transformers.x_transformers import ScaledSinusoidalEmbedding
from vector_quantize_pytorch import FSQ


def lens_to_mask(t, length=None):
    # Build a boolean padding mask of shape (b, n) from per-example lengths.
    if length is None:
        length = t.amax()
    seq = torch.arange(length, device=t.device)
    return einx.less("n, b -> b n", seq, t)
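
# Illustrative example (not in the original gist):
#   lens_to_mask(torch.tensor([2, 3]), 4)
#   -> tensor([[ True,  True, False, False],
#              [ True,  True,  True, False]])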


class SemanticEncoder(nn.Module):
    def __init__(
        self,
        encoder_dim: int = 512,
        encoder_depth: int = 8,
        encoder_heads: int = 8,
        ctc_dropout: float = 0.1,
        num_mels: int = 100,
        num_codebooks: int = 1,
        codebook_size: int = 1024,
        codebook_levels: List[int] = [8, 5, 5, 5],
        sample_rate: int = 24_000,
        num_text_embeds: int = 256,
        use_rmsnorm: bool = True,
    ):
        super().__init__()

        self.encoder_dim = encoder_dim
        self.sample_rate = sample_rate

        # mel frontend: 2x temporal downsampling via the strided conv
        self.conv1 = nn.Conv1d(num_mels, encoder_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(
            encoder_dim, encoder_dim, kernel_size=3, stride=2, padding=1
        )

        self.pos_emb = ScaledSinusoidalEmbedding(encoder_dim)

        self.encoder = Encoder(
            dim=encoder_dim,
            depth=encoder_depth,
            heads=encoder_heads,
            use_rmsnorm=use_rmsnorm,
            disable_abs_pos_emb=True,
            attn_flash=torch.cuda.is_available(),
        )

        self.ctc_head = nn.Sequential(
            nn.Dropout(ctc_dropout), nn.Linear(encoder_dim, num_text_embeds)
        )

        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        self.codebook_levels = codebook_levels

        effective_codebook_size = reduce(lambda x, y: x * y, codebook_levels)
        assert (
            effective_codebook_size <= codebook_size
        ), "product of codebook levels must be less than or equal to codebook size"

        self.quantizer = FSQ(
            levels=codebook_levels,
            dim=encoder_dim,
            num_codebooks=num_codebooks,
        )

    def forward(self, x, lens):
        # x: (b, n, num_mels) mel features; lens: per-example frame counts or None
        x = x.transpose(1, 2)
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.transpose(1, 2)

        # adjust lens for the 2x conv downsample
        if lens is not None:
            lens = ((lens - 1) // 2) + 1

        mask = lens_to_mask(lens, x.shape[1]) if lens is not None else None

        x = x + self.pos_emb(x)
        x = self.encoder(x, mask=mask)

        # auxiliary CTC head over characters, plus FSQ quantization of the
        # encoder output into discrete semantic tokens
        ctc_logits = self.ctc_head(x)

        quantized, indices = self.quantizer(x)
        quantization_error = F.mse_loss(x, quantized, reduction="mean").detach()

        return indices, ctc_logits, lens, quantization_error
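
# Illustrative shapes, assuming the defaults above (not in the original gist):
# a (2, 200, 100) mel batch with lens = [200, 150] yields indices of shape
# (2, 100) after the 2x downsample, ctc_logits of shape (2, 100, 256), and
# downsampled lens of [100, 75].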


class ASRDecoder(nn.Module):
    def __init__(
        self,
        dim: int = 512,
        depth: int = 8,
        heads: int = 8,
        num_speech_embeds: int = 1024,
        num_text_embeds: int = 256,
    ):
        super().__init__()

        self.speech_emb = nn.Embedding(
            num_embeddings=num_speech_embeds, embedding_dim=dim
        )
        self.text_emb = nn.Embedding(
            num_embeddings=num_text_embeds, embedding_dim=dim, padding_idx=0
        )

        self.decoder = Decoder(
            dim=dim,
            depth=depth,
            heads=heads,
            rotary_pos_emb=True,
            cross_attend=True,
            attn_flash=torch.cuda.is_available(),
        )

        # zero-init the output projection so training starts from uniform logits
        self.to_logits = nn.Linear(dim, num_text_embeds, bias=False)
        nn.init.zeros_(self.to_logits.weight)

    def forward(
        self,
        text,
        speech,
        mask=None,
        speech_mask=None,
    ):
        device = text.device

        speech_emb = self.speech_emb(speech)
        speech_emb_pos = torch.arange(speech_emb.shape[1], device=device)

        text_emb = self.text_emb(text)

        decoded = self.decoder(
            text_emb,
            mask=mask,
            context=speech_emb,
            context_mask=speech_mask,
            context_pos=speech_emb_pos,
        )
        logits = self.to_logits(decoded)

        # teacher forcing: logits at position i are scored against the token at i + 1
        target = text[:, 1:]
        loss = F.cross_entropy(logits[:, :-1].transpose(1, 2), target, ignore_index=0)
        predicted = logits[:, :-1].detach().log_softmax(dim=-1).argmax(dim=-1)

        return loss, target, predicted


class SemanticCodec(nn.Module):
    def __init__(
        self,
        encoder_dim: int = 512,
        encoder_depth: int = 8,
        encoder_heads: int = 8,
        decoder_dim: int = 512,
        decoder_depth: int = 8,
        decoder_heads: int = 8,
        num_codebooks: int = 1,
        codebook_size: int = 1024,
        codebook_levels: List[int] = [8, 5, 5, 5],
        num_mels: int = 100,
        num_text_embeds: int = 256,
        sample_rate: int = 24_000,
        use_rmsnorm: bool = True,
    ):
        super().__init__()

        # reserve two extra ids beyond the raw byte vocabulary for bos / eos
        self.bos_token = num_text_embeds
        self.eos_token = num_text_embeds + 1
        self.num_text_embeds = num_text_embeds + 2

        self.sample_rate = sample_rate

        self.encoder = SemanticEncoder(
            encoder_dim=encoder_dim,
            encoder_depth=encoder_depth,
            encoder_heads=encoder_heads,
            num_mels=num_mels,
            num_codebooks=num_codebooks,
            codebook_size=codebook_size,
            codebook_levels=codebook_levels,
            sample_rate=sample_rate,
            num_text_embeds=num_text_embeds,
            use_rmsnorm=use_rmsnorm,
        )

        self.decoder = ASRDecoder(
            dim=decoder_dim,
            depth=decoder_depth,
            heads=decoder_heads,
            num_speech_embeds=codebook_size,
            num_text_embeds=self.num_text_embeds,
        )

    def normalize_text(self, texts: list[str]) -> list[str]:
        def normalize(text: str) -> str:
            # replace sentence boundaries with spaces
            text = re.sub(r"[.!?]+", " ", text)
            # keep alphanumeric, spaces, apostrophes
            text = re.sub(r"[^a-zA-Z0-9\s']", " ", text)
            # convert multiple spaces to single space
            text = re.sub(r"\s+", " ", text)
            return text.lower().strip()

        return [normalize(text) for text in texts]
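
    # Illustrative example (not in the original gist):
    #   normalize_text(["Hello, World!"]) -> ["hello world"]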

    def tokenize(self, texts: list[str]) -> list[torch.Tensor]:
        # byte-level tokenization: each string becomes a tensor of its UTF-8 byte values
        return [torch.tensor([*bytes(t, "UTF-8")]) for t in texts]

    def add_bos_and_eos(self, text_tensors: list[torch.Tensor]) -> list[torch.Tensor]:
        return [
            torch.cat(
                [torch.tensor([self.bos_token]), t, torch.tensor([self.eos_token])]
            )
            for t in text_tensors
        ]

    def forward(self, x, text, lens):
        device = x.device

        # tokenize: CTC targets use the normalized text, the AR decoder targets
        # use the raw text wrapped in bos / eos
        ctc_text = self.tokenize(self.normalize_text(text))
        ctc_text_lengths = torch.tensor([len(t) for t in ctc_text], device=device)
        ctc_text = pad_sequence(ctc_text, padding_value=0, batch_first=True)
        ctc_text = ctc_text.to(device)

        asr_text = self.tokenize(text)
        asr_text = self.add_bos_and_eos(asr_text)
        asr_text = pad_sequence(asr_text, padding_value=0, batch_first=True)
        asr_text = asr_text.to(device)

        # encoder
        semantic_tokens, ctc_logits, enc_lens, quantization_error = self.encoder(
            x, lens
        )

        # decoder, cross-attending only to non-padded semantic tokens
        speech_mask = (
            lens_to_mask(enc_lens, semantic_tokens.shape[1])
            if enc_lens is not None
            else None
        )
        ce_loss, target, predicted = self.decoder(
            asr_text, semantic_tokens, speech_mask=speech_mask
        )

        # ctc loss
        ctc_logprobs = ctc_logits.log_softmax(dim=-1)
        ctc_loss = F.ctc_loss(
            ctc_logprobs.transpose(0, 1),
            ctc_text,
            enc_lens,
            ctc_text_lengths,
            blank=0,
            reduction="mean",
            zero_infinity=True,
        )

        # losses
        loss = ce_loss + ctc_loss

        return (
            loss,
            target,
            predicted,
            ctc_loss.detach(),
            ce_loss.detach(),
            quantization_error,
        )
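

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original gist). It builds the codec
# with its default hyperparameters and runs one forward pass on random mel
# features; the shapes, lengths, and transcripts below are made up purely for
# illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    codec = SemanticCodec()

    batch_size, max_frames, num_mels = 2, 200, 100
    mels = torch.randn(batch_size, max_frames, num_mels)
    lens = torch.tensor([200, 150])
    texts = ["the quick brown fox", "jumps over the lazy dog"]

    loss, target, predicted, ctc_loss, ce_loss, quant_err = codec(mels, texts, lens)
    print(f"loss={loss.item():.4f}  ctc={ctc_loss.item():.4f}  ce={ce_loss.item():.4f}")
    print(f"quantization error={quant_err.item():.6f}")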