Created March 4, 2020 10:59
-
-
Save espdev/d278962828ce7360d653b8679d7bb511 to your computer and use it in GitHub Desktop.
Rust/Python sparse matrix multiplication benchmark
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import numpy as np | |
from scipy.sparse import diags | |
def make_matrix(pcount):
    """Build the sparse ``(pcount - 2, pcount)`` three-band matrix for a uniform grid.

    The grid has ``pcount`` points spanning ``[0, 10]``. With ``h`` the grid steps,
    row ``i`` holds ``1/h[i]``, ``-(1/h[i] + 1/h[i+1])`` and ``1/h[i+1]`` on the
    diagonals with offsets 0, 1 and 2 respectively.
    """
    steps = np.diff(np.linspace(0., 10., pcount))
    recip = 1. / steps
    lower, upper = recip[:-1], recip[1:]
    # One row per diagonal, in the same order as the offsets passed to `diags`.
    bands = np.array([lower, -(upper + lower), upper])
    return diags(bands, [0, 1, 2], (pcount - 2, pcount))
def main():
    """Benchmark ``qt @ qt.T`` for a range of grid sizes and print the timings.

    For every size the shape and nnz of ``qt`` and the multiplication time are
    printed, followed by summary lists of the product's nnz and of the elapsed
    time in (rounded) milliseconds.
    """
    pcounts = [
        100, 500, 1000, 2000, 3000, 4000, 5000, 7500, 10000, 20000, 30000, 40000, 50000,
        60000, 70000, 80000, 90000, 100000, 125000, 150000, 175000, 200000
    ]
    elapseds = []
    nnzs = []
    for pcount in pcounts:
        qt = make_matrix(pcount)
        print('\n--- pcount: {} ---------------'.format(pcount))
        print('`qt` shape={}, nnz={}'.format(qt.shape, qt.nnz))
        # perf_counter is the highest-resolution clock intended for benchmarking.
        t0 = time.perf_counter()
        product = qt @ qt.T
        elapsed_ms = (time.perf_counter() - t0) * 1000
        print(" -> qt * qt.T: {:.2f} msec".format(elapsed_ms))
        elapseds.append(round(elapsed_ms))
        # Report the nnz of the product itself, matching the summary label below.
        nnzs.append(product.nnz)
    print("\n`qt * qt.T` nnz: {}".format(nnzs))
    print("\n`qt * qt.T` elapsed msec: {}".format(elapseds))


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::time::Instant; | |
use num_traits::NumOps; | |
use ndarray::{NdFloat, ScalarOperand, Dimension, Array, Array1, Array2, AsArray, Axis, Slice, s, stack}; | |
use sprs; | |
use sprs::{CsMat, Shape}; | |
pub fn diff<'a, T: 'a, D, V>(data: V, axis: Option<Axis>) -> Array<T, D> | |
where T: NumOps + ScalarOperand, D: Dimension, V: AsArray<'a, T, D> | |
{ | |
let data_view = data.into(); | |
let axis = axis.unwrap_or(Axis(data_view.ndim() - 1)); | |
let head = data_view.slice_axis(axis, Slice::from(..-1)); | |
let tail = data_view.slice_axis(axis, Slice::from(1..)); | |
&tail - &head | |
} | |
pub fn spdiags<T>(diags: Array2<T>, offsets: &[isize], shape: Shape) -> CsMat<T> | |
where T: NdFloat | |
{ | |
let (rows, cols) = shape; | |
let numel_and_indices = |offset: isize| { | |
let mut i: usize = 0; | |
let mut j: usize = 0; | |
if offset < 0 { | |
i = offset.abs() as usize; | |
} else { | |
j = offset as usize; | |
} | |
((rows - i).min(cols - j), i, j) | |
}; | |
let mut mat = sprs::TriMat::<T>::new(shape); | |
for (k, &offset) in offsets.iter().enumerate() { | |
let (n, i, j) = numel_and_indices(offset); | |
let diag_row = diags.row(k); | |
let row_head = || diag_row.slice(s![..n]); | |
let row_tail = || diag_row.slice(s![-(n as isize)..]); | |
let diag = if offset < 0 { | |
if rows >= cols { | |
row_head() | |
} else { | |
row_tail() | |
} | |
} else { | |
if rows >= cols { | |
row_tail() | |
} else { | |
row_head() | |
} | |
}; | |
for l in 0..n { | |
mat.add_triplet(l + i, l + j, diag[l]); | |
} | |
} | |
mat.to_csr() | |
} | |
fn make_matrix(pcount: usize) -> CsMat<f64> { | |
let x = Array1::<f64>::linspace(0.0, 10.0, pcount); | |
let dx = diff(&x, None); | |
let ones = |n| Array1::<f64>::ones((n, )); | |
let odx = ones(pcount - 1) / &dx; | |
let odx_head = odx.slice(s![..-1]).insert_axis(Axis(0)).into_owned(); | |
let odx_tail = odx.slice(s![1..]).insert_axis(Axis(0)).into_owned(); | |
let odx_body = -(&odx_tail + &odx_head); | |
let diags_qt = stack![Axis(0), odx_head, odx_body, odx_tail]; | |
spdiags(diags_qt, &[0, 1, 2], (pcount - 2, pcount)) | |
} | |
fn main() { | |
let pcounts: &[usize] = &[ | |
100, 500, 1000, 2000, 3000, 4000, 5000, 7500, 10000, 20000, 30000, 40000, 50000, | |
60000, 70000, 80000, 90000, 100000, 125000, 150000, 175000, 200000 | |
]; | |
let mut elapseds = Vec::<u128>::new(); | |
let mut nnzs = Vec::<usize>::new(); | |
for &pcount in pcounts { | |
let qt = make_matrix(pcount); | |
println!("\n--- pcount: {} ---------------", pcount); | |
println!("`qt` shape={:?}, nnz={}", qt.shape(), qt.nnz()); | |
let now = Instant::now(); | |
let _ = &qt * &qt.transpose_view(); | |
let elapsed = now.elapsed().as_millis(); | |
println!(" -> qt * qt.T: {} msec", elapsed); | |
elapseds.push(elapsed); | |
nnzs.push(qt.nnz()); | |
} | |
println!("\n`qt * qt.T` nnz: {:?}", nnzs); | |
println!("\n`qt * qt.T` elapsed msec: {:?}", elapseds); | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.