Skip to content

Instantly share code, notes, and snippets.

@manofstick
Created April 5, 2023 04:59
Show Gist options
  • Save manofstick/bff14a76a00cc8a067ae04772cca3c60 to your computer and use it in GitHub Desktop.
Save manofstick/bff14a76a00cc8a067ae04772cca3c60 to your computer and use it in GitHub Desktop.
Matrix Mult
/// <summary>
/// SIMD matrix multiplication routines that produce the result in tiles of
/// 4 rows by 3 columns, with scalar/one-column and one-row fallbacks for the edges.
/// </summary>
/// <remarks>
/// NOTE(review): FMatrix, Tmp4_3 and Transpose are declared outside this excerpt, and the
/// generic constraints needed for T.Zero and Vector of T are presumably declared there too - confirm.
/// </remarks>
class Matrix<T>
{
    /// <summary>
    /// Computes the vectorized part of twelve dot products at once: each of the four A rows
    /// against each of the three BT rows.
    /// </summary>
    /// <param name="ARowVector0">Vector view of the first A row.</param>
    /// <param name="ARowVector1">Vector view of the second A row.</param>
    /// <param name="ARowVector2">Vector view of the third A row.</param>
    /// <param name="ARowVector3">Vector view of the fourth A row.</param>
    /// <param name="BTRowVector0">Vector view of the first BT row.</param>
    /// <param name="BTRowVector1">Vector view of the second BT row.</param>
    /// <param name="BTRowVector2">Vector view of the third BT row.</param>
    /// <param name="tmp">
    /// Receives the twelve sums; field tmpR_C holds the sum for A row R and BT row C.
    /// Only whole vectors are covered, so callers add any scalar tail themselves.
    /// </param>
    /// <remarks>
    /// Every span is indexed with the indices of <paramref name="ARowVector0"/>, so all seven
    /// spans must be at least that long.
    /// </remarks>
    [MethodImpl(MethodImplOptions.AggressiveOptimization)]
    public static void InnerLoop4_3(Span<Vector<T>> ARowVector0, Span<Vector<T>> ARowVector1, Span<Vector<T>> ARowVector2, Span<Vector<T>> ARowVector3, Span<Vector<T>> BTRowVector0, Span<Vector<T>> BTRowVector1, Span<Vector<T>> BTRowVector2, out Tmp4_3 tmp)
    {
        var tmpVector0_0 = Vector<T>.Zero; var tmpVector0_1 = Vector<T>.Zero; var tmpVector0_2 = Vector<T>.Zero;
        var tmpVector1_0 = Vector<T>.Zero; var tmpVector1_1 = Vector<T>.Zero; var tmpVector1_2 = Vector<T>.Zero;
        var tmpVector2_0 = Vector<T>.Zero; var tmpVector2_1 = Vector<T>.Zero; var tmpVector2_2 = Vector<T>.Zero;
        var tmpVector3_0 = Vector<T>.Zero; var tmpVector3_1 = Vector<T>.Zero; var tmpVector3_2 = Vector<T>.Zero;

        // The loop runs from the last vector down to the first; keep this order, because the
        // accumulation order affects floating-point rounding.
        for (var i = ARowVector0.Length - 1; i >= 0; --i)
        {
            // One reusable temporary: stops the JIT generating code that causes register spill.
            Vector<T> x;
            x = ARowVector0[i]; x *= BTRowVector0[i]; tmpVector0_0 += x;
            x = ARowVector0[i]; x *= BTRowVector1[i]; tmpVector0_1 += x;
            x = ARowVector0[i]; x *= BTRowVector2[i]; tmpVector0_2 += x;

            x = ARowVector1[i]; x *= BTRowVector0[i]; tmpVector1_0 += x;
            x = ARowVector1[i]; x *= BTRowVector1[i]; tmpVector1_1 += x;
            x = ARowVector1[i]; x *= BTRowVector2[i]; tmpVector1_2 += x;

            x = ARowVector2[i]; x *= BTRowVector0[i]; tmpVector2_0 += x;
            x = ARowVector2[i]; x *= BTRowVector1[i]; tmpVector2_1 += x;
            x = ARowVector2[i]; x *= BTRowVector2[i]; tmpVector2_2 += x;

            x = ARowVector3[i]; x *= BTRowVector0[i]; tmpVector3_0 += x;
            x = ARowVector3[i]; x *= BTRowVector1[i]; tmpVector3_1 += x;
            x = ARowVector3[i]; x *= BTRowVector2[i]; tmpVector3_2 += x;
        }

        // Every field of the out struct is assigned before any of them is read.
        tmp.tmp0_0 = T.Zero; tmp.tmp0_1 = T.Zero; tmp.tmp0_2 = T.Zero;
        tmp.tmp1_0 = T.Zero; tmp.tmp1_1 = T.Zero; tmp.tmp1_2 = T.Zero;
        tmp.tmp2_0 = T.Zero; tmp.tmp2_1 = T.Zero; tmp.tmp2_2 = T.Zero;
        tmp.tmp3_0 = T.Zero; tmp.tmp3_1 = T.Zero; tmp.tmp3_2 = T.Zero;

        // Horizontal add: fold the lanes of each accumulator into its scalar slot.
        for (var i = 0; i < Vector<T>.Count; ++i)
        {
            tmp.tmp0_0 += tmpVector0_0[i]; tmp.tmp0_1 += tmpVector0_1[i]; tmp.tmp0_2 += tmpVector0_2[i];
            tmp.tmp1_0 += tmpVector1_0[i]; tmp.tmp1_1 += tmpVector1_1[i]; tmp.tmp1_2 += tmpVector1_2[i];
            tmp.tmp2_0 += tmpVector2_0[i]; tmp.tmp2_1 += tmpVector2_1[i]; tmp.tmp2_2 += tmpVector2_2[i];
            tmp.tmp3_0 += tmpVector3_0[i]; tmp.tmp3_1 += tmpVector3_1[i]; tmp.tmp3_2 += tmpVector3_2[i];
        }
    }

