mirror of https://github.com/tracel-ai/burn.git
aa79e36a8d
* Add cubecl quantization kernels and QTensorOps for burn-jit
* Fix typo
* Fix output vec factor
* Fix output dtype size_of
* Remove unused code in dequantize test
* Fix dequantize vectorization
* Handle tensors when number of elems is not a multiple of 4
* Support quantize for tensors with less than 4 elems (no vectorization)
* Fix equal 0 test
* Add quantize/dequantize tests
* Add q_to_device
* Refactor kernels for latest cubecl
* intermediate i32 cast
* Fix size_of output type
* Use strict=false to ignore floating point precision issues with qparams equality
* Only check that lhs & rhs strategies match (but not strict on qparams values)
* Use assert_approx_eq on dequant values
* Reduce precision for flaky test
* Remove todo comment
* Add comment for cast to unsigned
* More comment

---------

Co-authored-by: louisfd <louisfd94@gmail.com>

| Name |
|---|
| burn |
| burn-autodiff |
| burn-candle |
| burn-common |
| burn-core |
| burn-cuda |
| burn-dataset |
| burn-derive |
| burn-fusion |
| burn-import |
| burn-jit |
| burn-ndarray |
| burn-no-std-tests |
| burn-tch |
| burn-tensor |
| burn-tensor-testgen |
| burn-train |
| burn-wgpu |
| onnx-ir |