mirror of https://github.com/tracel-ai/burn.git
refactor: device functions (#157)
parent 2d4e514b41
commit c7963d8485
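The diff below is one mechanical refactor applied across the workspace: every function that took a device by value (`device: B::Device`) now borrows it (`device: &B::Device`), and the `Backend::Device` associated type drops its `Copy` bound. A minimal sketch of the signature change, using a hypothetical `make_zeros` helper rather than any real burn API:

    use std::fmt::Debug;

    // Before the refactor: the device argument is consumed (or silently
    // copied) on every call, which forces device types to be Copy.
    fn make_zeros_by_value<D: Copy + Debug>(shape: usize, device: D) -> Vec<f32> {
        println!("zeros({shape}) on {device:?}");
        vec![0.0; shape]
    }

    // After the refactor: the device is borrowed, so Clone is enough.
    fn make_zeros_by_ref<D: Clone + Debug>(shape: usize, device: &D) -> Vec<f32> {
        println!("zeros({shape}) on {device:?}");
        vec![0.0; shape]
    }

    fn main() {
        let _ = make_zeros_by_value(2, 0u8); // u8 is Copy
        let device = String::from("cuda:0"); // Clone but not Copy
        let _a = make_zeros_by_ref(4, &device);
        let _b = make_zeros_by_ref(4, &device); // caller still owns `device`
    }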
@@ -19,14 +19,14 @@ impl<B: Backend, const D: usize> std::ops::Add<ADTensor<D, B>> for ADTensor<D, B
 }
 
 impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
-    fn from_data<const D: usize>(data: Data<B::Elem, D>, device: B::Device) -> ADTensor<D, B> {
+    fn from_data<const D: usize>(data: Data<B::Elem, D>, device: &B::Device) -> ADTensor<D, B> {
         let tensor = B::from_data(data, device);
         ADTensor::from_tensor(tensor)
     }
 
     fn from_data_bool<const D: usize>(
         data: Data<bool, D>,
-        device: B::Device,
+        device: &B::Device,
     ) -> B::BoolTensorPrimitive<D> {
         B::from_data_bool(data, device)
     }
@@ -34,16 +34,16 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
     fn random<const D: usize>(
         shape: Shape<D>,
         distribution: Distribution<B::Elem>,
-        device: B::Device,
+        device: &B::Device,
     ) -> ADTensor<D, B> {
         ADTensor::from_tensor(B::random(shape, distribution, device))
     }
 
-    fn zeros<const D: usize>(shape: Shape<D>, device: B::Device) -> ADTensor<D, B> {
+    fn zeros<const D: usize>(shape: Shape<D>, device: &B::Device) -> ADTensor<D, B> {
         ADTensor::from_tensor(B::zeros(shape, device))
     }
 
-    fn ones<const D: usize>(shape: Shape<D>, device: B::Device) -> ADTensor<D, B> {
+    fn ones<const D: usize>(shape: Shape<D>, device: &B::Device) -> ADTensor<D, B> {
         ADTensor::from_tensor(B::ones(shape, device))
     }
 
@@ -99,7 +99,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
 
     fn bool_to_device<const D: usize>(
         tensor: &<ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D>,
-        device: <ADBackendDecorator<B> as Backend>::Device,
+        device: &<ADBackendDecorator<B> as Backend>::Device,
     ) -> <ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D> {
         B::bool_to_device(tensor, device)
     }
@@ -112,7 +112,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
 
     fn to_device<const D: usize>(
         tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
-        device: <ADBackendDecorator<B> as Backend>::Device,
+        device: &<ADBackendDecorator<B> as Backend>::Device,
     ) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
         #[derive(new, Debug)]
         struct ToDeviceBackward<B: Backend, const D: usize> {
@@ -126,7 +126,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
             &self,
             state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
         ) -> B::TensorPrimitive<D> {
-            B::to_device(&state.output.grad(), self.device)
+            B::to_device(&state.output.grad(), &self.device)
         }
     }
 
@@ -140,7 +140,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
 
     fn empty<const D: usize>(
         shape: Shape<D>,
-        device: <ADBackendDecorator<B> as Backend>::Device,
+        device: &<ADBackendDecorator<B> as Backend>::Device,
     ) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
         ADTensor::from_tensor(B::empty(shape, device))
     }
@@ -775,7 +775,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
             state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<1>>,
         ) -> B::TensorPrimitive<D> {
             let grad = state.output.grad();
-            let ones = B::ones(self.shape.clone(), B::device(&grad));
+            let ones = B::ones(self.shape.clone(), &B::device(&grad));
 
             let grad: Tensor<B, 1> = Tensor::from_primitive(grad);
             let val = 1_f64 / self.shape.num_elements() as f64;
@@ -809,7 +809,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
             state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<1>>,
         ) -> B::TensorPrimitive<D> {
             let grad = state.output.grad();
-            let ones = B::ones(self.shape.clone(), B::device(&grad));
+            let ones = B::ones(self.shape.clone(), &B::device(&grad));
 
             let grad: Tensor<B, 1> = Tensor::from_primitive(grad);
             let ones: Tensor<B, D> = Tensor::from_primitive(ones);
@@ -844,7 +844,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
             state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
         ) -> B::TensorPrimitive<D> {
             let grad = B::sum_dim(&state.output.grad(), self.dim);
-            let ones = B::ones(self.shape.clone(), B::device(&grad));
+            let ones = B::ones(self.shape.clone(), &B::device(&grad));
 
             let val = 1_f64 / self.shape.dims[self.dim] as f64;
             let ones = B::mul_scalar(&ones, &B::Elem::from_elem(val));
@@ -879,7 +879,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
             state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
         ) -> B::TensorPrimitive<D> {
             let grad = B::sum_dim(&state.output.grad(), self.dim);
-            let ones = B::ones(self.shape.clone(), B::device(&grad));
+            let ones = B::ones(self.shape.clone(), &B::device(&grad));
 
             B::mul(&ones, &grad)
         }
@@ -41,7 +41,7 @@ pub trait Module: Send + Sync + std::fmt::Debug + std::fmt::Display {
     /// Get the device list of the module and all of its sub-modules.
     fn devices(&self) -> Vec<<Self::Backend as Backend>::Device>;
     /// Move the module and all of its sub-modules to the given device.
-    fn to_device(&mut self, device: <Self::Backend as Backend>::Device);
+    fn to_device(&mut self, device: &<Self::Backend as Backend>::Device);
     /// Load the module state.
    fn load(&mut self, state: &State<<Self::Backend as Backend>::Elem>)
        -> Result<(), LoadingError>;
@@ -15,7 +15,7 @@ impl<M: Module> Module for Param<M> {
         self.value.devices()
     }
 
-    fn to_device(&mut self, device: <Self::Backend as Backend>::Device) {
+    fn to_device(&mut self, device: &<Self::Backend as Backend>::Device) {
         self.value.to_device(device)
     }
 
@@ -65,7 +65,7 @@ impl<M: Module> Module for Param<Vec<M>> {
         devices
     }
 
-    fn to_device(&mut self, device: <M::Backend as Backend>::Device) {
+    fn to_device(&mut self, device: &<M::Backend as Backend>::Device) {
         for module in self.value.iter_mut() {
             module.to_device(device);
         }
@@ -16,7 +16,7 @@ impl<const D: usize, B: Backend> Module for Param<Tensor<B, D>> {
         vec![self.value.device()]
     }
 
-    fn to_device(&mut self, device: B::Device) {
+    fn to_device(&mut self, device: &B::Device) {
         self.value = self.value.to_device(device);
     }
 
@@ -32,7 +32,7 @@ impl<const D: usize, B: Backend> Module for Param<Tensor<B, D>> {
 
         match state {
             State::Data(data) => {
-                self.value = Tensor::from_data_device(Data::from(data), self.value.device());
+                self.value = Tensor::from_data_device(Data::from(data), &self.value.device());
             }
             _ => return Err(LoadingError::new("Can't load tensor".to_string())),
         };
@@ -72,7 +72,7 @@ impl<const D: usize, B: Backend> Module for Param<Option<Tensor<B, D>>> {
         vec![]
     }
 
-    fn to_device(&mut self, device: B::Device) {
+    fn to_device(&mut self, device: &B::Device) {
         if let Some(value) = &self.value {
             self.value = Some(value.to_device(device));
         }
@@ -101,7 +101,7 @@ impl<const D: usize, B: Backend> Module for Param<Option<Tensor<B, D>>> {
         };
 
         if let Some(value) = &self.value {
-            self.value = Some(Tensor::from_data_device(Data::from(data), value.device()));
+            self.value = Some(Tensor::from_data_device(Data::from(data), &value.device()));
         }
 
         Ok(())
@@ -6,7 +6,7 @@ use burn_tensor::{backend::Backend, BoolTensor, Data, ElementConversion, Shape,
 pub fn generate_autoregressive_mask<B: Backend>(
     batch_size: usize,
     seq_length: usize,
-    device: B::Device,
+    device: &B::Device,
 ) -> BoolTensor<B, 3> {
     let mut mask = Tensor::<B::IntegerBackend, 3>::zeros([1, seq_length, seq_length]);
 
@@ -30,7 +30,7 @@ pub fn generate_padding_mask<B: Backend>(
     pad_token: usize,
     tokens_list: Vec<Vec<usize>>,
     max_seq_lenght: Option<usize>,
-    device: B::Device,
+    device: &B::Device,
 ) -> GeneratePaddingMask<B> {
     let mut max_size = 0;
     let batch_size = tokens_list.len();
@@ -88,7 +88,7 @@ mod tests {
     fn test_generate_autoregressive_mask() {
         let device = <TestBackend as Backend>::Device::default();
 
-        let mask = generate_autoregressive_mask::<TestBackend>(2, 3, device);
+        let mask = generate_autoregressive_mask::<TestBackend>(2, 3, &device);
 
         assert_eq!(
             mask.into_data(),
@@ -117,7 +117,7 @@ mod tests {
             vec![3, 3, 3, 4, 10, 15],
         ];
 
-        let mask = generate_padding_mask::<TestBackend>(0, tokens, None, device);
+        let mask = generate_padding_mask::<TestBackend>(0, tokens, None, &device);
 
         assert_eq!(
             mask.mask.into_data(),
@@ -365,7 +365,7 @@ mod tests {
             [batch_size, seq_length, d_model],
             Distribution::Standard,
         );
-        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, tensor.device());
+        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device());
         let input = MhaInput::self_attn(tensor.clone()).mask_attn(mask_attn);
 
         let output_1 = mha.forward(input);
@@ -38,7 +38,7 @@ impl<B: Backend> CrossEntropyLoss<B> {
         let indexes = targets.to_data();
 
         let mut targets_logits =
-            Tensor::<B, 2>::zeros_device([batch_size, self.num_targets], device);
+            Tensor::<B, 2>::zeros_device([batch_size, self.num_targets], &device);
 
         for b in 0..batch_size {
             let index = indexes.value[b] as usize;
@@ -50,7 +50,7 @@ impl<B: Backend> CrossEntropyLoss<B> {
 
             targets_logits = targets_logits.index_assign(
                 [b..b + 1, index..index + 1],
-                &Tensor::ones_device([1, 1], device),
+                &Tensor::ones_device([1, 1], &device),
             );
         }
 
@@ -276,7 +276,7 @@ mod tests {
             [batch_size, seq_length, d_model],
             Distribution::Standard,
         );
-        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, tensor.device());
+        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device());
         let input = TransformerEncoderInput::new(tensor.clone()).mask_attn(mask_attn);
 
         let output_1 = transformer.forward(input);
@@ -111,7 +111,7 @@ pub(super) fn load_state_gradients<const D: usize, B: ADBackend, F: Fn(&ParamId)
     device: &B::Device,
 ) {
     if let Some(State::Data(data)) = state.get(id_to_key(id).as_str()) {
-        let tensor = Tensor::<B::InnerBackend, D>::from_data_device(Data::from(data), *device);
+        let tensor = Tensor::<B::InnerBackend, D>::from_data_device(Data::from(data), device);
         grads.register(id.clone(), tensor);
     };
 }
@@ -73,7 +73,8 @@ impl<'a, B: ADBackend> ModuleVisitor<B> for GradientsParamsConverter<'a, B> {
 impl<'a, B: ADBackend> ModuleVisitor<B> for GradientsParamsChangeDevice<'a, B> {
     fn visit<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
         if let Some(grad) = self.grads.remove::<D>(id) {
-            self.grads.register(id.clone(), grad.to_device(self.device));
+            self.grads
+                .register(id.clone(), grad.to_device(&self.device));
         }
     }
 }
@@ -115,7 +116,7 @@ mod tests {
     fn test_convert_grads() {
         let layer_1 = layer();
         let mut layer_2 = layer_1.clone();
-        layer_2.to_device(<TestADBackend as Backend>::Device::default());
+        layer_2.to_device(&<TestADBackend as Backend>::Device::default());
         layer_2.detach();
         let loss_1 = layer_1.forward(random_tensor());
         let loss_2 = layer_2.forward(random_tensor());
@@ -111,7 +111,7 @@ impl Param {
     }
 
     quote! {
-        fn to_device(&mut self, device: B::Device) {
+        fn to_device(&mut self, device: &B::Device) {
            #body
        }
    }
@@ -81,7 +81,8 @@ fn conv2d_with_kernel<E: NdArrayElement>(
 
     let heigth_new = f32::ceil((heigth - k1 + 1) as f32 / stride[0] as f32) as usize;
     let width_new = f32::ceil((width - k2 + 1) as f32 / stride[1] as f32) as usize;
-    let mut output = NdArrayBackend::empty(Shape::new([heigth_new, width_new]), NdArrayDevice::Cpu);
+    let mut output =
+        NdArrayBackend::empty(Shape::new([heigth_new, width_new]), &NdArrayDevice::Cpu);
 
     for i in 0..heigth_new {
         for j in 0..width_new {
@@ -65,7 +65,7 @@ pub(crate) fn max_pool2d_backward_naive<E: NdArrayElement>(
         NdArrayBackend::reshape(indexes, Shape::new([batch_size, channels, heigth * width]));
     let mut output_flatten = NdArrayBackend::zeros(
         Shape::new([batch_size, channels, heigth_x * width_x]),
-        NdArrayDevice::Cpu,
+        &NdArrayDevice::Cpu,
     );
 
     for b in 0..batch_size {
@@ -104,9 +104,10 @@ fn max_pool2d_with_kernel<E: NdArrayElement>(
 
     let heigth_new = f32::ceil((heigth - k1 + 1) as f32 / stride[0] as f32) as usize;
     let width_new = f32::ceil((width - k2 + 1) as f32 / stride[1] as f32) as usize;
-    let mut output = NdArrayBackend::empty(Shape::new([heigth_new, width_new]), NdArrayDevice::Cpu);
+    let mut output =
+        NdArrayBackend::empty(Shape::new([heigth_new, width_new]), &NdArrayDevice::Cpu);
     let mut indexes =
-        NdArrayBackend::empty(Shape::new([heigth_new, width_new]), NdArrayDevice::Cpu);
+        NdArrayBackend::empty(Shape::new([heigth_new, width_new]), &NdArrayDevice::Cpu);
 
     for i in 0..heigth_new {
         for j in 0..width_new {
@@ -120,7 +121,7 @@ fn max_pool2d_with_kernel<E: NdArrayElement>(
             let value = NdArrayBackend::into_data(x_flatten).value[index as usize];
             let value = NdArrayBackend::from_data(
                 Data::new(vec![value], Shape::new([1, 1])),
-                NdArrayDevice::Cpu,
+                &NdArrayDevice::Cpu,
             );
 
             let index_i = index / k2 as i64;
@@ -132,7 +133,7 @@ fn max_pool2d_with_kernel<E: NdArrayElement>(
 
             let index = NdArrayBackend::from_data(
                 Data::new(vec![index], Shape::new([1, 1])),
-                NdArrayDevice::Cpu,
+                &NdArrayDevice::Cpu,
             );
 
             indexes = NdArrayBackend::index_assign(&indexes, [i..i + 1, j..j + 1], &index);
@@ -10,7 +10,7 @@ pub(crate) fn apply_padding2d<E: NdArrayElement>(
     let heigth_new = heigth + (2 * padding[0]);
     let width_new = width + (2 * padding[1]);
 
-    let mut x_new = NdArrayBackend::zeros(Shape::new([heigth_new, width_new]), NdArrayDevice::Cpu);
+    let mut x_new = NdArrayBackend::zeros(Shape::new([heigth_new, width_new]), &NdArrayDevice::Cpu);
     x_new = NdArrayBackend::index_assign(
         &x_new,
         [
@@ -35,13 +35,13 @@ macro_rules! keepdim {
 }
 
 impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
-    fn from_data<const D: usize>(data: Data<E, D>, _device: NdArrayDevice) -> NdArrayTensor<E, D> {
+    fn from_data<const D: usize>(data: Data<E, D>, _device: &NdArrayDevice) -> NdArrayTensor<E, D> {
         NdArrayTensor::from_data(data)
     }
 
     fn from_data_bool<const D: usize>(
         data: Data<bool, D>,
-        _device: NdArrayDevice,
+        _device: &NdArrayDevice,
     ) -> NdArrayTensor<bool, D> {
         NdArrayTensor::from_data(data)
     }
@@ -49,7 +49,7 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
     fn random<const D: usize>(
         shape: Shape<D>,
         distribution: Distribution<E>,
-        device: NdArrayDevice,
+        device: &NdArrayDevice,
     ) -> NdArrayTensor<E, D> {
         let mut seed = SEED.lock().unwrap();
         let mut rng: StdRng = match seed.as_ref() {
@@ -105,7 +105,7 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
 
     fn bool_to_device<const D: usize>(
         tensor: &NdArrayTensor<bool, D>,
-        _device: NdArrayDevice,
+        _device: &NdArrayDevice,
     ) -> NdArrayTensor<bool, D> {
         tensor.clone()
     }
@@ -131,14 +131,14 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
 
     fn to_device<const D: usize>(
         tensor: &NdArrayTensor<E, D>,
-        _device: NdArrayDevice,
+        _device: &NdArrayDevice,
     ) -> NdArrayTensor<E, D> {
         tensor.clone()
     }
 
     fn empty<const D: usize>(
         shape: Shape<D>,
-        device: <NdArrayBackend<E> as Backend>::Device,
+        device: &<NdArrayBackend<E> as Backend>::Device,
     ) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
         NdArrayBackend::<E>::zeros(shape, device)
     }
@@ -3,56 +3,56 @@ use burn_tensor::{backend::Backend, ops::TensorOps, Data, Distribution, ElementC
 use std::ops::{Add, Div, Mul, Range, Sub};
 
 impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
-    fn from_data<const D: usize>(data: Data<E, D>, device: TchDevice) -> TchTensor<E, D> {
-        TchTensor::from_data(data, device.into())
+    fn from_data<const D: usize>(data: Data<E, D>, device: &TchDevice) -> TchTensor<E, D> {
+        TchTensor::from_data(data, (*device).into())
     }
 
     fn from_data_bool<const D: usize>(
         data: Data<bool, D>,
-        device: TchDevice,
+        device: &TchDevice,
     ) -> TchTensor<bool, D> {
-        TchTensor::from_data(data, device.into())
+        TchTensor::from_data(data, (*device).into())
     }
 
     fn random<const D: usize>(
         shape: Shape<D>,
         distribution: Distribution<E>,
-        device: TchDevice,
+        device: &TchDevice,
     ) -> TchTensor<E, D> {
         match distribution {
             Distribution::Standard => {
-                let mut tensor = TchTensor::<E, D>::empty(shape, device);
+                let mut tensor = TchTensor::<E, D>::empty(shape, *device);
                 tensor.tensor = tensor.tensor.normal_(0.0, 1.0);
                 tensor
             }
             Distribution::Bernoulli(prob) => {
-                let mut tensor = TchTensor::<E, D>::empty(shape, device);
+                let mut tensor = TchTensor::<E, D>::empty(shape, *device);
                 tensor.tensor = tensor.tensor.f_bernoulli_float_(prob).unwrap();
                 tensor
             }
             Distribution::Uniform(from, to) => {
-                let mut tensor = TchTensor::<E, D>::empty(shape, device);
+                let mut tensor = TchTensor::<E, D>::empty(shape, *device);
                 tensor.tensor = tensor
                     .tensor
                     .uniform_(from.to_f64().unwrap(), to.to_f64().unwrap());
                 tensor
             }
             Distribution::Normal(mean, std) => {
-                let mut tensor = TchTensor::<E, D>::empty(shape, device);
+                let mut tensor = TchTensor::<E, D>::empty(shape, *device);
                 tensor.tensor = tensor.tensor.normal_(mean, std);
                 tensor
             }
         }
     }
 
-    fn zeros<const D: usize>(shape: Shape<D>, device: TchDevice) -> TchTensor<E, D> {
-        let mut tensor = TchTensor::<E, D>::empty(shape, device);
+    fn zeros<const D: usize>(shape: Shape<D>, device: &TchDevice) -> TchTensor<E, D> {
+        let mut tensor = TchTensor::<E, D>::empty(shape, *device);
         tensor.tensor = tensor.tensor.zero_();
         tensor
     }
 
-    fn ones<const D: usize>(shape: Shape<D>, device: TchDevice) -> TchTensor<E, D> {
-        let mut tensor = TchTensor::<E, D>::empty(shape, device);
+    fn ones<const D: usize>(shape: Shape<D>, device: &TchDevice) -> TchTensor<E, D> {
+        let mut tensor = TchTensor::<E, D>::empty(shape, *device);
         tensor.tensor = tensor.tensor.ones_like();
         tensor
     }
@@ -99,11 +99,11 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
 
     fn bool_to_device<const D: usize>(
         tensor: &TchTensor<bool, D>,
-        device: TchDevice,
+        device: &TchDevice,
     ) -> TchTensor<bool, D> {
         TchTensor {
             kind: tensor.kind,
-            tensor: tensor.tensor.to(device.into()),
+            tensor: tensor.tensor.to((*device).into()),
         }
     }
 
@@ -124,20 +124,22 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
         tensor.tensor.device().into()
     }
 
-    fn to_device<const D: usize>(tensor: &TchTensor<E, D>, device: TchDevice) -> TchTensor<E, D> {
+    fn to_device<const D: usize>(tensor: &TchTensor<E, D>, device: &TchDevice) -> TchTensor<E, D> {
         TchTensor {
             kind: tensor.kind,
-            tensor: tensor.tensor.to(device.into()),
+            tensor: tensor.tensor.to((*device).into()),
         }
     }
 
     fn empty<const D: usize>(
         shape: Shape<D>,
-        device: <TchBackend<E> as Backend>::Device,
+        device: &<TchBackend<E> as Backend>::Device,
     ) -> <TchBackend<E> as Backend>::TensorPrimitive<D> {
         let kind = TchKind::<E>::new();
-        let tensor =
-            tch::Tensor::empty(&shape.dims.map(|a| a as i64), (kind.kind(), device.into()));
+        let tensor = tch::Tensor::empty(
+            &shape.dims.map(|a| a as i64),
+            (kind.kind(), (*device).into()),
+        );
 
         to_tensor(tensor)
     }
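In the tch backend the incoming `&TchDevice` is dereferenced before conversion (`(*device).into()`), which compiles because `TchDevice` itself is still a cheap `Copy` enum even though the `Backend::Device` bound no longer requires `Copy`. A standalone sketch of the pattern, with simplified stand-ins for the real `TchDevice` and `tch::Device` types:

    // Simplified stand-ins; the real types live in burn-tch and tch.
    #[derive(Clone, Copy, Debug)]
    enum TchDevice {
        Cpu,
        Cuda(usize),
    }

    #[derive(Debug)]
    enum RawDevice {
        Cpu,
        Cuda(usize),
    }

    impl From<TchDevice> for RawDevice {
        fn from(device: TchDevice) -> Self {
            match device {
                TchDevice::Cpu => RawDevice::Cpu,
                TchDevice::Cuda(index) => RawDevice::Cuda(index),
            }
        }
    }

    fn to_raw(device: &TchDevice) -> RawDevice {
        // Deref-then-convert, as in `(*device).into()` above: the deref
        // copies the enum (TchDevice is Copy) because `From` wants an
        // owned value.
        (*device).into()
    }

    fn main() {
        let device = TchDevice::Cuda(0);
        println!("{:?}", to_raw(&device));
    }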
@@ -12,7 +12,7 @@ pub trait Backend:
     + std::fmt::Debug
     + 'static
 {
-    type Device: Copy + Clone + Default + std::fmt::Debug + Send + Sync;
+    type Device: Clone + Default + std::fmt::Debug + Send + Sync;
     type Elem: Element;
     type FullPrecisionElem: Element;
     type FullPrecisionBackend: Backend<Elem = Self::FullPrecisionElem, Device = Self::Device>;
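This hunk is the core of the refactor: `Backend::Device` no longer has to be `Copy`, only `Clone`, which is what makes every borrowed `&B::Device` parameter above worthwhile. A minimal sketch of what the relaxed bound permits; the `GpuDevice` type and `allocate` function are hypothetical, not part of burn:

    // A device identified by a String is Clone but not Copy, so the old
    // `type Device: Copy + ...` bound would have rejected it.
    #[derive(Clone, Default, Debug, PartialEq)]
    struct GpuDevice {
        id: String,
    }

    // After the refactor, APIs borrow the device instead of consuming it...
    fn allocate(device: &GpuDevice) {
        println!("allocating on {:?}", device);
    }

    fn main() {
        let device = GpuDevice { id: "cuda:0".into() };
        allocate(&device); // ...so callers can keep using it afterwards
        allocate(&device); // no copy or clone needed per call
    }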
@@ -19,10 +19,10 @@ where
 {
     /// Returns a new integer tensor on the default device which values are generated from the given range.
     pub fn arange(range: Range<usize>) -> Tensor<B::IntegerBackend, 1> {
-        Tensor::new(B::arange(range, B::Device::default()))
+        Tensor::new(B::arange(range, &B::Device::default()))
     }
     /// Returns a new integer tensor on the specified device which values are generated from the given range.
-    pub fn arange_device(range: Range<usize>, device: B::Device) -> Tensor<B::IntegerBackend, 1> {
+    pub fn arange_device(range: Range<usize>, device: &B::Device) -> Tensor<B::IntegerBackend, 1> {
         Tensor::new(B::arange(range, device))
     }
 }
@@ -63,7 +63,7 @@ where
     }
 
     /// Returns a new tensor on the given device.
-    pub fn to_device(&self, device: B::Device) -> Self {
+    pub fn to_device(&self, device: &B::Device) -> Self {
         Self::new(B::to_device(&self.value, device))
     }
 
@@ -144,12 +144,12 @@ where
 
     /// Create a tensor from the given data.
     pub fn from_data(data: Data<B::Elem, D>) -> Self {
-        let tensor = B::from_data(data, B::Device::default());
+        let tensor = B::from_data(data, &B::Device::default());
         Tensor::new(tensor)
     }
 
     /// Create a tensor from the given data on the given device.
-    pub fn from_data_device(data: Data<B::Elem, D>, device: B::Device) -> Self {
+    pub fn from_data_device(data: Data<B::Elem, D>, device: &B::Device) -> Self {
         let tensor = B::from_data(data, device);
         Tensor::new(tensor)
     }
@@ -173,18 +173,18 @@ where
 
     /// Returns a new tensor with the same shape and device as the current tensor filled with zeros.
     pub fn zeros_like(&self) -> Self {
-        Tensor::new(B::zeros(self.shape(), self.device()))
+        Tensor::new(B::zeros(self.shape(), &self.device()))
     }
 
     /// Returns a new tensor with the same shape and device as the current tensor filled with ones.
     pub fn ones_like(&self) -> Self {
-        Tensor::new(B::ones(self.shape(), self.device()))
+        Tensor::new(B::ones(self.shape(), &self.device()))
     }
 
     /// Returns a new tensor with the same shape and device as the current tensor filled random
     /// values sampled from the given distribution.
     pub fn random_like(&self, distribution: Distribution<B::Elem>) -> Self {
-        Tensor::new(B::random(self.shape(), distribution, self.device()))
+        Tensor::new(B::random(self.shape(), distribution, &self.device()))
     }
 
     /// Create a one hot tensor.
@@ -424,30 +424,30 @@ where
     /// Create a random tensor of the given shape where each element is sampled from the given
     /// distribution.
     pub fn random<S: Into<Shape<D>>>(shape: S, distribution: Distribution<B::Elem>) -> Self {
-        let tensor = B::random(shape.into(), distribution, B::Device::default());
+        let tensor = B::random(shape.into(), distribution, &B::Device::default());
         Self::new(tensor)
     }
 
     /// Create a tensor of the given shape where each element is zero.
     pub fn zeros<S: Into<Shape<D>>>(shape: S) -> Self {
-        let tensor = B::zeros(shape.into(), B::Device::default());
+        let tensor = B::zeros(shape.into(), &B::Device::default());
         Self::new(tensor)
     }
 
     /// Create a tensor of the given shape where each element is zero.
-    pub fn zeros_device<S: Into<Shape<D>>>(shape: S, device: B::Device) -> Self {
+    pub fn zeros_device<S: Into<Shape<D>>>(shape: S, device: &B::Device) -> Self {
         let tensor = B::zeros(shape.into(), device);
         Self::new(tensor)
     }
 
     /// Create a tensor of the given shape where each element is one.
     pub fn ones<S: Into<Shape<D>>>(shape: S) -> Self {
-        let tensor = B::ones(shape.into(), B::Device::default());
+        let tensor = B::ones(shape.into(), &B::Device::default());
         Self::new(tensor)
     }
 
     /// Create a tensor of the given shape where each element is one.
-    pub fn ones_device<S: Into<Shape<D>>>(shape: S, device: B::Device) -> Self {
+    pub fn ones_device<S: Into<Shape<D>>>(shape: S, device: &B::Device) -> Self {
         let tensor = B::ones(shape.into(), device);
         Self::new(tensor)
     }
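On the user-facing `Tensor` API the change is purely mechanical: creation helpers and `to_device` now borrow the device. A small generic sketch written against the post-refactor signatures shown above; it is an illustration, not code taken from the repo:

    use burn_tensor::{backend::Backend, Tensor};

    // Works for any backend; the same borrowed device can be reused
    // across every call without cloning.
    fn build_on<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
        let zeros = Tensor::<B, 2>::zeros_device([2, 3], device);
        zeros.to_device(device)
    }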
@@ -19,7 +19,7 @@ where
         B::bool_shape(&self.value)
     }
 
-    pub fn to_device(&self, device: B::Device) -> Self {
+    pub fn to_device(&self, device: &B::Device) -> Self {
         Self::new(B::bool_to_device(&self.value, device))
     }
 
@@ -39,7 +39,7 @@ where
     }
 
     pub fn from_data(data: Data<bool, D>) -> Self {
-        let value = B::from_data_bool(data, B::Device::default());
+        let value = B::from_data_bool(data, &B::Device::default());
         Self::new(value)
     }
 
@@ -73,7 +73,7 @@ pub(crate) fn conv1d_backward<B: Backend>(
             let elem = batch_size * length_out;
             let elem = (elem as i32).to_elem();
 
-            let b = B::zeros(B::shape(b), B::device(b));
+            let b = B::zeros(B::shape(b), &B::device(b));
 
             B::add_scalar(&b, &elem)
         }),
@@ -128,7 +128,7 @@ pub(crate) fn conv2d_backward<B: Backend>(
             let elem = batch_size * width_out * height_out;
             let elem = (elem as i32).to_elem();
 
-            let b = B::zeros(B::shape(b), B::device(b));
+            let b = B::zeros(B::shape(b), &B::device(b));
 
             B::add_scalar(&b, &elem)
         }),
@@ -4,21 +4,21 @@ use std::ops::Range;
 pub trait TensorOps<B: Backend> {
     fn from_data<const D: usize>(
         data: Data<B::Elem, D>,
-        device: B::Device,
+        device: &B::Device,
     ) -> B::TensorPrimitive<D>;
     fn from_data_bool<const D: usize>(
         data: Data<bool, D>,
-        device: B::Device,
+        device: &B::Device,
     ) -> B::BoolTensorPrimitive<D>;
     fn random<const D: usize>(
         shape: Shape<D>,
         distribution: Distribution<B::Elem>,
-        device: B::Device,
+        device: &B::Device,
     ) -> B::TensorPrimitive<D>;
-    fn zeros<const D: usize>(shape: Shape<D>, device: B::Device) -> B::TensorPrimitive<D> {
+    fn zeros<const D: usize>(shape: Shape<D>, device: &B::Device) -> B::TensorPrimitive<D> {
         Self::from_data(Data::zeros(shape), device)
     }
-    fn ones<const D: usize>(shape: Shape<D>, device: B::Device) -> B::TensorPrimitive<D> {
+    fn ones<const D: usize>(shape: Shape<D>, device: &B::Device) -> B::TensorPrimitive<D> {
         Self::from_data(Data::ones(shape), device)
     }
     fn shape<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Shape<D>;
@@ -29,7 +29,7 @@ pub trait TensorOps<B: Backend> {
     fn bool_into_data<const D: usize>(tensor: B::BoolTensorPrimitive<D>) -> Data<bool, D>;
     fn bool_to_device<const D: usize>(
         tensor: &B::BoolTensorPrimitive<D>,
-        device: B::Device,
+        device: &B::Device,
     ) -> B::BoolTensorPrimitive<D>;
     fn bool_reshape<const D1: usize, const D2: usize>(
         tensor: &B::BoolTensorPrimitive<D1>,
@@ -42,11 +42,11 @@ pub trait TensorOps<B: Backend> {
     fn device<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::Device;
     fn to_device<const D: usize>(
         tensor: &B::TensorPrimitive<D>,
-        device: B::Device,
+        device: &B::Device,
     ) -> B::TensorPrimitive<D>;
     fn arange(
         range: Range<usize>,
-        device: B::Device,
+        device: &B::Device,
     ) -> <B::IntegerBackend as Backend>::TensorPrimitive<1> {
         let shape = Shape::new([range.end - range.start]);
         let value = range
@@ -56,7 +56,7 @@ pub trait TensorOps<B: Backend> {
         let data = Data::new(value, shape);
         <B::IntegerBackend as TensorOps<B::IntegerBackend>>::from_data(data, device)
     }
-    fn empty<const D: usize>(shape: Shape<D>, device: B::Device) -> B::TensorPrimitive<D>;
+    fn empty<const D: usize>(shape: Shape<D>, device: &B::Device) -> B::TensorPrimitive<D>;
     fn repeat<const D: usize>(
         tensor: &B::TensorPrimitive<D>,
         dim: usize,
@@ -76,7 +76,7 @@ pub trait TensorOps<B: Backend> {
             start..end
         });
 
-        let mut tensor_output = B::empty(shape, B::device(tensor));
+        let mut tensor_output = B::empty(shape, &B::device(tensor));
         for i in 0..times {
             let mut indexes = indexes_select_all.clone();
             indexes[dim] = i..i + 1;
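One recurring pattern above is `&B::device(tensor)`: `TensorOps::device` still returns an owned `B::Device`, so call sites borrow the temporary when forwarding it. A standalone sketch of why that borrow is valid; the types here are illustrative stand-ins, not burn's:

    #[derive(Clone, Debug, Default)]
    struct Device(String); // stand-in for an owned B::Device

    fn device_of(_tensor: &Vec<f32>) -> Device {
        Device("cpu".to_string()) // returned by value, like B::device(tensor)
    }

    fn empty(len: usize, device: &Device) -> Vec<f32> {
        println!("allocating {len} elements on {device:?}");
        Vec::with_capacity(len)
    }

    fn main() {
        let tensor = vec![1.0, 2.0];
        // The temporary returned by device_of lives until the end of the
        // full expression, so borrowing it inline is sound:
        let _out = empty(4, &device_of(&tensor));
    }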
@@ -41,13 +41,13 @@ where
     TO: Send + 'static,
     M: TrainStep<B, TI, TO> + Send + 'static,
 {
-    let device = self.device;
+    let device = self.device.clone();
 
     spawn(move || loop {
         match receiver_input.recv() {
             Ok(item) => {
                 let mut step = item.model;
-                step.to_device(device);
+                step.to_device(&device);
                 step.detach();
 
                 let output = step.step(item.item);
@@ -80,7 +80,7 @@ where
     let (sender_input, receiver_input) = std::sync::mpsc::channel();
     let worker = Worker {
         sender_input,
-        device: *device,
+        device: device.clone(),
     };
 
     worker.start(sender_output.clone(), receiver_input);
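Since `Device` is no longer `Copy`, values captured by `move` closures or stored in worker structs must be cloned explicitly, as the training worker above now does. A stripped-down sketch of the pattern; the names are illustrative, not burn's API:

    use std::thread;

    #[derive(Clone, Debug)]
    struct Device {
        id: String, // Clone but not Copy
    }

    fn main() {
        let device = Device { id: "cuda:0".to_string() };

        // Clone before the move so the caller keeps its own copy:
        let worker_device = device.clone();
        let handle = thread::spawn(move || {
            println!("worker using {:?}", worker_device);
        });

        println!("main still owns {:?}", device);
        handle.join().unwrap();
    }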
@@ -60,7 +60,7 @@ where
 
     // The reference model is always on the first device provided.
     if let Some(device) = self.devices.get(0) {
-        self.model.to_device(*device);
+        self.model.to_device(device);
         self.model.detach();
     }
 
@@ -102,7 +102,7 @@ where
     let step = MultiDevicesTrainStep::new(&self.devices);
 
     // The main device is always the first in the list.
-    let device_main = *self.devices.get(0).unwrap();
+    let device_main = self.devices.get(0).unwrap().clone();
 
     loop {
         let items = step.step(&mut iterator, &self.model);
@@ -114,7 +114,7 @@ where
         iteration += 1;
         let progress = iterator.progress();
 
-        to_device_grads(&mut item.grads, device_main, &self.model);
+        to_device_grads(&mut item.grads, device_main.clone(), &self.model);
         log::info!("Updated device");
         accumulator.accumulate(&self.model, item.grads);
         accumulation_current += 1;
@@ -31,11 +31,11 @@ impl<B: Backend> Metric for AccuracyMetric<B> {
     fn update(&mut self, input: &AccuracyInput<B>) -> MetricEntry {
         let [batch_size, _n_classes] = input.outputs.dims();
 
-        let targets = input.targets.to_device(B::Device::default());
+        let targets = input.targets.to_device(&B::Device::default());
         let outputs = input
             .outputs
             .argmax(1)
-            .to_device(B::Device::default())
+            .to_device(&B::Device::default())
             .reshape([batch_size]);
 
         let total_current = outputs.equal(&targets).to_int().sum().to_data().value[0] as usize;
@@ -34,8 +34,8 @@ impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
             .map(|item| Tensor::<B::IntegerBackend, 1>::from_data(Data::from([item.label as i64])))
             .collect();
 
-        let images = Tensor::cat(images, 0).to_device(self.device).detach();
-        let targets = Tensor::cat(targets, 0).to_device(self.device).detach();
+        let images = Tensor::cat(images, 0).to_device(&self.device).detach();
+        let targets = Tensor::cat(targets, 0).to_device(&self.device).detach();
 
         MNISTBatch { images, targets }
     }
@@ -25,8 +25,8 @@ pub fn run<B: ADBackend>(device: B::Device) {
     B::seed(config.seed);
 
     // Data
-    let batcher_train = Arc::new(MNISTBatcher::<B>::new(device));
-    let batcher_valid = Arc::new(MNISTBatcher::<B::InnerBackend>::new(device));
+    let batcher_train = Arc::new(MNISTBatcher::<B>::new(device.clone()));
+    let batcher_valid = Arc::new(MNISTBatcher::<B::InnerBackend>::new(device.clone()));
     let dataloader_train = DataLoaderBuilder::new(batcher_train)
         .batch_size(config.batch_size)
         .shuffle(config.seed)
@@ -36,13 +36,13 @@ impl<B: Backend> Batcher<TextClassificationItem, TextClassificationBatch<B>>
             self.tokenizer.pad_token(),
             tokens_list,
             Some(self.max_seq_lenght),
-            B::Device::default(),
+            &B::Device::default(),
         );
 
         TextClassificationBatch {
-            tokens: mask.tensor.to_device(self.device).detach(),
-            labels: Tensor::cat(labels_list, 0).to_device(self.device).detach(),
-            mask_pad: mask.mask.to_device(self.device),
+            tokens: mask.tensor.to_device(&self.device).detach(),
+            labels: Tensor::cat(labels_list, 0).to_device(&self.device).detach(),
+            mask_pad: mask.mask.to_device(&self.device),
         }
     }
 }
@@ -55,7 +55,7 @@ impl<B: Backend> TextClassificationModel<B> {
 
     pub fn forward(&self, item: TextClassificationBatch<B>) -> ClassificationOutput<B> {
         let [batch_size, seq_length] = item.tokens.dims();
-        let device = self.embedding_token.devices()[0];
+        let device = &self.embedding_token.devices()[0];
 
         let tokens = item.tokens.to_device(device).detach();
         let labels = item.labels.to_device(device).detach();
@@ -42,12 +42,12 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
     let tokenizer = Arc::new(BertCasedTokenizer::default());
     let batcher_train = Arc::new(TextClassificationBatcher::<B>::new(
         tokenizer.clone(),
-        device,
+        device.clone(),
         config.max_seq_length,
     ));
     let batcher_test = Arc::new(TextClassificationBatcher::<B::InnerBackend>::new(
         tokenizer.clone(),
-        device,
+        device.clone(),
         config.max_seq_length,
     ));
 
@@ -37,7 +37,7 @@ impl<B: Backend> Batcher<TextGenerationItem, TextGenerationBatch<B>> for TextGen
             self.tokenizer.pad_token(),
             tokens_list,
             Some(self.max_seq_lenght),
-            B::Device::default(),
+            &B::Device::default(),
         );
 
         TextGenerationBatch {
@@ -61,7 +61,7 @@ impl<B: Backend> TextClassificationModel<B> {
         item: TrainingTextGenerationBatch<B>,
     ) -> ClassificationOutput<B> {
         let [batch_size, seq_length] = item.tokens_inputs.dims();
-        let device = self.embedding_token.devices()[0];
+        let device = &self.embedding_token.devices()[0];
 
         let inputs = item.tokens_inputs.to_device(device).detach();
         let mask_pad = item.mask_pad.to_device(device);
@@ -75,7 +75,7 @@ impl<B: Backend> TextClassificationModel<B> {
         let embedding = (embedding_positions + embedding_tokens) / 2;
 
         let mask_attn =
-            generate_autoregressive_mask::<B>(batch_size, seq_length, embedding.device());
+            generate_autoregressive_mask::<B>(batch_size, seq_length, &embedding.device());
 
         let encoded = self.transformer.forward(
             TransformerEncoderInput::new(embedding)