    /// <summary>
    /// Multiplies <paramref name="A"/> by <paramref name="B"/> in parallel, giving each worker
    /// one contiguous range of rows.
    /// </summary>
    /// <param name="A">Left operand.</param>
    /// <param name="B">Right operand; its row count must equal the column count of A.</param>
    /// <returns>A new matrix with A.Rows rows and B.Columns columns.</returns>
    /// <exception cref="ArgumentException">A.Columns differs from B.Rows.</exception>
    public static FMatrix<T> Multiple_SIMD4_3P(FMatrix<T> A, FMatrix<T> B)
    {
        if (A.Columns != B.Rows)
        {
            throw new ArgumentException("(A.Columns != B.Rows)");
        }

        var BT = Transpose(B);
        var rows = A.Rows;
        var columns = B.Columns;
        var result = new FMatrix<T>(rows, columns);

        // At least one worker is required: with fewer than four rows "rows / 4" is zero, which
        // previously made the batch size computation divide by zero.
        var processors = Math.Max(1, Math.Min(Environment.ProcessorCount, rows / 4));
        var batchSize = rows / (processors * 4);

        Parallel.For(0, processors, processorIdx =>
        {
            var startIdx = processorIdx * batchSize * 4;
            // The last worker also takes whatever rows are left over after the even split.
            var endIdx = (processorIdx < (processors - 1)) ? (processorIdx + 1) * batchSize * 4 : rows;
            MultiplyRowRange(A, BT, result, startIdx, endIdx, columns);
        });
        return result;
    }

