Skip to content

Instantly share code, notes, and snippets.

@rygorous
Last active November 15, 2024 06:58
Show Gist options
  • Save rygorous/991df2d8dae81ba1a844a2b0f99559ea to your computer and use it in GitHub Desktop.
Save rygorous/991df2d8dae81ba1a844a2b0f99559ea to your computer and use it in GitHub Desktop.
template <int t_nbits>
static inline void quant_endpoint_with_pbit(U8 *deq0, U8 *deq1, int val)
{
// Quantize the 8-bit endpoint value "val" to t_nbits twice, once assuming
// the shared p-bit will end up 0 and once assuming it will end up 1, and
// store the 8-bit reconstruction for each case into *deq0 / *deq1.
//
// The decoder appends the p-bit as the low bit and reconstructs the result
// as a (t_nbits+1)-bit value, i.e. (quant*2 + pbit) / (range - 1).
// Taking the two cases separately:
//   pbit=0: (quant*2)   / (range-1) = quant / ((range-1)/2)
//   pbit=1: (quant*2+1) / (range-1) = quant / ((range-1)/2) + 1/(range-1)
//
// Case pbit=0 is a plain uniform quantizer with step 0.5/(range-1), so it
// wants the usual 1/2-step rounding bias; case pbit=1 carries an extra
// offset of 1/(range-1), which cancels the bias exactly, so it just
// truncates.
//
// In integers, with "quant" t_nbits wide, that is
//   quantP = (val * (range - 1) + (p == 0 ? 254 : 0)) / 510
// and we replace the division by 510 with a multiply by a fixed-point
// reciprocal of 255 (val*(range-1) + bias fits in 16 bits, so 0x8081 has
// enough precision). In this scalar version the (range-1) factor is fused
// into one larger multiply constant; the SIMD version keeps the two
// multiplies separate because a pair of 16x16 multiplies (low half + high
// half) is much cheaper there than widening to 32 bits.
const U32 range = 1u << (t_nbits + 1); // (t_nbits+1)-bit reconstruction range
const U32 recip255 = 0x8081; // ceil(2^16 * 128 / 255); enough bits for our value range
// Scale that maps a t_nbits quant index back to an 8-bit value.
const int postscale = (0x10000 >> t_nbits) + (0x10000 >> (t_nbits*2 + 1));
const U32 scaled = val * ((range - 1) * recip255);
const U32 q_pbit1 = scaled >> 24;                  // pbit=1: truncate (bias 0)
const U32 q_pbit0 = (scaled + 254*recip255) >> 24; // pbit=0: bias of 254 before the divide
// Dequantize back to 8 bits; the pbit=1 result also ORs the p-bit's
// contribution into the reconstruction.
*deq0 = static_cast<U8>((q_pbit0 * postscale) >> 8);
*deq1 = static_cast<U8>(((q_pbit1 * postscale) >> 8) | (128 >> t_nbits));
}
// ...
const Vec128_U16 recip255 { 0x8081 }; // ceil(2^16 * 128 / 255); accurate enough for a full 16 bits
const int cpb = 256 >> (cb + 1);
const int apb = 256 >> (ab + 1);
// Follow the explanation in quant_endpoint_with_pbit above
const U16 quant_cb = (2 << cb) - 1;
const U16 quant_ab = (2 << ab) - 1;
const U16 dequant_cb = ((0x10000 >> cb) + (0x10000 >> (2*cb + 1)));
const U16 dequant_ab = (ab != 0) ? (((0x10000 >> ab) + (0x10000 >> (2*ab + 1))) & 0xffff) : 0; // & 0xffff to fix Clang warning in the ab=0 case
const Vec128_U16 quant_scale = Vec128_U16::repeat4(quant_cb, quant_cb, quant_cb, quant_ab);
const Vec128_U16 dequant_scale = Vec128_U16::repeat4(dequant_cb,dequant_cb,dequant_cb,dequant_ab);
const Vec128_U16 pbit_value = Vec128_U16::repeat4(cpb,cpb,cpb,apb);
const Vec128_U16 himask = Vec128_U16(0xff00);
// Quantize two ways, once assuming pbit=0 and once assuming pbit=1
Vec128_U16 endpoint1_prediv = endpoints16 * quant_scale; // pbit=1 value has bias of 0
Vec128_U16 endpoint0_prediv = endpoint1_prediv + Vec128_U16(254); // pbit=0 value has bias of 254
Vec128_U16 quant0 = endpoint0_prediv.mulhi(recip255) & himask;
Vec128_U16 quant1 = endpoint1_prediv.mulhi(recip255) & himask;
// quantX is now 256 * (endpointX_prediv / 510)
// Dequant, add the pbit in
Vec128_U16 dequant0 = quant0.mulhi(dequant_scale);
Vec128_U16 dequant1 = quant1.mulhi(dequant_scale) | pbit_value;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment