from __future__ import annotations
from functools import reduce
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import einx
from x_transformers import Encoder, Decoder
from x_transformers.x_transformers import ScaledSinusoidalEmbedding
from vector_quantize_pytorch import FSQ


def lens_to_mask(t, length):
    if length is None:
        length = t.amax()
    seq = torch.arange(length, device=t.device)
    return einx.less("n, b -> b n", seq, t)
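# e.g. lens_to_mask(torch.tensor([2, 3]), 3) ->
#   tensor([[True, True, False],
#           [True, True, True]])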


class SemanticEncoder(nn.Module):
    def __init__(
        self,
        encoder_dim: int = 512,
        encoder_depth: int = 8,
        encoder_heads: int = 8,
        ctc_dropout: float = 0.1,
        num_mels: int = 100,
        num_codebooks: int = 1,
        codebook_size: int = 1024,
        codebook_levels: List[int] = [8, 5, 5, 5],
        sample_rate: int = 24_000,
        num_text_embeds: int = 256,
        use_rmsnorm: bool = True,
    ):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.sample_rate = sample_rate
        self.conv1 = nn.Conv1d(num_mels, encoder_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(
            encoder_dim, encoder_dim, kernel_size=3, stride=2, padding=1
        )
        self.pos_emb = ScaledSinusoidalEmbedding(encoder_dim)
        self.encoder = Encoder(
            dim=encoder_dim,
            depth=encoder_depth,
            heads=encoder_heads,
            use_rmsnorm=use_rmsnorm,
            disable_abs_pos_emb=True,
            attn_flash=torch.cuda.is_available(),
        )
        self.ctc_head = nn.Sequential(
            nn.Dropout(ctc_dropout), nn.Linear(encoder_dim, num_text_embeds)
        )
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        self.codebook_levels = codebook_levels
        effective_codebook_size = reduce(lambda x, y: x * y, codebook_levels)
        assert (
            effective_codebook_size <= codebook_size
        ), "product of codebook levels must be less than or equal to codebook size"
        self.quantizer = FSQ(
            levels=codebook_levels,
            dim=encoder_dim,
            num_codebooks=num_codebooks,
        )

    def forward(self, x, lens):
        # x: (batch, frames, num_mels) mel spectrogram
        x = x.transpose(1, 2)
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.transpose(1, 2)
        # adjust lens for the stride-2 conv downsample
        if lens is not None:
            lens = ((lens - 1) // 2) + 1
        mask = lens_to_mask(lens, x.shape[1]) if lens is not None else None
        x = x + self.pos_emb(x)
        x = self.encoder(x, mask=mask)
        ctc_logits = self.ctc_head(x)
        quantized, indices = self.quantizer(x)
        quantization_error = F.mse_loss(x, quantized, reduction="mean").detach()
        return indices, ctc_logits, lens, quantization_error
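
# With the default config, a (B, T, num_mels) input yields T' = ((T - 1) // 2) + 1
# frames after the stride-2 conv: `indices` is (B, T') with values in
# [0, prod(codebook_levels)), `ctc_logits` is (B, T', num_text_embeds), and `lens`
# is downsampled with the same formula.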


class ASRDecoder(nn.Module):
    def __init__(
        self,
        dim: int = 512,
        depth: int = 8,
        heads: int = 8,
        num_speech_embeds: int = 1024,
        num_text_embeds: int = 256,
    ):
        super().__init__()
        self.speech_emb = nn.Embedding(
            num_embeddings=num_speech_embeds, embedding_dim=dim
        )
        self.text_emb = nn.Embedding(
            num_embeddings=num_text_embeds, embedding_dim=dim, padding_idx=0
        )
        self.decoder = Decoder(
            dim=dim,
            depth=depth,
            heads=heads,
            rotary_pos_emb=True,
            cross_attend=True,
            attn_flash=torch.cuda.is_available(),
        )
        self.to_logits = nn.Linear(dim, num_text_embeds, bias=False)
        nn.init.zeros_(self.to_logits.weight)

    def forward(
        self,
        text,
        speech,
        mask=None,
        speech_mask=None,
    ):
        device = text.device
        speech_emb = self.speech_emb(speech)
        speech_emb_pos = torch.arange(speech_emb.shape[1], device=device)
        text_emb = self.text_emb(text)
        decoded = self.decoder(
            text_emb,
            context=speech_emb,
            context_pos=speech_emb_pos,
            mask=mask,
            context_mask=speech_mask,
        )
        logits = self.to_logits(decoded)
        # teacher forcing: predict the next token, ignoring the padding id
        target = text[:, 1:]
        loss = F.cross_entropy(logits[:, :-1].transpose(1, 2), target, ignore_index=0)
        predicted = logits[:, :-1].detach().log_softmax(dim=-1).argmax(dim=-1)
        return loss, target, predicted


class SemanticCodec(nn.Module):
    def __init__(
        self,
        encoder_dim: int = 512,
        encoder_depth: int = 8,
        encoder_heads: int = 8,
        decoder_dim: int = 512,
        decoder_depth: int = 8,
        decoder_heads: int = 8,
        num_codebooks: int = 1,
        codebook_size: int = 1024,
        codebook_levels: List[int] = [8, 5, 5, 5],
        num_mels: int = 100,
        num_text_embeds: int = 256,
        sample_rate: int = 24_000,
        use_rmsnorm: bool = True,
    ):
        super().__init__()
        self.bos_token = num_text_embeds
        self.eos_token = num_text_embeds + 1
        self.num_text_embeds = num_text_embeds + 2
        self.sample_rate = sample_rate
        self.encoder = SemanticEncoder(
            encoder_dim=encoder_dim,
            encoder_depth=encoder_depth,
            encoder_heads=encoder_heads,
            num_mels=num_mels,
            num_codebooks=num_codebooks,
            codebook_size=codebook_size,
            codebook_levels=codebook_levels,
            sample_rate=sample_rate,
            num_text_embeds=num_text_embeds,
            use_rmsnorm=use_rmsnorm,
        )
        self.decoder = ASRDecoder(
            dim=decoder_dim,
            depth=decoder_depth,
            heads=decoder_heads,
            num_speech_embeds=codebook_size,
            num_text_embeds=self.num_text_embeds,
        )

    def normalize_text(self, texts: list[str]) -> list[str]:
        import re

        def normalize(text: str) -> str:
            # replace sentence boundaries with spaces
            text = re.sub(r"[.!?]+", " ", text)
            # keep alphanumeric, spaces, apostrophes
            text = re.sub(r"[^a-zA-Z0-9\s']", " ", text)
            # convert multiple spaces to single space
            text = re.sub(r"\s+", " ", text)
            return text.lower().strip()

        return [normalize(text) for text in texts]

    def tokenize(self, text: list[str]) -> list[torch.Tensor]:
        # byte-level tokenization: each string maps to its UTF-8 byte values
        return [torch.tensor([*bytes(t, "UTF-8")]) for t in text]

    def add_bos_and_eos(self, text_tensors: list[torch.Tensor]) -> list[torch.Tensor]:
        return [
            torch.cat(
                [torch.tensor([self.bos_token]), t, torch.tensor([self.eos_token])]
            )
            for t in text_tensors
        ]
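
    # e.g. tokenize(["hi"]) -> [tensor([104, 105])]; add_bos_and_eos then yields
    # [tensor([256, 104, 105, 257])] with the default num_text_embeds of 256.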

    def forward(self, x, text, lens):
        device = x.device
        # tokenize for the CTC objective (normalized text, no bos/eos);
        # target lengths must be measured before padding
        ctc_text = self.tokenize(self.normalize_text(text))
        ctc_text_lengths = torch.tensor([len(t) for t in ctc_text], device=device)
        ctc_text = pad_sequence(ctc_text, padding_value=0, batch_first=True)
        ctc_text = ctc_text.to(device)
        # tokenize for the autoregressive ASR objective (raw text with bos/eos)
        asr_text = self.tokenize(text)
        asr_text = self.add_bos_and_eos(asr_text)
        asr_text = pad_sequence(asr_text, padding_value=0, batch_first=True)
        asr_text = asr_text.to(device)
        # encoder
        semantic_tokens, ctc_logits, enc_lens, quantization_error = self.encoder(
            x, lens
        )
        # decoder
        ce_loss, target, predicted = self.decoder(asr_text, semantic_tokens)
        # ctc loss (fall back to full-length inputs when no lens were given)
        if enc_lens is None:
            enc_lens = torch.full(
                (x.shape[0],), ctc_logits.shape[1], device=device, dtype=torch.long
            )
        ctc_logprobs = ctc_logits.log_softmax(dim=-1)
        ctc_loss = F.ctc_loss(
            ctc_logprobs.transpose(0, 1),
            ctc_text,
            enc_lens,
            ctc_text_lengths,
            blank=0,
            reduction="mean",
            zero_infinity=True,
        )
        # losses
        loss = ce_loss + ctc_loss
        return (
            loss,
            target,
            predicted,
            ctc_loss.detach(),
            ce_loss.detach(),
            quantization_error,
        )
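

# A minimal usage sketch (illustrative shapes and values, assuming 100-bin mel
# spectrograms and plain-text transcripts): feed a padded batch of mels, their
# valid frame counts, and the matching transcripts; the codec returns the
# combined loss plus per-term diagnostics.
if __name__ == "__main__":
    codec = SemanticCodec()
    mels = torch.randn(2, 200, 100)    # (batch, frames, num_mels=100)
    lens = torch.tensor([200, 160])    # valid (unpadded) frame counts
    texts = ["hello world", "semantic codecs are neat"]
    loss, target, predicted, ctc_loss, ce_loss, quant_err = codec(mels, texts, lens)
    loss.backward()
    print(f"loss={loss.item():.3f} ctc={ctc_loss.item():.3f} ce={ce_loss.item():.3f}")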