@drisspg · Created July 10, 2024 19:24
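"""Benchmark a SwiGLU FeedForward block in BF16 against several
float8_experimental inference quantization variants (dynamic, static, and
weight-only activation casting, at TensorWise and AxisWise scaling
granularity), sweeping over a range of (batch_size, num_tokens) shapes."""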
import torch
import copy
import torch.nn as nn
import torch.nn.functional as F
from float8_experimental.inference import quantize_to_float8, ActivationCasting, QuantConfig, ScalingGranularity
from float8_experimental.float8_utils import compute_error
from transformer_nuggets.utils import benchmark_cuda_function_in_microseconds, profiler
from pathlib import Path
from tqdm import tqdm
from tabulate import tabulate
torch._dynamo.config.automatic_dynamic_shapes = False
# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000

class FeedForward(nn.Module):
    """SwiGLU feed-forward block (4096 -> 14336 -> 4096, the MLP shape used by
    Mistral-7B / Llama-3-8B)."""

    def __init__(self) -> None:
        super().__init__()
        self.w1 = nn.Linear(4096, 14336, bias=False)
        self.w3 = nn.Linear(4096, 14336, bias=False)
        self.w2 = nn.Linear(14336, 4096, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
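
# Deep-copy the reference model, optionally quantize it to float8 with the
# given activation-casting config and scaling granularity, then compile it
# with the globally selected backend.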
def setup_model(original_model, quant_config, scaling_granularity=None):
    model = copy.deepcopy(original_model)
    if quant_config:
        quantize_to_float8(model, quant_config, scaling_granularity=scaling_granularity)
    return torch.compile(model, backend=compile_backend)
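
# Warm up the compiled model, optionally capture a profiler trace, then time
# a forward pass and return (time_in_microseconds, output).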
def run_benchmark(model, input_tensor, name, num_warmup=10, profile=False):
    with torch.no_grad():
        for _ in range(num_warmup):
            model(input_tensor)
        if profile:
            with profiler(Path(f"/home/drisspg/meta/scripts/fp8/data/{name}")):
                model(input_tensor)
        time = benchmark_cuda_function_in_microseconds(model, input_tensor)
        output = model(input_tensor)
    return time, output
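
# For each (batch_size, num_tokens) shape, benchmark every variant and record
# its latency, speedup, and numerical error relative to the BF16 baseline.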
def run_sweep(original_mlp, variants, input_sizes, profile=False):
    results = []
    for batch_size, num_tokens in tqdm(input_sizes):
        input_tensor = torch.rand(batch_size, num_tokens, 4096, device="cuda", dtype=torch.bfloat16) * 5
        variant_results = []
        outputs = {}
        for name, quant_config, scaling_granularity in variants:
            if name == "FP8_Static_AxisWise":
                # AxisWise static quantization needs a per-row scale, so the
                # config is built here, once the input shape is known.
                quant_config = QuantConfig(
                    ActivationCasting.STATIC,
                    torch.full((num_tokens * batch_size, 1), 1.0, device="cuda", dtype=torch.float32),
                )
            model = setup_model(original_mlp, quant_config, scaling_granularity)
            time, output = run_benchmark(model, input_tensor, f"{name}_{batch_size}_{num_tokens}", profile=profile)
            variant_results.append([name, f"{time:.2f}"])
            outputs[name] = output
        bf16_output = outputs["BF16"]
        bf16_time = float(variant_results[0][1])  # BF16 is the first variant
        comparison_results = [
            [name, row[1], f"{bf16_time / float(row[1]):.2f}x", f"{compute_error(output, bf16_output):.6e}"]
            for row, (name, output) in zip(variant_results, outputs.items())
        ]
        results.append((batch_size, num_tokens, comparison_results))
    return results

if __name__ == "__main__":
    profile = False
    compile_backend = "inductor"
    original_mlp = FeedForward().to("cuda").to(torch.bfloat16)

    variants = [
        ("BF16", None, None),
        ("FP8_Dynamic_TensorWise", QuantConfig(ActivationCasting.DYNAMIC), ScalingGranularity.TensorWise),
        ("FP8_Static_TensorWise", QuantConfig(ActivationCasting.STATIC, torch.tensor([1.0], device="cuda", dtype=torch.float32)), ScalingGranularity.TensorWise),
        ("FP8_Weight_Only_TensorWise", QuantConfig(ActivationCasting.WEIGHT_ONLY), ScalingGranularity.TensorWise),
        ("FP8_Dynamic_AxisWise", QuantConfig(ActivationCasting.DYNAMIC), ScalingGranularity.AxisWise),
        ("FP8_Static_AxisWise", None, ScalingGranularity.AxisWise),  # config is built per-shape in run_sweep
        ("FP8_Weight_Only_AxisWise", QuantConfig(ActivationCasting.WEIGHT_ONLY), ScalingGranularity.AxisWise),
    ]
    input_sizes = [
        (1, 128),
        (1, 1024),
        (32, 128),
        (32, 1024),
        (64, 2048),
    ]

    sweep_results = run_sweep(original_mlp, variants, input_sizes, profile=profile)

    # Print a comparison table for each input shape.
    headers = ["Variant", "Time (μs)", "Speedup vs BF16", "SQNR vs BF16"]
    for batch_size, num_tokens, comparison_results in sweep_results:
        print(f"\nResults for batch_size={batch_size}, num_tokens={num_tokens}:")
        print(tabulate(comparison_results, headers=headers, tablefmt="grid"))
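
# Note: "SQNR vs BF16" comes from compute_error, which measures how closely a
# quantized variant's output tracks the BF16 reference (higher is better); the
# Time column is the latency reported by benchmark_cuda_function_in_microseconds,
# so "Speedup vs BF16" above 1.00x means the FP8 variant ran faster.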