template <int t_nbits>
static inline void quant_endpoint_with_pbit(U8 *deq0, U8 *deq1, int val)
{
    const int expanded_nbits = t_nbits + 1;
    const U32 range = 1u << expanded_nbits;
    const U32 recip255 = 0x8081; // enough bits for our value range
    const int postscale = (0x10000 >> t_nbits) + (0x10000 >> (t_nbits*2 + 1));
    // The reconstruction here adds the pbit as the lowest bit and then reconstructs
    // it as a (nbits+1)-bit value to float, i.e. (quant*2 + pbit) / (range - 1).
    // Consider the two cases separately:
    //   pbit=0 reconstructs (quant*2) / (range-1)   = quant / ((range-1) / 2)
    //   pbit=1 reconstructs (quant*2+1) / (range-1) = quant / ((range-1) / 2) + 1/(range-1)
    //
    // the former is a uniform quantizer with a step size of 2 / (range - 1)
    // -> quantize for that with the usual 1/2 rounding bias (see above in quant_endpoint).
    //
    // the latter is biased by 1/(range-1) which works out to needing a 0 rounding bias
    // (i.e. truncating).
    //
    // "quant" here is t_nbits wide; we then expand with the p-bit value in the
    // right place.
    // The math for quant here is
    //   quantP = (val * (range - 1) + (p == 0 ? 254 : 0)) / 510
    // except we use a sufficient-precision reciprocal (using that val*(range-1) + bias
    // fits in 16 bits). In this scalar version we fuse the mul by (range-1)*recip255
    // into one larger constant, in the SIMD version we keep them separate since two
    // 16x16 multiplies (one low half, one high half) are much cheaper than going to
    // 32 bits.
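    //
    // Worked example with t_nbits=4 (so 5 bits including the p-bit), val=200:
    //   range = 32, quant0 = (200*31 + 254)/510 = 12, quant1 = (200*31)/510 = 12
    //   deq0 = (12 * 4224) >> 8   = 198   (matches bit-replicating the 5-bit code 2*12+0 = 24 to 8 bits)
    //   deq1 = 198 | (128 >> 4)   = 206   (matches bit-replicating the 5-bit code 2*12+1 = 25 to 8 bits)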
    U32 prescaled = val * ((range - 1) * recip255);
    U32 quant0 = (prescaled + 254*recip255) >> 24; // quant for pbit=0
    U32 quant1 = prescaled >> 24; // quant for pbit=1
    // dequantize back to 8 bits
    *deq0 = static_cast<U8>((quant0 * postscale) >> 8);
    *deq1 = static_cast<U8>(((quant1 * postscale) >> 8) | (128 >> t_nbits));
}
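
// ---------------------------------------------------------------------------
// A minimal brute-force self-check sketch (not part of the original gist). It
// assumes U8/U32 are the usual fixed-width unsigned typedefs (uint8_t/uint32_t)
// and that t_nbits is one of the widths BC7 p-bit endpoints actually use (4..7).
// The reference does the exact divide by 510 from the comment above, then
// reconstructs by bit-replicating the (t_nbits+1)-bit code to 8 bits, which is
// what the postscale multiply works out to for these widths.
#include <cassert>

template <int t_nbits>
static void check_quant_endpoint_with_pbit()
{
    const U32 range = 1u << (t_nbits + 1);
    for (int val = 0; val < 256; ++val)
    {
        // Reference quantizers: pbit=0 uses the ~1/2 rounding bias (254/510), pbit=1 truncates.
        U32 ref_q0 = (val * (range - 1) + 254) / 510;
        U32 ref_q1 = (val * (range - 1)) / 510;
        // Reference dequant: put the p-bit in the low bit, then bit-replicate to 8 bits.
        U32 x0 = ref_q0 * 2 + 0;
        U32 x1 = ref_q1 * 2 + 1;
        U8 ref_d0 = static_cast<U8>((x0 << (7 - t_nbits)) | (x0 >> (2*t_nbits - 6)));
        U8 ref_d1 = static_cast<U8>((x1 << (7 - t_nbits)) | (x1 >> (2*t_nbits - 6)));

        U8 deq0, deq1;
        quant_endpoint_with_pbit<t_nbits>(&deq0, &deq1, val);
        assert(deq0 == ref_d0 && deq1 == ref_d1);
    }
}
// Usage: check_quant_endpoint_with_pbit<4>();  // likewise for 5, 6, 7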
// ...
const Vec128_U16 recip255 { 0x8081 }; // ceil(2^16 * 128 / 255); accurate enough for a full 16 bits
const int cpb = 256 >> (cb + 1);
const int apb = 256 >> (ab + 1);
// Follow the explanation in quant_endpoint_with_pbit above
const U16 quant_cb = (2 << cb) - 1;
const U16 quant_ab = (2 << ab) - 1;
const U16 dequant_cb = ((0x10000 >> cb) + (0x10000 >> (2*cb + 1)));
const U16 dequant_ab = (ab != 0) ? (((0x10000 >> ab) + (0x10000 >> (2*ab + 1))) & 0xffff) : 0; // & 0xffff to fix Clang warning in the ab=0 case
const Vec128_U16 quant_scale = Vec128_U16::repeat4(quant_cb, quant_cb, quant_cb, quant_ab);
const Vec128_U16 dequant_scale = Vec128_U16::repeat4(dequant_cb, dequant_cb, dequant_cb, dequant_ab);
const Vec128_U16 pbit_value = Vec128_U16::repeat4(cpb, cpb, cpb, apb);
const Vec128_U16 himask = Vec128_U16(0xff00);
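// Illustrative values (not from the gist; endpoint widths as in the BC7 spec):
//   cb=4, ab=0 (mode 0-style color): quant_cb = 31,  dequant_cb = 4224, cpb = 8, dequant_ab = 0
//   cb=7, ab=7 (mode 6-style RGBA):  quant_cb = 255, dequant_cb = 514,  cpb = apb = 1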
// Quantize two ways, once assuming pbit=0 and once assuming pbit=1
Vec128_U16 endpoint1_prediv = endpoints16 * quant_scale; // pbit=1 value has bias of 0
Vec128_U16 endpoint0_prediv = endpoint1_prediv + Vec128_U16(254); // pbit=0 value has bias of 254
Vec128_U16 quant0 = endpoint0_prediv.mulhi(recip255) & himask;
Vec128_U16 quant1 = endpoint1_prediv.mulhi(recip255) & himask;
// quantX is now 256 * (endpointX_prediv / 510)
// Dequant, add the pbit in
Vec128_U16 dequant0 = quant0.mulhi(dequant_scale);
Vec128_U16 dequant1 = quant1.mulhi(dequant_scale) | pbit_value;
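
// ---------------------------------------------------------------------------
// A scalar model of one 16-bit lane of the path above (a sketch, not from the
// gist). It assumes Vec128_U16::mulhi returns the high 16 bits of an unsigned
// 16x16-bit multiply (PMULHUW-style) and that nbits is a p-bit endpoint width
// (4..7), so every intermediate fits in 16 bits and plain U32 arithmetic models
// the lanes exactly. Masking the first mulhi with 0xff00 leaves 256*(prediv/510)
// in the lane, so the second mulhi reproduces the scalar (quant * postscale) >> 8.
static inline void quant_pbit_lane_model(U8 *deq0, U8 *deq1, int val, int nbits)
{
    const U32 recip255 = 0x8081;
    const U32 quant_scale = (2u << nbits) - 1; // range - 1
    const U32 dequant_scale = (0x10000 >> nbits) + (0x10000 >> (2*nbits + 1));
    const U32 pbit_value = 256 >> (nbits + 1);

    U32 prediv1 = (U32)val * quant_scale; // pbit=1 value has bias of 0
    U32 prediv0 = prediv1 + 254;          // pbit=0 value has bias of 254
    U32 quant0 = ((prediv0 * recip255) >> 16) & 0xff00; // = 256 * (prediv0 / 510)
    U32 quant1 = ((prediv1 * recip255) >> 16) & 0xff00; // = 256 * (prediv1 / 510)
    *deq0 = static_cast<U8>((quant0 * dequant_scale) >> 16);
    *deq1 = static_cast<U8>(((quant1 * dequant_scale) >> 16) | pbit_value);
}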