mirror of https://github.com/tracel-ai/burn.git
fix: mix precision training on tch backend (#209)
This commit is contained in:
parent
be96160065
commit
cf7847acb5
|
@ -343,7 +343,7 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
}
|
||||
|
||||
fn to_full_precision<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<f32, D> {
|
||||
let tensor = tensor.tensor.to_kind(E::KIND);
|
||||
let tensor = tensor.tensor.to_kind(tch::Kind::Float);
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue