Skip to content

Instantly share code, notes, and snippets.

@robertknight
Last active December 23, 2023 23:16
Show Gist options
  • Save robertknight/20acb0496803ac5e4bfba9cdfb2bedfd to your computer and use it in GitHub Desktop.
Save robertknight/20acb0496803ac5e4bfba9cdfb2bedfd to your computer and use it in GitHub Desktop.
//! Port of https://github.com/danieldk/gemm-benchmark/ for comparison of
//! matrix multiplication performance against other popular libraries.
use std::time::Duration;
use wasnn::gemm;
use wasnn_tensor::prelude::*;
use wasnn_tensor::NdTensor;
use rayon::iter::IntoParallelIterator;
use rayon::iter::ParallelIterator;
/// Timing and work totals collected by a single benchmark run.
///
/// Dividing `flops` by `elapsed` gives the achieved throughput.
struct BenchmarkStats {
    /// Wall-clock duration of the timed section of the run.
    elapsed: Duration,

    /// Number of floating-point operations credited to the timed section.
    flops: usize,
}
/// Runs `threads` independent work items on rayon's current thread pool.
/// Each item multiplies two constant `dim` x `dim` matrices `iterations`
/// times. Returns the wall-clock time and the FLOPs credited to the run.
///
/// Every work item owns its own output matrix `C` (initialised to ones) and
/// passes `1.` for both scalar arguments of `gemm`. The expected semantics
/// are `C = 1·A·B + 1·C`, i.e. accumulating into `C`; confirm against
/// wasnn's `gemm` docs. `A` (all twos) and `B` (all halves) are shared by
/// reference between work items.
///
/// The number of OS threads is decided by the pool, not by this function.
/// The caller is expected to have configured the pool, e.g. `main` builds
/// the global pool with `threads` threads.
///
/// # Panics
///
/// - Panics before any timing starts if the FLOP count does not fit in
///   `usize` (see `total_flops`).
/// - Panics if an output matrix's data cannot be borrowed as a slice.
///   Presumably this only happens for non-contiguous layouts, which
///   `NdTensor::full` should not produce (TODO confirm).
fn gemm_benchmark(dim: usize, iterations: usize, threads: usize) -> BenchmarkStats {
    // Count FLOPs up front so an impossible configuration fails immediately
    // instead of after a potentially very long benchmark run.
    let flops = total_flops(dim, iterations, threads);

    let one = 1.0;
    let two = one + one;
    let point_five = one / two;
    let matrix_a = NdTensor::full([dim, dim], two);
    let matrix_b = NdTensor::full([dim, dim], point_five);
    // One output matrix per work item so items never write to shared memory.
    let c_matrices: Vec<_> = (0..threads)
        .map(|_| NdTensor::full([dim, dim], one))
        .collect();

    let start = std::time::Instant::now();
    c_matrices.into_par_iter().for_each(|mut matrix_c| {
        // The layout of `matrix_c` never changes inside the loop, so read the
        // row stride once rather than on every iteration.
        let row_stride = matrix_c.stride(0);
        for _ in 0..iterations {
            gemm(
                matrix_c
                    .data_mut()
                    .expect("output matrix should be contiguous"),
                row_stride,
                matrix_a.view(),
                matrix_b.view(),
                1.,
                1.,
            );
        }
        // Best-effort guard: keep the results observable so the optimiser
        // cannot treat the multiplications as dead stores.
        std::hint::black_box(&matrix_c);
    });
    let elapsed = start.elapsed();

    BenchmarkStats { elapsed, flops }
}

/// Returns the FLOPs credited to `threads * iterations` multiplications of
/// `dim` x `dim` matrices.
///
/// Each call is credited with `2 * dim^3` for the product plus `2 * dim^2`.
/// The second term presumably covers the scalar scaling/accumulation into
/// the output (confirm).
///
/// # Panics
///
/// Panics if the count overflows `usize`. This can already happen with the
/// default configuration on 32-bit targets such as wasm32. Unchecked
/// arithmetic there would silently wrap in release builds and report a
/// bogus GFLOPS figure.
fn total_flops(dim: usize, iterations: usize, threads: usize) -> usize {
    let per_gemm = dim
        .checked_pow(3)
        .zip(dim.checked_pow(2))
        .and_then(|(cube, square)| cube.checked_add(square))
        .and_then(|ops| ops.checked_mul(2));
    per_gemm
        .and_then(|ops| ops.checked_mul(iterations))
        .and_then(|ops| ops.checked_mul(threads))
        .expect("FLOP count overflows usize; reduce dim, iterations or threads")
}
/// Entry point. Configures rayon's global thread pool, runs the benchmark,
/// and prints the configuration together with the achieved GFLOPS.
///
/// Usage: `[threads] [iterations] [dim]`. Every positional argument is
/// optional, with defaults 4, 2000 and 1024 respectively. Running without
/// arguments therefore reproduces the original fixed configuration.
fn main() {
    let mut args = std::env::args().skip(1);
    // Parses the next positional argument. Falls back to `default` when the
    // argument is absent, and exits with a usage error when it is malformed.
    let mut next_arg = |name: &str, default: usize| -> usize {
        match args.next() {
            None => default,
            Some(raw) => raw.parse().unwrap_or_else(|err| {
                eprintln!("invalid {name} {raw:?}: {err}");
                eprintln!("usage: [threads] [iterations] [dim]");
                std::process::exit(2)
            }),
        }
    };
    let threads = next_arg("threads", 4);
    let iterations = next_arg("iterations", 2000);
    let dim = next_arg("dim", 1024);
    if threads == 0 {
        // rayon interprets 0 as "choose automatically", but the benchmark
        // would then create no work items and report 0 GFLOPS.
        eprintln!("threads must be at least 1");
        std::process::exit(2);
    }

    rayon::ThreadPoolBuilder::new()
        .num_threads(threads)
        .build_global()
        .expect("failed to initialise the global rayon thread pool");

    println!("Threads: {}", threads);
    println!("Iterations per thread: {}", iterations);
    println!("Matrix shape: {} x {}", dim, dim);

    let stats = gemm_benchmark(dim, iterations, threads);
    // FLOP/s divided by 1e9 gives GFLOPS.
    let gflops = stats.flops as f64 / stats.elapsed.as_secs_f64() / 1e9;
    println!("GFLOPS: {:.2}", gflops);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment