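// Accuracy comparison of several CUDA implementations of the mish activation
// gradient (a darknet-style variant, two "tb" variants, a fast-intrinsic
// variant and a double-precision variant), each checked against a long double
// reference computed on the host. A dump mode prints (x, gradient) samples
// that are consumed by the plotting script further below.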
#include <cuda_runtime.h>

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>
struct relu_grad
{
    __device__ float operator()(float x) { return x > 0; }
};
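// All mish gradient functors below compute d/dx [x * tanh(softplus(x))] as
//     sp      = softplus(x) = log(1 + e^x)
//     grad_sp = d(sp)/dx    = sigmoid(x) = 1 - e^(-sp)   (hence -expm1(-sp))
//     tsp     = tanh(sp)
//     grad    = x * (1 - tsp^2) * grad_sp + tsp
// They differ only in how sp and tanh(sp) are evaluated (thresholds,
// expm1 vs. 1 - exp, fast intrinsics, double precision).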
// "dn" variant (labelled "darknet" in the plotting script): softplus with a
// two-sided threshold to avoid overflow/underflow in expf.
struct mish_grad_dn
{
    __device__ float softplus_kernel(float x, float threshold = 20)
    {
        if (x > threshold) return x;
        else if (x < -threshold) return expf(x);
        return log1pf(expf(x));
    }

    __device__ float operator()(float x)
    {
        const float MISH_THRESHOLD = 20.0f;
        const float sp = softplus_kernel(x, MISH_THRESHOLD);
        const float grad_sp = -expm1f(-sp);
        const float tsp = tanh(sp);
        const float grad_tsp = (1 - tsp*tsp) * grad_sp;
        const float grad = x * grad_tsp + tsp;
        return grad;
    }
};
// "tb" variant: one-sided softplus threshold; grad_sp computed as 1 - exp(-sp).
struct mish_grad_tb
{
    __device__ float operator()(float x)
    {
        const float THRESHOLD = 20.0f;
        const float sp = x < THRESHOLD ? log1p(expf(x)) : x;
        const float grad_sp = 1 - exp(-sp);
        const float tsp = tanh(sp);
        const float grad_tsp = (1 - tsp*tsp) * grad_sp;
        const float grad = x * grad_tsp + tsp;
        return grad;
    }
};
// Same as mish_grad_tb but with grad_sp = -expm1(-sp) for better accuracy
// when sp is close to zero.
struct mish_grad_tb_expm1
{
    __device__ float operator()(float x)
    {
        const float THRESHOLD = 20.0f;
        const float sp = x < THRESHOLD ? log1p(expf(x)) : x;
        const float grad_sp = -expm1(-sp);
        const float tsp = tanh(sp);
        const float grad_tsp = (1 - tsp*tsp) * grad_sp;
        const float grad = x * grad_tsp + tsp;
        return grad;
    }
};
// Fast variant built on the __expf/__fdividef intrinsics; avoids tanh entirely.
struct mish_grad_fast
{
    __device__ float operator()(float x)
    {
        auto e = __expf(x);
        auto n = e * e + 2 * e;

        float tsp;
        if (x <= -0.6f)
            tsp = __fdividef(n, n + 2);
        else
            tsp = 1 - __fdividef(2, n + 2);

        const float grad_sp = __fdividef(e, e + 1);
        const float grad_tsp = (1 - tsp*tsp) * grad_sp;
        const float grad = x * grad_tsp + tsp;
        return x > 10.5f ? 1 : grad;
    }
};
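// Why the n / (n + 2) form works (derivation, not part of the original gist):
// with e = e^x and sp = log(1 + e),
//     tanh(sp) = (e^(2*sp) - 1) / (e^(2*sp) + 1)
//              = ((1 + e)^2 - 1) / ((1 + e)^2 + 1)
//              = (e^2 + 2e) / (e^2 + 2e + 2) = n / (n + 2),
// and grad_sp = sigmoid(x) = e / (e + 1). The two branches are algebraically
// equal; the split at x <= -0.6 presumably picks whichever form loses less
// precision in float. The clamp to 1 for x > 10.5 reflects that the gradient
// saturates to 1 for large x.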
// Reference-quality variant: evaluates the same formula in double precision.
struct mish_grad_double
{
    __device__ float operator()(float x)
    {
        const double sp = log1p(exp(x));
        const double grad_sp = -expm1(-sp);
        const double tsp = tanh(sp);
        const double grad_tsp = (1 - tsp*tsp) * grad_sp;
        const double grad = x * grad_tsp + tsp;
        return grad;
    }
};
// Scalar kernel: dz[i] *= grad(input[i]) over a grid-stride loop.
template <class GradientFunc>
__global__ void grad_vec1(float* __restrict__ dz, const float* __restrict__ input, int n)
{
    GradientFunc grad;
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x)
    {
        dz[i] *= grad(input[i]);
    }
}
// Vectorized kernel: same as grad_vec1 but processes float4 packets (n = N / 4).
template <class GradientFunc>
__global__ void grad_vec4(float4* __restrict__ dz, const float4* __restrict__ input, int n)
{
    GradientFunc grad;
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x)
    {
        float4 temp = input[i];
        float4 dy = dz[i];
        dy.w *= grad(temp.w);
        dy.x *= grad(temp.x);
        dy.y *= grad(temp.y);
        dy.z *= grad(temp.z);
        dz[i] = dy;
    }
}
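// Note (not in the original gist): the float4 kernels assume the element count
// is a multiple of 4 and that the buffers are 16-byte aligned. Both hold here,
// since N = 16 * 1024 * 1024 and cudaMalloc returns well-aligned pointers.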
// Memory-bound baselines: 2 loads + 1 store per element, used to gauge the
// bandwidth limit the gradient kernels run up against.
__global__ void limit_2L1S_v1(float * __restrict__ dz, const float * __restrict__ input, int n)
{
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x)
        dz[i] += input[i];
}

__global__ void limit_2L1S_v4(float4 * __restrict__ dz, const float4 * __restrict__ input, int n)
{
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x)
    {
        auto dy = dz[i];
        auto inp = input[i];
        dy.w += inp.w;
        dy.x += inp.x;
        dy.y += inp.y;
        dy.z += inp.z;
        dz[i] = dy;
    }
}
// Prints "x grad(x)" pairs over [-100, 20) in steps of 1e-4; the output is
// consumed by the accompanying plotting script.
template <class GradientFunc>
__global__ void dump()
{
    GradientFunc grad;
    for (float x = -100; x < 20; x += 0.0001)
        printf("%.7f %.7e\n", x, grad(x));
}
int main()
{
    // Toggle: when enabled, only dump (x, gradient) pairs for the plotting
    // script and exit without running the accuracy comparison below.
    if (1)
    {
        dump<mish_grad_tb><<<1, 1>>>();
        cudaDeviceSynchronize();
        return 0;
    }

    constexpr int N = 1024 * 1024 * 16;

    float *input_activation;
    float *grad;
    cudaMalloc(&input_activation, N * sizeof(float));
    cudaMalloc(&grad, N * sizeof(float));

    float *input_activation_h = new float[N];
    float *grad_h = new float[N];
    float *output_h = new float[N];
    float *output_ref = new float[N];

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<float> in_dis(-50, 50);

    // Reference gradients computed on the host in long double precision.
    for (int i = 0; i < N; i++)
    {
        long double a = in_dis(gen);
        input_activation_h[i] = a;

        long double dy = 1.0;
        grad_h[i] = dy;

        const long double sp = std::log1p(std::exp(a));
        const long double grad_sp = -std::expm1(-sp);
        const long double tsp = std::tanh(sp);
        const long double grad_tsp = (1 - tsp * tsp) * grad_sp;
        const long double grad = a * grad_tsp + tsp;
        output_ref[i] = dy * grad;
    }
    // L-infinity norm: maximum absolute difference against the reference.
    auto lInorm = [&] (float* x, float* y, int n) {
        float max = 0;
        for (int i = 0; i < n; i++)
            max = std::max(max, std::abs(y[i] - x[i]));
        return max;
    };

    // Euclidean (L2) norm of the difference, accumulated in double.
    auto l2norm = [] (float* x, float* y, int n) {
        std::vector<double> diff(n);
        for (int i = 0; i < n; i++)
            diff[i] = y[i] - x[i];
        auto sqr_sum = std::accumulate(std::begin(diff), std::end(diff), 0.0, [](auto lhs, auto rhs) { return lhs + rhs * rhs; });
        return std::sqrt(sqr_sum);
    };

    auto grad4 = reinterpret_cast<float4*>(grad);
    auto input_activation4 = reinterpret_cast<float4*>(input_activation);
    // relu gradient and memory-limit kernels: run once each (no accuracy check).
    cudaMemcpy(input_activation, input_activation_h, N * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(grad, grad_h, N * sizeof(float), cudaMemcpyHostToDevice);
    grad_vec1<relu_grad><<<10, 1024>>>(grad, input_activation, N);
    cudaDeviceSynchronize();

    cudaMemcpy(input_activation, input_activation_h, N * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(grad, grad_h, N * sizeof(float), cudaMemcpyHostToDevice);
    grad_vec4<relu_grad><<<10, 1024>>>(grad4, input_activation4, N / 4);
    cudaDeviceSynchronize();

    cudaMemcpy(input_activation, input_activation_h, N * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(grad, grad_h, N * sizeof(float), cudaMemcpyHostToDevice);
    limit_2L1S_v1<<<10, 1024>>>(grad, input_activation, N);
    cudaDeviceSynchronize();

    cudaMemcpy(input_activation, input_activation_h, N * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(grad, grad_h, N * sizeof(float), cudaMemcpyHostToDevice);
    limit_2L1S_v4<<<10, 1024>>>(grad4, input_activation4, N / 4);
    cudaDeviceSynchronize();
    // mish gradient variants: run, copy back, and report L2 / L-infinity error.
    cudaMemcpy(input_activation, input_activation_h, N * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(grad, grad_h, N * sizeof(float), cudaMemcpyHostToDevice);
    grad_vec1<mish_grad_dn><<<10, 1024>>>(grad, input_activation, N);
    cudaDeviceSynchronize();
    cudaMemcpy(output_h, grad, N * sizeof(float), cudaMemcpyDeviceToHost);
    std::cout << "[vec1] mish_grad_dn: " << l2norm(output_ref, output_h, N) << ' ' << lInorm(output_ref, output_h, N) << '\n';

    cudaMemcpy(input_activation, input_activation_h, N * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(grad, grad_h, N * sizeof(float), cudaMemcpyHostToDevice);
    grad_vec4<mish_grad_dn><<<10, 1024>>>(grad4, input_activation4, N / 4);
    cudaDeviceSynchronize();
    cudaMemcpy(output_h, grad, N * sizeof(float), cudaMemcpyDeviceToHost);
    std::cout << "[vec4] mish_grad_dn: " << l2norm(output_ref, output_h, N) << ' ' << lInorm(output_ref, output_h, N) << '\n';

    cudaMemcpy(input_activation, input_activation_h, N * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(grad, grad_h, N * sizeof(float), cudaMemcpyHostToDevice);
    grad_vec1<mish_grad_tb><<<10, 1024>>>(grad, input_activation, N);
    cudaDeviceSynchronize();
    cudaMemcpy(output_h, grad, N * sizeof(float), cudaMemcpyDeviceToHost);
    std::cout << "[vec1] mish_grad_tb: " << l2norm(output_ref, output_h, N) << ' ' << lInorm(output_ref, output_h, N) << '\n';

    cudaMemcpy(input_activation, input_activation_h, N * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(grad, grad_h, N * sizeof(float), cudaMemcpyHostToDevice);
    grad_vec4<mish_grad_tb><<<10, 1024>>>(grad4, input_activation4, N / 4);
    cudaDeviceSynchronize();
    cudaMemcpy(output_h, grad, N * sizeof(float), cudaMemcpyDeviceToHost);
    std::cout << "[vec4] mish_grad_tb: " << l2norm(output_ref, output_h, N) << ' ' << lInorm(output_ref, output_h, N) << '\n';

    cudaMemcpy(input_activation, input_activation_h, N * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(grad, grad_h, N * sizeof(float), cudaMemcpyHostToDevice);
    grad_vec1<mish_grad_tb_expm1><<<10, 1024>>>(grad, input_activation, N);
    cudaDeviceSynchronize();
    cudaMemcpy(output_h, grad, N * sizeof(float), cudaMemcpyDeviceToHost);
    std::cout << "[vec1] mish_grad_tb_expm1: " << l2norm(output_ref, output_h, N) << ' ' << lInorm(output_ref, output_h, N) << '\n';

    cudaMemcpy(input_activation, input_activation_h, N * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(grad, grad_h, N * sizeof(float), cudaMemcpyHostToDevice);
    grad_vec4<mish_grad_tb_expm1><<<10, 1024>>>(grad4, input_activation4, N / 4);
    cudaDeviceSynchronize();
    cudaMemcpy(output_h, grad, N * sizeof(float), cudaMemcpyDeviceToHost);
    std::cout << "[vec4] mish_grad_tb_expm1: " << l2norm(output_ref, output_h, N) << ' ' << lInorm(output_ref, output_h, N) << '\n';

    cudaMemcpy(input_activation, input_activation_h, N * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(grad, grad_h, N * sizeof(float), cudaMemcpyHostToDevice);
    grad_vec1<mish_grad_fast><<<10, 1024>>>(grad, input_activation, N);
    cudaDeviceSynchronize();
    cudaMemcpy(output_h, grad, N * sizeof(float), cudaMemcpyDeviceToHost);
    std::cout << "[vec1] mish_grad_fast: " << l2norm(output_ref, output_h, N) << ' ' << lInorm(output_ref, output_h, N) << '\n';

    cudaMemcpy(input_activation, input_activation_h, N * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(grad, grad_h, N * sizeof(float), cudaMemcpyHostToDevice);
    grad_vec4<mish_grad_fast><<<10, 1024>>>(grad4, input_activation4, N / 4);
    cudaDeviceSynchronize();
    cudaMemcpy(output_h, grad, N * sizeof(float), cudaMemcpyDeviceToHost);
    std::cout << "[vec4] mish_grad_fast: " << l2norm(output_ref, output_h, N) << ' ' << lInorm(output_ref, output_h, N) << '\n';

    cudaMemcpy(input_activation, input_activation_h, N * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(grad, grad_h, N * sizeof(float), cudaMemcpyHostToDevice);
    grad_vec1<mish_grad_double><<<10, 1024>>>(grad, input_activation, N);
    cudaDeviceSynchronize();
    cudaMemcpy(output_h, grad, N * sizeof(float), cudaMemcpyDeviceToHost);
    std::cout << "[vec1] mish_grad_double: " << l2norm(output_ref, output_h, N) << ' ' << lInorm(output_ref, output_h, N) << '\n';

    cudaMemcpy(input_activation, input_activation_h, N * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(grad, grad_h, N * sizeof(float), cudaMemcpyHostToDevice);
    grad_vec4<mish_grad_double><<<10, 1024>>>(grad4, input_activation4, N / 4);
    cudaDeviceSynchronize();
    cudaMemcpy(output_h, grad, N * sizeof(float), cudaMemcpyDeviceToHost);
    std::cout << "[vec4] mish_grad_double: " << l2norm(output_ref, output_h, N) << ' ' << lInorm(output_ref, output_h, N) << '\n';
    return 0;
}
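// Build/run sketch (filename assumed, not part of the original gist):
//   nvcc -O3 -o mish_grad mish_grad.cu && ./mish_grad
// The plotting script below reads files named dump_dn_grad, dump_tb_grad and
// dump_fast_grad; presumably each was produced by switching the template
// argument of dump<>() above and redirecting stdout, e.g.
//   ./mish_grad > dump_tb_grad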
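# Plots the accuracy of the dumped mish gradient implementations against a
# float128 reference: reads the "x grad" pairs printed by the CUDA dump<>()
# kernel and shows relative error (log10) and absolute difference per variant.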
import numpy as np

SMOOTHING_STEP_SIZE = 1000
LEFT_X_CUTOFF = -100
RIGHT_X_CUTOFF = 100

def ref_mish(x):
    return x * np.tanh(np.log1p(np.exp(x)))

def ref_grad(x):
    # same formulation as the long double reference in the CUDA benchmark
    sp = np.log1p(np.exp(x))
    grad_sp = -np.expm1(-sp)
    tsp = np.tanh(sp)
    grad_tsp = (1 - tsp * tsp) * grad_sp
    return x * grad_tsp + tsp
def generate_stats(src):
    x_list = []
    y_list = []
    with open(src, "r") as f:
        for line in f.readlines():
            x, y = [float(field.strip()) for field in line.split(' ')]
            if LEFT_X_CUTOFF < x and x < RIGHT_X_CUTOFF:
                x_list.append(x)
                y_list.append(y)

    rel_error_log10 = []
    abs_diff_err = []
    for x, y in zip(x_list, y_list):
        x128 = np.float128(x)
        y128 = np.float128(y)
        ref = ref_grad(x128)
        diff = np.abs(y128 - ref)
        rerr = -np.inf if diff == 0 else np.log10(np.abs(diff / ref))
        log_diff = 0 if diff == 0 else np.log10(diff)
        rel_error_log10.append(float(rerr))
        abs_diff_err.append(float(diff))

    # smoothing: average x and take the worst-case error over buckets of
    # SMOOTHING_STEP_SIZE consecutive samples
    x_final = []
    rel_error_log10_final = []
    abs_diff_err_final = []
    for step in range(len(x_list) // SMOOTHING_STEP_SIZE):
        ibegin = step * SMOOTHING_STEP_SIZE
        iend = ibegin + SMOOTHING_STEP_SIZE
        avg_x = np.mean(x_list[ibegin : iend])
        max_rel_err_log10 = np.max(rel_error_log10[ibegin : iend])
        max_diff_err = np.max(abs_diff_err[ibegin : iend])
        x_final.append(avg_x)
        rel_error_log10_final.append(max_rel_err_log10)
        abs_diff_err_final.append(max_diff_err)

    return x_final, rel_error_log10_final, abs_diff_err_final
x_double, re_double, ad_double = generate_stats("dump_fast_grad")
x_ocv, re_ocv, ad_ocv = generate_stats("dump_dn_grad")
x_tb, re_tb, ad_tb = generate_stats("dump_tb_grad")

import matplotlib.pyplot as plt

linewidth = 0.5
fig, ax = plt.subplots(1, 3)

ax[0].plot(x_double, re_double, linewidth = linewidth, c = 'g', label = "fast grad")
ax[0].plot(x_ocv, re_ocv, linewidth = linewidth, c = 'r', label = "darknet")
ax[0].plot(x_tb, re_tb, linewidth = linewidth, c = 'b', label = "tb")
ax[0].set_title("relative error (log10)")
ax[0].legend()

ax[1].plot(x_double, ad_double, linewidth = linewidth, c = 'g', label = "fast grad")
ax[1].plot(x_ocv, ad_ocv, linewidth = linewidth, c = 'r', label = "darknet")
ax[1].plot(x_tb, ad_tb, linewidth = linewidth, c = 'b', label = "tb")
ax[1].set_title("abs(diff)")
ax[1].legend()

ax[2].plot(x_double, [np.log10(a) for a in ad_double], linewidth = linewidth, c = 'g', label = "fast grad")
ax[2].plot(x_ocv, [np.log10(a) for a in ad_ocv], linewidth = linewidth, c = 'r', label = "darknet")
ax[2].plot(x_tb, [np.log10(a) for a in ad_tb], linewidth = linewidth, c = 'b', label = "tb")
ax[2].set_title("log10(abs(diff))")
ax[2].legend()

plt.show()

print(np.max(re_ocv), np.max(ad_ocv))
print(x_ocv[np.argmax(re_ocv)], x_ocv[np.argmax(ad_ocv)])
print(np.max(re_tb), np.max(ad_tb))
print(x_tb[np.argmax(re_tb)], x_tb[np.argmax(ad_tb)])