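# Benchmark: BF16 vs. float8_experimental FP8 inference on a SwiGLU MLP.
# Sweeps activation-casting modes (dynamic, static, weight-only) crossed with
# scaling granularities (TensorWise, AxisWise) over several input sizes, and
# reports per-variant latency, speedup over BF16, and SQNR vs. the BF16 output.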
import copy
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from float8_experimental.inference import (
    quantize_to_float8,
    ActivationCasting,
    QuantConfig,
    ScalingGranularity,
)
from float8_experimental.float8_utils import compute_error
from transformer_nuggets.utils import benchmark_cuda_function_in_microseconds, profiler
from tqdm import tqdm
from tabulate import tabulate

torch._dynamo.config.automatic_dynamic_shapes = False
# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000
class FeedForward(nn.Module):
    """SwiGLU feed-forward block with Llama/Mistral-7B-sized dims (4096 -> 14336 -> 4096)."""

    def __init__(self) -> None:
        super().__init__()
        self.w1 = nn.Linear(4096, 14336, bias=False)
        self.w3 = nn.Linear(4096, 14336, bias=False)
        self.w2 = nn.Linear(14336, 4096, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
def setup_model(original_model, quant_config, scaling_granularity=None):
    """Deep-copy the reference model, optionally quantize it to FP8, then compile it.

    `compile_backend` is read from module scope (set in __main__ below).
    """
    model = copy.deepcopy(original_model)
    if quant_config:
        quantize_to_float8(model, quant_config, scaling_granularity=scaling_granularity)
    return torch.compile(model, backend=compile_backend)
def run_benchmark(model, input_tensor, name, num_warmup=10, profile=False):
    with torch.no_grad():
        # Warm up so compilation/autotuning doesn't pollute the timing
        for _ in range(num_warmup):
            model(input_tensor)
        if profile:
            with profiler(Path(f"/home/drisspg/meta/scripts/fp8/data/{name}")):
                model(input_tensor)
        time = benchmark_cuda_function_in_microseconds(model, input_tensor)
        output = model(input_tensor)
    return time, output
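# For readers without transformer_nuggets installed: a minimal CUDA-event-based
# stand-in for benchmark_cuda_function_in_microseconds (a sketch of the idea,
# not the library's actual implementation) could look like:
#
#   def benchmark_cuda_function_in_microseconds(f, *args, iters: int = 100):
#       start = torch.cuda.Event(enable_timing=True)
#       end = torch.cuda.Event(enable_timing=True)
#       torch.cuda.synchronize()
#       start.record()
#       for _ in range(iters):
#           f(*args)
#       end.record()
#       torch.cuda.synchronize()
#       return start.elapsed_time(end) * 1e3 / iters  # elapsed_time is ms; -> µs per call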
def run_sweep(original_mlp, variants, input_sizes, profile=False):
    results = []
    for batch_size, num_tokens in tqdm(input_sizes):
        input_tensor = (
            torch.rand(batch_size, num_tokens, 4096, device="cuda", dtype=torch.bfloat16) * 5
        )
        variant_results = []
        outputs = {}
        for name, quant_config, scaling_granularity in variants:
            if name == "FP8_Static_AxisWise":
                # AxisWise static quantization needs one scale per row of the
                # flattened (batch_size * num_tokens, 4096) activation matrix
                quant_config = QuantConfig(
                    ActivationCasting.STATIC,
                    torch.full((num_tokens * batch_size, 1), 1.0, device="cuda", dtype=torch.float32),
                )
            model = setup_model(original_mlp, quant_config, scaling_granularity)
            time, output = run_benchmark(
                model, input_tensor, f"{name}_{batch_size}_{num_tokens}", profile=profile
            )
            variant_results.append([name, f"{time:.2f}"])
            outputs[name] = output
        bf16_output = outputs["BF16"]
        bf16_time = float(variant_results[0][1])  # BF16 is the first variant
        comparison_results = [
            [name, row[1], f"{bf16_time / float(row[1]):.2f}x", f"{compute_error(out, bf16_output):.6e}"]
            for row, (name, out) in zip(variant_results, outputs.items())
        ]
        results.append((batch_size, num_tokens, comparison_results))
    return results
if __name__ == "__main__":
    profile = False
    compile_backend = "inductor"
    original_mlp = FeedForward().to("cuda").to(torch.bfloat16)

    variants = [
        ("BF16", None, None),
        ("FP8_Dynamic_TensorWise", QuantConfig(ActivationCasting.DYNAMIC), ScalingGranularity.TensorWise),
        (
            "FP8_Static_TensorWise",
            QuantConfig(ActivationCasting.STATIC, torch.tensor([1.0], device="cuda", dtype=torch.float32)),
            ScalingGranularity.TensorWise,
        ),
        ("FP8_Weight_Only_TensorWise", QuantConfig(ActivationCasting.WEIGHT_ONLY), ScalingGranularity.TensorWise),
        ("FP8_Dynamic_AxisWise", QuantConfig(ActivationCasting.DYNAMIC), ScalingGranularity.AxisWise),
        ("FP8_Static_AxisWise", None, ScalingGranularity.AxisWise),  # QuantConfig is built per input size in run_sweep
        ("FP8_Weight_Only_AxisWise", QuantConfig(ActivationCasting.WEIGHT_ONLY), ScalingGranularity.AxisWise),
    ]

    input_sizes = [
        (1, 128),
        (1, 1024),
        (32, 128),
        (32, 1024),
        (64, 2048),
    ]

    sweep_results = run_sweep(original_mlp, variants, input_sizes, profile=profile)

    # Print results
    headers = ["Variant", "Time (μs)", "Speedup vs BF16", "SQNR vs BF16 (dB)"]
    for batch_size, num_tokens, comparison_results in sweep_results:
        print(f"\nResults for batch_size={batch_size}, num_tokens={num_tokens}:")
        print(tabulate(comparison_results, headers=headers, tablefmt="grid"))
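# Note on the error metric: compute_error reports SQNR in decibels; higher
# means the FP8 output is closer to the BF16 reference. A minimal sketch of
# the standard definition (the float8_experimental implementation may differ
# in details):
#
#   def sqnr(ref: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
#       signal = torch.linalg.norm(ref.float())
#       noise = torch.linalg.norm(ref.float() - x.float())
#       return 20 * torch.log10(signal / noise)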