$ python bench_linear.py --bs 1
BS: 1, Latency: 0.389 ms, IC: 4096, OC: 11008, Samples: 100, Warmup: 10
$ python bench_linear.py --bs 128
BS: 128, Latency: 3.640 ms, IC: 4096, OC: 11008, Samples: 100, Warmup: 10
$ python bench_linear.py --bs 1024
BS: 1024, Latency: 41.244 ms, IC: 4096, OC: 11008, Samples: 100, Warmup: 10
$ python bench_linear.py --bs 1024 --ic 11008 --oc 4096
BS: 1024, Latency: 58.562 ms, IC: 11008, OC: 4096, Samples: 100, Warmup: 10
import torch
import torch.nn as nn
import argparse
import time
from functools import wraps
class SimpleLinear(nn.Module):
    """Thin module holding one bias-free ``nn.Linear`` projection.

    Args:
        input_size: ``in_features`` of the wrapped linear layer.
        output_size: ``out_features`` of the wrapped linear layer.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # The attribute must stay named ``linear`` so state_dict keys do not change.
        projection = nn.Linear(input_size, output_size, bias=False)
        self.linear = projection

    def forward(self, x):
        """Return ``self.linear(x)``."""
        projected = self.linear(x)
        return projected
def _synchronize(device):
    """Block until queued work on ``device`` has finished (no-op on CPU)."""
    # CUDA/XPU kernel launches are asynchronous. Without a sync, the host timer
    # would only measure launch overhead, not the actual matmul.
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    elif device.type == "xpu" and hasattr(torch, "xpu"):
        torch.xpu.synchronize(device)


def benchmark_linear_layer(
    input_size,
    output_size,
    batch_size,
    warmup,
    num_samples,
    device='cpu',
    dtype=torch.float32,
):
    """Measure the mean forward-pass latency of a bias-free linear layer.

    Args:
        input_size: input feature dimension of the layer (IC).
        output_size: output feature dimension of the layer (OC).
        batch_size: number of rows in the random input batch.
        warmup: number of untimed forward passes run before measuring.
        num_samples: number of timed forward passes; must be >= 1.
        device: device (string or ``torch.device``) for layer and input.
        dtype: dtype for both the layer weight and the input tensor.

    Returns:
        Average latency of one forward pass, in milliseconds.

    Raises:
        ValueError: if ``num_samples`` is less than 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    device = torch.device(device)
    layer = SimpleLinear(input_size, output_size).to(device=device, dtype=dtype)
    layer.eval()
    # Sample in the default dtype on CPU and then convert. This keeps the RNG
    # stream identical to the float32 case and avoids half-precision CPU randn gaps.
    input_data = torch.randn(batch_size, input_size).to(device=device, dtype=dtype)

    # Inference-only benchmark: disable autograd so graph construction does not
    # inflate the measured latency.
    with torch.no_grad():
        for _ in range(warmup):
            layer(input_data)
        _synchronize(device)

        total_time_ms = 0.0
        for _ in range(num_samples):
            # perf_counter is monotonic and high-resolution. time.time() is too
            # coarse for sub-millisecond forward passes on some platforms.
            start_time = time.perf_counter()
            layer(input_data)
            _synchronize(device)
            total_time_ms += (time.perf_counter() - start_time) * 1000.0

    return total_time_ms / num_samples
def main():
    """Parse CLI arguments, run the linear-layer benchmark, and print one result line."""
    parser = argparse.ArgumentParser(
        description="Benchmark PyTorch Linear layer latency."
    )
    parser.add_argument(
        "--ic", type=int, default=4096, help="input channel dimension."
    )
    parser.add_argument(
        "--oc", type=int, default=11008, help="output channel dimension."
    )
    parser.add_argument(
        "--bs", type=int, default=1, help="batch size for the input of linear layer."
    )
    parser.add_argument(
        "--warmup", type=int, default=10, help="number of warm-up iterations."
    )
    parser.add_argument(
        "--samples", type=int, default=100, help="number of samples to collect for latency measurement."
    )
    parser.add_argument(
        "--device", type=str, default="cpu", help="device to run the benchmark on (e.g. cpu, cuda)."
    )
    args = parser.parse_args()

    # Reject values that would crash later (--samples 0 divides by zero) or
    # produce a meaningless measurement, with a normal argparse usage error.
    for name in ("ic", "oc", "bs", "samples"):
        if getattr(args, name) < 1:
            parser.error(f"--{name} must be >= 1")
    if args.warmup < 0:
        parser.error("--warmup must be >= 0")

    latency = benchmark_linear_layer(
        args.ic, args.oc, args.bs, args.warmup, args.samples, device=args.device
    )
    print(f"BS: {args.bs}, Latency: {latency:.3f} ms, IC: {args.ic}, OC: {args.oc}, Samples: {args.samples}, Warmup: {args.warmup}")
# Run the CLI only when executed as a script, not when imported.
if "__main__" == __name__:
    main()
# TODO:
# - extend the benchmark to Intel XPU devices.
# - extend the benchmark to torch.float16 and torch.bfloat16 weights/inputs.