refactor: device functions (#157)

Nathaniel Simard 2023-01-27 18:37:21 -05:00 committed by GitHub
parent 2d4e514b41
commit c7963d8485
31 changed files with 120 additions and 115 deletions
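
This commit switches every device parameter in the tensor, module, and backend APIs from pass-by-value (B::Device) to pass-by-reference (&B::Device), which lets the Device associated type drop its Copy bound. Below is a minimal usage sketch of the new calling convention, assuming the burn_tensor paths that appear in the diff; the helper function and shapes are illustrative, not part of the commit:

use burn_tensor::{backend::Backend, Tensor};

// Hypothetical helper: with the borrowed-device signatures introduced here,
// a single device value can be reused across constructors without cloning.
fn zeros_and_ones<B: Backend>(device: &B::Device) -> (Tensor<B, 2>, Tensor<B, 2>) {
    let zeros = Tensor::<B, 2>::zeros_device([2, 3], device);
    let ones = Tensor::<B, 2>::ones_device([2, 3], device);
    (zeros, ones)
}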

View File

@ -19,14 +19,14 @@ impl<B: Backend, const D: usize> std::ops::Add<ADTensor<D, B>> for ADTensor<D, B
}
impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
fn from_data<const D: usize>(data: Data<B::Elem, D>, device: B::Device) -> ADTensor<D, B> {
fn from_data<const D: usize>(data: Data<B::Elem, D>, device: &B::Device) -> ADTensor<D, B> {
let tensor = B::from_data(data, device);
ADTensor::from_tensor(tensor)
}
fn from_data_bool<const D: usize>(
data: Data<bool, D>,
device: B::Device,
device: &B::Device,
) -> B::BoolTensorPrimitive<D> {
B::from_data_bool(data, device)
}
@ -34,16 +34,16 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
fn random<const D: usize>(
shape: Shape<D>,
distribution: Distribution<B::Elem>,
device: B::Device,
device: &B::Device,
) -> ADTensor<D, B> {
ADTensor::from_tensor(B::random(shape, distribution, device))
}
fn zeros<const D: usize>(shape: Shape<D>, device: B::Device) -> ADTensor<D, B> {
fn zeros<const D: usize>(shape: Shape<D>, device: &B::Device) -> ADTensor<D, B> {
ADTensor::from_tensor(B::zeros(shape, device))
}
fn ones<const D: usize>(shape: Shape<D>, device: B::Device) -> ADTensor<D, B> {
fn ones<const D: usize>(shape: Shape<D>, device: &B::Device) -> ADTensor<D, B> {
ADTensor::from_tensor(B::ones(shape, device))
}
@ -99,7 +99,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
fn bool_to_device<const D: usize>(
tensor: &<ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D>,
device: <ADBackendDecorator<B> as Backend>::Device,
device: &<ADBackendDecorator<B> as Backend>::Device,
) -> <ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D> {
B::bool_to_device(tensor, device)
}
@ -112,7 +112,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
fn to_device<const D: usize>(
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
device: <ADBackendDecorator<B> as Backend>::Device,
device: &<ADBackendDecorator<B> as Backend>::Device,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
#[derive(new, Debug)]
struct ToDeviceBackward<B: Backend, const D: usize> {
@ -126,7 +126,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
&self,
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
) -> B::TensorPrimitive<D> {
B::to_device(&state.output.grad(), self.device)
B::to_device(&state.output.grad(), &self.device)
}
}
@ -140,7 +140,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
fn empty<const D: usize>(
shape: Shape<D>,
device: <ADBackendDecorator<B> as Backend>::Device,
device: &<ADBackendDecorator<B> as Backend>::Device,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
ADTensor::from_tensor(B::empty(shape, device))
}
@ -775,7 +775,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<1>>,
) -> B::TensorPrimitive<D> {
let grad = state.output.grad();
let ones = B::ones(self.shape.clone(), B::device(&grad));
let ones = B::ones(self.shape.clone(), &B::device(&grad));
let grad: Tensor<B, 1> = Tensor::from_primitive(grad);
let val = 1_f64 / self.shape.num_elements() as f64;
@ -809,7 +809,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<1>>,
) -> B::TensorPrimitive<D> {
let grad = state.output.grad();
let ones = B::ones(self.shape.clone(), B::device(&grad));
let ones = B::ones(self.shape.clone(), &B::device(&grad));
let grad: Tensor<B, 1> = Tensor::from_primitive(grad);
let ones: Tensor<B, D> = Tensor::from_primitive(ones);
@ -844,7 +844,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
) -> B::TensorPrimitive<D> {
let grad = B::sum_dim(&state.output.grad(), self.dim);
let ones = B::ones(self.shape.clone(), B::device(&grad));
let ones = B::ones(self.shape.clone(), &B::device(&grad));
let val = 1_f64 / self.shape.dims[self.dim] as f64;
let ones = B::mul_scalar(&ones, &B::Elem::from_elem(val));
@ -879,7 +879,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
) -> B::TensorPrimitive<D> {
let grad = B::sum_dim(&state.output.grad(), self.dim);
let ones = B::ones(self.shape.clone(), B::device(&grad));
let ones = B::ones(self.shape.clone(), &B::device(&grad));
B::mul(&ones, &grad)
}

View File

@ -41,7 +41,7 @@ pub trait Module: Send + Sync + std::fmt::Debug + std::fmt::Display {
/// Get the device list of the module and all of its sub-modules.
fn devices(&self) -> Vec<<Self::Backend as Backend>::Device>;
/// Move the module and all of its sub-modules to the given device.
fn to_device(&mut self, device: <Self::Backend as Backend>::Device);
fn to_device(&mut self, device: &<Self::Backend as Backend>::Device);
/// Load the module state.
fn load(&mut self, state: &State<<Self::Backend as Backend>::Elem>)
-> Result<(), LoadingError>;
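
A hedged sketch of driving the updated devices/to_device pair generically; the move_to_first_device helper is invented here, and the burn::module::Module import path is assumed from this repository's layout:

use burn::module::Module;

// Hypothetical helper: move a module onto the first device it reports.
// to_device now borrows the device, so no Copy bound is required.
fn move_to_first_device<M: Module>(module: &mut M) {
    let devices = module.devices();
    if let Some(device) = devices.first() {
        module.to_device(device);
    }
}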

View File

@ -15,7 +15,7 @@ impl<M: Module> Module for Param<M> {
self.value.devices()
}
fn to_device(&mut self, device: <Self::Backend as Backend>::Device) {
fn to_device(&mut self, device: &<Self::Backend as Backend>::Device) {
self.value.to_device(device)
}
@ -65,7 +65,7 @@ impl<M: Module> Module for Param<Vec<M>> {
devices
}
fn to_device(&mut self, device: <M::Backend as Backend>::Device) {
fn to_device(&mut self, device: &<M::Backend as Backend>::Device) {
for module in self.value.iter_mut() {
module.to_device(device);
}

View File

@ -16,7 +16,7 @@ impl<const D: usize, B: Backend> Module for Param<Tensor<B, D>> {
vec![self.value.device()]
}
fn to_device(&mut self, device: B::Device) {
fn to_device(&mut self, device: &B::Device) {
self.value = self.value.to_device(device);
}
@ -32,7 +32,7 @@ impl<const D: usize, B: Backend> Module for Param<Tensor<B, D>> {
match state {
State::Data(data) => {
self.value = Tensor::from_data_device(Data::from(data), self.value.device());
self.value = Tensor::from_data_device(Data::from(data), &self.value.device());
}
_ => return Err(LoadingError::new("Can't load tensor".to_string())),
};
@ -72,7 +72,7 @@ impl<const D: usize, B: Backend> Module for Param<Option<Tensor<B, D>>> {
vec![]
}
fn to_device(&mut self, device: B::Device) {
fn to_device(&mut self, device: &B::Device) {
if let Some(value) = &self.value {
self.value = Some(value.to_device(device));
}
@ -101,7 +101,7 @@ impl<const D: usize, B: Backend> Module for Param<Option<Tensor<B, D>>> {
};
if let Some(value) = &self.value {
self.value = Some(Tensor::from_data_device(Data::from(data), value.device()));
self.value = Some(Tensor::from_data_device(Data::from(data), &value.device()));
}
Ok(())

View File

@ -6,7 +6,7 @@ use burn_tensor::{backend::Backend, BoolTensor, Data, ElementConversion, Shape,
pub fn generate_autoregressive_mask<B: Backend>(
batch_size: usize,
seq_length: usize,
device: B::Device,
device: &B::Device,
) -> BoolTensor<B, 3> {
let mut mask = Tensor::<B::IntegerBackend, 3>::zeros([1, seq_length, seq_length]);
@ -30,7 +30,7 @@ pub fn generate_padding_mask<B: Backend>(
pad_token: usize,
tokens_list: Vec<Vec<usize>>,
max_seq_lenght: Option<usize>,
device: B::Device,
device: &B::Device,
) -> GeneratePaddingMask<B> {
let mut max_size = 0;
let batch_size = tokens_list.len();
@ -88,7 +88,7 @@ mod tests {
fn test_generate_autoregressive_mask() {
let device = <TestBackend as Backend>::Device::default();
let mask = generate_autoregressive_mask::<TestBackend>(2, 3, device);
let mask = generate_autoregressive_mask::<TestBackend>(2, 3, &device);
assert_eq!(
mask.into_data(),
@ -117,7 +117,7 @@ mod tests {
vec![3, 3, 3, 4, 10, 15],
];
let mask = generate_padding_mask::<TestBackend>(0, tokens, None, device);
let mask = generate_padding_mask::<TestBackend>(0, tokens, None, &device);
assert_eq!(
mask.mask.into_data(),

View File

@ -365,7 +365,7 @@ mod tests {
[batch_size, seq_length, d_model],
Distribution::Standard,
);
let mask_attn = generate_autoregressive_mask(batch_size, seq_length, tensor.device());
let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device());
let input = MhaInput::self_attn(tensor.clone()).mask_attn(mask_attn);
let output_1 = mha.forward(input);

View File

@ -38,7 +38,7 @@ impl<B: Backend> CrossEntropyLoss<B> {
let indexes = targets.to_data();
let mut targets_logits =
Tensor::<B, 2>::zeros_device([batch_size, self.num_targets], device);
Tensor::<B, 2>::zeros_device([batch_size, self.num_targets], &device);
for b in 0..batch_size {
let index = indexes.value[b] as usize;
@ -50,7 +50,7 @@ impl<B: Backend> CrossEntropyLoss<B> {
targets_logits = targets_logits.index_assign(
[b..b + 1, index..index + 1],
&Tensor::ones_device([1, 1], device),
&Tensor::ones_device([1, 1], &device),
);
}

View File

@ -276,7 +276,7 @@ mod tests {
[batch_size, seq_length, d_model],
Distribution::Standard,
);
let mask_attn = generate_autoregressive_mask(batch_size, seq_length, tensor.device());
let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device());
let input = TransformerEncoderInput::new(tensor.clone()).mask_attn(mask_attn);
let output_1 = transformer.forward(input);

View File

@ -111,7 +111,7 @@ pub(super) fn load_state_gradients<const D: usize, B: ADBackend, F: Fn(&ParamId)
device: &B::Device,
) {
if let Some(State::Data(data)) = state.get(id_to_key(id).as_str()) {
let tensor = Tensor::<B::InnerBackend, D>::from_data_device(Data::from(data), *device);
let tensor = Tensor::<B::InnerBackend, D>::from_data_device(Data::from(data), device);
grads.register(id.clone(), tensor);
};
}

View File

@ -73,7 +73,8 @@ impl<'a, B: ADBackend> ModuleVisitor<B> for GradientsParamsConverter<'a, B> {
impl<'a, B: ADBackend> ModuleVisitor<B> for GradientsParamsChangeDevice<'a, B> {
fn visit<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
if let Some(grad) = self.grads.remove::<D>(id) {
self.grads.register(id.clone(), grad.to_device(self.device));
self.grads
.register(id.clone(), grad.to_device(&self.device));
}
}
}
@ -115,7 +116,7 @@ mod tests {
fn test_convert_grads() {
let layer_1 = layer();
let mut layer_2 = layer_1.clone();
layer_2.to_device(<TestADBackend as Backend>::Device::default());
layer_2.to_device(&<TestADBackend as Backend>::Device::default());
layer_2.detach();
let loss_1 = layer_1.forward(random_tensor());
let loss_2 = layer_2.forward(random_tensor());

View File

@ -111,7 +111,7 @@ impl Param {
}
quote! {
fn to_device(&mut self, device: B::Device) {
fn to_device(&mut self, device: &B::Device) {
#body
}
}

View File

@ -81,7 +81,8 @@ fn conv2d_with_kernel<E: NdArrayElement>(
let heigth_new = f32::ceil((heigth - k1 + 1) as f32 / stride[0] as f32) as usize;
let width_new = f32::ceil((width - k2 + 1) as f32 / stride[1] as f32) as usize;
let mut output = NdArrayBackend::empty(Shape::new([heigth_new, width_new]), NdArrayDevice::Cpu);
let mut output =
NdArrayBackend::empty(Shape::new([heigth_new, width_new]), &NdArrayDevice::Cpu);
for i in 0..heigth_new {
for j in 0..width_new {

View File

@ -65,7 +65,7 @@ pub(crate) fn max_pool2d_backward_naive<E: NdArrayElement>(
NdArrayBackend::reshape(indexes, Shape::new([batch_size, channels, heigth * width]));
let mut output_flatten = NdArrayBackend::zeros(
Shape::new([batch_size, channels, heigth_x * width_x]),
NdArrayDevice::Cpu,
&NdArrayDevice::Cpu,
);
for b in 0..batch_size {
@ -104,9 +104,10 @@ fn max_pool2d_with_kernel<E: NdArrayElement>(
let heigth_new = f32::ceil((heigth - k1 + 1) as f32 / stride[0] as f32) as usize;
let width_new = f32::ceil((width - k2 + 1) as f32 / stride[1] as f32) as usize;
let mut output = NdArrayBackend::empty(Shape::new([heigth_new, width_new]), NdArrayDevice::Cpu);
let mut output =
NdArrayBackend::empty(Shape::new([heigth_new, width_new]), &NdArrayDevice::Cpu);
let mut indexes =
NdArrayBackend::empty(Shape::new([heigth_new, width_new]), NdArrayDevice::Cpu);
NdArrayBackend::empty(Shape::new([heigth_new, width_new]), &NdArrayDevice::Cpu);
for i in 0..heigth_new {
for j in 0..width_new {
@ -120,7 +121,7 @@ fn max_pool2d_with_kernel<E: NdArrayElement>(
let value = NdArrayBackend::into_data(x_flatten).value[index as usize];
let value = NdArrayBackend::from_data(
Data::new(vec![value], Shape::new([1, 1])),
NdArrayDevice::Cpu,
&NdArrayDevice::Cpu,
);
let index_i = index / k2 as i64;
@ -132,7 +133,7 @@ fn max_pool2d_with_kernel<E: NdArrayElement>(
let index = NdArrayBackend::from_data(
Data::new(vec![index], Shape::new([1, 1])),
NdArrayDevice::Cpu,
&NdArrayDevice::Cpu,
);
indexes = NdArrayBackend::index_assign(&indexes, [i..i + 1, j..j + 1], &index);

View File

@ -10,7 +10,7 @@ pub(crate) fn apply_padding2d<E: NdArrayElement>(
let heigth_new = heigth + (2 * padding[0]);
let width_new = width + (2 * padding[1]);
let mut x_new = NdArrayBackend::zeros(Shape::new([heigth_new, width_new]), NdArrayDevice::Cpu);
let mut x_new = NdArrayBackend::zeros(Shape::new([heigth_new, width_new]), &NdArrayDevice::Cpu);
x_new = NdArrayBackend::index_assign(
&x_new,
[

View File

@ -35,13 +35,13 @@ macro_rules! keepdim {
}
impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
fn from_data<const D: usize>(data: Data<E, D>, _device: NdArrayDevice) -> NdArrayTensor<E, D> {
fn from_data<const D: usize>(data: Data<E, D>, _device: &NdArrayDevice) -> NdArrayTensor<E, D> {
NdArrayTensor::from_data(data)
}
fn from_data_bool<const D: usize>(
data: Data<bool, D>,
_device: NdArrayDevice,
_device: &NdArrayDevice,
) -> NdArrayTensor<bool, D> {
NdArrayTensor::from_data(data)
}
@ -49,7 +49,7 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
fn random<const D: usize>(
shape: Shape<D>,
distribution: Distribution<E>,
device: NdArrayDevice,
device: &NdArrayDevice,
) -> NdArrayTensor<E, D> {
let mut seed = SEED.lock().unwrap();
let mut rng: StdRng = match seed.as_ref() {
@ -105,7 +105,7 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
fn bool_to_device<const D: usize>(
tensor: &NdArrayTensor<bool, D>,
_device: NdArrayDevice,
_device: &NdArrayDevice,
) -> NdArrayTensor<bool, D> {
tensor.clone()
}
@ -131,14 +131,14 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
fn to_device<const D: usize>(
tensor: &NdArrayTensor<E, D>,
_device: NdArrayDevice,
_device: &NdArrayDevice,
) -> NdArrayTensor<E, D> {
tensor.clone()
}
fn empty<const D: usize>(
shape: Shape<D>,
device: <NdArrayBackend<E> as Backend>::Device,
device: &<NdArrayBackend<E> as Backend>::Device,
) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
NdArrayBackend::<E>::zeros(shape, device)
}

View File

@ -3,56 +3,56 @@ use burn_tensor::{backend::Backend, ops::TensorOps, Data, Distribution, ElementC
use std::ops::{Add, Div, Mul, Range, Sub};
impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
fn from_data<const D: usize>(data: Data<E, D>, device: TchDevice) -> TchTensor<E, D> {
TchTensor::from_data(data, device.into())
fn from_data<const D: usize>(data: Data<E, D>, device: &TchDevice) -> TchTensor<E, D> {
TchTensor::from_data(data, (*device).into())
}
fn from_data_bool<const D: usize>(
data: Data<bool, D>,
device: TchDevice,
device: &TchDevice,
) -> TchTensor<bool, D> {
TchTensor::from_data(data, device.into())
TchTensor::from_data(data, (*device).into())
}
fn random<const D: usize>(
shape: Shape<D>,
distribution: Distribution<E>,
device: TchDevice,
device: &TchDevice,
) -> TchTensor<E, D> {
match distribution {
Distribution::Standard => {
let mut tensor = TchTensor::<E, D>::empty(shape, device);
let mut tensor = TchTensor::<E, D>::empty(shape, *device);
tensor.tensor = tensor.tensor.normal_(0.0, 1.0);
tensor
}
Distribution::Bernoulli(prob) => {
let mut tensor = TchTensor::<E, D>::empty(shape, device);
let mut tensor = TchTensor::<E, D>::empty(shape, *device);
tensor.tensor = tensor.tensor.f_bernoulli_float_(prob).unwrap();
tensor
}
Distribution::Uniform(from, to) => {
let mut tensor = TchTensor::<E, D>::empty(shape, device);
let mut tensor = TchTensor::<E, D>::empty(shape, *device);
tensor.tensor = tensor
.tensor
.uniform_(from.to_f64().unwrap(), to.to_f64().unwrap());
tensor
}
Distribution::Normal(mean, std) => {
let mut tensor = TchTensor::<E, D>::empty(shape, device);
let mut tensor = TchTensor::<E, D>::empty(shape, *device);
tensor.tensor = tensor.tensor.normal_(mean, std);
tensor
}
}
}
fn zeros<const D: usize>(shape: Shape<D>, device: TchDevice) -> TchTensor<E, D> {
let mut tensor = TchTensor::<E, D>::empty(shape, device);
fn zeros<const D: usize>(shape: Shape<D>, device: &TchDevice) -> TchTensor<E, D> {
let mut tensor = TchTensor::<E, D>::empty(shape, *device);
tensor.tensor = tensor.tensor.zero_();
tensor
}
fn ones<const D: usize>(shape: Shape<D>, device: TchDevice) -> TchTensor<E, D> {
let mut tensor = TchTensor::<E, D>::empty(shape, device);
fn ones<const D: usize>(shape: Shape<D>, device: &TchDevice) -> TchTensor<E, D> {
let mut tensor = TchTensor::<E, D>::empty(shape, *device);
tensor.tensor = tensor.tensor.ones_like();
tensor
}
@ -99,11 +99,11 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
fn bool_to_device<const D: usize>(
tensor: &TchTensor<bool, D>,
device: TchDevice,
device: &TchDevice,
) -> TchTensor<bool, D> {
TchTensor {
kind: tensor.kind,
tensor: tensor.tensor.to(device.into()),
tensor: tensor.tensor.to((*device).into()),
}
}
@ -124,20 +124,22 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
tensor.tensor.device().into()
}
fn to_device<const D: usize>(tensor: &TchTensor<E, D>, device: TchDevice) -> TchTensor<E, D> {
fn to_device<const D: usize>(tensor: &TchTensor<E, D>, device: &TchDevice) -> TchTensor<E, D> {
TchTensor {
kind: tensor.kind,
tensor: tensor.tensor.to(device.into()),
tensor: tensor.tensor.to((*device).into()),
}
}
fn empty<const D: usize>(
shape: Shape<D>,
device: <TchBackend<E> as Backend>::Device,
device: &<TchBackend<E> as Backend>::Device,
) -> <TchBackend<E> as Backend>::TensorPrimitive<D> {
let kind = TchKind::<E>::new();
let tensor =
tch::Tensor::empty(&shape.dims.map(|a| a as i64), (kind.kind(), device.into()));
let tensor = tch::Tensor::empty(
&shape.dims.map(|a| a as i64),
(kind.kind(), (*device).into()),
);
to_tensor(tensor)
}
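
The (*device).into() calls above work because the tch device enum is still Copy even though the Backend trait no longer requires it: dereferencing copies the enum out of the borrow before the Into conversion. A standalone sketch of that pattern with invented types (MyDevice standing in for TchDevice, LibDevice for tch::Device):

// Illustrative stand-ins only; both enums are invented for this sketch.
#[derive(Clone, Copy, Debug)]
enum MyDevice {
    Cpu,
    Cuda(usize),
}

#[derive(Debug, PartialEq)]
enum LibDevice {
    Cpu,
    Cuda(usize),
}

impl From<MyDevice> for LibDevice {
    fn from(device: MyDevice) -> Self {
        match device {
            MyDevice::Cpu => LibDevice::Cpu,
            MyDevice::Cuda(index) => LibDevice::Cuda(index),
        }
    }
}

fn to_lib(device: &MyDevice) -> LibDevice {
    // *device copies the Copy enum out of the shared borrow, then .into() converts it.
    (*device).into()
}

fn main() {
    assert_eq!(to_lib(&MyDevice::Cpu), LibDevice::Cpu);
    assert_eq!(to_lib(&MyDevice::Cuda(0)), LibDevice::Cuda(0));
}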

View File

@ -12,7 +12,7 @@ pub trait Backend:
+ std::fmt::Debug
+ 'static
{
type Device: Copy + Clone + Default + std::fmt::Debug + Send + Sync;
type Device: Clone + Default + std::fmt::Debug + Send + Sync;
type Elem: Element;
type FullPrecisionElem: Element;
type FullPrecisionBackend: Backend<Elem = Self::FullPrecisionElem, Device = Self::Device>;
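
Relaxing the bound from Copy + Clone to Clone allows a device type that owns data (a handle, a string id, and so on); callers now borrow it for reuse and clone only when another thread needs ownership, which is the pattern the learner and batcher changes later in this commit follow. A self-contained sketch under that assumption, with an invented NamedDevice type:

// Illustrative device type that is Clone but not Copy because it owns a String.
#[derive(Clone)]
struct NamedDevice {
    name: String,
}

fn init_on(device: &NamedDevice) {
    println!("initializing on {}", device.name);
}

fn main() {
    let device = NamedDevice { name: "gpu:0".into() };
    // Borrow for repeated use on the current thread...
    init_on(&device);
    init_on(&device);
    // ...and clone explicitly when another thread needs its own owned copy.
    let owned = device.clone();
    std::thread::spawn(move || init_on(&owned)).join().unwrap();
}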

View File

@ -19,10 +19,10 @@ where
{
/// Returns a new integer tensor on the default device which values are generated from the given range.
pub fn arange(range: Range<usize>) -> Tensor<B::IntegerBackend, 1> {
Tensor::new(B::arange(range, B::Device::default()))
Tensor::new(B::arange(range, &B::Device::default()))
}
/// Returns a new integer tensor on the specified device which values are generated from the given range.
pub fn arange_device(range: Range<usize>, device: B::Device) -> Tensor<B::IntegerBackend, 1> {
pub fn arange_device(range: Range<usize>, device: &B::Device) -> Tensor<B::IntegerBackend, 1> {
Tensor::new(B::arange(range, device))
}
}
@ -63,7 +63,7 @@ where
}
/// Returns a new tensor on the given device.
pub fn to_device(&self, device: B::Device) -> Self {
pub fn to_device(&self, device: &B::Device) -> Self {
Self::new(B::to_device(&self.value, device))
}
@ -144,12 +144,12 @@ where
/// Create a tensor from the given data.
pub fn from_data(data: Data<B::Elem, D>) -> Self {
let tensor = B::from_data(data, B::Device::default());
let tensor = B::from_data(data, &B::Device::default());
Tensor::new(tensor)
}
/// Create a tensor from the given data on the given device.
pub fn from_data_device(data: Data<B::Elem, D>, device: B::Device) -> Self {
pub fn from_data_device(data: Data<B::Elem, D>, device: &B::Device) -> Self {
let tensor = B::from_data(data, device);
Tensor::new(tensor)
}
@ -173,18 +173,18 @@ where
/// Returns a new tensor with the same shape and device as the current tensor filled with zeros.
pub fn zeros_like(&self) -> Self {
Tensor::new(B::zeros(self.shape(), self.device()))
Tensor::new(B::zeros(self.shape(), &self.device()))
}
/// Returns a new tensor with the same shape and device as the current tensor filled with ones.
pub fn ones_like(&self) -> Self {
Tensor::new(B::ones(self.shape(), self.device()))
Tensor::new(B::ones(self.shape(), &self.device()))
}
/// Returns a new tensor with the same shape and device as the current tensor filled random
/// values sampled from the given distribution.
pub fn random_like(&self, distribution: Distribution<B::Elem>) -> Self {
Tensor::new(B::random(self.shape(), distribution, self.device()))
Tensor::new(B::random(self.shape(), distribution, &self.device()))
}
/// Create a one hot tensor.
@ -424,30 +424,30 @@ where
/// Create a random tensor of the given shape where each element is sampled from the given
/// distribution.
pub fn random<S: Into<Shape<D>>>(shape: S, distribution: Distribution<B::Elem>) -> Self {
let tensor = B::random(shape.into(), distribution, B::Device::default());
let tensor = B::random(shape.into(), distribution, &B::Device::default());
Self::new(tensor)
}
/// Create a tensor of the given shape where each element is zero.
pub fn zeros<S: Into<Shape<D>>>(shape: S) -> Self {
let tensor = B::zeros(shape.into(), B::Device::default());
let tensor = B::zeros(shape.into(), &B::Device::default());
Self::new(tensor)
}
/// Create a tensor of the given shape where each element is zero.
pub fn zeros_device<S: Into<Shape<D>>>(shape: S, device: B::Device) -> Self {
pub fn zeros_device<S: Into<Shape<D>>>(shape: S, device: &B::Device) -> Self {
let tensor = B::zeros(shape.into(), device);
Self::new(tensor)
}
/// Create a tensor of the given shape where each element is one.
pub fn ones<S: Into<Shape<D>>>(shape: S) -> Self {
let tensor = B::ones(shape.into(), B::Device::default());
let tensor = B::ones(shape.into(), &B::Device::default());
Self::new(tensor)
}
/// Create a tensor of the given shape where each element is one.
pub fn ones_device<S: Into<Shape<D>>>(shape: S, device: B::Device) -> Self {
pub fn ones_device<S: Into<Shape<D>>>(shape: S, device: &B::Device) -> Self {
let tensor = B::ones(shape.into(), device);
Self::new(tensor)
}

View File

@ -19,7 +19,7 @@ where
B::bool_shape(&self.value)
}
pub fn to_device(&self, device: B::Device) -> Self {
pub fn to_device(&self, device: &B::Device) -> Self {
Self::new(B::bool_to_device(&self.value, device))
}
@ -39,7 +39,7 @@ where
}
pub fn from_data(data: Data<bool, D>) -> Self {
let value = B::from_data_bool(data, B::Device::default());
let value = B::from_data_bool(data, &B::Device::default());
Self::new(value)
}

View File

@ -73,7 +73,7 @@ pub(crate) fn conv1d_backward<B: Backend>(
let elem = batch_size * length_out;
let elem = (elem as i32).to_elem();
let b = B::zeros(B::shape(b), B::device(b));
let b = B::zeros(B::shape(b), &B::device(b));
B::add_scalar(&b, &elem)
}),
@ -128,7 +128,7 @@ pub(crate) fn conv2d_backward<B: Backend>(
let elem = batch_size * width_out * height_out;
let elem = (elem as i32).to_elem();
let b = B::zeros(B::shape(b), B::device(b));
let b = B::zeros(B::shape(b), &B::device(b));
B::add_scalar(&b, &elem)
}),

View File

@ -4,21 +4,21 @@ use std::ops::Range;
pub trait TensorOps<B: Backend> {
fn from_data<const D: usize>(
data: Data<B::Elem, D>,
device: B::Device,
device: &B::Device,
) -> B::TensorPrimitive<D>;
fn from_data_bool<const D: usize>(
data: Data<bool, D>,
device: B::Device,
device: &B::Device,
) -> B::BoolTensorPrimitive<D>;
fn random<const D: usize>(
shape: Shape<D>,
distribution: Distribution<B::Elem>,
device: B::Device,
device: &B::Device,
) -> B::TensorPrimitive<D>;
fn zeros<const D: usize>(shape: Shape<D>, device: B::Device) -> B::TensorPrimitive<D> {
fn zeros<const D: usize>(shape: Shape<D>, device: &B::Device) -> B::TensorPrimitive<D> {
Self::from_data(Data::zeros(shape), device)
}
fn ones<const D: usize>(shape: Shape<D>, device: B::Device) -> B::TensorPrimitive<D> {
fn ones<const D: usize>(shape: Shape<D>, device: &B::Device) -> B::TensorPrimitive<D> {
Self::from_data(Data::ones(shape), device)
}
fn shape<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Shape<D>;
@ -29,7 +29,7 @@ pub trait TensorOps<B: Backend> {
fn bool_into_data<const D: usize>(tensor: B::BoolTensorPrimitive<D>) -> Data<bool, D>;
fn bool_to_device<const D: usize>(
tensor: &B::BoolTensorPrimitive<D>,
device: B::Device,
device: &B::Device,
) -> B::BoolTensorPrimitive<D>;
fn bool_reshape<const D1: usize, const D2: usize>(
tensor: &B::BoolTensorPrimitive<D1>,
@ -42,11 +42,11 @@ pub trait TensorOps<B: Backend> {
fn device<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::Device;
fn to_device<const D: usize>(
tensor: &B::TensorPrimitive<D>,
device: B::Device,
device: &B::Device,
) -> B::TensorPrimitive<D>;
fn arange(
range: Range<usize>,
device: B::Device,
device: &B::Device,
) -> <B::IntegerBackend as Backend>::TensorPrimitive<1> {
let shape = Shape::new([range.end - range.start]);
let value = range
@ -56,7 +56,7 @@ pub trait TensorOps<B: Backend> {
let data = Data::new(value, shape);
<B::IntegerBackend as TensorOps<B::IntegerBackend>>::from_data(data, device)
}
fn empty<const D: usize>(shape: Shape<D>, device: B::Device) -> B::TensorPrimitive<D>;
fn empty<const D: usize>(shape: Shape<D>, device: &B::Device) -> B::TensorPrimitive<D>;
fn repeat<const D: usize>(
tensor: &B::TensorPrimitive<D>,
dim: usize,
@ -76,7 +76,7 @@ pub trait TensorOps<B: Backend> {
start..end
});
let mut tensor_output = B::empty(shape, B::device(tensor));
let mut tensor_output = B::empty(shape, &B::device(tensor));
for i in 0..times {
let mut indexes = indexes_select_all.clone();
indexes[dim] = i..i + 1;

View File

@ -41,13 +41,13 @@ where
TO: Send + 'static,
M: TrainStep<B, TI, TO> + Send + 'static,
{
let device = self.device;
let device = self.device.clone();
spawn(move || loop {
match receiver_input.recv() {
Ok(item) => {
let mut step = item.model;
step.to_device(device);
step.to_device(&device);
step.detach();
let output = step.step(item.item);
@ -80,7 +80,7 @@ where
let (sender_input, receiver_input) = std::sync::mpsc::channel();
let worker = Worker {
sender_input,
device: *device,
device: device.clone(),
};
worker.start(sender_output.clone(), receiver_input);

View File

@ -60,7 +60,7 @@ where
// The reference model is always on the first device provided.
if let Some(device) = self.devices.get(0) {
self.model.to_device(*device);
self.model.to_device(device);
self.model.detach();
}
@ -102,7 +102,7 @@ where
let step = MultiDevicesTrainStep::new(&self.devices);
// The main device is always the first in the list.
let device_main = *self.devices.get(0).unwrap();
let device_main = self.devices.get(0).unwrap().clone();
loop {
let items = step.step(&mut iterator, &self.model);
@ -114,7 +114,7 @@ where
iteration += 1;
let progress = iterator.progress();
to_device_grads(&mut item.grads, device_main, &self.model);
to_device_grads(&mut item.grads, device_main.clone(), &self.model);
log::info!("Updated device");
accumulator.accumulate(&self.model, item.grads);
accumulation_current += 1;

View File

@ -31,11 +31,11 @@ impl<B: Backend> Metric for AccuracyMetric<B> {
fn update(&mut self, input: &AccuracyInput<B>) -> MetricEntry {
let [batch_size, _n_classes] = input.outputs.dims();
let targets = input.targets.to_device(B::Device::default());
let targets = input.targets.to_device(&B::Device::default());
let outputs = input
.outputs
.argmax(1)
.to_device(B::Device::default())
.to_device(&B::Device::default())
.reshape([batch_size]);
let total_current = outputs.equal(&targets).to_int().sum().to_data().value[0] as usize;

View File

@ -34,8 +34,8 @@ impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
.map(|item| Tensor::<B::IntegerBackend, 1>::from_data(Data::from([item.label as i64])))
.collect();
let images = Tensor::cat(images, 0).to_device(self.device).detach();
let targets = Tensor::cat(targets, 0).to_device(self.device).detach();
let images = Tensor::cat(images, 0).to_device(&self.device).detach();
let targets = Tensor::cat(targets, 0).to_device(&self.device).detach();
MNISTBatch { images, targets }
}

View File

@ -25,8 +25,8 @@ pub fn run<B: ADBackend>(device: B::Device) {
B::seed(config.seed);
// Data
let batcher_train = Arc::new(MNISTBatcher::<B>::new(device));
let batcher_valid = Arc::new(MNISTBatcher::<B::InnerBackend>::new(device));
let batcher_train = Arc::new(MNISTBatcher::<B>::new(device.clone()));
let batcher_valid = Arc::new(MNISTBatcher::<B::InnerBackend>::new(device.clone()));
let dataloader_train = DataLoaderBuilder::new(batcher_train)
.batch_size(config.batch_size)
.shuffle(config.seed)

View File

@ -36,13 +36,13 @@ impl<B: Backend> Batcher<TextClassificationItem, TextClassificationBatch<B>>
self.tokenizer.pad_token(),
tokens_list,
Some(self.max_seq_lenght),
B::Device::default(),
&B::Device::default(),
);
TextClassificationBatch {
tokens: mask.tensor.to_device(self.device).detach(),
labels: Tensor::cat(labels_list, 0).to_device(self.device).detach(),
mask_pad: mask.mask.to_device(self.device),
tokens: mask.tensor.to_device(&self.device).detach(),
labels: Tensor::cat(labels_list, 0).to_device(&self.device).detach(),
mask_pad: mask.mask.to_device(&self.device),
}
}
}

View File

@ -55,7 +55,7 @@ impl<B: Backend> TextClassificationModel<B> {
pub fn forward(&self, item: TextClassificationBatch<B>) -> ClassificationOutput<B> {
let [batch_size, seq_length] = item.tokens.dims();
let device = self.embedding_token.devices()[0];
let device = &self.embedding_token.devices()[0];
let tokens = item.tokens.to_device(device).detach();
let labels = item.labels.to_device(device).detach();

View File

@ -42,12 +42,12 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
let tokenizer = Arc::new(BertCasedTokenizer::default());
let batcher_train = Arc::new(TextClassificationBatcher::<B>::new(
tokenizer.clone(),
device,
device.clone(),
config.max_seq_length,
));
let batcher_test = Arc::new(TextClassificationBatcher::<B::InnerBackend>::new(
tokenizer.clone(),
device,
device.clone(),
config.max_seq_length,
));

View File

@ -37,7 +37,7 @@ impl<B: Backend> Batcher<TextGenerationItem, TextGenerationBatch<B>> for TextGen
self.tokenizer.pad_token(),
tokens_list,
Some(self.max_seq_lenght),
B::Device::default(),
&B::Device::default(),
);
TextGenerationBatch {

View File

@ -61,7 +61,7 @@ impl<B: Backend> TextClassificationModel<B> {
item: TrainingTextGenerationBatch<B>,
) -> ClassificationOutput<B> {
let [batch_size, seq_length] = item.tokens_inputs.dims();
let device = self.embedding_token.devices()[0];
let device = &self.embedding_token.devices()[0];
let inputs = item.tokens_inputs.to_device(device).detach();
let mask_pad = item.mask_pad.to_device(device);
@ -75,7 +75,7 @@ impl<B: Backend> TextClassificationModel<B> {
let embedding = (embedding_positions + embedding_tokens) / 2;
let mask_attn =
generate_autoregressive_mask::<B>(batch_size, seq_length, embedding.device());
generate_autoregressive_mask::<B>(batch_size, seq_length, &embedding.device());
let encoded = self.transformer.forward(
TransformerEncoderInput::new(embedding)