    /// <summary>
    /// Multiplies <paramref name="A"/> by <paramref name="B"/> in parallel, scheduling every
    /// group of four rows as its own work item.
    /// </summary>
    /// <param name="A">Left operand.</param>
    /// <param name="B">Right operand; its row count must equal the column count of A.</param>
    /// <returns>A new matrix with A.Rows rows and B.Columns columns.</returns>
    /// <exception cref="ArgumentException">A.Columns differs from B.Rows.</exception>
    public static FMatrix<T> Multiple_SIMD4_3PP(FMatrix<T> A, FMatrix<T> B)
    {
        if (A.Columns != B.Rows)
        {
            throw new ArgumentException("(A.Columns != B.Rows)");
        }

        var BT = Transpose(B);
        var rows = A.Rows;
        var columns = B.Columns;
        var result = new FMatrix<T>(rows, columns);

        // One extra work item covers the rows left over when the row count is not a multiple of four.
        Parallel.For(0, rows / 4 + 1, batchRowIdx =>
        {
            var row = batchRowIdx * 4;
            if (row < rows - 4 + 1)
            {
                MultiplyRows4(A, BT, result, row, columns);
            }
            else
            {
                for (; row < rows; ++row)
                {
                    MultiplySingleRow(A, BT, result, row, columns);
                }
            }
        });
        return result;
    }

    /// <summary>
    /// Multiplies <paramref name="A"/> by <paramref name="B"/> on the calling thread.
    /// </summary>
    /// <param name="A">Left operand.</param>
    /// <param name="B">Right operand; its row count must equal the column count of A.</param>
    /// <returns>A new matrix with A.Rows rows and B.Columns columns.</returns>
    /// <exception cref="ArgumentException">A.Columns differs from B.Rows.</exception>
    public static FMatrix<T> Multiple_SIMD4_3(FMatrix<T> A, FMatrix<T> B)
    {
        if (A.Columns != B.Rows)
        {
            throw new ArgumentException("(A.Columns != B.Rows)");
        }

        var BT = Transpose(B);
        var rows = A.Rows;
        var columns = B.Columns;
        var result = new FMatrix<T>(rows, columns);
        MultiplyRowRange(A, BT, result, 0, rows, columns);
        return result;
    }

    // Fills result rows [startRow, endRow): groups of four rows first, then the leftovers one by one.
    private static void MultiplyRowRange(FMatrix<T> A, FMatrix<T> BT, FMatrix<T> result, int startRow, int endRow, int columns)
    {
        var row = startRow;
        for (; row < endRow - 4 + 1; row += 4)
        {
            MultiplyRows4(A, BT, result, row, columns);
        }
        for (; row < endRow; ++row)
        {
            MultiplySingleRow(A, BT, result, row, columns);
        }
    }

