diff --git a/burn-autodiff/src/ops/tensor.rs b/burn-autodiff/src/ops/tensor.rs
index 4d8e01477..24417bce1 100644
--- a/burn-autodiff/src/ops/tensor.rs
+++ b/burn-autodiff/src/ops/tensor.rs
@@ -512,7 +512,7 @@ impl TensorOps> for ADBackendDecorator {
         fn backward(self, ops: Ops, grads: &mut Gradients) {
             unary::(ops.parents, ops.node, grads, |grad| {
-                B::mask_fill(grad, ops.state, 0.to_elem())
+                B::mask_fill(grad, ops.state, 0.elem())
             });
         }
     }
@@ -529,8 +529,8 @@ impl TensorOps> for ADBackendDecorator {
         B::equal(lhs.primitive, rhs.primitive)
     }

-    fn equal_scalar(lhs: ADTensor, rhs: FloatElem) -> BoolTensor {
-        B::equal_scalar(lhs.primitive, rhs)
+    fn equal_elem(lhs: ADTensor, rhs: FloatElem) -> BoolTensor {
+        B::equal_elem(lhs.primitive, rhs)
     }

     fn greater(lhs: ADTensor, rhs: ADTensor) -> BoolTensor {
@@ -596,7 +596,7 @@ impl TensorOps> for ADBackendDecorator {
                 let shape = ops.state;
                 let val = 1_f64 / shape.num_elements() as f64;
                 let ones = B::ones(shape, &B::device(&grad));
-                let val = B::mul_scalar(ones, val.to_elem());
+                let val = B::mul_scalar(ones, val.elem());

                 let grad: Tensor = Tensor::from_primitive(grad);
                 let val: Tensor = Tensor::from_primitive(val);
@@ -821,7 +821,7 @@ impl TensorOps> for ADBackendDecorator {
             unary::(ops.parents, ops.node, grads, |grad| {
                 let input = ops.state;
                 let ones = B::ones(B::shape(&input), &B::device(&input));
-                let value = B::div(ones, B::add_scalar(input, 1.to_elem()));
+                let value = B::div(ones, B::add_scalar(input, 1.elem()));

                 B::mul(grad, value)
             });
@@ -848,7 +848,7 @@ impl TensorOps> for ADBackendDecorator {
             unary::(ops.parents, ops.node, grads, |grad| {
                 let tmp = B::powf(tensor, value - 1.0);
-                let value = B::mul_scalar(tmp, value.to_elem());
+                let value = B::mul_scalar(tmp, value.elem());

                 B::mul(grad, value)
             });
@@ -874,7 +874,7 @@ impl TensorOps> for ADBackendDecorator {
         fn backward(self, ops: Ops, grads: &mut Gradients) {
             unary::(ops.parents, ops.node, grads, |grad| {
                 let input = ops.state;
-                let value = B::div_scalar(B::powf(input, -0.5), 2.to_elem());
+                let value = B::div_scalar(B::powf(input, -0.5), 2.elem());

                 B::mul(grad, value)
             });
@@ -946,7 +946,7 @@ impl TensorOps> for ADBackendDecorator {
         fn backward(self, ops: Ops, grads: &mut Gradients) {
             unary::(ops.parents, ops.node, grads, |grad| {
-                let value = B::add_scalar(B::neg(B::powf(ops.state, 2.0)), 1.to_elem());
+                let value = B::add_scalar(B::neg(B::powf(ops.state, 2.0)), 1.elem());
                 B::mul(grad, value)
             });
         }
@@ -971,8 +971,8 @@ impl TensorOps> for ADBackendDecorator {
         fn backward(self, ops: Ops, grads: &mut Gradients) {
             unary::(ops.parents, ops.node, grads, |grad| {
                 let exponent = B::neg(B::powf(ops.state, 2.0));
-                let numerator = B::mul_scalar(B::exp(exponent), 2.0.to_elem());
-                let denominator = std::f64::consts::PI.sqrt().to_elem();
+                let numerator = B::mul_scalar(B::exp(exponent), 2.0.elem());
+                let denominator = std::f64::consts::PI.sqrt().elem();
                 let value = B::div_scalar(numerator, denominator);

                 B::mul(grad, value)
@@ -1055,7 +1055,7 @@ impl TensorOps> for ADBackendDecorator {
         fn backward(self, ops: Ops, grads: &mut Gradients) {
             unary::(ops.parents, ops.node, grads, |grad| {
-                let zero = 0.to_elem();
+                let zero = 0.elem();
                 let mask = B::lower_equal_scalar(ops.state, zero);
                 B::mask_fill(grad, mask, zero)
             });
diff --git a/burn-core/src/nn/attention/mask.rs b/burn-core/src/nn/attention/mask.rs
index e3a61cf7d..6af73f54c 100644
--- a/burn-core/src/nn/attention/mask.rs
+++ b/burn-core/src/nn/attention/mask.rs
@@ -19,7 +19,7 @@ pub fn generate_autoregressive_mask(
     mask = mask.to_device(device).repeat(0, batch_size);

-    mask.equal_scalar(1_i64)
+    mask.equal_elem(1_i64)
 }

 pub struct GeneratePaddingMask {
@@ -67,7 +67,7 @@ pub fn generate_padding_mask(
         tensor = tensor.index_assign(
             [index..index + 1, 0..tokens.len()],
             Tensor::from_data(Data::new(
-                tokens.into_iter().map(|e| (e as i64).to_elem()).collect(),
+                tokens.into_iter().map(|e| (e as i64).elem()).collect(),
                 Shape::new([1, seq_length]),
             )),
         );
@@ -75,7 +75,7 @@ pub fn generate_padding_mask(
     let mask = tensor
         .clone()
-        .equal_scalar(pad_token as i64)
+        .equal_elem(pad_token as i64)
         .to_device(device);
     let tensor = tensor.to_device(device);
diff --git a/burn-core/src/nn/attention/mha.rs b/burn-core/src/nn/attention/mha.rs
index 3f2c5791e..0257fa75e 100644
--- a/burn-core/src/nn/attention/mha.rs
+++ b/burn-core/src/nn/attention/mha.rs
@@ -325,7 +325,7 @@ mod tests {
             [0..batch_size, seq_length - num_padded..seq_length],
             Tensor::ones([batch_size, num_padded]),
         );
-        let mask_pad = mask_pad.equal_scalar(1);
+        let mask_pad = mask_pad.equal_elem(1);

         let tensor_1 = Tensor::::random(
             [batch_size, seq_length, d_model],
diff --git a/burn-core/src/nn/conv/conv1d.rs b/burn-core/src/nn/conv/conv1d.rs
index e5625125b..69ee7ac4d 100644
--- a/burn-core/src/nn/conv/conv1d.rs
+++ b/burn-core/src/nn/conv/conv1d.rs
@@ -64,8 +64,8 @@ impl Conv1d {
         let k = (config.channels_in * config.kernel_size) as f64;
         let k = sqrt(1.0 / k);

-        let k1: B::FloatElem = (-k).to_elem();
-        let k2: B::FloatElem = k.to_elem();
+        let k1: B::FloatElem = (-k).elem();
+        let k2: B::FloatElem = k.elem();

         let weight = Tensor::random(
             [config.channels_out, config.channels_in, config.kernel_size],
diff --git a/burn-core/src/nn/conv/conv2d.rs b/burn-core/src/nn/conv/conv2d.rs
index 456e792a8..c2544bcf0 100644
--- a/burn-core/src/nn/conv/conv2d.rs
+++ b/burn-core/src/nn/conv/conv2d.rs
@@ -64,8 +64,8 @@ impl Conv2d {
         let k = (config.channels[0] * config.kernel_size[0] * config.kernel_size[1]) as f64;
         let k = sqrt(1.0 / k);

-        let k1: B::FloatElem = (-k).to_elem();
-        let k2: B::FloatElem = k.to_elem();
+        let k1: B::FloatElem = (-k).elem();
+        let k2: B::FloatElem = k.elem();

         let weight = Tensor::random(
             [
diff --git a/burn-core/src/nn/linear.rs b/burn-core/src/nn/linear.rs
index 41623a0f5..a611b21e0 100644
--- a/burn-core/src/nn/linear.rs
+++ b/burn-core/src/nn/linear.rs
@@ -43,7 +43,7 @@ impl Linear {
     /// Create the module from the given configuration.
     pub fn new(config: &LinearConfig) -> Self {
         let k = sqrt(1.0 / config.d_input as f64);
-        let distribution = Distribution::Uniform((-1.0 * k).to_elem(), k.to_elem());
+        let distribution = Distribution::Uniform((-1.0 * k).elem(), k.elem());

         let weight = Tensor::random([config.d_input, config.d_output], distribution);
         let bias = match config.bias {
diff --git a/burn-core/src/optim/adam.rs b/burn-core/src/optim/adam.rs
index 3c7a9d32d..87c277db6 100644
--- a/burn-core/src/optim/adam.rs
+++ b/burn-core/src/optim/adam.rs
@@ -37,7 +37,7 @@ pub struct Adam {
 impl Adam {
     pub fn new(config: &AdamConfig) -> Self {
         Self {
-            learning_rate: config.learning_rate.to_elem(),
+            learning_rate: config.learning_rate.elem(),
             momentum: AdaptiveMomentum {
                 beta_1: config.beta_1,
                 beta_2: config.beta_2,
@@ -139,7 +139,7 @@ impl AdaptiveMomentum {
         self.moment_2.register(id.clone(), moment_2.clone());
         self.time.register(id.clone(), time.clone());

-        let time = time.single_value().to_elem();
+        let time = time.single_value().elem();
         let moment_1_corrected = moment_1.div_scalar(1f32 - self.beta_1.powf(time));
         let moment_2_corrected = moment_2.div_scalar(1f32 - self.beta_2.powf(time));
diff --git a/burn-core/src/optim/decay.rs b/burn-core/src/optim/decay.rs
index 112876dbe..08ad0adfd 100644
--- a/burn-core/src/optim/decay.rs
+++ b/burn-core/src/optim/decay.rs
@@ -23,7 +23,7 @@ pub struct WeightDecay {
 impl WeightDecay {
     pub fn new(config: &WeightDecayConfig) -> Self {
         Self {
-            penalty: config.penalty.to_elem(),
+            penalty: config.penalty.elem(),
             gradients: GradientsParams::new(),
         }
     }
diff --git a/burn-core/src/optim/momentum.rs b/burn-core/src/optim/momentum.rs
index ea583421c..8618c757c 100644
--- a/burn-core/src/optim/momentum.rs
+++ b/burn-core/src/optim/momentum.rs
@@ -31,7 +31,7 @@ pub struct Momentum {
 impl Momentum {
     pub fn new(config: &MomentumConfig) -> Self {
         Self {
-            momentum: config.momentum.to_elem(),
+            momentum: config.momentum.elem(),
             dampening: config.dampening,
             velocity: GradientsParams::new(),
             nesterov: config.nesterov,
diff --git a/burn-core/src/optim/sgd.rs b/burn-core/src/optim/sgd.rs
index a901292a2..7fb2fdb10 100644
--- a/burn-core/src/optim/sgd.rs
+++ b/burn-core/src/optim/sgd.rs
@@ -30,7 +30,7 @@ pub struct Sgd {
 impl Sgd {
     pub fn new(config: &SgdConfig) -> Self {
-        let learning_rate = config.learning_rate.to_elem();
+        let learning_rate = config.learning_rate.elem();
         let momentum = config.momentum.as_ref().map(|config| Momentum::new(config));
         let weight_decay = config
             .weight_decay
diff --git a/burn-ndarray/src/element.rs b/burn-ndarray/src/element.rs
index 274c9816d..1df5ca30f 100644
--- a/burn-ndarray/src/element.rs
+++ b/burn-ndarray/src/element.rs
@@ -3,7 +3,13 @@ use libm::{exp, log, log1p, pow, sqrt};
 use libm::{expf, log1pf, logf, powf, sqrtf};

 pub(crate) trait NdArrayElement:
-    Element + ndarray::LinalgScalar + ndarray::ScalarOperand + ExpElement + num_traits::FromPrimitive
+    Element
+    + ndarray::LinalgScalar
+    + ndarray::ScalarOperand
+    + ExpElement
+    + num_traits::FromPrimitive
+    + core::cmp::PartialEq
+    + core::cmp::PartialOrd
 {
 }
diff --git a/burn-ndarray/src/ops/tensor.rs b/burn-ndarray/src/ops/tensor.rs
index 1c09c97cd..cf0520d1c 100644
--- a/burn-ndarray/src/ops/tensor.rs
+++ b/burn-ndarray/src/ops/tensor.rs
@@ -132,7 +132,7 @@ impl TensorOps> for NdArrayBackend {
     }

     fn neg(tensor: NdArrayTensor) -> NdArrayTensor {
-        Self::mul_scalar(tensor, (-1f32).to_elem::())
+        Self::mul_scalar(tensor, (-1f32).elem::())
     }

     fn swap_dims(
@@ -174,12 +174,12 @@
         value: E,
     ) -> NdArrayTensor {
         let mask_mul = mask.array.mapv(|x| match x {
-            true => 0.to_elem(),
-            false => 1.to_elem(),
+            true => 0.elem(),
+            false => 1.elem(),
         });
         let mask_add = mask.array.mapv(|x| match x {
             true => value,
-            false => 0.to_elem(),
+            false => 0.elem(),
         });
         let array = (tensor.array * mask_mul) + mask_add;
@@ -191,12 +191,12 @@
         rhs: NdArrayTensor,
     ) -> NdArrayTensor {
         let tensor = NdArrayBackend::::sub(lhs, rhs);
-        let zero = 0.to_elem();
+        let zero = 0.elem();

-        Self::equal_scalar(tensor, zero)
+        Self::equal_elem(tensor, zero)
     }

-    fn equal_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor {
+    fn equal_elem(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor {
         let array = lhs.array.mapv(|a| a == rhs).into_shared();

         NdArrayTensor { array }
@@ -207,7 +207,7 @@
         rhs: NdArrayTensor,
     ) -> NdArrayTensor {
         let tensor = NdArrayBackend::::sub(lhs, rhs);
-        let zero = 0.to_elem();
+        let zero = 0.elem();
         Self::greater_scalar(tensor, zero)
     }
@@ -222,7 +222,7 @@
         rhs: NdArrayTensor,
     ) -> NdArrayTensor {
         let tensor = NdArrayBackend::::sub(lhs, rhs);
-        let zero = 0.to_elem();
+        let zero = 0.elem();
         Self::greater_equal_scalar(tensor, zero)
     }
@@ -240,7 +240,7 @@
         rhs: NdArrayTensor,
     ) -> NdArrayTensor {
         let tensor = NdArrayBackend::::sub(lhs, rhs);
-        let zero = 0.to_elem();
+        let zero = 0.elem();
         Self::lower_scalar(tensor, zero)
     }
@@ -255,7 +255,7 @@
         rhs: NdArrayTensor,
     ) -> NdArrayTensor {
         let tensor = NdArrayBackend::::sub(lhs, rhs);
-        let zero = 0.to_elem();
+        let zero = 0.elem();
         Self::lower_equal_scalar(tensor, zero)
     }
@@ -289,13 +289,13 @@
     }

     fn to_full_precision(tensor: &NdArrayTensor) -> NdArrayTensor {
-        let array = tensor.array.mapv(|a| a.to_elem()).into_shared();
+        let array = tensor.array.mapv(|a| a.elem()).into_shared();

         NdArrayTensor { array }
     }

     fn from_full_precision(tensor: NdArrayTensor) -> NdArrayTensor {
-        let array = tensor.array.mapv(|a| a.to_elem()).into_shared();
+        let array = tensor.array.mapv(|a| a.elem()).into_shared();

         NdArrayTensor { array }
     }
@@ -341,7 +341,7 @@
     fn cos(tensor: NdArrayTensor) -> NdArrayTensor {
         let array = tensor
             .array
-            .mapv_into(|a| cos(a.to_f64().unwrap()).to_elem())
+            .mapv_into(|a| cos(a.to_f64().unwrap()).elem())
             .into_shared();

         NdArrayTensor { array }
@@ -350,7 +350,7 @@
     fn sin(tensor: NdArrayTensor) -> NdArrayTensor {
         let array = tensor
             .array
-            .mapv_into(|a| sin(a.to_f64().unwrap()).to_elem())
+            .mapv_into(|a| sin(a.to_f64().unwrap()).elem())
             .into_shared();

         NdArrayTensor { array }
@@ -359,7 +359,7 @@
     fn tanh(tensor: NdArrayTensor) -> NdArrayTensor {
         let array = tensor
             .array
-            .mapv_into(|a| tanh(a.to_f64().unwrap()).to_elem())
+            .mapv_into(|a| tanh(a.to_f64().unwrap()).elem())
             .into_shared();

         NdArrayTensor { array }
@@ -368,7 +368,7 @@
     fn erf(tensor: NdArrayTensor) -> NdArrayTensor {
         let array = tensor
             .array
-            .mapv_into(|a| erf(a.to_f64().unwrap()).to_elem())
+            .mapv_into(|a| erf(a.to_f64().unwrap()).elem())
             .into_shared();

         NdArrayTensor { array }
@@ -379,11 +379,11 @@
     }

     fn relu(tensor: NdArrayTensor) -> NdArrayTensor {
-        let zero = 0.to_elem();
+        let zero = 0.elem();
         let array = tensor
             .array
             .mapv_into(|elem| match elem < zero {
-                true => 0.0.to_elem(),
+                true => 0.0.elem(),
                 false => elem,
             })
             .into_shared();
@@ -409,7 +409,7 @@ where
     while end <= data.value.len() {
         let data_dim = &mut data.value[start..end];
-        let mut sorted: Vec = data_dim.iter().map(|a| a.to_elem()).collect();
+        let mut sorted: Vec = data_dim.iter().map(|a| a.elem()).collect();
         sorted.sort_by(&cmp);

         let max = sorted[0];
@@ -417,7 +417,7 @@ where
         let data_dim = &mut data.value[start..end];
         let mut index: i64 = 0;
         for elem in data_dim {
-            let as_float: f64 = elem.to_elem();
+            let as_float: f64 = elem.elem();
             if as_float == max {
                 break;
             }
diff --git a/burn-tch/src/element.rs b/burn-tch/src/element.rs
index 9d52d4401..df128c859 100644
--- a/burn-tch/src/element.rs
+++ b/burn-tch/src/element.rs
@@ -1,24 +1,7 @@
 use burn_tensor::Element;
 use half::f16;

-pub trait IsInt {
-    fn is_int(&self) -> bool;
-}
-pub trait TchElement: Element + tch::kind::Element + IsInt {}
-
-macro_rules! make_element {
-    (
-        $ty:ident,
-        $bool:expr
-
-    ) => {
-        impl IsInt for $ty {
-            fn is_int(&self) -> bool {
-                $bool
-            }
-        }
-    };
-}
+pub trait TchElement: Element + tch::kind::Element {}

 impl TchElement for f64 {}
 impl TchElement for f32 {}
@@ -29,11 +12,3 @@ impl TchElement for f16 {}
 impl TchElement for i32 {}
 impl TchElement for i16 {}

 impl TchElement for u8 {}
-
-make_element!(f64, false);
-make_element!(f32, false);
-make_element!(f16, false);
-make_element!(i64, true);
-make_element!(i32, true);
-make_element!(i16, true);
-make_element!(u8, false);
diff --git a/burn-tch/src/ops/tensor.rs b/burn-tch/src/ops/tensor.rs
index 40b55dca6..b47f3aa22 100644
--- a/burn-tch/src/ops/tensor.rs
+++ b/burn-tch/src/ops/tensor.rs
@@ -110,7 +110,7 @@ impl TensorOps> for TchBackend {
     }

     fn add_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
-        let rhs: f64 = rhs.to_elem();
+        let rhs: f64 = rhs.elem();

         let tensor = lhs.unary_ops(
             |mut tensor| tensor.f_add_scalar_(rhs).unwrap(),
@@ -125,7 +125,7 @@
     }

     fn sub_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
-        let rhs: f64 = rhs.to_elem();
+        let rhs: f64 = rhs.elem();
         let tensor = lhs.unary_ops(
             |mut tensor| tensor.f_sub_scalar_(rhs).unwrap(),
             |tensor| tensor.f_sub_scalar(rhs).unwrap(),
@@ -139,7 +139,7 @@
     }

     fn mul_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
-        let rhs: f64 = rhs.to_elem();
+        let rhs: f64 = rhs.elem();
         let tensor = lhs.unary_ops(
             |mut tensor| tensor.f_mul_scalar_(rhs).unwrap(),
             |tensor| tensor.f_mul_scalar(rhs).unwrap(),
@@ -153,7 +153,7 @@
     }

     fn div_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
-        let rhs: f64 = rhs.to_elem();
+        let rhs: f64 = rhs.elem();
         let tensor = lhs.unary_ops(
             |mut tensor| tensor.f_div_scalar_(rhs).unwrap(),
             |tensor| tensor.f_div_scalar(rhs).unwrap(),
@@ -168,7 +168,7 @@
     }

     fn neg(tensor: TchTensor) -> TchTensor {
-        Self::mul_scalar(tensor, (-1f32).to_elem::())
+        Self::mul_scalar(tensor, (-1f32).elem::())
     }

     fn swap_dims(
@@ -210,7 +210,7 @@
         mask: TchTensor,
         value: E,
     ) -> TchTensor {
-        let value: f64 = value.to_elem();
+        let value: f64 = value.elem();
         let tensor = tensor.unary_ops(
             |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(),
             |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(),
@@ -223,8 +223,8 @@
         TchOps::equal(lhs, rhs)
     }

-    fn equal_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
-        let rhs: f64 = rhs.to_elem();
+    fn equal_elem(lhs: TchTensor, rhs: E) -> TchTensor {
+        let rhs: f64 = rhs.elem();

         let tensor = lhs.unary_ops(
             |mut tensor| tensor.eq_(rhs).to_kind(tch::Kind::Bool),
             |tensor| tensor.eq(rhs),
@@ -245,7 +245,7 @@
     }

     fn greater_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
-        let rhs: f64 = rhs.to_elem();
+        let rhs: f64 = rhs.elem();
         let tensor = lhs.unary_ops(
             |mut tensor| tensor.greater_(rhs).to_kind(tch::Kind::Bool),
             |tensor| tensor.greater(rhs),
@@ -269,7 +269,7 @@
     }

     fn greater_equal_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
-        let rhs: f64 = rhs.to_elem();
+        let rhs: f64 = rhs.elem();
         let tensor = lhs.unary_ops(
             |mut tensor| tensor.greater_equal_(rhs).to_kind(tch::Kind::Bool),
             |tensor| tensor.greater_equal(rhs),
@@ -290,7 +290,7 @@
     }

     fn lower_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
-        let rhs: f64 = rhs.to_elem();
+        let rhs: f64 = rhs.elem();
         let tensor = lhs.unary_ops(
             |mut tensor| tensor.less_(rhs).to_kind(tch::Kind::Bool),
             |tensor| tensor.less(rhs),
@@ -314,7 +314,7 @@
     }

     fn lower_equal_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
-        let rhs: f64 = rhs.to_elem();
+        let rhs: f64 = rhs.elem();
         let tensor = lhs.unary_ops(
             |mut tensor| tensor.less_equal_(rhs).to_kind(tch::Kind::Bool),
             |tensor| tensor.less_equal(rhs),
diff --git a/burn-tensor/src/tensor/api/base.rs b/burn-tensor/src/tensor/api/base.rs
index f528027fa..447ea66d8 100644
--- a/burn-tensor/src/tensor/api/base.rs
+++ b/burn-tensor/src/tensor/api/base.rs
@@ -150,9 +150,9 @@ where
     }

     /// Applies element wise equal comparison and returns a boolean tensor.
-    pub fn equal_scalar>(self, other: E) -> Tensor {
+    pub fn equal_elem>(self, other: E) -> Tensor {
         let elem: K::Elem = other.into();
-        K::equal_scalar::(self.primitive, elem)
+        K::equal_elem::(self.primitive, elem)
     }

     /// Concatenates all tensors into a new one along the given dimension.
@@ -206,13 +206,12 @@ pub trait BasicOps: TensorKind {
         dim: usize,
         times: usize,
     ) -> Self::Primitive;
+    fn cat(vectors: Vec>, dim: usize) -> Self::Primitive;
     fn equal(
         lhs: Self::Primitive,
         rhs: Self::Primitive,
     ) -> Tensor;
-    fn equal_scalar(lhs: Self::Primitive, rhs: Self::Elem)
-        -> Tensor;
-    fn cat(vectors: Vec>, dim: usize) -> Self::Primitive;
+    fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor;
 }

 impl BasicOps for Float {
@@ -277,6 +276,10 @@ impl BasicOps for Float {
         B::repeat(tensor, dim, times)
     }

+    fn cat(vectors: Vec>, dim: usize) -> Self::Primitive {
+        B::cat(vectors, dim)
+    }
+
     fn equal(
         lhs: Self::Primitive,
         rhs: Self::Primitive,
@@ -284,15 +287,8 @@
         Tensor::new(B::equal(lhs, rhs))
     }

-    fn equal_scalar(
-        lhs: Self::Primitive,
-        rhs: Self::Elem,
-    ) -> Tensor {
-        Tensor::new(B::equal_scalar(lhs, rhs))
-    }
-
-    fn cat(vectors: Vec>, dim: usize) -> Self::Primitive {
-        B::cat(vectors, dim)
+    fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor {
+        Tensor::new(B::equal_elem(lhs, rhs))
     }
 }
@@ -365,10 +361,7 @@
         Tensor::new(B::int_equal(lhs, rhs))
     }

-    fn equal_scalar(
-        lhs: Self::Primitive,
-        rhs: Self::Elem,
-    ) -> Tensor {
+    fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor {
         Tensor::new(B::int_equal_elem(lhs, rhs))
     }
@@ -446,10 +439,7 @@
         Tensor::new(B::bool_equal(lhs, rhs))
     }

-    fn equal_scalar(
-        lhs: Self::Primitive,
-        rhs: Self::Elem,
-    ) -> Tensor {
+    fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor {
         Tensor::new(B::bool_equal_elem(lhs, rhs))
     }
diff --git a/burn-tensor/src/tensor/api/float.rs b/burn-tensor/src/tensor/api/float.rs
index a4518fa1d..a1504991a 100644
--- a/burn-tensor/src/tensor/api/float.rs
+++ b/burn-tensor/src/tensor/api/float.rs
@@ -272,22 +272,22 @@
     /// Applies element wise greater comparison and returns a boolean tensor.
     pub fn greater_scalar(self, other: E) -> Tensor {
-        Tensor::new(B::greater_scalar(self.primitive, other.to_elem()))
+        Tensor::new(B::greater_scalar(self.primitive, other.elem()))
     }

     /// Applies element wise greater-equal comparison and returns a boolean tensor.
     pub fn greater_equal_scalar(self, other: E) -> Tensor {
-        Tensor::new(B::greater_equal_scalar(self.primitive, other.to_elem()))
+        Tensor::new(B::greater_equal_scalar(self.primitive, other.elem()))
     }

     /// Applies element wise lower comparison and returns a boolean tensor.
     pub fn lower_scalar(self, other: E) -> Tensor {
-        Tensor::new(B::lower_scalar(self.primitive, other.to_elem()))
+        Tensor::new(B::lower_scalar(self.primitive, other.elem()))
     }

     /// Applies element wise lower-equal comparison and returns a boolean tensor.
     pub fn lower_equal_scalar(self, other: E) -> Tensor {
-        Tensor::new(B::lower_equal_scalar(self.primitive, other.to_elem()))
+        Tensor::new(B::lower_equal_scalar(self.primitive, other.elem()))
     }

     /// Create a random tensor of the given shape where each element is sampled from the given
@@ -299,11 +299,7 @@ where
     /// Fill each element with the given value based on the given mask.
     pub fn mask_fill(self, mask: Tensor, value: E) -> Self {
-        Self::new(B::mask_fill(
-            self.primitive,
-            mask.primitive,
-            value.to_elem(),
-        ))
+        Self::new(B::mask_fill(self.primitive, mask.primitive, value.elem()))
     }

     /// Returns a tensor with full precision based on the selected backend.
diff --git a/burn-tensor/src/tensor/api/numeric.rs b/burn-tensor/src/tensor/api/numeric.rs
index d1d7c1dad..9d9a274df 100644
--- a/burn-tensor/src/tensor/api/numeric.rs
+++ b/burn-tensor/src/tensor/api/numeric.rs
@@ -160,7 +160,7 @@ impl Numeric for Int {
         lhs: Self::Primitive,
         rhs: E,
     ) -> Self::Primitive {
-        B::int_add_scalar(lhs, rhs.to_elem())
+        B::int_add_scalar(lhs, rhs.elem())
     }
     fn sub(
         lhs: Self::Primitive,
         rhs: Self::Primitive,
     ) -> Self::Primitive {
@@ -172,7 +172,7 @@
         lhs: Self::Primitive,
         rhs: E,
     ) -> Self::Primitive {
-        B::int_sub_scalar(lhs, rhs.to_elem())
+        B::int_sub_scalar(lhs, rhs.elem())
     }
     fn div(
         lhs: Self::Primitive,
         rhs: Self::Primitive,
     ) -> Self::Primitive {
@@ -184,7 +184,7 @@
         lhs: Self::Primitive,
         rhs: E,
     ) -> Self::Primitive {
-        B::int_div_scalar(lhs, rhs.to_elem())
+        B::int_div_scalar(lhs, rhs.elem())
     }
     fn mul(
         lhs: Self::Primitive,
         rhs: Self::Primitive,
     ) -> Self::Primitive {
@@ -196,7 +196,7 @@
         lhs: Self::Primitive,
         rhs: E,
     ) -> Self::Primitive {
-        B::int_mul_scalar(lhs, rhs.to_elem())
+        B::int_mul_scalar(lhs, rhs.elem())
     }
     fn neg(tensor: Self::Primitive) -> Self::Primitive {
         B::int_neg(tensor)
@@ -232,7 +232,7 @@ impl Numeric for Float {
         lhs: Self::Primitive,
         rhs: E,
     ) -> Self::Primitive {
-        B::add_scalar(lhs, rhs.to_elem())
+        B::add_scalar(lhs, rhs.elem())
     }
     fn sub(
         lhs: Self::Primitive,
         rhs: Self::Primitive,
     ) -> Self::Primitive {
@@ -244,7 +244,7 @@
         lhs: Self::Primitive,
         rhs: E,
     ) -> Self::Primitive {
-        B::sub_scalar(lhs, rhs.to_elem())
+        B::sub_scalar(lhs, rhs.elem())
     }
     fn div(
         lhs: Self::Primitive,
         rhs: Self::Primitive,
     ) -> Self::Primitive {
@@ -256,7 +256,7 @@
         lhs: Self::Primitive,
         rhs: E,
     ) -> Self::Primitive {
-        B::div_scalar(lhs, rhs.to_elem())
+        B::div_scalar(lhs, rhs.elem())
     }
     fn mul(
         lhs: Self::Primitive,
         rhs: Self::Primitive,
     ) -> Self::Primitive {
@@ -268,7 +268,7 @@
         lhs: Self::Primitive,
         rhs: E,
     ) -> Self::Primitive {
-        B::mul_scalar(lhs, rhs.to_elem())
+        B::mul_scalar(lhs, rhs.elem())
     }
     fn neg(tensor: Self::Primitive) -> Self::Primitive {
         B::neg(tensor)
diff --git a/burn-tensor/src/tensor/data.rs b/burn-tensor/src/tensor/data.rs
index 1fb51e7cf..12e2c36f0 100644
--- a/burn-tensor/src/tensor/data.rs
+++ b/burn-tensor/src/tensor/data.rs
@@ -61,14 +61,12 @@ where
             DistributionSamplerKind::Uniform(distribution) => self.rng.sample(distribution),
             DistributionSamplerKind::Bernoulli(distribution) => {
                 if self.rng.sample(distribution) {
-                    1.to_elem()
+                    1.elem()
                 } else {
-                    0.to_elem()
+                    0.elem()
                 }
             }
-            DistributionSamplerKind::Normal(distribution) => {
-                self.rng.sample(distribution).to_elem()
-            }
+            DistributionSamplerKind::Normal(distribution) => self.rng.sample(distribution).elem(),
         }
     }
 }
@@ -114,7 +112,7 @@ where
 impl Data {
     pub fn convert(self) -> Data {
-        let value: Vec = self.value.into_iter().map(|a| a.to_elem()).collect();
+        let value: Vec = self.value.into_iter().map(|a| a.elem()).collect();

         Data {
             value,
@@ -125,7 +123,7 @@ impl DataSerialize {
     pub fn convert(self) -> DataSerialize {
-        let value: Vec = self.value.into_iter().map(|a| a.to_elem()).collect();
+        let value: Vec = self.value.into_iter().map(|a| a.elem()).collect();

         DataSerialize {
             value,
@@ -136,11 +134,7 @@ impl DataSerialize {
 impl Data {
     pub fn convert(self) -> Data {
-        let value: Vec = self
-            .value
-            .into_iter()
-            .map(|a| (a as i64).to_elem())
-            .collect();
+        let value: Vec = self.value.into_iter().map(|a| (a as i64).elem()).collect();

         Data {
             value,
@@ -170,7 +164,7 @@ where
     let mut data = Vec::with_capacity(num_elements);
     for _ in 0..num_elements {
-        data.push(0.to_elem());
+        data.push(0.elem());
     }

     Data::new(data, shape)
@@ -189,7 +183,7 @@ where
     let mut data = Vec::with_capacity(num_elements);
     for _ in 0..num_elements {
-        data.push(1.to_elem());
+        data.push(1.elem());
     }

     Data::new(data, shape)
diff --git a/burn-tensor/src/tensor/element.rs b/burn-tensor/src/tensor/element.rs
index de9b40cdf..d3117cc6c 100644
--- a/burn-tensor/src/tensor/element.rs
+++ b/burn-tensor/src/tensor/element.rs
@@ -8,21 +8,18 @@ pub trait Element:
     + ElementRandom
     + ElementConversion
     + ElementPrecision
-    + ElementValue
-    + core::ops::Mul
     + core::fmt::Debug
     + Default
     + Send
     + Sync
     + Copy
-    + core::cmp::PartialOrd
     + 'static
 {
 }

 pub trait ElementConversion {
     fn from_elem(elem: E) -> Self;
-    fn to_elem(&self) -> E;
+    fn elem(self) -> E;
 }

 pub trait ElementRandom {
@@ -31,14 +28,6 @@ pub trait ElementRandom {
     fn random(distribution: Distribution, rng: &mut R) -> Self
     where
         Self: Sized;
 }

-pub trait ElementValue {
-    fn inf() -> Self;
-    fn inf_neg() -> Self;
-    fn nan() -> Self;
-    fn zero() -> Self;
-    fn one() -> Self;
-}
-
 #[derive(Clone, PartialEq, Eq, Copy, Debug)]
 pub enum Precision {
     Double,
@@ -55,8 +44,6 @@ pub trait ElementPrecision {
 macro_rules! make_element {
     (
         ty $type:ident $precision:expr,
-        zero $zero:expr,
-        one $one:expr,
         convert $convert:expr,
         random $random:expr
@@ -67,26 +54,8 @@ macro_rules! make_element {
             fn from_elem(elem: E) -> Self {
                 $convert(&elem)
             }
-            fn to_elem(&self) -> E {
-                E::from_elem(*self)
-            }
-        }
-
-        impl ElementValue for $type {
-            fn inf() -> Self {
-                Self::from_elem(f64::INFINITY)
-            }
-            fn inf_neg() -> Self {
-                Self::from_elem(core::ops::Neg::neg(f64::INFINITY))
-            }
-            fn nan() -> Self {
-                Self::from_elem(f64::NAN)
-            }
-            fn zero() -> Self {
-                $zero
-            }
-            fn one() -> Self {
-                $one
+            fn elem(self) -> E {
+                E::from_elem(self)
             }
         }
@@ -101,78 +70,53 @@ macro_rules! make_element {
             $random(distribution, rng)
         }
     }
-
-    };
-    (
-        float $float:ident $precision:expr,
-        convert $convert:expr,
-        random $random:expr
-    ) => {
-        make_element!(
-            ty $float $precision,
-            zero 0.0,
-            one 1.0,
-            convert $convert,
-            random $random
-        );
-    };
-    (
-        int $int:ident $precision:expr,
-        convert $convert:expr,
-        random $random:expr
-    ) => {
-        make_element!(
-            ty $int $precision,
-            zero 0,
-            one 1,
-            convert $convert,
-            random $random
-        );
     };
 }

 make_element!(
-    float f64 Precision::Double,
+    ty f64 Precision::Double,
     convert |elem: &dyn ToPrimitive| elem.to_f64().unwrap(),
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample()
 );
 make_element!(
-    float f32 Precision::Full,
+    ty f32 Precision::Full,
     convert |elem: &dyn ToPrimitive| elem.to_f32().unwrap(),
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample()
 );
 make_element!(
-    int i64 Precision::Double,
+    ty i64 Precision::Double,
     convert |elem: &dyn ToPrimitive| elem.to_i64().unwrap(),
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample()
 );
+
 make_element!(
-    int i32 Precision::Full,
+    ty i32 Precision::Full,
     convert |elem: &dyn ToPrimitive| elem.to_i32().unwrap(),
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample()
 );
+
 make_element!(
-    int i16 Precision::Half,
+    ty i16 Precision::Half,
     convert |elem: &dyn ToPrimitive| elem.to_i16().unwrap(),
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample()
 );
+
 make_element!(
-    int i8 Precision::Other,
+    ty i8 Precision::Other,
     convert |elem: &dyn ToPrimitive| elem.to_i8().unwrap(),
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample()
 );
 make_element!(
-    int u8 Precision::Other,
+    ty u8 Precision::Other,
     convert |elem: &dyn ToPrimitive| elem.to_u8().unwrap(),
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample()
 );
+
 make_element!(
     ty f16 Precision::Half,
-    zero ::zero(),
-    one ::one(),
     convert |elem: &dyn ToPrimitive| f16::from_f32(elem.to_f32().unwrap()),
     random |distribution: Distribution, rng: &mut R| {
         let distribution: Distribution = distribution.convert();
diff --git a/burn-tensor/src/tensor/ops/modules/conv.rs b/burn-tensor/src/tensor/ops/modules/conv.rs
index 9fea288df..11230d16f 100644
--- a/burn-tensor/src/tensor/ops/modules/conv.rs
+++ b/burn-tensor/src/tensor/ops/modules/conv.rs
@@ -72,7 +72,7 @@ pub(crate) fn conv1d_backward(
         weight_grad,
         bias.map(|b| {
             let elem = batch_size * length_out;
-            let elem = (elem as i32).to_elem();
+            let elem = (elem as i32).elem();

             let b = B::zeros(B::shape(&b), &B::device(&b));
@@ -127,7 +127,7 @@ pub(crate) fn conv2d_backward(
         weight_grad,
         bias.map(|b| {
             let elem = batch_size * width_out * height_out;
-            let elem = (elem as i32).to_elem();
+            let elem = (elem as i32).elem();

             let b = B::zeros(B::shape(&b), &B::device(&b));
diff --git a/burn-tensor/src/tensor/ops/tensor.rs b/burn-tensor/src/tensor/ops/tensor.rs
index 08fdb9e95..d81148d3c 100644
--- a/burn-tensor/src/tensor/ops/tensor.rs
+++ b/burn-tensor/src/tensor/ops/tensor.rs
@@ -31,7 +31,7 @@ pub trait TensorOps {
         let shape = Shape::new([range.end - range.start]);
         let value = range
             .into_iter()
-            .map(|i| (i as i64).to_elem())
+            .map(|i| (i as i64).elem())
             .collect::>();
         let data = Data::new(value, shape);
         B::int_from_data(data, device)
@@ -132,7 +132,7 @@ pub trait TensorOps {
         lhs: B::TensorPrimitive,
         rhs: B::TensorPrimitive,
     ) -> B::BoolTensorPrimitive;
-    fn equal_scalar(
+    fn equal_elem(
         lhs: B::TensorPrimitive,
         rhs: B::FloatElem,
     ) -> B::BoolTensorPrimitive;
diff --git a/examples/mnist/src/data.rs b/examples/mnist/src/data.rs
index f670bcbef..dea223762 100644
--- a/examples/mnist/src/data.rs
+++ b/examples/mnist/src/data.rs
@@ -31,7 +31,7 @@ impl Batcher> for MNISTBatcher {
         let targets = items
             .iter()
-            .map(|item| Tensor::::from_data(Data::from([(item.label as i64).to_elem()])))
+            .map(|item| Tensor::::from_data(Data::from([(item.label as i64).elem()])))
             .collect();

         let images = Tensor::cat(images, 0).to_device(&self.device);
diff --git a/examples/text-classification/src/data/batcher.rs b/examples/text-classification/src/data/batcher.rs
index 1518cc16d..1f81d004e 100644
--- a/examples/text-classification/src/data/batcher.rs
+++ b/examples/text-classification/src/data/batcher.rs
@@ -29,9 +29,7 @@ impl Batcher> for
         for item in items {
             tokens_list.push(self.tokenizer.encode(&item.text));
-            labels_list.push(Tensor::from_data(Data::from([
-                (item.label as i64).to_elem()
-            ])));
+            labels_list.push(Tensor::from_data(Data::from([(item.label as i64).elem()])));
         }

         let mask = generate_padding_mask(
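
Reviewer note: the change is mechanical. `ElementConversion::to_elem(&self)` becomes `elem(self)`, taking the element by value — safe because `Element` keeps its `Copy` bound — and the scalar comparison `equal_scalar` becomes `equal_elem` across the tensor API and every backend. As a reading aid, here is a self-contained sketch of why call sites such as `0.elem()` and `k.elem()` read the way they do after the rename. The `Elem` trait and its `f64` pivot below are invented for this illustration; burn's real trait converts through `num_traits::ToPrimitive`, so treat this as a sketch, not the crate's code.

// Hypothetical mini-trait mirroring the *shape* of the new ElementConversion.
trait Elem: Copy {
    fn to_f64(self) -> f64;
    fn from_f64(v: f64) -> Self;

    // New style: `elem` consumes `self` by value (the removed `to_elem`
    // took `&self`); for Copy scalar types this costs nothing.
    fn elem<E: Elem>(self) -> E {
        E::from_f64(self.to_f64())
    }
}

impl Elem for f32 {
    fn to_f64(self) -> f64 {
        self as f64
    }
    fn from_f64(v: f64) -> Self {
        v as f32
    }
}

impl Elem for i64 {
    fn to_f64(self) -> f64 {
        self as f64
    }
    fn from_f64(v: f64) -> Self {
        v as i64
    }
}

fn main() {
    // Call sites now read like the diff's `0.elem()` / `k.elem()`:
    let zero: f32 = 0i64.elem();
    let one: i64 = 1.0f32.elem();
    assert_eq!((zero, one), (0.0, 1));
}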