    // Fills result rows row..row+3: three columns at a time via InnerLoop4_3, then any remaining columns singly.
    private static void MultiplyRows4(FMatrix<T> A, FMatrix<T> BT, FMatrix<T> result, int row, int columns)
    {
        var ARow0 = A.values[row + 0].AsSpan();
        var ARow1 = A.values[row + 1].AsSpan();
        var ARow2 = A.values[row + 2].AsSpan();
        var ARow3 = A.values[row + 3].AsSpan();
        var ARowVector0 = MemoryMarshal.Cast<T, Vector<T>>(ARow0);
        var ARowVector1 = MemoryMarshal.Cast<T, Vector<T>>(ARow1);
        var ARowVector2 = MemoryMarshal.Cast<T, Vector<T>>(ARow2);
        var ARowVector3 = MemoryMarshal.Cast<T, Vector<T>>(ARow3);

        // Index of the first element that does not fit into a whole vector.
        var tailStart = ARowVector0.Length * Vector<T>.Count;

        var column = 0;
        for (; column < columns - 3 + 1; column += 3)
        {
            var BTRow0 = BT.values[column + 0].AsSpan();
            var BTRow1 = BT.values[column + 1].AsSpan();
            var BTRow2 = BT.values[column + 2].AsSpan();
            var BTRowVector0 = MemoryMarshal.Cast<T, Vector<T>>(BTRow0);
            var BTRowVector1 = MemoryMarshal.Cast<T, Vector<T>>(BTRow1);
            var BTRowVector2 = MemoryMarshal.Cast<T, Vector<T>>(BTRow2);

            InnerLoop4_3(ARowVector0, ARowVector1, ARowVector2, ARowVector3, BTRowVector0, BTRowVector1, BTRowVector2, out var tmp);

            // Scalar tail for the elements the vector views do not cover.
            for (var i = tailStart; i < ARow0.Length; ++i)
            {
                tmp.tmp0_0 += ARow0[i] * BTRow0[i]; tmp.tmp0_1 += ARow0[i] * BTRow1[i]; tmp.tmp0_2 += ARow0[i] * BTRow2[i];
                tmp.tmp1_0 += ARow1[i] * BTRow0[i]; tmp.tmp1_1 += ARow1[i] * BTRow1[i]; tmp.tmp1_2 += ARow1[i] * BTRow2[i];
                tmp.tmp2_0 += ARow2[i] * BTRow0[i]; tmp.tmp2_1 += ARow2[i] * BTRow1[i]; tmp.tmp2_2 += ARow2[i] * BTRow2[i];
                tmp.tmp3_0 += ARow3[i] * BTRow0[i]; tmp.tmp3_1 += ARow3[i] * BTRow1[i]; tmp.tmp3_2 += ARow3[i] * BTRow2[i];
            }

            result.values[row + 0][column + 0] = tmp.tmp0_0; result.values[row + 0][column + 1] = tmp.tmp0_1; result.values[row + 0][column + 2] = tmp.tmp0_2;
            result.values[row + 1][column + 0] = tmp.tmp1_0; result.values[row + 1][column + 1] = tmp.tmp1_1; result.values[row + 1][column + 2] = tmp.tmp1_2;
            result.values[row + 2][column + 0] = tmp.tmp2_0; result.values[row + 2][column + 1] = tmp.tmp2_1; result.values[row + 2][column + 2] = tmp.tmp2_2;
            result.values[row + 3][column + 0] = tmp.tmp3_0; result.values[row + 3][column + 1] = tmp.tmp3_1; result.values[row + 3][column + 2] = tmp.tmp3_2;
        }

        // Columns left over when the column count is not a multiple of three.
        for (; column < columns; column++)
        {
            var BTRow = BT.values[column].AsSpan();
            var BTRowVector = MemoryMarshal.Cast<T, Vector<T>>(BTRow);
            var tmpVector0 = Vector<T>.Zero;
            var tmpVector1 = Vector<T>.Zero;
            var tmpVector2 = Vector<T>.Zero;
            var tmpVector3 = Vector<T>.Zero;
            for (var i = 0; i < ARowVector0.Length; ++i)
            {
                tmpVector0 += ARowVector0[i] * BTRowVector[i];
                tmpVector1 += ARowVector1[i] * BTRowVector[i];
                tmpVector2 += ARowVector2[i] * BTRowVector[i];
                tmpVector3 += ARowVector3[i] * BTRowVector[i];
            }

            var tmp0 = T.Zero;
            var tmp1 = T.Zero;
            var tmp2 = T.Zero;
            var tmp3 = T.Zero;
            for (var i = 0; i < Vector<T>.Count; ++i)
            {
                tmp0 += tmpVector0[i];
                tmp1 += tmpVector1[i];
                tmp2 += tmpVector2[i];
                tmp3 += tmpVector3[i];
            }
            for (var i = tailStart; i < ARow0.Length; ++i)
            {
                tmp0 += ARow0[i] * BTRow[i];
                tmp1 += ARow1[i] * BTRow[i];
                tmp2 += ARow2[i] * BTRow[i];
                tmp3 += ARow3[i] * BTRow[i];
            }

            result.values[row + 0][column] = tmp0;
            result.values[row + 1][column] = tmp1;
            result.values[row + 2][column] = tmp2;
            result.values[row + 3][column] = tmp3;
        }
    }

    // Fills one result row: a vectorized dot product per column, followed by a scalar tail.
    private static void MultiplySingleRow(FMatrix<T> A, FMatrix<T> BT, FMatrix<T> result, int row, int columns)
    {
        var ARow = A.values[row].AsSpan();
        var ARowVector = MemoryMarshal.Cast<T, Vector<T>>(ARow);
        for (var column = 0; column < columns; column++)
        {
            var BTRow = BT.values[column].AsSpan();
            var BTRowVector = MemoryMarshal.Cast<T, Vector<T>>(BTRow);
            var tmpVector = Vector<T>.Zero;
            for (var i = 0; i < ARowVector.Length; ++i)
            {
                tmpVector += ARowVector[i] * BTRowVector[i];
            }

            var tmp = T.Zero;
            for (var i = 0; i < Vector<T>.Count; ++i)
            {
                tmp += tmpVector[i];
            }
            for (var i = ARowVector.Length * Vector<T>.Count; i < ARow.Length; ++i)
            {
                tmp += ARow[i] * BTRow[i];
            }
            result.values[row][column] = tmp;
        }
    }
}
@manofstick
Copy link
Author

    /// <summary>
    /// Repacks <paramref name="A"/> so that groups of Vector&lt;T&gt;.Count consecutive rows are
    /// interleaved into a single striped row.
    /// </summary>
    /// <param name="A">
    /// Jagged matrix; the column count is taken from its first row, and every row is read up to
    /// that many columns.
    /// </param>
    /// <returns>
    /// A jagged array in which element (r, c) of A is stored at
    /// [r / Count][c * Count + r % Count], where Count is Vector&lt;T&gt;.Count. Lanes that would
    /// refer to rows past the end of A keep their default value. An empty input yields an empty array.
    /// </returns>
    private static T[][] StripedA(T[][] A)
    {
        var lanes = Vector<T>.Count;
        var rows = A.Length;
        // An empty matrix has no first row to measure, so treat it as having zero columns.
        var columns = rows == 0 ? 0 : A[0].Length;

        var stripedRows = (rows + lanes - 1) / lanes;
        // Each source column expands to exactly one full vector of lanes.
        var stripedColumns = columns * lanes;

        var striped = new T[stripedRows][];
        for (var stripedRowIdx = 0; stripedRowIdx < stripedRows; ++stripedRowIdx)
        {
            var rowIdx = stripedRowIdx * lanes;
            // The final stripe may cover fewer source rows than there are lanes.
            var activeLanes = Math.Min(lanes, rows - rowIdx);
            var stripedRow = new T[stripedColumns];
            striped[stripedRowIdx] = stripedRow;
            for (var columnIdx = 0; columnIdx < columns; ++columnIdx)
            {
                var offset = columnIdx * lanes;
                for (var i = 0; i < activeLanes; ++i)
                {
                    stripedRow[offset + i] = A[rowIdx + i][columnIdx];
                }
            }
        }
        return striped;
    }

    /// <summary>
    /// Repacks <paramref name="B"/> so that groups of Vector&lt;T&gt;.Count consecutive columns are
    /// interleaved into a single striped row.
    /// </summary>
    /// <param name="B">
    /// Jagged matrix; the column count is taken from its first row, and every row is read up to
    /// that many columns.
    /// </param>
    /// <returns>
    /// A jagged array in which element (r, c) of B is stored at
    /// [c / Count][r * Count + c % Count], where Count is Vector&lt;T&gt;.Count. Lanes that would
    /// refer to columns past the end of B keep their default value. An empty input yields an empty array.
    /// </returns>
    private static T[][] StripedB(T[][] B)
    {
        var lanes = Vector<T>.Count;
        var rows = B.Length;
        // An empty matrix has no first row to measure, so treat it as having zero columns.
        var columns = rows == 0 ? 0 : B[0].Length;

        var stripedRows = (columns + lanes - 1) / lanes;
        // Each source row expands to exactly one full vector of lanes.
        var stripedColumns = rows * lanes;

        var striped = new T[stripedRows][];
        for (var stripedRowIdx = 0; stripedRowIdx < stripedRows; ++stripedRowIdx)
        {
            var columnIdx = stripedRowIdx * lanes;
            // The final stripe may cover fewer source columns than there are lanes.
            var activeLanes = Math.Min(lanes, columns - columnIdx);
            var stripedRow = new T[stripedColumns];
            striped[stripedRowIdx] = stripedRow;
            for (var rowIdx = 0; rowIdx < rows; ++rowIdx)
            {
                var sourceRow = B[rowIdx];
                var offset = rowIdx * lanes;
                for (var i = 0; i < activeLanes; ++i)
                {
                    stripedRow[offset + i] = sourceRow[columnIdx + i];
                }
            }
        }
        return striped;
    }

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment