renaming FloatTensor Ops, Primitives, and maybe functions (#1174)

Joshua Ferguson 2024-01-27 09:04:50 -06:00 committed by GitHub
parent 3814c4c9fb
commit 4a70a0f8bc
49 changed files with 1551 additions and 1294 deletions

View File

@ -20,7 +20,7 @@ impl<B: Backend, const D: usize> Benchmark for BinaryBenchmark<B, D> {
fn execute(&self, (lhs, rhs): Self::Args) {
// Choice of add is arbitrary
B::add(lhs.clone().into_primitive(), rhs.clone().into_primitive());
B::float_add(lhs.clone().into_primitive(), rhs.clone().into_primitive());
}
fn prepare(&self) -> Self::Args {

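For orientation, a minimal sketch that is not part of this commit: the high-level `Tensor` API keeps its names, and only code that drops to primitives via `into_primitive()` and calls the backend directly picks up the `float_` prefix. The helper name below is illustrative.

```rust
use burn_tensor::{backend::Backend, Tensor};

// Illustrative helper mirroring the benchmark above: add two float tensors at
// the primitive level through the renamed backend op `float_add`.
fn add_via_primitives<B: Backend, const D: usize>(
    lhs: Tensor<B, D>,
    rhs: Tensor<B, D>,
) -> Tensor<B, D> {
    Tensor::from_primitive(B::float_add(lhs.into_primitive(), rhs.into_primitive()))
}
```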
View File

@ -22,7 +22,7 @@ impl<B: Backend, const D: usize> Benchmark for UnaryBenchmark<B, D> {
fn execute(&self, args: Self::Args) {
// Choice of tanh is arbitrary
B::tanh(args.clone().into_primitive());
B::float_tanh(args.clone().into_primitive());
}
fn prepare(&self) -> Self::Args {

View File

@ -17,7 +17,7 @@ impl<B: Backend> Backend for Autodiff<B> {
type FullPrecisionElem = B::FullPrecisionElem;
type FullPrecisionBackend = Autodiff<B::FullPrecisionBackend>;
type TensorPrimitive<const D: usize> = AutodiffTensor<B, D>;
type FloatTensorPrimitive<const D: usize> = AutodiffTensor<B, D>;
type FloatElem = B::FloatElem;
type IntTensorPrimitive<const D: usize> = B::IntTensorPrimitive<D>;
@ -53,28 +53,28 @@ impl<B: Backend> AutodiffBackend for Autodiff<B> {
fn grad<const D: usize>(
tensor: &AutodiffTensor<B, D>,
grads: &Gradients,
) -> Option<B::TensorPrimitive<D>> {
) -> Option<B::FloatTensorPrimitive<D>> {
grads.get(tensor)
}
fn grad_remove<const D: usize>(
tensor: &AutodiffTensor<B, D>,
grads: &mut Gradients,
) -> Option<B::TensorPrimitive<D>> {
) -> Option<B::FloatTensorPrimitive<D>> {
grads.remove(tensor)
}
fn inner<const D: usize>(tensor: AutodiffTensor<B, D>) -> B::TensorPrimitive<D> {
fn inner<const D: usize>(tensor: AutodiffTensor<B, D>) -> B::FloatTensorPrimitive<D> {
tensor.primitive
}
fn from_inner<const D: usize>(tensor: B::TensorPrimitive<D>) -> AutodiffTensor<B, D> {
fn from_inner<const D: usize>(tensor: B::FloatTensorPrimitive<D>) -> AutodiffTensor<B, D> {
AutodiffTensor::new(tensor)
}
fn grad_replace<const D: usize>(
tensor: &AutodiffTensor<B, D>,
grads: &mut Self::Gradients,
grad: B::TensorPrimitive<D>,
grad: B::FloatTensorPrimitive<D>,
) {
grads.remove(tensor);
grads.register::<B, D>(tensor.node.clone(), grad);
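A hedged usage sketch, not part of the diff: callers of the public API are unaffected; internally, `AutodiffBackend::grad`, `grad_remove`, `inner`, and `from_inner` now speak in terms of `B::FloatTensorPrimitive<D>` instead of `B::TensorPrimitive<D>`.

```rust
use burn_autodiff::Autodiff;
use burn_tensor::{backend::Backend, Tensor};

// Sketch: run the backward pass and fetch the gradient of `x` on the inner
// backend; under the hood `grad` returns `Option<B::FloatTensorPrimitive<D>>`.
fn gradient_of<B: Backend, const D: usize>(x: Tensor<Autodiff<B>, D>) -> Option<Tensor<B, D>> {
    let grads = x.backward();
    x.grad(&grads)
}
```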

View File

@ -13,7 +13,7 @@ pub struct Gradients {
container: TensorContainer<GradID>,
}
type TensorPrimitive<B, const D: usize> = <B as Backend>::TensorPrimitive<D>;
type TensorPrimitive<B, const D: usize> = <B as Backend>::FloatTensorPrimitive<D>;
impl Gradients {
/// Creates a new gradients container.
@ -26,7 +26,7 @@ impl Gradients {
};
gradients.register::<B, D>(
root_node,
B::ones(B::shape(&root_tensor), &B::device(&root_tensor)),
B::float_ones(B::float_shape(&root_tensor), &B::float_device(&root_tensor)),
);
gradients
}
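The gradient seed registered above is simply a ones tensor shaped like the root primitive. As a standalone sketch of the same pattern, with every query and constructor op carrying the new `float_` prefix:

```rust
use burn_tensor::backend::Backend;

// Sketch: build a ones tensor with the same shape and device as `tensor`,
// exactly the seed `Gradients::new` registers for the root node above.
fn ones_like<B: Backend, const D: usize>(
    tensor: &B::FloatTensorPrimitive<D>,
) -> B::FloatTensorPrimitive<D> {
    B::float_ones(B::float_shape(tensor), &B::float_device(tensor))
}
```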

View File

@ -14,7 +14,7 @@ impl<B: Backend> ActivationOps<Autodiff<B>> for Autodiff<B> {
struct Gelu<const D: usize>;
impl<const D: usize, B: Backend> Backward<B, D, 1> for Gelu<D> {
type State = B::TensorPrimitive<D>;
type State = B::FloatTensorPrimitive<D>;
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
let input = ops.state;
@ -39,7 +39,7 @@ impl<B: Backend> ActivationOps<Autodiff<B>> for Autodiff<B> {
struct Relu;
impl<B: Backend, const D: usize> Backward<B, D, 1> for Relu {
type State = B::TensorPrimitive<D>;
type State = B::FloatTensorPrimitive<D>;
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
@ -60,7 +60,7 @@ impl<B: Backend> ActivationOps<Autodiff<B>> for Autodiff<B> {
struct Sigmoid;
impl<B: Backend, const D: usize> Backward<B, D, 1> for Sigmoid {
type State = B::TensorPrimitive<D>;
type State = B::FloatTensorPrimitive<D>;
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {

View File

@ -44,8 +44,8 @@ pub fn binary<B, const D_OUT: usize, const D_LHS: usize, const D_RHS: usize, FLh
func_rhs: FRhs,
) where
B: Backend,
FLhs: FnOnce(B::TensorPrimitive<D_OUT>) -> B::TensorPrimitive<D_LHS>,
FRhs: FnOnce(B::TensorPrimitive<D_OUT>) -> B::TensorPrimitive<D_RHS>,
FLhs: FnOnce(B::FloatTensorPrimitive<D_OUT>) -> B::FloatTensorPrimitive<D_LHS>,
FRhs: FnOnce(B::FloatTensorPrimitive<D_OUT>) -> B::FloatTensorPrimitive<D_RHS>,
{
let [grad_4lhs, grad_4rhs] = duplicate(&parents, Some(grads.consume::<B, D_OUT>(&node)));
let [node_lhs, node_rhs] = parents;
@ -69,7 +69,7 @@ pub fn unary<B, const D_OUT: usize, const D_IN: usize, F>(
func: F,
) where
B: Backend,
F: FnOnce(B::TensorPrimitive<D_OUT>) -> B::TensorPrimitive<D_IN>,
F: FnOnce(B::FloatTensorPrimitive<D_OUT>) -> B::FloatTensorPrimitive<D_IN>,
{
let [parent_node] = parents;
let grad = grads.consume::<B, D_OUT>(&node);
@ -90,7 +90,7 @@ pub fn unary_different_backend<BIn, BOut, const D_OUT: usize, const D_IN: usize,
) where
BIn: Backend,
BOut: Backend,
F: FnOnce(BOut::TensorPrimitive<D_OUT>) -> BIn::TensorPrimitive<D_IN>,
F: FnOnce(BOut::FloatTensorPrimitive<D_OUT>) -> BIn::FloatTensorPrimitive<D_IN>,
{
let [parent_node] = parents;
let grad = grads.consume::<BOut, D_OUT>(&node);

View File

@ -37,7 +37,10 @@ where
BO: Backward<B, D, N, State = ()>,
{
/// Prepare a stateless operation.
pub fn stateless(self, output: <B as Backend>::TensorPrimitive<D>) -> AutodiffTensor<B, D> {
pub fn stateless(
self,
output: <B as Backend>::FloatTensorPrimitive<D>,
) -> AutodiffTensor<B, D> {
match self.stateful() {
OpsKind::Tracked(prep) => prep.finish((), output),
OpsKind::UnTracked(prep) => prep.finish(output),
@ -77,7 +80,7 @@ where
BO: Backward<B, D, N, State = S>,
{
/// Finish the preparation of an untracked operation and returns the output tensor.
pub fn finish(self, output: <B as Backend>::TensorPrimitive<D>) -> AutodiffTensor<B, D> {
pub fn finish(self, output: <B as Backend>::FloatTensorPrimitive<D>) -> AutodiffTensor<B, D> {
AutodiffTensor::from_parents(
output,
&self.nodes,
@ -97,7 +100,7 @@ where
pub fn finish(
self,
state: S,
output: <B as Backend>::TensorPrimitive<D>,
output: <B as Backend>::FloatTensorPrimitive<D>,
) -> AutodiffTensor<B, D> {
let output = AutodiffTensor::from_parents(
output,
@ -164,10 +167,10 @@ where
/// If broadcasting happened during the forward pass, the gradients will be sum along the
/// broadcasted dimension.
pub fn broadcast_shape<B: Backend, const D: usize>(
mut grad: B::TensorPrimitive<D>,
mut grad: B::FloatTensorPrimitive<D>,
shape: &Shape<D>,
) -> B::TensorPrimitive<D> {
let shape_grad = B::shape(&grad);
) -> B::FloatTensorPrimitive<D> {
let shape_grad = B::float_shape(&grad);
for i in 0..D {
if shape_grad.dims[i] != shape.dims[i] {
@ -177,7 +180,7 @@ pub fn broadcast_shape<B: Backend, const D: usize>(
shape.dims, shape_grad.dims, "Expected the shape of the next grad to be 1."
);
}
grad = B::sum_dim(grad, i);
grad = B::float_sum_dim(grad, i);
}
}
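A compact sketch of the rule `broadcast_shape` implements, assuming (as the assertion above does) that any mismatching dimension was broadcast from size 1: sum the incoming gradient back down along each broadcast dimension, now through the `float_`-prefixed ops.

```rust
use burn_tensor::{backend::Backend, Shape};

// Sketch: reduce a gradient back to `target`'s shape after broadcasting.
// `float_sum_dim` keeps the reduced dimension (as size 1), so the rank D is preserved.
fn unbroadcast<B: Backend, const D: usize>(
    mut grad: B::FloatTensorPrimitive<D>,
    target: &Shape<D>,
) -> B::FloatTensorPrimitive<D> {
    for dim in 0..D {
        if B::float_shape(&grad).dims[dim] != target.dims[dim] {
            grad = B::float_sum_dim(grad, dim);
        }
    }
    grad
}
```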

View File

@ -81,7 +81,7 @@ impl<B: Backend> BoolTensorOps<Self> for Autodiff<B> {
fn bool_into_float<const D: usize>(
tensor: BoolTensor<B, D>,
) -> <Autodiff<B> as Backend>::TensorPrimitive<D> {
) -> <Autodiff<B> as Backend>::FloatTensorPrimitive<D> {
AutodiffTensor::new(B::bool_into_float(tensor))
}

View File

@ -305,7 +305,7 @@ impl<B: Backend> IntTensorOps<Autodiff<B>> for Autodiff<B> {
}
fn int_into_float<const D: usize>(
tensor: <Autodiff<B> as Backend>::IntTensorPrimitive<D>,
) -> <Autodiff<B> as Backend>::TensorPrimitive<D> {
) -> <Autodiff<B> as Backend>::FloatTensorPrimitive<D> {
AutodiffTensor::new(B::int_into_float(tensor))
}
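From user code these conversions are still reached through the tensor API, whose names are unchanged; only the backend entry points (`bool_into_float` above, `int_into_float` here) now return the renamed `FloatTensorPrimitive`. A small sketch, not part of the commit:

```rust
use burn_tensor::{backend::Backend, Int, Tensor};

// Sketch: convert an int tensor to float; this dispatches to `B::int_into_float`,
// whose return type is now the backend's `FloatTensorPrimitive<D>`.
fn to_float<B: Backend, const D: usize>(x: Tensor<B, D, Int>) -> Tensor<B, D> {
    x.float()
}
```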

View File

@ -11,10 +11,10 @@ impl<B: Backend, const D: usize> Backward<B, D, 1> for MaxMinDim {
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
let (indices, shape) = ops.state;
let device = B::device(&grad);
let zeros = B::zeros(shape, &device);
let device = B::float_device(&grad);
let zeros = B::float_zeros(shape, &device);
B::scatter(D - 1, zeros, indices, grad)
B::float_scatter(D - 1, zeros, indices, grad)
});
}
}
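The backward pass above scatters the incoming gradient back to the positions picked by the forward max/min. The same logic as a standalone sketch (argument names assumed), using the renamed constructors:

```rust
use burn_tensor::{backend::Backend, ops::IntTensor, Shape};

// Sketch: start from zeros of the input shape and scatter `grad` into the
// positions recorded by `indices` along the last dimension.
fn route_grad_to_indices<B: Backend, const D: usize>(
    grad: B::FloatTensorPrimitive<D>,
    indices: IntTensor<B, D>,
    shape: Shape<D>,
) -> B::FloatTensorPrimitive<D> {
    let device = B::float_device(&grad);
    let zeros = B::float_zeros(shape, &device);
    B::float_scatter(D - 1, zeros, indices, grad)
}
```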

View File

@ -14,7 +14,7 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
struct Embedding;
impl<B: Backend> Backward<B, 3, 1> for Embedding {
type State = (B::TensorPrimitive<2>, IntTensor<B, 2>);
type State = (B::FloatTensorPrimitive<2>, IntTensor<B, 2>);
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
let (weights, indices) = ops.state;
@ -58,9 +58,9 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
impl<B: Backend> Backward<B, 4, 3> for Conv2DWithBias {
type State = (
B::TensorPrimitive<4>,
B::TensorPrimitive<4>,
B::TensorPrimitive<1>,
B::FloatTensorPrimitive<4>,
B::FloatTensorPrimitive<4>,
B::FloatTensorPrimitive<1>,
ConvOptions<2>,
);
@ -84,7 +84,11 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
}
impl<B: Backend> Backward<B, 4, 2> for Conv2DNoBias {
type State = (B::TensorPrimitive<4>, B::TensorPrimitive<4>, ConvOptions<2>);
type State = (
B::FloatTensorPrimitive<4>,
B::FloatTensorPrimitive<4>,
ConvOptions<2>,
);
fn backward(self, ops: Ops<Self::State, 2>, grads: &mut Gradients) {
let [node_x, node_weight] = ops.parents;
@ -158,9 +162,9 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
impl<B: Backend> Backward<B, 4, 3> for ConvTranspose2DWithBias {
type State = (
B::TensorPrimitive<4>,
B::TensorPrimitive<4>,
B::TensorPrimitive<1>,
B::FloatTensorPrimitive<4>,
B::FloatTensorPrimitive<4>,
B::FloatTensorPrimitive<1>,
ConvTransposeOptions<2>,
);
@ -185,8 +189,8 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
impl<B: Backend> Backward<B, 4, 2> for ConvTranspose2DNoBias {
type State = (
B::TensorPrimitive<4>,
B::TensorPrimitive<4>,
B::FloatTensorPrimitive<4>,
B::FloatTensorPrimitive<4>,
ConvTransposeOptions<2>,
);
@ -270,9 +274,9 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
impl<B: Backend> Backward<B, 3, 3> for Conv1DWithBias {
type State = (
B::TensorPrimitive<3>,
B::TensorPrimitive<3>,
B::TensorPrimitive<1>,
B::FloatTensorPrimitive<3>,
B::FloatTensorPrimitive<3>,
B::FloatTensorPrimitive<1>,
ConvOptions<1>,
);
@ -296,7 +300,11 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
}
impl<B: Backend> Backward<B, 3, 2> for Conv1DNoBias {
type State = (B::TensorPrimitive<3>, B::TensorPrimitive<3>, ConvOptions<1>);
type State = (
B::FloatTensorPrimitive<3>,
B::FloatTensorPrimitive<3>,
ConvOptions<1>,
);
fn backward(self, ops: Ops<Self::State, 2>, grads: &mut Gradients) {
let [node_x, node_weight] = ops.parents;
@ -369,9 +377,9 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
impl<B: Backend> Backward<B, 3, 3> for ConvTranspose1DWithBias {
type State = (
B::TensorPrimitive<3>,
B::TensorPrimitive<3>,
B::TensorPrimitive<1>,
B::FloatTensorPrimitive<3>,
B::FloatTensorPrimitive<3>,
B::FloatTensorPrimitive<1>,
ConvTransposeOptions<1>,
);
@ -396,8 +404,8 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
impl<B: Backend> Backward<B, 3, 2> for ConvTranspose1DNoBias {
type State = (
B::TensorPrimitive<3>,
B::TensorPrimitive<3>,
B::FloatTensorPrimitive<3>,
B::FloatTensorPrimitive<3>,
ConvTransposeOptions<1>,
);
@ -494,7 +502,7 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
struct AvgPool1D;
impl<B: Backend> Backward<B, 3, 1> for AvgPool1D {
type State = (B::TensorPrimitive<3>, usize, usize, usize, bool);
type State = (B::FloatTensorPrimitive<3>, usize, usize, usize, bool);
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
let [node_parent] = ops.parents;
@ -551,7 +559,7 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
impl<B: Backend> Backward<B, 4, 1> for AvgPool2D {
type State = (
B::TensorPrimitive<4>,
B::FloatTensorPrimitive<4>,
[usize; 2],
[usize; 2],
[usize; 2],
@ -807,7 +815,7 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
struct AdaptiveAvgPool1D;
impl<B: Backend> Backward<B, 3, 1> for AdaptiveAvgPool1D {
type State = B::TensorPrimitive<3>;
type State = B::FloatTensorPrimitive<3>;
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
let [node_parent] = ops.parents;
@ -839,7 +847,7 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
struct AdaptiveAvgPool2D;
impl<B: Backend> Backward<B, 4, 1> for AdaptiveAvgPool2D {
type State = B::TensorPrimitive<4>;
type State = B::FloatTensorPrimitive<4>;
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
let [node_parent] = ops.parents;
@ -866,7 +874,7 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
fn adaptive_avg_pool2d_backward(
_x: AutodiffTensor<B, 4>,
_grad: AutodiffTensor<B, 4>,
) -> <Autodiff<B> as Backend>::TensorPrimitive<4> {
) -> <Autodiff<B> as Backend>::FloatTensorPrimitive<4> {
panic!("Can't differentiate adaptive avg pool2d backward.");
}
}
@ -876,7 +884,7 @@ struct MaxPool1D;
impl<B: Backend> Backward<B, 3, 1> for MaxPool1D {
type State = (
B::TensorPrimitive<3>,
B::FloatTensorPrimitive<3>,
IntTensor<B, 3>,
usize,
usize,
@ -910,7 +918,7 @@ struct MaxPool2D;
impl<B: Backend> Backward<B, 4, 1> for MaxPool2D {
type State = (
B::TensorPrimitive<4>,
B::FloatTensorPrimitive<4>,
IntTensor<B, 4>,
[usize; 2],
[usize; 2],

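For the module backward ops above, the change is confined to the saved-state types. A sketch of what one of those tuples looks like after the rename; the alias name is illustrative, and `ConvOptions` is assumed to come from `burn_tensor::ops`.

```rust
use burn_tensor::{backend::Backend, ops::ConvOptions};

// Sketch: forward state kept for a biased 2D convolution: input (rank 4),
// weight (rank 4), bias (rank 1), and the convolution options.
type Conv2dWithBiasState<B> = (
    <B as Backend>::FloatTensorPrimitive<4>,
    <B as Backend>::FloatTensorPrimitive<4>,
    <B as Backend>::FloatTensorPrimitive<1>,
    ConvOptions<2>,
);
```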
File diff suppressed because it is too large

View File

@ -9,7 +9,7 @@ use crate::{
#[derive(Debug, Clone)]
pub struct AutodiffTensor<B: Backend, const D: usize> {
pub primitive: B::TensorPrimitive<D>,
pub primitive: B::FloatTensorPrimitive<D>,
pub node: NodeRef,
pub graph: Graph,
}
@ -31,7 +31,7 @@ impl Step for RootStep {
impl<B: Backend, const D: usize> AutodiffTensor<B, D> {
/// Create a new leaf tensor.
pub fn new(primitive: B::TensorPrimitive<D>) -> Self {
pub fn new(primitive: B::FloatTensorPrimitive<D>) -> Self {
let id = NodeID::new();
let node = Node::new(vec![], 0, id, Requirement::None);
@ -68,7 +68,7 @@ impl<B: Backend, const D: usize> AutodiffTensor<B, D> {
/// Create a tensor from parent infos.
pub fn from_parents<I: Iterator<Item = Graph>>(
output: B::TensorPrimitive<D>,
output: B::FloatTensorPrimitive<D>,
parent_nodes: &[NodeRef],
parent_graphs: I,
requirement: Requirement,

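`AutodiffTensor` now stores the inner backend's `FloatTensorPrimitive` directly. From user code, the usual way to produce one is still `Tensor::from_inner`, which goes through `AutodiffBackend::from_inner` and `AutodiffTensor::new` as shown earlier; a sketch:

```rust
use burn_autodiff::Autodiff;
use burn_tensor::{backend::Backend, Tensor};

// Sketch: lift an inner-backend tensor into the autodiff graph as a new leaf.
fn track<B: Backend, const D: usize>(inner: Tensor<B, D>) -> Tensor<Autodiff<B>, D> {
    Tensor::from_inner(inner)
}
```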
View File

@ -72,7 +72,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
type FullPrecisionBackend = Candle<Self::FullPrecisionElem, Self::IntElem>;
type FullPrecisionElem = f32;
type TensorPrimitive<const D: usize> = CandleTensor<Self::FloatElem, D>;
type FloatTensorPrimitive<const D: usize> = CandleTensor<Self::FloatElem, D>;
type FloatElem = F;
type IntTensorPrimitive<const D: usize> = CandleTensor<Self::IntElem, D>;

View File

@ -1,7 +1,7 @@
use std::borrow::Borrow;
use burn_tensor::{
ops::{BoolTensor, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor, TensorOps},
ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, FullPrecisionBackend, IntTensor},
Data, Device, Distribution, ElementConversion, Reader, Shape,
};
use candle_core::{backend::BackendStorage, shape, Tensor};
@ -11,12 +11,15 @@ use crate::{
Candle, CandleTensor,
};
impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I> {
fn from_data<const D: usize>(data: Data<F, D>, device: &Device<Self>) -> CandleTensor<F, D> {
impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle<F, I> {
fn float_from_data<const D: usize>(
data: Data<F, D>,
device: &Device<Self>,
) -> CandleTensor<F, D> {
CandleTensor::from_data(data, *device)
}
fn random<const D: usize>(
fn float_random<const D: usize>(
shape: Shape<D>,
distribution: Distribution,
device: &Device<Self>,
@ -50,90 +53,90 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
}
}
fn shape<const D: usize>(tensor: &CandleTensor<F, D>) -> Shape<D> {
fn float_shape<const D: usize>(tensor: &CandleTensor<F, D>) -> Shape<D> {
super::base::shape(tensor)
}
fn into_data<const D: usize>(tensor: CandleTensor<F, D>) -> Reader<Data<F, D>> {
fn float_into_data<const D: usize>(tensor: CandleTensor<F, D>) -> Reader<Data<F, D>> {
Reader::Concrete(super::base::into_data(tensor))
}
fn device<const D: usize>(tensor: &CandleTensor<F, D>) -> Device<Self> {
fn float_device<const D: usize>(tensor: &CandleTensor<F, D>) -> Device<Self> {
super::base::device(tensor)
}
fn to_device<const D: usize>(
fn float_to_device<const D: usize>(
tensor: CandleTensor<F, D>,
device: &Device<Self>,
) -> CandleTensor<F, D> {
super::base::to_device(tensor, device)
}
fn into_int<const D: usize>(tensor: CandleTensor<F, D>) -> IntTensor<Self, D> {
fn float_into_int<const D: usize>(tensor: CandleTensor<F, D>) -> IntTensor<Self, D> {
CandleTensor::new(tensor.tensor.to_dtype(I::DTYPE).unwrap())
}
fn empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
fn float_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
super::base::empty(shape, device)
}
fn add<const D: usize>(
fn float_add<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
CandleTensor::new(lhs.tensor.broadcast_add(&rhs.tensor).unwrap())
}
fn add_scalar<const D: usize>(
fn float_add_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
CandleTensor::new((lhs.tensor + rhs.elem::<f64>()).unwrap())
}
fn sub<const D: usize>(
fn float_sub<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
CandleTensor::new(lhs.tensor.broadcast_sub(&rhs.tensor).unwrap())
}
fn sub_scalar<const D: usize>(
fn float_sub_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
CandleTensor::new((lhs.tensor - rhs.elem::<f64>()).unwrap())
}
fn mul<const D: usize>(
fn float_mul<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
CandleTensor::new(lhs.tensor.broadcast_mul(&rhs.tensor).unwrap())
}
fn mul_scalar<const D: usize>(
fn float_mul_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
CandleTensor::new((lhs.tensor * rhs.elem::<f64>()).unwrap())
}
fn div<const D: usize>(
fn float_div<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
CandleTensor::new(lhs.tensor.broadcast_div(&rhs.tensor).unwrap())
}
fn div_scalar<const D: usize>(
fn float_div_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
CandleTensor::new((lhs.tensor / rhs.elem::<f64>()).unwrap())
}
fn matmul<const D: usize>(
fn float_matmul<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
@ -150,7 +153,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
CandleTensor::new(lhs_contiguous.broadcast_matmul(&rhs_contiguous).unwrap())
}
fn swap_dims<const D: usize>(
fn float_swap_dims<const D: usize>(
tensor: FloatTensor<Self, D>,
dim1: usize,
dim2: usize,
@ -158,14 +161,14 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
super::base::swap_dims(tensor, dim1, dim2)
}
fn reshape<const D1: usize, const D2: usize>(
fn float_reshape<const D1: usize, const D2: usize>(
tensor: FloatTensor<Self, D1>,
shape: Shape<D2>,
) -> FloatTensor<Self, D2> {
super::base::reshape(tensor, shape)
}
fn gather<const D: usize>(
fn float_gather<const D: usize>(
dim: usize,
tensor: FloatTensor<Self, D>,
indices: IntTensor<Self, D>,
@ -173,7 +176,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
CandleTensor::new(tensor.tensor.gather(&indices.tensor, dim).unwrap())
}
fn scatter<const D: usize>(
fn float_scatter<const D: usize>(
dim: usize,
tensor: FloatTensor<Self, D>,
indices: IntTensor<Self, D>,
@ -187,7 +190,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
)
}
fn select<const D: usize>(
fn float_select<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
indices: IntTensor<Self, 1>,
@ -195,7 +198,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
CandleTensor::new(tensor.tensor.index_select(&indices.tensor, dim).unwrap())
}
fn select_assign<const D: usize>(
fn float_select_assign<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
indices: IntTensor<Self, 1>,
@ -209,14 +212,14 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
)
}
fn slice<const D1: usize, const D2: usize>(
fn float_slice<const D1: usize, const D2: usize>(
tensor: FloatTensor<Self, D1>,
ranges: [std::ops::Range<usize>; D2],
) -> FloatTensor<Self, D1> {
super::base::slice(tensor, ranges)
}
fn slice_assign<const D1: usize, const D2: usize>(
fn float_slice_assign<const D1: usize, const D2: usize>(
tensor: FloatTensor<Self, D1>,
ranges: [std::ops::Range<usize>; D2],
value: FloatTensor<Self, D1>,
@ -224,7 +227,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
super::base::slice_assign(tensor, ranges, value)
}
fn mask_where<const D: usize>(
fn float_mask_where<const D: usize>(
tensor: FloatTensor<Self, D>,
mask: BoolTensor<Self, D>,
value: FloatTensor<Self, D>,
@ -236,7 +239,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
)
}
fn mask_fill<const D: usize>(
fn float_mask_fill<const D: usize>(
tensor: FloatTensor<Self, D>,
mask: BoolTensor<Self, D>,
value: FloatElem<Self>,
@ -251,14 +254,14 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
)
}
fn equal<const D: usize>(
fn float_equal<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> BoolTensor<Self, D> {
CandleTensor::new(lhs.tensor.eq(&rhs.tensor).unwrap())
}
fn equal_elem<const D: usize>(
fn float_equal_elem<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> BoolTensor<Self, D> {
@ -269,14 +272,14 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
)
}
fn greater<const D: usize>(
fn float_greater<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> BoolTensor<Self, D> {
CandleTensor::new(lhs.tensor.gt(&rhs.tensor).unwrap())
}
fn greater_elem<const D: usize>(
fn float_greater_elem<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> BoolTensor<Self, D> {
@ -287,14 +290,14 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
)
}
fn greater_equal<const D: usize>(
fn float_greater_equal<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> BoolTensor<Self, D> {
CandleTensor::new(lhs.tensor.ge(&rhs.tensor).unwrap())
}
fn greater_equal_elem<const D: usize>(
fn float_greater_equal_elem<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> BoolTensor<Self, D> {
@ -305,14 +308,14 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
)
}
fn lower<const D: usize>(
fn float_lower<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> BoolTensor<Self, D> {
CandleTensor::new(lhs.tensor.lt(&rhs.tensor).unwrap())
}
fn lower_elem<const D: usize>(
fn float_lower_elem<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> BoolTensor<Self, D> {
@ -323,14 +326,14 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
)
}
fn lower_equal<const D: usize>(
fn float_lower_equal<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> BoolTensor<Self, D> {
CandleTensor::new(lhs.tensor.le(&rhs.tensor).unwrap())
}
fn lower_equal_elem<const D: usize>(
fn float_lower_equal_elem<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> BoolTensor<Self, D> {
@ -341,79 +344,94 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
)
}
fn sum<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
fn float_sum<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
let sum = tensor.tensor.sum_all().unwrap().to_scalar::<F>().unwrap();
CandleTensor::from_data(Data::new([sum].into(), [1].into()), Self::device(&tensor))
CandleTensor::from_data(
Data::new([sum].into(), [1].into()),
Self::float_device(&tensor),
)
}
fn sum_dim<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> FloatTensor<Self, D> {
fn float_sum_dim<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap())
}
fn mean_dim<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> FloatTensor<Self, D> {
fn float_mean_dim<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.mean_keepdim(dim).unwrap())
}
fn to_full_precision<const D: usize>(
fn float_to_full_precision<const D: usize>(
tensor: &FloatTensor<Self, D>,
) -> FloatTensor<FullPrecisionBackend<Self>, D> {
CandleTensor::new(tensor.tensor.to_dtype(candle_core::DType::F32).unwrap())
}
fn from_full_precision<const D: usize>(
fn float_from_full_precision<const D: usize>(
tensor: FloatTensor<FullPrecisionBackend<Self>, D>,
) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap())
}
fn exp<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
fn float_exp<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.exp().unwrap())
}
fn log<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
fn float_log<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.log().unwrap())
}
fn log1p<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
fn float_log1p<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
CandleTensor::new((tensor.tensor + 1.).unwrap().log().unwrap())
}
fn powf_scalar<const D: usize>(
fn float_powf_scalar<const D: usize>(
tensor: FloatTensor<Self, D>,
value: f32,
) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.powf(value.elem::<f64>()).unwrap())
}
fn sqrt<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
fn float_sqrt<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.sqrt().unwrap())
}
fn abs<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
fn float_abs<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.abs().unwrap())
}
fn cos<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
fn float_cos<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.cos().unwrap())
}
fn sin<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
fn float_sin<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.sin().unwrap())
}
fn tanh<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
fn float_tanh<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.tanh().unwrap())
}
fn erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
fn float_erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.erf().unwrap())
}
fn cat<const D: usize>(tensors: Vec<FloatTensor<Self, D>>, dim: usize) -> FloatTensor<Self, D> {
fn float_cat<const D: usize>(
tensors: Vec<FloatTensor<Self, D>>,
dim: usize,
) -> FloatTensor<Self, D> {
super::base::cat(tensors, dim)
}
fn argmax<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
fn float_argmax<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> IntTensor<Self, D> {
CandleTensor::new(
tensor
.tensor
@ -424,7 +442,10 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
)
}
fn argmin<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
fn float_argmin<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> IntTensor<Self, D> {
CandleTensor::new(
tensor
.tensor
@ -435,21 +456,21 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
)
}
fn clamp_max<const D: usize>(
fn float_clamp_max<const D: usize>(
tensor: FloatTensor<Self, D>,
max: FloatElem<Self>,
) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.minimum(max).unwrap())
}
fn clamp_min<const D: usize>(
fn float_clamp_min<const D: usize>(
tensor: FloatTensor<Self, D>,
min: FloatElem<Self>,
) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.maximum(min).unwrap())
}
fn clamp<const D: usize>(
fn float_clamp<const D: usize>(
tensor: FloatTensor<Self, D>,
min: FloatElem<Self>,
max: FloatElem<Self>,
@ -457,11 +478,11 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
CandleTensor::new(tensor.tensor.clamp(min, max).unwrap())
}
fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
fn float_recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.recip().unwrap())
}
fn narrow<const D: usize>(
fn float_narrow<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
start: usize,
@ -470,7 +491,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
super::base::narrow(tensor, dim, start, length)
}
fn chunk<const D: usize>(
fn float_chunk<const D: usize>(
tensor: FloatTensor<Self, D>,
chunks: usize,
dim: usize,
@ -478,7 +499,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
super::base::chunk(tensor, chunks, dim)
}
fn powf<const D: usize>(
fn float_powf<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {

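The Candle implementation above is one implementor of the renamed trait; generic code written against `Backend` uses the same `float_*` names. A sketch with an illustrative function name:

```rust
use burn_tensor::{
    backend::Backend,
    ops::{BoolTensor, FloatTensor},
};

// Sketch: compare a float tensor against a scalar threshold through the renamed
// backend op; works for any backend, Candle included.
fn threshold_mask<B: Backend, const D: usize>(
    x: FloatTensor<B, D>,
    limit: B::FloatElem,
) -> BoolTensor<B, D> {
    B::float_greater_elem(x, limit)
}
```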
View File

@ -26,7 +26,7 @@ impl<B: FusionBackend> Backend for Fusion<B> {
type FullPrecisionBackend = Self;
type FullPrecisionElem = B::FloatElem;
type TensorPrimitive<const D: usize> = FusionTensor<B::FusionClient>;
type FloatTensorPrimitive<const D: usize> = FusionTensor<B::FusionClient>;
type FloatElem = B::FloatElem;
@ -152,11 +152,11 @@ pub trait FusionBackend: Backend {
fn optimizations(device: Device<Self>)
-> Vec<Box<dyn OptimizationBuilder<Self::Optimization>>>;
/// Convert a [handle](FusionBackend::Handle) to a [float tensor](Backend::TensorPrimitive).
/// Convert a [handle](FusionBackend::Handle) to a [float tensor](Backend::FloatTensorPrimitive).
fn float_tensor<const D: usize>(
handle: Self::Handle,
shape: Shape<D>,
) -> Self::TensorPrimitive<D>;
) -> Self::FloatTensorPrimitive<D>;
/// Convert a [handle](FusionBackend::Handle) to an [int tensor](Backend::IntTensorPrimitive).
fn int_tensor<const D: usize>(
handle: Self::Handle,
@ -168,8 +168,8 @@ pub trait FusionBackend: Backend {
shape: Shape<D>,
) -> Self::BoolTensorPrimitive<D>;
/// Convert a [float tensor](Backend::TensorPrimitive) to a [handle](FusionBackend::Handle).
fn float_tensor_handle<const D: usize>(tensor: Self::TensorPrimitive<D>) -> Self::Handle;
/// Convert a [float tensor](Backend::FloatTensorPrimitive) to a [handle](FusionBackend::Handle).
fn float_tensor_handle<const D: usize>(tensor: Self::FloatTensorPrimitive<D>) -> Self::Handle;
/// Convert an [int tensor](Backend::IntTensorPrimitive) to a [handle](FusionBackend::Handle).
fn int_tensor_handle<const D: usize>(tensor: Self::IntTensorPrimitive<D>) -> Self::Handle;
/// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](FusionBackend::Handle).
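The handle conversions keep their `float_tensor` / `float_tensor_handle` names; only the associated type they mention was renamed. A round-trip sketch, assuming `FusionBackend` is importable from the `burn_fusion` crate root:

```rust
use burn_fusion::FusionBackend;
use burn_tensor::Shape;

// Sketch: convert a float primitive into an opaque fusion handle and back.
fn roundtrip<B: FusionBackend, const D: usize>(
    tensor: B::FloatTensorPrimitive<D>,
    shape: Shape<D>,
) -> B::FloatTensorPrimitive<D> {
    let handle = B::float_tensor_handle(tensor);
    B::float_tensor(handle, shape)
}
```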

View File

@ -59,12 +59,12 @@ impl<B: FusionBackend> HandleContainer<B> {
}
}
/// Get the [float tensor](burn_tensor::backend::Backend::TensorPrimitive) corresponding to the
/// Get the [float tensor](burn_tensor::backend::Backend::FloatTensorPrimitive) corresponding to the
/// given [tensor description](TensorDescription).
pub fn get_float_tensor<const D: usize>(
&mut self,
tensor: &TensorDescription,
) -> B::TensorPrimitive<D> {
) -> B::FloatTensorPrimitive<D> {
B::float_tensor(
self.get_handle(&tensor.id, &tensor.status),
Shape::from(&tensor.shape),
@ -95,11 +95,11 @@ impl<B: FusionBackend> HandleContainer<B> {
)
}
/// Register a new [float tensor](burn_tensor::backend::Backend::TensorPrimitive) with the corresponding [tensor id](TensorId).
/// Register a new [float tensor](burn_tensor::backend::Backend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId).
pub fn register_float_tensor<const D: usize>(
&mut self,
id: &TensorId,
tensor: B::TensorPrimitive<D>,
tensor: B::FloatTensorPrimitive<D>,
) {
let handle = B::float_tensor_handle(tensor);
self.handles.insert(*id, Handle::Existing(handle));
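A sketch of how an operation body drives the container; the helper name and the `burn_fusion` import paths are assumptions, not part of the commit. It fetches a float primitive by its description, applies a renamed backend op, and registers the result under the output id, mirroring the `Operation::execute` bodies further below.

```rust
use burn_fusion::{FusionBackend, HandleContainer, TensorDescription, TensorId};

// Sketch: read-modify-write through the handle container.
fn abs_through_handles<B: FusionBackend, const D: usize>(
    handles: &mut HandleContainer<B>,
    input: &TensorDescription,
    out_id: &TensorId,
) {
    let tensor = handles.get_float_tensor::<D>(input);
    let output = B::float_abs(tensor);
    handles.register_float_tensor(out_id, output);
}
```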

View File

@ -17,19 +17,19 @@ use crate::{
unary_float_ops, Fusion, FusionBackend, TensorDescription,
};
use burn_tensor::{
ops::{BoolTensor, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor, TensorOps},
ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, FullPrecisionBackend, IntTensor},
Data, Device, Distribution, ElementConversion, Reader, Shape,
};
use std::ops::Range;
impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
fn from_data<const D: usize>(
impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
fn float_from_data<const D: usize>(
data: Data<FloatElem<Self>, D>,
device: &Device<Self>,
) -> FloatTensor<Self, D> {
let client = get_client::<B>(&device.clone().into());
let tensor = B::from_data(data, device);
let shape = B::shape(&tensor);
let tensor = B::float_from_data(data, device);
let shape = B::float_shape(&tensor);
client.register_tensor(
B::float_tensor_handle(tensor),
@ -38,7 +38,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
)
}
fn random<const D: usize>(
fn float_random<const D: usize>(
shape: Shape<D>,
distribution: Distribution,
device: &Device<Self>,
@ -51,8 +51,8 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
impl<const D: usize, B: FusionBackend> Operation<B> for RandomOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let shape = Shape::from(self.desc.out.shape.clone());
let output: B::TensorPrimitive<D> =
B::random(shape, self.desc.distribution, &handles.device);
let output: B::FloatTensorPrimitive<D> =
B::float_random(shape, self.desc.distribution, &handles.device);
handles.register_float_tensor(&self.desc.out.id, output);
}
}
@ -75,7 +75,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn zeros<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
fn float_zeros<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
#[derive(new)]
struct ZerosOps<const D: usize> {
out: TensorDescription,
@ -84,7 +84,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
impl<const D: usize, B: FusionBackend> Operation<B> for ZerosOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let shape = Shape::from(self.out.shape.clone());
let output = B::zeros::<D>(shape, &handles.device);
let output = B::float_zeros::<D>(shape, &handles.device);
handles.register_float_tensor(&self.out.id, output);
}
}
@ -104,7 +104,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn ones<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
fn float_ones<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
#[derive(new)]
struct OnesOps<const D: usize> {
out: TensorDescription,
@ -113,7 +113,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
impl<const D: usize, B: FusionBackend> Operation<B> for OnesOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let shape = Shape::from(self.out.shape.clone());
let output = B::ones::<D>(shape, &handles.device);
let output = B::float_ones::<D>(shape, &handles.device);
handles.register_float_tensor(&self.out.id, output);
}
}
@ -133,7 +133,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn full<const D: usize>(
fn float_full<const D: usize>(
shape: Shape<D>,
fill_value: FloatElem<Self>,
device: &Device<Self>,
@ -147,8 +147,8 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
impl<const D: usize, B: FusionBackend> Operation<B> for FullOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let shape = Shape::from(self.out.shape.clone());
let output: B::TensorPrimitive<D> =
B::full(shape, self.elem.elem(), &handles.device);
let output: B::FloatTensorPrimitive<D> =
B::float_full(shape, self.elem.elem(), &handles.device);
handles.register_float_tensor(&self.out.id, output);
}
}
@ -168,19 +168,21 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn shape<const D: usize>(tensor: &FloatTensor<Self, D>) -> Shape<D> {
fn float_shape<const D: usize>(tensor: &FloatTensor<Self, D>) -> Shape<D> {
tensor.shape()
}
fn into_data<const D: usize>(tensor: FloatTensor<Self, D>) -> Reader<Data<FloatElem<Self>, D>> {
fn float_into_data<const D: usize>(
tensor: FloatTensor<Self, D>,
) -> Reader<Data<FloatElem<Self>, D>> {
tensor.into_data()
}
fn device<const D: usize>(tensor: &FloatTensor<Self, D>) -> Device<Self> {
fn float_device<const D: usize>(tensor: &FloatTensor<Self, D>) -> Device<Self> {
tensor.client.device().clone().into()
}
fn to_device<const D: usize>(
fn float_to_device<const D: usize>(
tensor: FloatTensor<Self, D>,
device: &Device<Self>,
) -> FloatTensor<Self, D> {
@ -202,7 +204,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
)
}
fn into_int<const D: usize>(tensor: FloatTensor<Self, D>) -> IntTensor<Self, D> {
fn float_into_int<const D: usize>(tensor: FloatTensor<Self, D>) -> IntTensor<Self, D> {
#[derive(new)]
struct IntoIntOps<const D: usize> {
desc: UnaryOperationDescription,
@ -211,7 +213,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
impl<const D: usize, B: FusionBackend> Operation<B> for IntoIntOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_float_tensor::<D>(&self.desc.input);
let output = B::into_int(input);
let output = B::float_into_int(input);
handles.register_int_tensor(&self.desc.out.id, output);
}
@ -233,19 +235,19 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
fn float_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
let client = get_client::<B>(&device.clone().into());
let stream = StreamId::current();
let tensor = B::empty(shape.clone(), device);
let tensor = B::float_empty(shape.clone(), device);
client.register_tensor(B::float_tensor_handle(tensor), shape.dims.into(), stream)
}
fn add<const D: usize>(
fn float_add<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
binary_float_ops!(AddOps, B::add);
binary_float_ops!(AddOps, B::float_add);
let stream_1 = lhs.stream;
let stream_2 = rhs.stream;
@ -268,11 +270,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn add_scalar<const D: usize>(
fn float_add_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
scalar_float_ops!(AddOps, B::add_scalar);
scalar_float_ops!(AddOps, B::float_add_scalar);
let stream = lhs.stream;
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
@ -293,7 +295,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn clamp<const D: usize>(
fn float_clamp<const D: usize>(
tensor: FloatTensor<Self, D>,
min: FloatElem<Self>,
max: FloatElem<Self>,
@ -306,7 +308,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
impl<const D: usize, B: FusionBackend> Operation<B> for ClampOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_float_tensor::<D>(&self.desc.tensor);
let output = B::clamp(input, self.desc.min.elem(), self.desc.max.elem());
let output = B::float_clamp(input, self.desc.min.elem(), self.desc.max.elem());
handles.register_float_tensor(&self.desc.out.id, output);
}
@ -330,11 +332,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn sub<const D: usize>(
fn float_sub<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
binary_float_ops!(SubOps, B::sub);
binary_float_ops!(SubOps, B::float_sub);
let stream_1 = lhs.stream;
let stream_2 = rhs.stream;
@ -356,11 +358,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn sub_scalar<const D: usize>(
fn float_sub_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
scalar_float_ops!(SubOps, B::sub_scalar);
scalar_float_ops!(SubOps, B::float_sub_scalar);
let stream = lhs.stream;
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
@ -381,11 +383,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn mul<const D: usize>(
fn float_mul<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
binary_float_ops!(MulOps, B::mul);
binary_float_ops!(MulOps, B::float_mul);
let stream_1 = lhs.stream;
let stream_2 = rhs.stream;
@ -407,11 +409,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn mul_scalar<const D: usize>(
fn float_mul_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
scalar_float_ops!(MulOps, B::mul_scalar);
scalar_float_ops!(MulOps, B::float_mul_scalar);
let stream = lhs.stream;
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
@ -432,11 +434,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn div<const D: usize>(
fn float_div<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
binary_float_ops!(DivOps, B::div);
binary_float_ops!(DivOps, B::float_div);
let stream_1 = lhs.stream;
let stream_2 = rhs.stream;
@ -458,11 +460,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn div_scalar<const D: usize>(
fn float_div_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
scalar_float_ops!(DivOps, B::div_scalar);
scalar_float_ops!(DivOps, B::float_div_scalar);
let stream = lhs.stream;
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
@ -483,11 +485,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn matmul<const D: usize>(
fn float_matmul<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
binary_float_ops!(MatmulOps, B::matmul);
binary_float_ops!(MatmulOps, B::float_matmul);
let stream_1 = lhs.stream;
let stream_2 = rhs.stream;
@ -512,7 +514,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn swap_dims<const D: usize>(
fn float_swap_dims<const D: usize>(
tensor: FloatTensor<Self, D>,
dim1: usize,
dim2: usize,
@ -525,7 +527,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
impl<const D: usize, B: FusionBackend> Operation<B> for SwapDimsOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_float_tensor::<D>(&self.desc.input);
let output = B::swap_dims(input, self.desc.dim1, self.desc.dim2);
let output = B::float_swap_dims(input, self.desc.dim1, self.desc.dim2);
handles.register_float_tensor(&self.desc.out.id, output);
}
}
@ -553,7 +555,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn reshape<const D1: usize, const D2: usize>(
fn float_reshape<const D1: usize, const D2: usize>(
tensor: FloatTensor<Self, D1>,
shape: Shape<D2>,
) -> FloatTensor<Self, D2> {
@ -565,7 +567,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for ReshapeDimsOps<D1, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_float_tensor::<D1>(&self.desc.input);
let output = B::reshape::<D1, D2>(input, Shape::from(&self.desc.out.shape));
let output = B::float_reshape::<D1, D2>(input, Shape::from(&self.desc.out.shape));
handles.register_float_tensor(&self.desc.out.id, output);
}
}
@ -587,7 +589,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn gather<const D: usize>(
fn float_gather<const D: usize>(
dim: usize,
tensor: FloatTensor<Self, D>,
indices: IntTensor<Self, D>,
@ -602,7 +604,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
let indices = handles.get_int_tensor(&self.desc.indices);
let output = B::gather(self.desc.dim, tensor, indices);
let output = B::float_gather(self.desc.dim, tensor, indices);
handles.register_float_tensor(&self.desc.out.id, output);
}
}
@ -627,7 +629,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn scatter<const D: usize>(
fn float_scatter<const D: usize>(
dim: usize,
tensor: FloatTensor<Self, D>,
indices: IntTensor<Self, D>,
@ -644,7 +646,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let indices = handles.get_int_tensor(&self.desc.indices);
let value = handles.get_float_tensor(&self.desc.value);
let output = B::scatter(self.desc.dim, tensor, indices, value);
let output = B::float_scatter(self.desc.dim, tensor, indices, value);
handles.register_float_tensor(&self.desc.out.id, output);
}
@ -673,7 +675,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn select<const D: usize>(
fn float_select<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
indices: IntTensor<Self, 1>,
@ -688,7 +690,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
let indices = handles.get_int_tensor(&self.desc.indices);
let output = B::select(tensor, self.desc.dim, indices);
let output = B::float_select(tensor, self.desc.dim, indices);
handles.register_float_tensor(&self.desc.out.id, output);
}
@ -714,7 +716,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn select_assign<const D: usize>(
fn float_select_assign<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
indices: IntTensor<Self, 1>,
@ -731,7 +733,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let indices = handles.get_int_tensor(&self.desc.indices);
let value = handles.get_float_tensor(&self.desc.value);
let output = B::select_assign(tensor, self.desc.dim, indices, value);
let output = B::float_select_assign(tensor, self.desc.dim, indices, value);
handles.register_float_tensor(&self.desc.out.id, output);
}
@ -761,7 +763,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn slice<const D1: usize, const D2: usize>(
fn float_slice<const D1: usize, const D2: usize>(
tensor: FloatTensor<Self, D1>,
ranges: [Range<usize>; D2],
) -> FloatTensor<Self, D1> {
@ -775,7 +777,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let tensor = handles.get_float_tensor::<D1>(&self.desc.tensor);
let output =
B::slice::<D1, D2>(tensor, self.desc.ranges.clone().try_into().unwrap());
B::float_slice::<D1, D2>(tensor, self.desc.ranges.clone().try_into().unwrap());
handles.register_float_tensor(&self.desc.out.id, output);
}
@ -803,7 +805,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn slice_assign<const D1: usize, const D2: usize>(
fn float_slice_assign<const D1: usize, const D2: usize>(
tensor: FloatTensor<Self, D1>,
ranges: [Range<usize>; D2],
value: FloatTensor<Self, D1>,
@ -818,7 +820,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let tensor = handles.get_float_tensor::<D1>(&self.desc.tensor);
let value = handles.get_float_tensor::<D1>(&self.desc.value);
let output = B::slice_assign::<D1, D2>(
let output = B::float_slice_assign::<D1, D2>(
tensor,
self.desc.ranges.clone().try_into().unwrap(),
value,
@ -848,7 +850,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn mask_where<const D: usize>(
fn float_mask_where<const D: usize>(
tensor: FloatTensor<Self, D>,
mask: BoolTensor<Self, D>,
value: FloatTensor<Self, D>,
@ -864,7 +866,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let value = handles.get_float_tensor(&self.desc.value);
let mask = handles.get_bool_tensor(&self.desc.mask);
let output = B::mask_where(tensor, mask, value);
let output = B::float_mask_where(tensor, mask, value);
handles.register_float_tensor(&self.desc.out.id, output);
}
@ -893,7 +895,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn mask_fill<const D: usize>(
fn float_mask_fill<const D: usize>(
tensor: FloatTensor<Self, D>,
mask: BoolTensor<Self, D>,
value: FloatElem<Self>,
@ -908,7 +910,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
let mask = handles.get_bool_tensor(&self.desc.mask);
let output = B::mask_fill(tensor, mask, self.desc.value.elem());
let output = B::float_mask_fill(tensor, mask, self.desc.value.elem());
handles.register_float_tensor(&self.desc.out.id, output);
}
@ -933,11 +935,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn equal<const D: usize>(
fn float_equal<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> BoolTensor<Self, D> {
binary_float_cmp_ops!(EqualOps, B::equal);
binary_float_cmp_ops!(EqualOps, B::float_equal);
let stream_1 = lhs.stream;
let stream_2 = rhs.stream;
@ -959,11 +961,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn equal_elem<const D: usize>(
fn float_equal_elem<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> BoolTensor<Self, D> {
scalar_float_cmp_ops!(EqualElemOps, B::equal_elem);
scalar_float_cmp_ops!(EqualElemOps, B::float_equal_elem);
let stream = lhs.stream;
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
@ -984,11 +986,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn greater<const D: usize>(
fn float_greater<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> BoolTensor<Self, D> {
binary_float_cmp_ops!(GreaterOps, B::greater);
binary_float_cmp_ops!(GreaterOps, B::float_greater);
let stream_1 = lhs.stream;
let stream_2 = rhs.stream;
@ -1010,11 +1012,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn greater_elem<const D: usize>(
fn float_greater_elem<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> BoolTensor<Self, D> {
scalar_float_cmp_ops!(GreaterElemOps, B::greater_elem);
scalar_float_cmp_ops!(GreaterElemOps, B::float_greater_elem);
let stream = lhs.stream;
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
@ -1035,11 +1037,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn greater_equal<const D: usize>(
fn float_greater_equal<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> BoolTensor<Self, D> {
binary_float_cmp_ops!(GreaterEqualOps, B::greater_equal);
binary_float_cmp_ops!(GreaterEqualOps, B::float_greater_equal);
let stream_1 = lhs.stream;
let stream_2 = rhs.stream;
@ -1063,11 +1065,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn greater_equal_elem<const D: usize>(
fn float_greater_equal_elem<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> BoolTensor<Self, D> {
scalar_float_cmp_ops!(GreaterEqualElemOps, B::greater_equal_elem);
scalar_float_cmp_ops!(GreaterEqualElemOps, B::float_greater_equal_elem);
let stream = lhs.stream;
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
@ -1088,11 +1090,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn lower<const D: usize>(
fn float_lower<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> BoolTensor<Self, D> {
binary_float_cmp_ops!(LowerOps, B::lower);
binary_float_cmp_ops!(LowerOps, B::float_lower);
let stream_1 = lhs.stream;
let stream_2 = rhs.stream;
@ -1114,11 +1116,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn lower_elem<const D: usize>(
fn float_lower_elem<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> BoolTensor<Self, D> {
scalar_float_cmp_ops!(LowerElemOps, B::lower_elem);
scalar_float_cmp_ops!(LowerElemOps, B::float_lower_elem);
let stream = lhs.stream;
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
@ -1139,11 +1141,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn lower_equal<const D: usize>(
fn float_lower_equal<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> BoolTensor<Self, D> {
binary_float_cmp_ops!(LowerEqualOps, B::lower_equal);
binary_float_cmp_ops!(LowerEqualOps, B::float_lower_equal);
let stream_1 = lhs.stream;
let stream_2 = rhs.stream;
@ -1167,11 +1169,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn lower_equal_elem<const D: usize>(
fn float_lower_equal_elem<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> BoolTensor<Self, D> {
scalar_float_cmp_ops!(LowerEqualElemOps, B::lower_equal_elem);
scalar_float_cmp_ops!(LowerEqualElemOps, B::float_lower_equal_elem);
let stream = lhs.stream;
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
@ -1192,8 +1194,8 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn sum<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
unary_float_ops!(SumOps, B::sum);
fn float_sum<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
unary_float_ops!(SumOps, B::float_sum);
let stream = tensor.stream;
let out = tensor.client.tensor_uninitialized(vec![1]);
@ -1211,8 +1213,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn sum_dim<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> FloatTensor<Self, D> {
scalar_float_ops!(SumDimOps, B::sum_dim, usize, noconvert);
fn float_sum_dim<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> FloatTensor<Self, D> {
scalar_float_ops!(SumDimOps, B::float_sum_dim, usize, noconvert);
let stream = tensor.stream;
let mut shape = tensor.shape.clone();
@ -1233,8 +1238,8 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn mean<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
unary_float_ops!(MeanOps, B::mean);
fn float_mean<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
unary_float_ops!(MeanOps, B::float_mean);
let stream = tensor.stream;
let out = tensor.client.tensor_uninitialized(vec![1]);
@ -1252,8 +1257,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn mean_dim<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> FloatTensor<Self, D> {
scalar_float_ops!(MeanDimOps, B::mean_dim, usize, noconvert);
fn float_mean_dim<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> FloatTensor<Self, D> {
scalar_float_ops!(MeanDimOps, B::float_mean_dim, usize, noconvert);
let stream = tensor.stream;
let mut shape = tensor.shape.clone();
@ -1274,20 +1282,20 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn to_full_precision<const D: usize>(
fn float_to_full_precision<const D: usize>(
tensor: &FloatTensor<Self, D>,
) -> FloatTensor<FullPrecisionBackend<Self>, D> {
tensor.clone()
}
fn from_full_precision<const D: usize>(
fn float_from_full_precision<const D: usize>(
tensor: FloatTensor<FullPrecisionBackend<Self>, D>,
) -> FloatTensor<Self, D> {
tensor
}
fn exp<const D: usize>(lhs: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(ExpOps, B::exp);
fn float_exp<const D: usize>(lhs: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(ExpOps, B::float_exp);
let stream = lhs.stream;
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
@ -1305,8 +1313,8 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn log<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(LogOps, B::log);
fn float_log<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(LogOps, B::float_log);
let stream = tensor.stream;
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
@ -1324,8 +1332,8 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn log1p<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(Log1pOps, B::log1p);
fn float_log1p<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(Log1pOps, B::float_log1p);
let stream = tensor.stream;
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
@ -1343,8 +1351,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn powf_scalar<const D: usize>(lhs: FloatTensor<Self, D>, rhs: f32) -> FloatTensor<Self, D> {
scalar_float_ops!(PowfOps, B::powf_scalar, f32);
fn float_powf_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: f32,
) -> FloatTensor<Self, D> {
scalar_float_ops!(PowfOps, B::float_powf_scalar, f32);
let stream = lhs.stream;
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
@ -1363,8 +1374,8 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn sqrt<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(SqrtOps, B::sqrt);
fn float_sqrt<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(SqrtOps, B::float_sqrt);
let stream = tensor.stream;
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
@ -1382,8 +1393,8 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn abs<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(AbsOps, B::abs);
fn float_abs<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(AbsOps, B::float_abs);
let stream = tensor.stream;
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
@ -1401,8 +1412,8 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn cos<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(CosOps, B::cos);
fn float_cos<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(CosOps, B::float_cos);
let stream = tensor.stream;
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
@ -1420,8 +1431,8 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn sin<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(SinOps, B::sin);
fn float_sin<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(SinOps, B::float_sin);
let stream = tensor.stream;
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
@ -1439,8 +1450,8 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn tanh<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(TanhOps, B::tanh);
fn float_tanh<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(TanhOps, B::float_tanh);
let stream = tensor.stream;
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
@ -1458,8 +1469,8 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(Recip, B::recip);
fn float_recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(Recip, B::float_recip);
let stream = tensor.stream;
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
@ -1476,8 +1487,8 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(TanhOps, B::erf);
fn float_erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(TanhOps, B::float_erf);
let stream = tensor.stream;
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
@ -1495,7 +1506,10 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn cat<const D: usize>(tensors: Vec<FloatTensor<Self, D>>, dim: usize) -> FloatTensor<Self, D> {
fn float_cat<const D: usize>(
tensors: Vec<FloatTensor<Self, D>>,
dim: usize,
) -> FloatTensor<Self, D> {
#[derive(new)]
struct CatOps<const D: usize> {
desc: CatOperationDescription,
@ -1510,7 +1524,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
.map(|tensor| handles.get_float_tensor(tensor))
.collect();
let output = B::cat::<D>(tensors, self.desc.dim);
let output = B::float_cat::<D>(tensors, self.desc.dim);
handles.register_float_tensor(&self.desc.out.id, output);
}
@ -1543,8 +1557,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn argmax<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
scalar_float2int_ops!(ArgMaxOps, B::argmax, usize);
fn float_argmax<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> IntTensor<Self, D> {
scalar_float2int_ops!(ArgMaxOps, B::float_argmax, usize);
let stream = tensor.stream;
let mut shape = tensor.shape.clone();
@ -1565,8 +1582,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn argmin<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
scalar_float2int_ops!(ArgMinOps, B::argmin, usize);
fn float_argmin<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> IntTensor<Self, D> {
scalar_float2int_ops!(ArgMinOps, B::float_argmin, usize);
let stream = tensor.stream;
let mut shape = tensor.shape.clone();
@ -1587,8 +1607,8 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn max<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
unary_float_ops!(MaxOps, B::max);
fn float_max<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
unary_float_ops!(MaxOps, B::float_max);
let stream = tensor.stream;
let out = tensor.client.tensor_uninitialized(vec![1]);
@ -1606,8 +1626,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn max_dim<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> FloatTensor<Self, D> {
scalar_float_ops!(MaxDimOps, B::max_dim, usize, noconvert);
fn float_max_dim<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> FloatTensor<Self, D> {
scalar_float_ops!(MaxDimOps, B::float_max_dim, usize, noconvert);
let stream = tensor.stream;
let mut shape = tensor.shape.clone();
@ -1628,7 +1651,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn max_dim_with_indices<const D: usize>(
fn float_max_dim_with_indices<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> (FloatTensor<Self, D>, IntTensor<Self, D>) {
@ -1640,7 +1663,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
impl<const D: usize, B: FusionBackend> Operation<B> for MaxDimWithIndicesOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
let (output, indices) = B::max_dim_with_indices(tensor, self.desc.dim);
let (output, indices) = B::float_max_dim_with_indices(tensor, self.desc.dim);
handles.register_float_tensor(&self.desc.out.id, output);
handles.register_int_tensor(&self.desc.out_indices.id, indices);
@ -1671,8 +1694,8 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
(out, out_indices)
}
fn min<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
unary_float_ops!(MinOps, B::min);
fn float_min<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
unary_float_ops!(MinOps, B::float_min);
let stream = tensor.stream;
let out = tensor.client.tensor_uninitialized(vec![1]);
@ -1690,8 +1713,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn min_dim<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> FloatTensor<Self, D> {
scalar_float_ops!(MinDimOps, B::min_dim, usize, noconvert);
fn float_min_dim<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> FloatTensor<Self, D> {
scalar_float_ops!(MinDimOps, B::float_min_dim, usize, noconvert);
let stream = tensor.stream;
let mut shape = tensor.shape.clone();
@ -1712,7 +1738,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn min_dim_with_indices<const D: usize>(
fn float_min_dim_with_indices<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> (FloatTensor<Self, D>, IntTensor<Self, D>) {
@ -1724,7 +1750,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
impl<const D: usize, B: FusionBackend> Operation<B> for MinDimWithIndicesOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
let (output, indices) = B::min_dim_with_indices(tensor, self.desc.dim);
let (output, indices) = B::float_min_dim_with_indices(tensor, self.desc.dim);
handles.register_float_tensor(&self.desc.out.id, output);
handles.register_int_tensor(&self.desc.out_indices.id, indices);
@ -1755,11 +1781,11 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
(out, out_indices)
}
fn powf<const D: usize>(
fn float_powf<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
binary_float_ops!(PowOps, B::powf);
binary_float_ops!(PowOps, B::float_powf);
let stream_1 = lhs.stream;
let stream_2 = rhs.stream;

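The renames in this file stay entirely inside the backend trait layer: code written against the high-level `Tensor` API is unaffected and keeps lowering to the same operations under their new names. A minimal sketch (not part of this diff), assuming nothing beyond the public `Tensor` API:

```rust
use burn_tensor::{backend::Backend, Tensor};

/// Sums over dimension 1. After this commit the call lowers to the backend's
/// `float_sum_dim` (previously `sum_dim`); the user-facing method is unchanged.
fn reduce<B: Backend>(x: Tensor<B, 2>) -> Tensor<B, 2> {
    x.sum_dim(1)
}
```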
View File

@ -54,7 +54,7 @@ where
self.drain_stream(id);
let tensor = self.handles.get_float_tensor(&tensor);
B::into_data(tensor)
B::float_into_data(tensor)
}
pub fn read_int<const D: usize>(
@ -90,7 +90,7 @@ where
server_device: &mut Self,
) -> Arc<TensorId> {
let tensor = self.handles.get_float_tensor::<D>(tensor);
let tensor = B::to_device(tensor, device);
let tensor = B::float_to_device(tensor, device);
let id = server_device.create_empty_handle();
server_device

View File

@ -37,31 +37,31 @@ pub enum OperationDescription {
/// Operation description specific to a float tensor.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum FloatOperationDescription {
/// Operation corresponding to [exp](burn_tensor::ops::TensorOps::exp).
/// Operation corresponding to [exp](burn_tensor::ops::FloatTensorOps::float_exp).
Exp(UnaryOperationDescription),
/// Operation corresponding to [log](burn_tensor::ops::TensorOps::log).
/// Operation corresponding to [log](burn_tensor::ops::FloatTensorOps::float_log).
Log(UnaryOperationDescription),
/// Operation corresponding to [log1p](burn_tensor::ops::TensorOps::log1p).
/// Operation corresponding to [log1p](burn_tensor::ops::FloatTensorOps::float_log1p).
Log1p(UnaryOperationDescription),
/// Operation corresponding to [erf](burn_tensor::ops::TensorOps::erf).
/// Operation corresponding to [erf](burn_tensor::ops::FloatTensorOps::float_erf).
Erf(UnaryOperationDescription),
/// Operation corresponding to [powf_scalar](burn_tensor::ops::TensorOps::powf_scalar).
/// Operation corresponding to [powf_scalar](burn_tensor::ops::FloatTensorOps::float_powf_scalar).
PowfScalar(ScalarOperationDescription<f32>),
/// Operation corresponding to [sqrt](burn_tensor::ops::TensorOps::sqrt).
/// Operation corresponding to [sqrt](burn_tensor::ops::FloatTensorOps::float_sqrt).
Sqrt(UnaryOperationDescription),
/// Operation corresponding to [cos](burn_tensor::ops::TensorOps::cos).
/// Operation corresponding to [cos](burn_tensor::ops::FloatTensorOps::float_cos).
Cos(UnaryOperationDescription),
/// Operation corresponding to [sin](burn_tensor::ops::TensorOps::sin).
/// Operation corresponding to [sin](burn_tensor::ops::FloatTensorOps::float_sin).
Sin(UnaryOperationDescription),
/// Operation corresponding to [tanh](burn_tensor::ops::TensorOps::tanh).
/// Operation corresponding to [tanh](burn_tensor::ops::FloatTensorOps::float_tanh).
Tanh(UnaryOperationDescription),
/// Operation corresponding to [into_int](burn_tensor::ops::TensorOps::into_int).
/// Operation corresponding to [into_int](burn_tensor::ops::FloatTensorOps::float_into_int).
IntoInt(UnaryOperationDescription),
/// Operation corresponding to [matmul](burn_tensor::ops::TensorOps::matmul).
/// Operation corresponding to [matmul](burn_tensor::ops::FloatTensorOps::float_matmul).
Matmul(BinaryOperationDescription),
/// Operation corresponding to [random](burn_tensor::ops::TensorOps::random).
/// Operation corresponding to [random](burn_tensor::ops::FloatTensorOps::float_random).
Random(RandomOperationDescription),
/// Operation corresponding to [recip](burn_tensor::ops::TensorOps::recip).
/// Operation corresponding to [recip](burn_tensor::ops::FloatTensorOps::float_recip).
Recip(UnaryOperationDescription),
}
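Each variant above maps one-to-one onto a renamed `FloatTensorOps` function. Below is a sketch of how a fusion-style executor might dispatch one variant after the rename; the description fields (`input`, `out`) and the handle-container helpers are assumptions modelled on the fusion code earlier in this diff, not an exact copy of it:

```rust
// Hypothetical dispatcher: mirrors the `Operation` impls shown above, but the
// exact field names on `UnaryOperationDescription` are assumed for illustration.
fn execute_float_exp<B: FusionBackend, const D: usize>(
    desc: &FloatOperationDescription,
    handles: &mut HandleContainer<B>,
) {
    if let FloatOperationDescription::Exp(op) = desc {
        let input = handles.get_float_tensor::<D>(&op.input);
        let output = B::float_exp(input); // was `B::exp` before this commit
        handles.register_float_tensor(&op.out.id, output);
    }
}
```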
@ -127,49 +127,49 @@ pub enum ModuleOperationDescription {
pub enum BaseOperationDescription {
/// Operation corresponding to:
///
/// Float => [to device](burn_tensor::ops::TensorOps::to_device).
/// Float => [to device](burn_tensor::ops::FloatTensorOps::float_to_device).
/// Int => [to device](burn_tensor::ops::IntTensorOps::int_to_device).
/// Bool => [to device](burn_tensor::ops::BoolTensorOps::bool_to_device).
ToDevice(TensorDescription),
/// Operation corresponding to:
///
/// Float => [reshape](burn_tensor::ops::TensorOps::reshape).
/// Float => [reshape](burn_tensor::ops::FloatTensorOps::float_reshape).
/// Int => [reshape](burn_tensor::ops::IntTensorOps::int_reshape).
/// Bool => [reshape](burn_tensor::ops::BoolTensorOps::bool_reshape).
Reshape(ReshapeDescription),
/// Operation corresponding to:
///
/// Float => [swap_dims](burn_tensor::ops::TensorOps::swap_dims).
/// Float => [swap_dims](burn_tensor::ops::FloatTensorOps::float_swap_dims).
/// Int => [swap_dims](burn_tensor::ops::IntTensorOps::int_swap_dims).
/// Bool => [swap_dims](burn_tensor::ops::BoolTensorOps::bool_swap_dims).
SwapDims(SwapDimsDescription),
/// Operation corresponding to:
///
/// Float => [slice](burn_tensor::ops::TensorOps::slice).
/// Float => [slice](burn_tensor::ops::FloatTensorOps::float_slice).
/// Int => [slice](burn_tensor::ops::IntTensorOps::int_slice).
/// Bool => [slice](burn_tensor::ops::BoolTensorOps::bool_slice).
Slice(SliceOperationDescription),
/// Operation corresponding to:
///
/// Float => [slice assign](burn_tensor::ops::TensorOps::slice_assign).
/// Float => [slice assign](burn_tensor::ops::FloatTensorOps::float_slice_assign).
/// Int => [slice assign](burn_tensor::ops::IntTensorOps::int_slice_assign).
/// Bool => [slice assign](burn_tensor::ops::BoolTensorOps::bool_slice_assign).
SliceAssign(SliceAssignOperationDescription),
/// Operation corresponding to:
///
/// Float => [equal](burn_tensor::ops::TensorOps::equal).
/// Float => [equal](burn_tensor::ops::FloatTensorOps::float_equal).
/// Int => [equal](burn_tensor::ops::IntTensorOps::int_equal).
/// Bool => [equal](burn_tensor::ops::BoolTensorOps::bool_equal).
Equal(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [repeat](burn_tensor::ops::TensorOps::repeat).
/// Float => [repeat](burn_tensor::ops::FloatTensorOps::float_repeat).
/// Int => [repeat](burn_tensor::ops::IntTensorOps::int_repeat).
/// Bool => [repeat](burn_tensor::ops::BoolTensorOps::bool_repeat).
Repeat(RepeatOperationDescription),
/// Operation corresponding to:
///
/// Float => [cat](burn_tensor::ops::TensorOps::cat).
/// Float => [cat](burn_tensor::ops::FloatTensorOps::float_cat).
/// Int => [cat](burn_tensor::ops::IntTensorOps::int_cat).
/// Bool => [cat](burn_tensor::ops::BoolTensorOps::bool_cat).
Cat(CatOperationDescription),
@ -180,207 +180,207 @@ pub enum BaseOperationDescription {
pub enum NumericOperationDescription<E> {
/// Operation corresponding to:
///
/// Float => [add](burn_tensor::ops::TensorOps::add).
/// Float => [add](burn_tensor::ops::FloatTensorOps::float_add).
/// Int => [add](burn_tensor::ops::IntTensorOps::int_add).
Add(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [add scalar](burn_tensor::ops::TensorOps::add_scalar).
/// Float => [add scalar](burn_tensor::ops::FloatTensorOps::float_add_scalar).
/// Int => [add scalar](burn_tensor::ops::IntTensorOps::int_add_scalar).
AddScalar(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [sub](burn_tensor::ops::TensorOps::sub).
/// Float => [sub](burn_tensor::ops::FloatTensorOps::float_sub).
/// Int => [sub](burn_tensor::ops::IntTensorOps::int_sub).
Sub(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [sub scalar](burn_tensor::ops::TensorOps::sub_scalar).
/// Float => [sub scalar](burn_tensor::ops::FloatTensorOps::float_sub_scalar).
/// Int => [sub scalar](burn_tensor::ops::IntTensorOps::int_sub_scalar).
SubScalar(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [div](burn_tensor::ops::TensorOps::div).
/// Float => [div](burn_tensor::ops::FloatTensorOps::float_div).
/// Int => [div](burn_tensor::ops::IntTensorOps::int_div).
Div(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [div scalar](burn_tensor::ops::TensorOps::div_scalar).
/// Float => [div scalar](burn_tensor::ops::FloatTensorOps::float_div_scalar).
/// Int => [div scalar](burn_tensor::ops::IntTensorOps::int_div_scalar).
DivScalar(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [mul](burn_tensor::ops::TensorOps::mul).
/// Float => [mul](burn_tensor::ops::FloatTensorOps::float_mul).
/// Int => [mul](burn_tensor::ops::IntTensorOps::int_mul).
Mul(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [mul scalar](burn_tensor::ops::TensorOps::mul_scalar).
/// Float => [mul scalar](burn_tensor::ops::FloatTensorOps::float_mul_scalar).
/// Int => [mul scalar](burn_tensor::ops::IntTensorOps::int_mul_scalar).
MulScalar(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [abs](burn_tensor::ops::TensorOps::abs).
/// Float => [abs](burn_tensor::ops::FloatTensorOps::float_abs).
/// Int => [abs](burn_tensor::ops::IntTensorOps::int_abs).
Abs(UnaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [ones](burn_tensor::ops::TensorOps::ones).
/// Float => [ones](burn_tensor::ops::FloatTensorOps::float_ones).
/// Int => [ones](burn_tensor::ops::IntTensorOps::int_ones).
Ones(TensorDescription),
/// Operation corresponding to:
///
/// Float => [zeros](burn_tensor::ops::TensorOps::zeros).
/// Float => [zeros](burn_tensor::ops::FloatTensorOps::float_zeros).
/// Int => [zeros](burn_tensor::ops::IntTensorOps::int_zeros).
Zeros(TensorDescription),
/// Operation corresponding to:
///
/// Float => [full](burn_tensor::ops::TensorOps::full).
/// Float => [full](burn_tensor::ops::FloatTensorOps::float_full).
/// Int => [full](burn_tensor::ops::IntTensorOps::int_full).
Full((TensorDescription, E)),
/// Operation corresponding to:
///
/// Float => [gather](burn_tensor::ops::TensorOps::gather).
/// Float => [gather](burn_tensor::ops::FloatTensorOps::float_gather).
/// Int => [gather](burn_tensor::ops::IntTensorOps::int_gather).
Gather(GatherOperationDescription),
/// Operation corresponding to:
///
/// Float => [scatter](burn_tensor::ops::TensorOps::scatter).
/// Float => [scatter](burn_tensor::ops::FloatTensorOps::float_scatter).
/// Int => [scatter](burn_tensor::ops::IntTensorOps::int_scatter).
Scatter(ScatterOperationDescription),
/// Operation corresponding to:
///
/// Float => [select](burn_tensor::ops::TensorOps::select).
/// Float => [select](burn_tensor::ops::FloatTensorOps::float_select).
/// Int => [select](burn_tensor::ops::IntTensorOps::int_select).
Select(SelectOperationDescription),
/// Operation corresponding to:
///
/// Float => [select assign](burn_tensor::ops::TensorOps::select_assign).
/// Float => [select assign](burn_tensor::ops::FloatTensorOps::float_select_assign).
/// Int => [select assign](burn_tensor::ops::IntTensorOps::int_select_assign).
SelectAssign(SelectAssignOperationDescription),
/// Operation corresponding to:
///
/// Float => [mask where](burn_tensor::ops::TensorOps::mask_where).
/// Float => [mask where](burn_tensor::ops::FloatTensorOps::float_mask_where).
/// Int => [mask where](burn_tensor::ops::IntTensorOps::int_mask_where).
MaskWhere(MaskWhereOperationDescription),
/// Operation corresponding to:
///
/// Float => [mask fill](burn_tensor::ops::TensorOps::mask_fill).
/// Float => [mask fill](burn_tensor::ops::FloatTensorOps::float_mask_fill).
/// Int => [mask fill](burn_tensor::ops::IntTensorOps::int_mask_fill).
MaskFill(MaskFillOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [mean dim](burn_tensor::ops::TensorOps::mean_dim).
/// Float => [mean dim](burn_tensor::ops::FloatTensorOps::float_mean_dim).
/// Int => [mean dim](burn_tensor::ops::IntTensorOps::int_mean_dim).
MeanDim(ScalarOperationDescription<usize>),
/// Operation corresponding to:
///
/// Float => [mean](burn_tensor::ops::TensorOps::mean).
/// Float => [mean](burn_tensor::ops::FloatTensorOps::float_mean).
/// Int => [mean](burn_tensor::ops::IntTensorOps::int_mean).
Mean(UnaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [sum](burn_tensor::ops::TensorOps::sum).
/// Float => [sum](burn_tensor::ops::FloatTensorOps::float_sum).
/// Int => [sum](burn_tensor::ops::IntTensorOps::int_sum).
Sum(UnaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [sum dim](burn_tensor::ops::TensorOps::sum_dim).
/// Float => [sum dim](burn_tensor::ops::FloatTensorOps::float_sum_dim).
/// Int => [sum dim](burn_tensor::ops::IntTensorOps::int_sum_dim).
SumDim(ScalarOperationDescription<usize>),
/// Operation corresponding to:
///
/// Float => [equal elem](burn_tensor::ops::TensorOps::equal_elem).
/// Float => [equal elem](burn_tensor::ops::FloatTensorOps::float_equal_elem).
/// Int => [equal elem](burn_tensor::ops::IntTensorOps::int_equal_elem).
EqualElem(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [greater](burn_tensor::ops::TensorOps::greater).
/// Float => [greater](burn_tensor::ops::FloatTensorOps::float_greater).
/// Int => [greater](burn_tensor::ops::IntTensorOps::int_greater).
Greater(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [greater elem](burn_tensor::ops::TensorOps::greater_elem).
/// Float => [greater elem](burn_tensor::ops::FloatTensorOps::float_greater_elem).
/// Int => [greater elem](burn_tensor::ops::IntTensorOps::int_greater_elem).
GreaterElem(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [greater equal](burn_tensor::ops::TensorOps::greater_elem).
/// Float => [greater equal](burn_tensor::ops::FloatTensorOps::float_greater_equal).
/// Int => [greater equal](burn_tensor::ops::IntTensorOps::int_greater_equal).
GreaterEqual(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [greater equal elem](burn_tensor::ops::TensorOps::greater_equal_elem).
/// Float => [greater equal elem](burn_tensor::ops::FloatTensorOps::float_greater_equal_elem).
/// Int => [greater equal elem](burn_tensor::ops::IntTensorOps::int_greater_equal_elem).
GreaterEqualElem(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [lower](burn_tensor::ops::TensorOps::lower).
/// Float => [lower](burn_tensor::ops::FloatTensorOps::float_lower).
/// Int => [lower](burn_tensor::ops::IntTensorOps::int_lower).
Lower(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [lower elem](burn_tensor::ops::TensorOps::lower_elem).
/// Float => [lower elem](burn_tensor::ops::FloatTensorOps::float_lower_elem).
/// Int => [lower elem](burn_tensor::ops::IntTensorOps::int_lower_elem).
LowerElem(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [lower equal](burn_tensor::ops::TensorOps::lower_equal).
/// Float => [lower equal](burn_tensor::ops::FloatTensorOps::float_lower_equal).
/// Int => [lower equal](burn_tensor::ops::IntTensorOps::int_lower_equal).
LowerEqual(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [lower equal elem](burn_tensor::ops::TensorOps::lower_equal_elem).
/// Float => [lower equal elem](burn_tensor::ops::FloatTensorOps::float_lower_equal_elem).
/// Int => [lower equal elem](burn_tensor::ops::IntTensorOps::int_lower_equal_elem).
LowerEqualElem(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [argmax](burn_tensor::ops::TensorOps::argmax).
/// Float => [argmax](burn_tensor::ops::FloatTensorOps::float_argmax).
/// Int => [argmax](burn_tensor::ops::IntTensorOps::int_argmax).
ArgMax(ScalarOperationDescription<usize>),
/// Operation corresponding to:
///
/// Float => [argmin](burn_tensor::ops::TensorOps::argmin).
/// Float => [argmin](burn_tensor::ops::FloatTensorOps::float_argmin).
/// Int => [argmin](burn_tensor::ops::IntTensorOps::int_argmin).
ArgMin(ScalarOperationDescription<usize>),
/// Operation corresponding to:
///
/// Float => [max](burn_tensor::ops::TensorOps::max).
/// Float => [max](burn_tensor::ops::FloatTensorOps::float_max).
/// Int => [max](burn_tensor::ops::IntTensorOps::int_max).
Max(UnaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [max dim with indices](burn_tensor::ops::TensorOps::max_dim_with_indices).
/// Float => [max dim with indices](burn_tensor::ops::FloatTensorOps::float_max_dim_with_indices).
/// Int => [max dim with indices](burn_tensor::ops::IntTensorOps::int_max_dim_with_indices).
MaxDimWithIndices(ReduceDimWithIndicesDescription),
/// Operation corresponding to:
///
/// Float => [min dim with indices](burn_tensor::ops::TensorOps::min_dim_with_indices).
/// Float => [min dim with indices](burn_tensor::ops::FloatTensorOps::float_min_dim_with_indices).
/// Int => [min dim with indices](burn_tensor::ops::IntTensorOps::int_min_dim_with_indices).
MinDimWithIndices(ReduceDimWithIndicesDescription),
/// Operation corresponding to:
///
/// Float => [min](burn_tensor::ops::TensorOps::min).
/// Float => [min](burn_tensor::ops::FloatTensorOps::float_min).
/// Int => [min](burn_tensor::ops::IntTensorOps::int_min).
Min(UnaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [max dim](burn_tensor::ops::TensorOps::max_dim).
/// Float => [max dim](burn_tensor::ops::FloatTensorOps::float_max_dim).
/// Int => [max dim](burn_tensor::ops::IntTensorOps::int_max_dim).
MaxDim(ScalarOperationDescription<usize>),
/// Operation corresponding to:
///
/// Float => [min dim](burn_tensor::ops::TensorOps::min_dim).
/// Float => [min dim](burn_tensor::ops::FloatTensorOps::float_min_dim).
/// Int => [min dim](burn_tensor::ops::IntTensorOps::int_min_dim).
MinDim(ScalarOperationDescription<usize>),
/// Operation corresponding to:
///
/// Float => [clamp](burn_tensor::ops::TensorOps::clamp).
/// Float => [clamp](burn_tensor::ops::FloatTensorOps::float_clamp).
/// Int => [clamp](burn_tensor::ops::IntTensorOps::int_clamp).
Clamp(ClampOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [powf](burn_tensor::ops::TensorOps::powf).
/// Float => [powf](burn_tensor::ops::FloatTensorOps::float_powf).
/// Int => [powf](burn_tensor::ops::IntTensorOps::int_powf).
Powf(BinaryOperationDescription),
}

View File

@ -35,7 +35,7 @@ impl<E: FloatNdArrayElement> Backend for NdArray<E> {
type FullPrecisionElem = f32;
type FullPrecisionBackend = NdArray<f32>;
type TensorPrimitive<const D: usize> = NdArrayTensor<E, D>;
type FloatTensorPrimitive<const D: usize> = NdArrayTensor<E, D>;
type FloatElem = E;
type IntTensorPrimitive<const D: usize> = NdArrayTensor<i64, D>;

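The associated-type rename also ripples into generic code that names the float primitive directly. A minimal sketch, assuming only the public `Backend` trait and its usual `Clone` bound on the primitive:

```rust
use burn_tensor::backend::Backend;

/// Doubles a raw float primitive. The only change this commit requires here is
/// writing `B::FloatTensorPrimitive<D>` where `B::TensorPrimitive<D>` stood
/// before, and calling `B::float_add` instead of `B::add`.
fn double<B: Backend, const D: usize>(
    x: B::FloatTensorPrimitive<D>,
) -> B::FloatTensorPrimitive<D> {
    B::float_add(x.clone(), x)
}
```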
View File

@ -116,7 +116,7 @@ impl<E: FloatNdArrayElement> BoolTensorOps<Self> for NdArray<E> {
fn bool_into_float<const D: usize>(
tensor: <NdArray<E> as Backend>::BoolTensorPrimitive<D>,
) -> <NdArray<E> as Backend>::TensorPrimitive<D> {
) -> <NdArray<E> as Backend>::FloatTensorPrimitive<D> {
let array = tensor.array.mapv(|a| (a as i32).elem()).into_shared();
NdArrayTensor { array }
}

View File

@ -365,7 +365,7 @@ impl<E: FloatNdArrayElement> IntTensorOps<Self> for NdArray<E> {
fn int_into_float<const D: usize>(
tensor: <NdArray<E> as Backend>::IntTensorPrimitive<D>,
) -> <NdArray<E> as Backend>::TensorPrimitive<D> {
) -> <NdArray<E> as Backend>::FloatTensorPrimitive<D> {
let array = tensor.array.mapv(|a| a.elem()).into_shared();
NdArrayTensor { array }
}

View File

@ -1,7 +1,7 @@
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArray};
use crate::{iter_range_par, run_par, UnsafeSharedRef};
use burn_tensor::ElementConversion;
use burn_tensor::{ops::TensorOps, Shape};
use burn_tensor::{ops::FloatTensorOps, Shape};
use ndarray::s;
pub(crate) fn matmul<E, const D: usize>(
@ -29,7 +29,7 @@ where
let out = general_matmul(lhs, rhs);
NdArray::<E>::reshape(out, shape_out)
NdArray::<E>::float_reshape(out, shape_out)
}
fn general_matmul<E: FloatNdArrayElement>(
@ -91,13 +91,13 @@ fn reshape<E: FloatNdArrayElement, const D: usize>(
let shape = tensor.shape();
if D < 2 {
NdArray::<E>::reshape(tensor, Shape::new([1, 1, shape.dims[0]]))
NdArray::<E>::float_reshape(tensor, Shape::new([1, 1, shape.dims[0]]))
} else {
let batch_size = batch_size(&shape);
let size0 = shape.dims[D - 2];
let size1 = shape.dims[D - 1];
NdArray::<E>::reshape(tensor, Shape::new([batch_size, size0, size1]))
NdArray::<E>::float_reshape(tensor, Shape::new([batch_size, size0, size1]))
}
}

View File

@ -1,5 +1,5 @@
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArray};
use burn_tensor::ops::TensorOps;
use burn_tensor::ops::FloatTensorOps;
use ndarray::Array4;
pub(crate) fn apply_padding_4d<E: FloatNdArrayElement>(
@ -18,7 +18,7 @@ pub(crate) fn apply_padding_4d<E: FloatNdArrayElement>(
);
let mut x_new = NdArrayTensor::new(x_new.into_shared().into_dyn());
x_new = NdArray::slice_assign(
x_new = NdArray::float_slice_assign(
x_new,
[
0..batch_size,

View File

@ -10,7 +10,7 @@ use crate::{NdArrayDevice, SEED};
// Workspace crates
use burn_common::rand::get_seeded_rng;
use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Shape};
use burn_tensor::{backend::Backend, ops::FloatTensorOps, Data, ElementConversion, Shape};
use burn_tensor::{Distribution, Reader};
// External crates
@ -20,12 +20,15 @@ use libm::{cos, erf, sin, tanh};
#[allow(unused_imports)]
use num_traits::Float;
impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
fn from_data<const D: usize>(data: Data<E, D>, _device: &NdArrayDevice) -> NdArrayTensor<E, D> {
impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
fn float_from_data<const D: usize>(
data: Data<E, D>,
_device: &NdArrayDevice,
) -> NdArrayTensor<E, D> {
NdArrayTensor::from_data(data)
}
fn random<const D: usize>(
fn float_random<const D: usize>(
shape: Shape<D>,
distribution: Distribution,
device: &NdArrayDevice,
@ -36,16 +39,16 @@ impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
} else {
get_seeded_rng()
};
let tensor = Self::from_data(Data::random(shape, distribution, &mut rng), device);
let tensor = Self::float_from_data(Data::random(shape, distribution, &mut rng), device);
*seed = Some(rng);
tensor
}
fn shape<const D: usize>(tensor: &NdArrayTensor<E, D>) -> Shape<D> {
fn float_shape<const D: usize>(tensor: &NdArrayTensor<E, D>) -> Shape<D> {
tensor.shape()
}
fn into_data<const D: usize>(
fn float_into_data<const D: usize>(
tensor: NdArrayTensor<E, D>,
) -> Reader<Data<<NdArray<E> as Backend>::FloatElem, D>> {
let shape = tensor.shape();
@ -54,84 +57,84 @@ impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
Reader::Concrete(Data::new(values, shape))
}
fn device<const D: usize>(_tensor: &NdArrayTensor<E, D>) -> NdArrayDevice {
fn float_device<const D: usize>(_tensor: &NdArrayTensor<E, D>) -> NdArrayDevice {
NdArrayDevice::Cpu
}
fn to_device<const D: usize>(
fn float_to_device<const D: usize>(
tensor: NdArrayTensor<E, D>,
_device: &NdArrayDevice,
) -> NdArrayTensor<E, D> {
tensor
}
fn empty<const D: usize>(
fn float_empty<const D: usize>(
shape: Shape<D>,
device: &<NdArray<E> as Backend>::Device,
) -> NdArrayTensor<E, D> {
NdArray::<E>::zeros(shape, device)
NdArray::<E>::float_zeros(shape, device)
}
fn add<const D: usize>(
fn float_add<const D: usize>(
lhs: NdArrayTensor<E, D>,
rhs: NdArrayTensor<E, D>,
) -> NdArrayTensor<E, D> {
NdArrayMathOps::add(lhs, rhs)
}
fn add_scalar<const D: usize>(lhs: NdArrayTensor<E, D>, rhs: E) -> NdArrayTensor<E, D> {
fn float_add_scalar<const D: usize>(lhs: NdArrayTensor<E, D>, rhs: E) -> NdArrayTensor<E, D> {
NdArrayMathOps::add_scalar(lhs, rhs)
}
fn sub<const D: usize>(
fn float_sub<const D: usize>(
lhs: NdArrayTensor<E, D>,
rhs: NdArrayTensor<E, D>,
) -> NdArrayTensor<E, D> {
NdArrayMathOps::sub(lhs, rhs)
}
fn sub_scalar<const D: usize>(lhs: NdArrayTensor<E, D>, rhs: E) -> NdArrayTensor<E, D> {
fn float_sub_scalar<const D: usize>(lhs: NdArrayTensor<E, D>, rhs: E) -> NdArrayTensor<E, D> {
NdArrayMathOps::sub_scalar(lhs, rhs)
}
fn mul<const D: usize>(
fn float_mul<const D: usize>(
lhs: NdArrayTensor<E, D>,
rhs: NdArrayTensor<E, D>,
) -> NdArrayTensor<E, D> {
NdArrayMathOps::mul(lhs, rhs)
}
fn mul_scalar<const D: usize>(lhs: NdArrayTensor<E, D>, rhs: E) -> NdArrayTensor<E, D> {
fn float_mul_scalar<const D: usize>(lhs: NdArrayTensor<E, D>, rhs: E) -> NdArrayTensor<E, D> {
NdArrayMathOps::mul_scalar(lhs, rhs)
}
fn div<const D: usize>(
fn float_div<const D: usize>(
lhs: NdArrayTensor<E, D>,
rhs: NdArrayTensor<E, D>,
) -> NdArrayTensor<E, D> {
NdArrayMathOps::div(lhs, rhs)
}
fn div_scalar<const D: usize>(lhs: NdArrayTensor<E, D>, rhs: E) -> NdArrayTensor<E, D> {
fn float_div_scalar<const D: usize>(lhs: NdArrayTensor<E, D>, rhs: E) -> NdArrayTensor<E, D> {
NdArrayMathOps::div_scalar(lhs, rhs)
}
fn matmul<const D: usize>(
fn float_matmul<const D: usize>(
lhs: NdArrayTensor<E, D>,
rhs: NdArrayTensor<E, D>,
) -> NdArrayTensor<E, D> {
matmul(lhs, rhs)
}
fn neg<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
Self::mul_scalar(tensor, (-1f32).elem::<E>())
fn float_neg<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
Self::float_mul_scalar(tensor, (-1f32).elem::<E>())
}
fn recip<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
fn float_recip<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
NdArrayMathOps::recip(tensor)
}
fn swap_dims<const D: usize>(
fn float_swap_dims<const D: usize>(
tensor: NdArrayTensor<E, D>,
dim1: usize,
dim2: usize,
@ -139,14 +142,14 @@ impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
NdArrayOps::swap_dims(tensor, dim1, dim2)
}
fn reshape<const D1: usize, const D2: usize>(
fn float_reshape<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<E, D1>,
shape: Shape<D2>,
) -> NdArrayTensor<E, D2> {
NdArrayOps::reshape(tensor, shape)
}
fn gather<const D: usize>(
fn float_gather<const D: usize>(
dim: usize,
tensor: NdArrayTensor<E, D>,
indices: NdArrayTensor<i64, D>,
@ -154,7 +157,7 @@ impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
NdArrayMathOps::gather(dim, tensor, indices)
}
fn scatter<const D: usize>(
fn float_scatter<const D: usize>(
dim: usize,
tensor: NdArrayTensor<E, D>,
indices: NdArrayTensor<i64, D>,
@ -163,7 +166,7 @@ impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
NdArrayMathOps::scatter(dim, tensor, indices, value)
}
fn select<const D: usize>(
fn float_select<const D: usize>(
tensor: NdArrayTensor<E, D>,
dim: usize,
indices: NdArrayTensor<i64, 1>,
@ -171,7 +174,7 @@ impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
NdArrayMathOps::select(tensor, dim, indices)
}
fn select_assign<const D: usize>(
fn float_select_assign<const D: usize>(
tensor: NdArrayTensor<E, D>,
dim: usize,
indices: NdArrayTensor<i64, 1>,
@ -180,14 +183,14 @@ impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
NdArrayMathOps::select_assign(tensor, dim, indices, value)
}
fn slice<const D1: usize, const D2: usize>(
fn float_slice<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<E, D1>,
ranges: [Range<usize>; D2],
) -> NdArrayTensor<E, D1> {
NdArrayOps::slice(tensor, ranges)
}
fn slice_assign<const D1: usize, const D2: usize>(
fn float_slice_assign<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<E, D1>,
ranges: [Range<usize>; D2],
value: NdArrayTensor<E, D1>,
@ -195,7 +198,7 @@ impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
NdArrayOps::slice_assign(tensor, ranges, value)
}
fn mask_where<const D: usize>(
fn float_mask_where<const D: usize>(
tensor: NdArrayTensor<E, D>,
mask: NdArrayTensor<bool, D>,
value: NdArrayTensor<E, D>,
@ -203,7 +206,7 @@ impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
NdArrayMathOps::mask_where(tensor, mask, value)
}
fn mask_fill<const D: usize>(
fn float_mask_fill<const D: usize>(
tensor: NdArrayTensor<E, D>,
mask: NdArrayTensor<bool, D>,
value: E,
@ -211,47 +214,53 @@ impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
NdArrayMathOps::mask_fill(tensor, mask, value)
}
fn equal<const D: usize>(
fn float_equal<const D: usize>(
lhs: NdArrayTensor<E, D>,
rhs: NdArrayTensor<E, D>,
) -> NdArrayTensor<bool, D> {
let tensor = NdArray::<E>::sub(lhs, rhs);
let tensor = NdArray::<E>::float_sub(lhs, rhs);
let zero = 0.elem();
Self::equal_elem(tensor, zero)
Self::float_equal_elem(tensor, zero)
}
fn equal_elem<const D: usize>(lhs: NdArrayTensor<E, D>, rhs: E) -> NdArrayTensor<bool, D> {
fn float_equal_elem<const D: usize>(
lhs: NdArrayTensor<E, D>,
rhs: E,
) -> NdArrayTensor<bool, D> {
let array = lhs.array.mapv(|a| a == rhs).into_shared();
NdArrayTensor::new(array)
}
fn greater<const D: usize>(
fn float_greater<const D: usize>(
lhs: NdArrayTensor<E, D>,
rhs: NdArrayTensor<E, D>,
) -> NdArrayTensor<bool, D> {
let tensor = NdArray::<E>::sub(lhs, rhs);
let tensor = NdArray::<E>::float_sub(lhs, rhs);
let zero = 0.elem();
Self::greater_elem(tensor, zero)
Self::float_greater_elem(tensor, zero)
}
fn greater_elem<const D: usize>(lhs: NdArrayTensor<E, D>, rhs: E) -> NdArrayTensor<bool, D> {
fn float_greater_elem<const D: usize>(
lhs: NdArrayTensor<E, D>,
rhs: E,
) -> NdArrayTensor<bool, D> {
let array = lhs.array.mapv(|a| a > rhs).into_shared();
NdArrayTensor::new(array)
}
fn greater_equal<const D: usize>(
fn float_greater_equal<const D: usize>(
lhs: NdArrayTensor<E, D>,
rhs: NdArrayTensor<E, D>,
) -> NdArrayTensor<bool, D> {
let tensor = NdArray::<E>::sub(lhs, rhs);
let tensor = NdArray::<E>::float_sub(lhs, rhs);
let zero = 0.elem();
Self::greater_equal_elem(tensor, zero)
Self::float_greater_equal_elem(tensor, zero)
}
fn greater_equal_elem<const D: usize>(
fn float_greater_equal_elem<const D: usize>(
lhs: NdArrayTensor<E, D>,
rhs: E,
) -> NdArrayTensor<bool, D> {
@ -260,31 +269,34 @@ impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
NdArrayTensor::new(array)
}
fn lower<const D: usize>(
fn float_lower<const D: usize>(
lhs: NdArrayTensor<E, D>,
rhs: NdArrayTensor<E, D>,
) -> NdArrayTensor<bool, D> {
let tensor = NdArray::<E>::sub(lhs, rhs);
let tensor = NdArray::<E>::float_sub(lhs, rhs);
let zero = 0.elem();
Self::lower_elem(tensor, zero)
Self::float_lower_elem(tensor, zero)
}
fn lower_elem<const D: usize>(lhs: NdArrayTensor<E, D>, rhs: E) -> NdArrayTensor<bool, D> {
fn float_lower_elem<const D: usize>(
lhs: NdArrayTensor<E, D>,
rhs: E,
) -> NdArrayTensor<bool, D> {
let array = lhs.array.mapv(|a| a < rhs).into_shared();
NdArrayTensor::new(array)
}
fn lower_equal<const D: usize>(
fn float_lower_equal<const D: usize>(
lhs: NdArrayTensor<E, D>,
rhs: NdArrayTensor<E, D>,
) -> NdArrayTensor<bool, D> {
let tensor = NdArray::<E>::sub(lhs, rhs);
let tensor = NdArray::<E>::float_sub(lhs, rhs);
let zero = 0.elem();
Self::lower_equal_elem(tensor, zero)
Self::float_lower_equal_elem(tensor, zero)
}
fn lower_equal_elem<const D: usize>(
fn float_lower_equal_elem<const D: usize>(
lhs: NdArrayTensor<E, D>,
rhs: E,
) -> NdArrayTensor<bool, D> {
@ -293,65 +305,84 @@ impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
NdArrayTensor::new(array)
}
fn detach<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
fn float_detach<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
tensor
}
fn mean<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, 1> {
fn float_mean<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, 1> {
NdArrayMathOps::mean(tensor)
}
fn sum<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, 1> {
fn float_sum<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, 1> {
NdArrayMathOps::sum(tensor)
}
fn mean_dim<const D: usize>(tensor: NdArrayTensor<E, D>, dim: usize) -> NdArrayTensor<E, D> {
fn float_mean_dim<const D: usize>(
tensor: NdArrayTensor<E, D>,
dim: usize,
) -> NdArrayTensor<E, D> {
NdArrayMathOps::mean_dim(tensor, dim)
}
fn sum_dim<const D: usize>(tensor: NdArrayTensor<E, D>, dim: usize) -> NdArrayTensor<E, D> {
fn float_sum_dim<const D: usize>(
tensor: NdArrayTensor<E, D>,
dim: usize,
) -> NdArrayTensor<E, D> {
NdArrayMathOps::sum_dim(tensor, dim)
}
fn to_full_precision<const D: usize>(tensor: &NdArrayTensor<E, D>) -> NdArrayTensor<f32, D> {
fn float_to_full_precision<const D: usize>(
tensor: &NdArrayTensor<E, D>,
) -> NdArrayTensor<f32, D> {
let array = tensor.array.mapv(|a| a.elem()).into_shared();
NdArrayTensor::new(array)
}
fn from_full_precision<const D: usize>(tensor: NdArrayTensor<f32, D>) -> NdArrayTensor<E, D> {
fn float_from_full_precision<const D: usize>(
tensor: NdArrayTensor<f32, D>,
) -> NdArrayTensor<E, D> {
let array = tensor.array.mapv(|a| a.elem()).into_shared();
NdArrayTensor::new(array)
}
fn argmax<const D: usize>(tensor: NdArrayTensor<E, D>, dim: usize) -> NdArrayTensor<i64, D> {
fn float_argmax<const D: usize>(
tensor: NdArrayTensor<E, D>,
dim: usize,
) -> NdArrayTensor<i64, D> {
NdArrayMathOps::argmax(tensor, dim)
}
fn argmin<const D: usize>(tensor: NdArrayTensor<E, D>, dim: usize) -> NdArrayTensor<i64, D> {
fn float_argmin<const D: usize>(
tensor: NdArrayTensor<E, D>,
dim: usize,
) -> NdArrayTensor<i64, D> {
NdArrayMathOps::argmin(tensor, dim)
}
fn exp<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
fn float_exp<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let array = tensor.array.mapv_into(|a| a.exp_elem()).into_shared();
NdArrayTensor::new(array)
}
fn log<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
fn float_log<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let array = tensor.array.mapv_into(|a| a.log_elem()).into_shared();
NdArrayTensor::new(array)
}
fn log1p<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
fn float_log1p<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let array = tensor.array.mapv_into(|a| a.log1p_elem()).into_shared();
NdArrayTensor::new(array)
}
fn powf_scalar<const D: usize>(tensor: NdArrayTensor<E, D>, value: f32) -> NdArrayTensor<E, D> {
fn float_powf_scalar<const D: usize>(
tensor: NdArrayTensor<E, D>,
value: f32,
) -> NdArrayTensor<E, D> {
let array = if value == 2.0 {
// Happens often and is faster.
tensor.array.mapv_into(|a| a * a).into_shared()
@ -369,19 +400,19 @@ impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
NdArrayTensor::new(array)
}
fn sqrt<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
fn float_sqrt<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let array = tensor.array.mapv_into(|a| a.sqrt_elem()).into_shared();
NdArrayTensor::new(array)
}
fn abs<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
fn float_abs<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let array = tensor.array.mapv_into(|a| a.abs_elem()).into_shared();
NdArrayTensor::new(array)
}
fn cos<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
fn float_cos<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let array = tensor
.array
.mapv_into(|a| cos(a.to_f64().unwrap()).elem())
@ -390,7 +421,7 @@ impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
NdArrayTensor::new(array)
}
fn sin<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
fn float_sin<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let array = tensor
.array
.mapv_into(|a| sin(a.to_f64().unwrap()).elem())
@ -399,7 +430,7 @@ impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
NdArrayTensor::new(array)
}
fn tanh<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
fn float_tanh<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let array = tensor
.array
.mapv_into(|a| tanh(a.to_f64().unwrap()).elem())
@ -408,7 +439,7 @@ impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
NdArrayTensor::new(array)
}
fn erf<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
fn float_erf<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let array = tensor
.array
.mapv_into(|a| erf(a.to_f64().unwrap()).elem())
@ -417,30 +448,37 @@ impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
NdArrayTensor::new(array)
}
fn cat<const D: usize>(tensors: Vec<NdArrayTensor<E, D>>, dim: usize) -> NdArrayTensor<E, D> {
fn float_cat<const D: usize>(
tensors: Vec<NdArrayTensor<E, D>>,
dim: usize,
) -> NdArrayTensor<E, D> {
NdArrayOps::cat(tensors, dim)
}
fn clamp_min<const D: usize>(tensor: NdArrayTensor<E, D>, min: E) -> NdArrayTensor<E, D> {
fn float_clamp_min<const D: usize>(tensor: NdArrayTensor<E, D>, min: E) -> NdArrayTensor<E, D> {
NdArrayMathOps::clamp_min(tensor, min)
}
fn clamp_max<const D: usize>(tensor: NdArrayTensor<E, D>, max: E) -> NdArrayTensor<E, D> {
fn float_clamp_max<const D: usize>(tensor: NdArrayTensor<E, D>, max: E) -> NdArrayTensor<E, D> {
NdArrayMathOps::clamp_max(tensor, max)
}
fn clamp<const D: usize>(tensor: NdArrayTensor<E, D>, min: E, max: E) -> NdArrayTensor<E, D> {
fn float_clamp<const D: usize>(
tensor: NdArrayTensor<E, D>,
min: E,
max: E,
) -> NdArrayTensor<E, D> {
NdArrayMathOps::clamp(tensor, min, max)
}
fn into_int<const D: usize>(
tensor: <NdArray<E> as Backend>::TensorPrimitive<D>,
fn float_into_int<const D: usize>(
tensor: <NdArray<E> as Backend>::FloatTensorPrimitive<D>,
) -> <NdArray<E> as Backend>::IntTensorPrimitive<D> {
let array = tensor.array.mapv(|a| a.elem()).into_shared();
NdArrayTensor { array }
}
fn powf<const D: usize>(
fn float_powf<const D: usize>(
lhs: NdArrayTensor<E, D>,
rhs: NdArrayTensor<E, D>,
) -> NdArrayTensor<E, D> {

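At a call site on the ndarray backend, the renamed functions read as below. A minimal sketch, assuming `Data::from` for a 1-D array and the `NdArrayDevice::Cpu` device used elsewhere in this diff:

```rust
use burn_ndarray::{NdArray, NdArrayDevice};
use burn_tensor::{ops::FloatTensorOps, Data, Shape};

fn demo() {
    let device = NdArrayDevice::Cpu;
    // `from_data`, `ones`, `add` and `sum` all gained the `float_` prefix.
    let lhs = NdArray::<f32>::float_from_data(Data::from([1.0, 2.0, 3.0]), &device);
    let rhs = NdArray::<f32>::float_ones(Shape::new([3]), &device);
    let total = NdArray::<f32>::float_sum(NdArray::<f32>::float_add(lhs, rhs));
    // `total` is a rank-1 tensor holding the single value 9.0.
    let _shape = NdArray::<f32>::float_shape(&total);
}
```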
View File

@ -78,7 +78,7 @@ impl<E: TchElement> Backend for LibTorch<E> {
type FullPrecisionElem = f32;
type FullPrecisionBackend = LibTorch<f32>;
type TensorPrimitive<const D: usize> = TchTensor<E, D>;
type FloatTensorPrimitive<const D: usize> = TchTensor<E, D>;
type FloatElem = E;
type IntTensorPrimitive<const D: usize> = TchTensor<i64, D>;

View File

@ -1,16 +1,19 @@
use super::TchOps;
use crate::{element::TchElement, LibTorch, LibTorchDevice, TchShape, TchTensor};
use burn_tensor::{
backend::Backend, ops::TensorOps, Data, Distribution, ElementConversion, Reader, Shape,
backend::Backend, ops::FloatTensorOps, Data, Distribution, ElementConversion, Reader, Shape,
};
use std::ops::Range;
impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
fn from_data<const D: usize>(data: Data<E, D>, device: &LibTorchDevice) -> TchTensor<E, D> {
impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
fn float_from_data<const D: usize>(
data: Data<E, D>,
device: &LibTorchDevice,
) -> TchTensor<E, D> {
TchTensor::from_data(data, (*device).into())
}
fn random<const D: usize>(
fn float_random<const D: usize>(
shape: Shape<D>,
distribution: Distribution,
device: &LibTorchDevice,
@ -39,7 +42,7 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
}
}
fn arange(range: Range<usize>, device: &LibTorchDevice) -> TchTensor<i64, 1> {
fn float_arange(range: Range<usize>, device: &LibTorchDevice) -> TchTensor<i64, 1> {
let device: tch::Device = (*device).into();
let mut tensor = tch::Tensor::arange(
range.end as i64 - range.start as i64,
@ -53,7 +56,7 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
TchTensor::new(tensor)
}
fn repeat<const D: usize>(
fn float_repeat<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,
times: usize,
@ -61,59 +64,61 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
TchOps::repeat(tensor, dim, times)
}
fn zeros<const D: usize>(shape: Shape<D>, device: &LibTorchDevice) -> TchTensor<E, D> {
fn float_zeros<const D: usize>(shape: Shape<D>, device: &LibTorchDevice) -> TchTensor<E, D> {
let shape = TchShape::from(shape);
let device: tch::Device = (*device).into();
TchTensor::new(tch::Tensor::zeros(shape.dims, (E::KIND, device)))
}
fn ones<const D: usize>(shape: Shape<D>, device: &LibTorchDevice) -> TchTensor<E, D> {
fn float_ones<const D: usize>(shape: Shape<D>, device: &LibTorchDevice) -> TchTensor<E, D> {
let shape = TchShape::from(shape);
let device: tch::Device = (*device).into();
TchTensor::new(tch::Tensor::ones(shape.dims, (E::KIND, device)))
}
fn shape<const D: usize>(tensor: &<LibTorch<E> as Backend>::TensorPrimitive<D>) -> Shape<D> {
fn float_shape<const D: usize>(
tensor: &<LibTorch<E> as Backend>::FloatTensorPrimitive<D>,
) -> Shape<D> {
tensor.shape()
}
fn into_data<const D: usize>(
tensor: <LibTorch<E> as Backend>::TensorPrimitive<D>,
fn float_into_data<const D: usize>(
tensor: <LibTorch<E> as Backend>::FloatTensorPrimitive<D>,
) -> Reader<Data<<LibTorch<E> as Backend>::FloatElem, D>> {
let shape = Self::shape(&tensor);
let tensor = Self::reshape(tensor.clone(), Shape::new([shape.num_elements()]));
let shape = Self::float_shape(&tensor);
let tensor = Self::float_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
let values: Result<Vec<E>, tch::TchError> = tensor.tensor.try_into();
Reader::Concrete(Data::new(values.unwrap(), shape))
}
fn device<const D: usize>(tensor: &TchTensor<E, D>) -> LibTorchDevice {
fn float_device<const D: usize>(tensor: &TchTensor<E, D>) -> LibTorchDevice {
tensor.tensor.device().into()
}
fn to_device<const D: usize>(
fn float_to_device<const D: usize>(
tensor: TchTensor<E, D>,
device: &LibTorchDevice,
) -> TchTensor<E, D> {
TchTensor::new(tensor.tensor.to((*device).into()))
}
fn empty<const D: usize>(
fn float_empty<const D: usize>(
shape: Shape<D>,
device: &<LibTorch<E> as Backend>::Device,
) -> <LibTorch<E> as Backend>::TensorPrimitive<D> {
) -> <LibTorch<E> as Backend>::FloatTensorPrimitive<D> {
let tensor = tch::Tensor::empty(shape.dims.map(|a| a as i64), (E::KIND, (*device).into()));
TchTensor::new(tensor)
}
fn add<const D: usize>(lhs: TchTensor<E, D>, rhs: TchTensor<E, D>) -> TchTensor<E, D> {
fn float_add<const D: usize>(lhs: TchTensor<E, D>, rhs: TchTensor<E, D>) -> TchTensor<E, D> {
TchOps::add(lhs, rhs)
}
fn add_scalar<const D: usize>(lhs: TchTensor<E, D>, rhs: E) -> TchTensor<E, D> {
fn float_add_scalar<const D: usize>(lhs: TchTensor<E, D>, rhs: E) -> TchTensor<E, D> {
let rhs: f64 = rhs.elem();
lhs.unary_ops(
@ -122,11 +127,11 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
)
}
fn sub<const D: usize>(lhs: TchTensor<E, D>, rhs: TchTensor<E, D>) -> TchTensor<E, D> {
fn float_sub<const D: usize>(lhs: TchTensor<E, D>, rhs: TchTensor<E, D>) -> TchTensor<E, D> {
TchOps::sub(lhs, rhs)
}
fn sub_scalar<const D: usize>(lhs: TchTensor<E, D>, rhs: E) -> TchTensor<E, D> {
fn float_sub_scalar<const D: usize>(lhs: TchTensor<E, D>, rhs: E) -> TchTensor<E, D> {
let rhs: f64 = rhs.elem();
lhs.unary_ops(
@ -135,11 +140,11 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
)
}
fn mul<const D: usize>(lhs: TchTensor<E, D>, rhs: TchTensor<E, D>) -> TchTensor<E, D> {
fn float_mul<const D: usize>(lhs: TchTensor<E, D>, rhs: TchTensor<E, D>) -> TchTensor<E, D> {
TchOps::mul(lhs, rhs)
}
fn mul_scalar<const D: usize>(lhs: TchTensor<E, D>, rhs: E) -> TchTensor<E, D> {
fn float_mul_scalar<const D: usize>(lhs: TchTensor<E, D>, rhs: E) -> TchTensor<E, D> {
let rhs: f64 = rhs.elem();
lhs.unary_ops(
@ -148,11 +153,11 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
)
}
fn div<const D: usize>(lhs: TchTensor<E, D>, rhs: TchTensor<E, D>) -> TchTensor<E, D> {
fn float_div<const D: usize>(lhs: TchTensor<E, D>, rhs: TchTensor<E, D>) -> TchTensor<E, D> {
TchOps::div(lhs, rhs)
}
fn div_scalar<const D: usize>(lhs: TchTensor<E, D>, rhs: E) -> TchTensor<E, D> {
fn float_div_scalar<const D: usize>(lhs: TchTensor<E, D>, rhs: E) -> TchTensor<E, D> {
let rhs: f64 = rhs.elem();
lhs.unary_ops(
@ -161,20 +166,20 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
)
}
fn matmul<const D: usize>(lhs: TchTensor<E, D>, rhs: TchTensor<E, D>) -> TchTensor<E, D> {
fn float_matmul<const D: usize>(lhs: TchTensor<E, D>, rhs: TchTensor<E, D>) -> TchTensor<E, D> {
let tensor = lhs.tensor.matmul(&rhs.tensor);
TchTensor::new(tensor)
}
fn neg<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
Self::mul_scalar(tensor, (-1f32).elem::<E>())
fn float_neg<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
Self::float_mul_scalar(tensor, (-1f32).elem::<E>())
}
fn recip<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
fn float_recip<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
TchTensor::new(tensor.tensor.reciprocal())
}
fn swap_dims<const D: usize>(
fn float_swap_dims<const D: usize>(
tensor: TchTensor<E, D>,
dim1: usize,
dim2: usize,
@ -182,14 +187,14 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
TchOps::swap_dims(tensor, dim1, dim2)
}
fn reshape<const D1: usize, const D2: usize>(
fn float_reshape<const D1: usize, const D2: usize>(
tensor: TchTensor<E, D1>,
shape: Shape<D2>,
) -> TchTensor<E, D2> {
TchOps::reshape(tensor, shape)
}
fn gather<const D: usize>(
fn float_gather<const D: usize>(
dim: usize,
tensor: TchTensor<E, D>,
indices: TchTensor<i64, D>,
@ -197,7 +202,7 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
TchOps::gather(dim, tensor, indices)
}
fn scatter<const D: usize>(
fn float_scatter<const D: usize>(
dim: usize,
tensor: TchTensor<E, D>,
indices: TchTensor<i64, D>,
@ -206,7 +211,7 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
TchOps::scatter(dim, tensor, indices, value)
}
fn select<const D: usize>(
fn float_select<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,
indices: TchTensor<i64, 1>,
@ -214,7 +219,7 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
TchOps::index_select_dim(tensor, dim, indices)
}
fn select_assign<const D: usize>(
fn float_select_assign<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,
indices: TchTensor<i64, 1>,
@ -223,22 +228,22 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
TchOps::select_assign(tensor, dim, indices, value)
}
fn slice<const D1: usize, const D2: usize>(
fn float_slice<const D1: usize, const D2: usize>(
tensor: TchTensor<E, D1>,
ranges: [Range<usize>; D2],
) -> TchTensor<E, D1> {
TchOps::slice(tensor, ranges)
}
fn slice_assign<const D1: usize, const D2: usize>(
fn float_slice_assign<const D1: usize, const D2: usize>(
tensor: TchTensor<E, D1>,
ranges: [Range<usize>; D2],
value: TchTensor<E, D1>,
) -> <LibTorch<E> as Backend>::TensorPrimitive<D1> {
) -> <LibTorch<E> as Backend>::FloatTensorPrimitive<D1> {
TchOps::slice_assign(tensor, ranges, value)
}
fn mask_where<const D: usize>(
fn float_mask_where<const D: usize>(
tensor: TchTensor<E, D>,
mask: TchTensor<bool, D>,
value: TchTensor<E, D>,
@ -248,7 +253,7 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
TchTensor::new(output)
}
fn mask_fill<const D: usize>(
fn float_mask_fill<const D: usize>(
tensor: TchTensor<E, D>,
mask: TchTensor<bool, D>,
value: E,
@ -261,187 +266,199 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
)
}
fn equal<const D: usize>(lhs: TchTensor<E, D>, rhs: TchTensor<E, D>) -> TchTensor<bool, D> {
fn float_equal<const D: usize>(
lhs: TchTensor<E, D>,
rhs: TchTensor<E, D>,
) -> TchTensor<bool, D> {
TchOps::equal(lhs, rhs)
}
fn equal_elem<const D: usize>(lhs: TchTensor<E, D>, rhs: E) -> TchTensor<bool, D> {
fn float_equal_elem<const D: usize>(lhs: TchTensor<E, D>, rhs: E) -> TchTensor<bool, D> {
TchOps::equal_elem(lhs, rhs.elem::<f64>())
}
fn greater<const D: usize>(lhs: TchTensor<E, D>, rhs: TchTensor<E, D>) -> TchTensor<bool, D> {
fn float_greater<const D: usize>(
lhs: TchTensor<E, D>,
rhs: TchTensor<E, D>,
) -> TchTensor<bool, D> {
TchOps::greater(lhs, rhs)
}
fn greater_elem<const D: usize>(lhs: TchTensor<E, D>, rhs: E) -> TchTensor<bool, D> {
fn float_greater_elem<const D: usize>(lhs: TchTensor<E, D>, rhs: E) -> TchTensor<bool, D> {
TchOps::greater_elem(lhs, rhs.elem::<f64>())
}
fn greater_equal<const D: usize>(
fn float_greater_equal<const D: usize>(
lhs: TchTensor<E, D>,
rhs: TchTensor<E, D>,
) -> TchTensor<bool, D> {
TchOps::greater_equal(lhs, rhs)
}
fn greater_equal_elem<const D: usize>(lhs: TchTensor<E, D>, rhs: E) -> TchTensor<bool, D> {
fn float_greater_equal_elem<const D: usize>(
lhs: TchTensor<E, D>,
rhs: E,
) -> TchTensor<bool, D> {
TchOps::greater_equal_elem(lhs, rhs.elem::<f64>())
}
fn lower<const D: usize>(lhs: TchTensor<E, D>, rhs: TchTensor<E, D>) -> TchTensor<bool, D> {
fn float_lower<const D: usize>(
lhs: TchTensor<E, D>,
rhs: TchTensor<E, D>,
) -> TchTensor<bool, D> {
TchOps::lower(lhs, rhs)
}
fn lower_elem<const D: usize>(lhs: TchTensor<E, D>, rhs: E) -> TchTensor<bool, D> {
fn float_lower_elem<const D: usize>(lhs: TchTensor<E, D>, rhs: E) -> TchTensor<bool, D> {
TchOps::lower_elem(lhs, rhs.elem::<f64>())
}
fn lower_equal<const D: usize>(
fn float_lower_equal<const D: usize>(
lhs: TchTensor<E, D>,
rhs: TchTensor<E, D>,
) -> TchTensor<bool, D> {
TchOps::lower_equal(lhs, rhs)
}
fn lower_equal_elem<const D: usize>(lhs: TchTensor<E, D>, rhs: E) -> TchTensor<bool, D> {
fn float_lower_equal_elem<const D: usize>(lhs: TchTensor<E, D>, rhs: E) -> TchTensor<bool, D> {
TchOps::lower_equal_elem(lhs, rhs.elem::<f64>())
}
fn mean<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, 1> {
fn float_mean<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, 1> {
TchOps::mean(tensor)
}
fn sum<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, 1> {
fn float_sum<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, 1> {
TchOps::sum(tensor)
}
fn mean_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
fn float_mean_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
TchOps::mean_dim(tensor, dim)
}
fn sum_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
fn float_sum_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
TchOps::sum_dim(tensor, dim)
}
fn to_full_precision<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<f32, D> {
fn float_to_full_precision<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<f32, D> {
let storage = tensor.storage.clone();
let tensor = tensor.tensor.to_kind(tch::Kind::Float);
TchTensor::from_existing(tensor, storage)
}
fn from_full_precision<const D: usize>(tensor: TchTensor<f32, D>) -> TchTensor<E, D> {
fn float_from_full_precision<const D: usize>(tensor: TchTensor<f32, D>) -> TchTensor<E, D> {
let storage = tensor.storage.clone();
let tensor = tensor.tensor.to_kind(E::KIND);
TchTensor::from_existing(tensor, storage)
}
fn argmax<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<i64, D> {
fn float_argmax<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<i64, D> {
TchOps::argmax(tensor, dim)
}
fn argmin<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<i64, D> {
fn float_argmin<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<i64, D> {
TchOps::argmin(tensor, dim)
}
fn max_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
fn float_max_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
TchOps::max_dim(tensor, dim)
}
fn max_dim_with_indices<const D: usize>(
fn float_max_dim_with_indices<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,
) -> (TchTensor<E, D>, TchTensor<i64, D>) {
TchOps::max_dim_with_indices(tensor, dim)
}
fn min_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
fn float_min_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
TchOps::min_dim(tensor, dim)
}
fn min_dim_with_indices<const D: usize>(
fn float_min_dim_with_indices<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,
) -> (TchTensor<E, D>, TchTensor<i64, D>) {
TchOps::min_dim_with_indices(tensor, dim)
}
fn exp<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
fn float_exp<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
tensor.unary_ops(|mut tensor| tensor.exp_(), |tensor| tensor.exp())
}
fn log<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
fn float_log<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
tensor.unary_ops(|mut tensor| tensor.log_(), |tensor| tensor.log())
}
fn log1p<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
fn float_log1p<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
tensor.unary_ops(|mut tensor| tensor.log1p_(), |tensor| tensor.log1p())
}
fn powf_scalar<const D: usize>(tensor: TchTensor<E, D>, value: f32) -> TchTensor<E, D> {
fn float_powf_scalar<const D: usize>(tensor: TchTensor<E, D>, value: f32) -> TchTensor<E, D> {
tensor.unary_ops(
|mut tensor| tensor.f_pow_(value as f64).unwrap(),
|tensor| tensor.pow_tensor_scalar(value as f64),
)
}
fn sqrt<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
fn float_sqrt<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
tensor.unary_ops(|mut tensor| tensor.sqrt_(), |tensor| tensor.sqrt())
}
fn abs<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
fn float_abs<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs())
}
fn cos<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
fn float_cos<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
tensor.unary_ops(|mut tensor| tensor.cos_(), |tensor| tensor.cos())
}
fn sin<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
fn float_sin<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
tensor.unary_ops(|mut tensor| tensor.sin_(), |tensor| tensor.sin())
}
fn tanh<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
fn float_tanh<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
tensor.unary_ops(|mut tensor| tensor.tanh_(), |tensor| tensor.tanh())
}
fn erf<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
fn float_erf<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
tensor.unary_ops(|mut tensor| tensor.erf_(), |tensor| tensor.erf())
}
fn cat<const D: usize>(tensors: Vec<TchTensor<E, D>>, dim: usize) -> TchTensor<E, D> {
fn float_cat<const D: usize>(tensors: Vec<TchTensor<E, D>>, dim: usize) -> TchTensor<E, D> {
TchOps::cat(tensors, dim)
}
fn clamp_min<const D: usize>(
fn float_clamp_min<const D: usize>(
tensor: TchTensor<E, D>,
min: E,
) -> <LibTorch<E> as Backend>::TensorPrimitive<D> {
) -> <LibTorch<E> as Backend>::FloatTensorPrimitive<D> {
TchOps::clamp_min(tensor, min.elem::<f64>())
}
fn clamp_max<const D: usize>(
tensor: <LibTorch<E> as Backend>::TensorPrimitive<D>,
fn float_clamp_max<const D: usize>(
tensor: <LibTorch<E> as Backend>::FloatTensorPrimitive<D>,
max: <LibTorch<E> as Backend>::FloatElem,
) -> <LibTorch<E> as Backend>::TensorPrimitive<D> {
) -> <LibTorch<E> as Backend>::FloatTensorPrimitive<D> {
TchOps::clamp_max(tensor, max.elem::<f64>())
}
fn clamp<const D: usize>(
tensor: <LibTorch<E> as Backend>::TensorPrimitive<D>,
fn float_clamp<const D: usize>(
tensor: <LibTorch<E> as Backend>::FloatTensorPrimitive<D>,
min: <LibTorch<E> as Backend>::FloatElem,
max: <LibTorch<E> as Backend>::FloatElem,
) -> <LibTorch<E> as Backend>::TensorPrimitive<D> {
) -> <LibTorch<E> as Backend>::FloatTensorPrimitive<D> {
TchOps::clamp(tensor, min.elem::<f64>(), max.elem::<f64>())
}
fn into_int<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<i64, D> {
fn float_into_int<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<i64, D> {
let tensor = tensor.tensor.to_kind(tch::Kind::Int64);
TchTensor::new(tensor)
}
fn narrow<const D: usize>(
fn float_narrow<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,
start: usize,
@ -450,7 +467,7 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
TchOps::narrow(tensor, dim, start, length)
}
fn chunk<const D: usize>(
fn float_chunk<const D: usize>(
tensor: TchTensor<E, D>,
chunks: usize,
dim: usize,
@ -458,7 +475,7 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
TchOps::chunk(tensor, chunks, dim)
}
fn powf<const D: usize>(
fn float_powf<const D: usize>(
lhs: burn_tensor::ops::FloatTensor<Self, D>,
rhs: burn_tensor::ops::FloatTensor<Self, D>,
) -> burn_tensor::ops::FloatTensor<Self, D> {

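As a point of reference for downstream backend-generic code, here is a minimal, purely illustrative sketch (not part of this commit) that composes the renamed low-level ops; it assumes only the `Backend` bound and the `FloatTensor` alias that appear elsewhere in this diff.

use burn_tensor::backend::Backend;
use burn_tensor::ops::FloatTensor;

// Root-mean-square of all elements, written directly against the `float_*` ops.
fn rms<B: Backend, const D: usize>(x: FloatTensor<B, D>) -> FloatTensor<B, 1> {
    let squared = B::float_mul(x.clone(), x);
    B::float_sqrt(B::float_mean(squared))
}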
View File

@ -1,5 +1,5 @@
use crate::{element::TchElement, LibTorch, LibTorchDevice};
use burn_tensor::{ops::TensorOps, Data, Shape};
use burn_tensor::{ops::FloatTensorOps, Data, Shape};
use libc::c_void;
use std::{marker::PhantomData, sync::Arc};
@ -70,7 +70,7 @@ impl<E: TchElement, const D: usize> std::ops::Add for TchTensor<E, D> {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
LibTorch::add(self, rhs)
LibTorch::float_add(self, rhs)
}
}
@ -221,7 +221,7 @@ mod utils {
where
P: tch::kind::Element,
{
<LibTorch<P> as TensorOps<LibTorch<P>>>::into_data(self).read()
<LibTorch<P> as FloatTensorOps<LibTorch<P>>>::float_into_data(self).read()
}
}
}

View File

@ -1139,21 +1139,21 @@ impl<B: Backend> BasicOps<B> for Float {
type Elem = B::FloatElem;
fn empty<const D: usize>(shape: Shape<D>, device: &B::Device) -> Self::Primitive<D> {
B::empty(shape, device)
B::float_empty(shape, device)
}
fn shape<const D: usize>(tensor: &Self::Primitive<D>) -> Shape<D> {
B::shape(tensor)
B::float_shape(tensor)
}
fn reshape<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
shape: Shape<D2>,
) -> Self::Primitive<D2> {
B::reshape(tensor, shape)
B::float_reshape(tensor, shape)
}
fn transpose<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<D> {
B::transpose(tensor)
B::float_transpose(tensor)
}
fn swap_dims<const D: usize>(
@ -1162,14 +1162,14 @@ impl<B: Backend> BasicOps<B> for Float {
dim2: usize,
) -> Self::Primitive<D> {
check!(TensorCheck::swap_dims::<D>(dim1, dim2));
B::swap_dims(tensor, dim1, dim2)
B::float_swap_dims(tensor, dim1, dim2)
}
fn slice<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
ranges: [Range<usize>; D2],
) -> Self::Primitive<D1> {
B::slice(tensor, ranges)
B::float_slice(tensor, ranges)
}
fn slice_assign<const D1: usize, const D2: usize>(
@ -1177,29 +1177,29 @@ impl<B: Backend> BasicOps<B> for Float {
ranges: [Range<usize>; D2],
value: Self::Primitive<D1>,
) -> Self::Primitive<D1> {
B::slice_assign(tensor, ranges, value)
B::float_slice_assign(tensor, ranges, value)
}
fn device<const D: usize>(tensor: &Self::Primitive<D>) -> <B as Backend>::Device {
B::device(tensor)
B::float_device(tensor)
}
fn to_device<const D: usize>(
tensor: Self::Primitive<D>,
device: &<B as Backend>::Device,
) -> Self::Primitive<D> {
B::to_device(tensor, device)
B::float_to_device(tensor, device)
}
fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Reader<Data<Self::Elem, D>> {
B::into_data(tensor)
B::float_into_data(tensor)
}
fn from_data<const D: usize>(
data: Data<Self::Elem, D>,
device: &B::Device,
) -> Self::Primitive<D> {
B::from_data(data, device)
B::float_from_data(data, device)
}
fn repeat<const D: usize>(
@ -1207,18 +1207,18 @@ impl<B: Backend> BasicOps<B> for Float {
dim: usize,
times: usize,
) -> Self::Primitive<D> {
B::repeat(tensor, dim, times)
B::float_repeat(tensor, dim, times)
}
fn cat<const D: usize>(vectors: Vec<Self::Primitive<D>>, dim: usize) -> Self::Primitive<D> {
B::cat(vectors, dim)
B::float_cat(vectors, dim)
}
fn equal<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>,
) -> Tensor<B, D, Bool> {
Tensor::new(B::equal(lhs, rhs))
Tensor::new(B::float_equal(lhs, rhs))
}
}

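The kind dispatch above leaves the public `Tensor` API untouched; a hypothetical caller (illustrative, not from this commit) writes the same code as before and simply ends up in `float_cat` / `float_equal` under the hood.

use burn_tensor::{backend::Backend, Tensor};

// Concatenate two float tensors along dim 0 through the kind-dispatched API.
fn pair<B: Backend, const D: usize>(a: Tensor<B, D>, b: Tensor<B, D>) -> Tensor<B, D> {
    Tensor::cat(vec![a, b], 0)
}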
View File

@ -35,53 +35,53 @@ where
///
/// `y = e^x`
pub fn exp(self) -> Self {
Self::new(B::exp(self.primitive))
Self::new(B::float_exp(self.primitive))
}
/// Applies element wise natural log operation *ln*.
///
/// `y = log(x)`
pub fn log(self) -> Self {
Self::new(B::log(self.primitive))
Self::new(B::float_log(self.primitive))
}
/// Applies the natural logarithm of one plus the input tensor, element-wise.
///
/// `y = log(x+1)`
pub fn log1p(self) -> Self {
Self::new(B::log1p(self.primitive))
Self::new(B::float_log1p(self.primitive))
}
/// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element wise.
///
/// `y = erf(x)`
pub fn erf(self) -> Self {
Self::new(B::erf(self.primitive))
Self::new(B::float_erf(self.primitive))
}
/// Applies element wise reciprocal operation.
pub fn recip(self) -> Self {
Self::new(B::recip(self.primitive))
Self::new(B::float_recip(self.primitive))
}
/// Applies element wise square root operation.
pub fn sqrt(self) -> Self {
Self::new(B::sqrt(self.primitive))
Self::new(B::float_sqrt(self.primitive))
}
/// Applies element wise cosine operation.
pub fn cos(self) -> Self {
Self::new(B::cos(self.primitive))
Self::new(B::float_cos(self.primitive))
}
/// Applies element wise sine operation.
pub fn sin(self) -> Self {
Self::new(B::sin(self.primitive))
Self::new(B::float_sin(self.primitive))
}
/// Applies element wise hyperbolic tangent operation.
pub fn tanh(self) -> Self {
Self::new(B::tanh(self.primitive))
Self::new(B::float_tanh(self.primitive))
}
/// Create a tensor from floats (f32) on a given device.
@ -118,23 +118,23 @@ where
/// }
/// ```
pub fn int(self) -> Tensor<B, D, Int> {
Tensor::new(B::into_int(self.primitive))
Tensor::new(B::float_into_int(self.primitive))
}
/// Returns a new tensor with the same shape and device as the current tensor filled with zeros.
pub fn zeros_like(&self) -> Self {
Tensor::new(B::zeros(self.shape(), &self.device()))
Tensor::new(B::float_zeros(self.shape(), &self.device()))
}
/// Returns a new tensor with the same shape and device as the current tensor filled with ones.
pub fn ones_like(&self) -> Self {
Tensor::new(B::ones(self.shape(), &self.device()))
Tensor::new(B::float_ones(self.shape(), &self.device()))
}
/// Returns a new tensor with the same shape and device as the current tensor filled with random
/// values sampled from the given distribution.
pub fn random_like(&self, distribution: Distribution) -> Self {
Tensor::new(B::random(self.shape(), distribution, &self.device()))
Tensor::new(B::float_random(self.shape(), distribution, &self.device()))
}
/// Create a one hot tensor.
@ -175,7 +175,7 @@ where
/// If the two tensors don't have a compatible shape.
pub fn matmul(self, other: Self) -> Self {
check!(TensorCheck::matmul(&self, &other));
Self::new(B::matmul(self.primitive, other.primitive))
Self::new(B::float_matmul(self.primitive, other.primitive))
}
/// Calculate the variance along the given dimension.
@ -209,17 +209,17 @@ where
distribution: Distribution,
device: &B::Device,
) -> Self {
let tensor = B::random(shape.into(), distribution, device);
let tensor = B::float_random(shape.into(), distribution, device);
Self::new(tensor)
}
/// Returns a tensor with full precision based on the selected backend.
pub fn to_full_precision(&self) -> Tensor<B::FullPrecisionBackend, D> {
Tensor::new(B::to_full_precision(&self.primitive))
Tensor::new(B::float_to_full_precision(&self.primitive))
}
/// Returns a tensor on the selected backend from a full precision tensor.
pub fn from_full_precision(tensor: Tensor<B::FullPrecisionBackend, D>) -> Self {
Self::new(B::from_full_precision(tensor.primitive))
Self::new(B::float_from_full_precision(tensor.primitive))
}
/// Detach the current tensor from the autodiff graph.
@ -228,7 +228,7 @@ where
/// This can be used in batchers or elsewhere to ensure that previous operations are not
/// considered in the autodiff graph.
pub fn detach(self) -> Self {
Self::new(B::detach(self.primitive))
Self::new(B::float_detach(self.primitive))
}
/// Mark the tensor to keep gradients during the backward pass.
@ -240,7 +240,7 @@ where
/// Returns true if the tensor requires gradients during the backward pass.
pub fn is_require_grad(&self) -> bool {
B::is_require_grad(&self.primitive)
B::float_is_require_grad(&self.primitive)
}
/// Mark the tensor as tracked or untracked depending on the require grad argument.
@ -248,7 +248,7 @@ where
///
/// This function does nothing when autodiff is not enabled.
pub fn set_require_grad(self, require_grad: bool) -> Self {
Self::new(B::set_require_grad(self.primitive, require_grad))
Self::new(B::float_set_require_grad(self.primitive, require_grad))
}
/// Applies the relu function to the tensor.

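Since only the backend-facing names change, chained math on the public API reads exactly as before; a small hedged example (not part of the commit) using the methods shown above:

use burn_tensor::{backend::Backend, Tensor};

// softplus(x) = log(1 + exp(x)), composed from the unchanged `exp`/`log1p` methods.
fn softplus<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
    x.exp().log1p()
}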
View File

@ -12,7 +12,7 @@ where
/// * `range` - The range of values to generate.
/// * `device` - The device to create the tensor on.
pub fn arange(range: Range<usize>, device: &B::Device) -> Self {
Tensor::new(B::arange(range, device))
Tensor::new(B::float_arange(range, device))
}
/// Returns a new integer tensor on the specified device.
@ -22,7 +22,7 @@ where
/// * `range` - The range of values to generate.
/// * `step` - The step between each value.
pub fn arange_step(range: Range<usize>, step: usize, device: &B::Device) -> Self {
Tensor::new(B::arange_step(range, step, device))
Tensor::new(B::float_arange_step(range, step, device))
}
}

View File

@ -22,7 +22,7 @@ pub trait TensorKind<B: Backend>: Clone + core::fmt::Debug {
}
impl<B: Backend> TensorKind<B> for Float {
type Primitive<const D: usize> = B::TensorPrimitive<D>;
type Primitive<const D: usize> = B::FloatTensorPrimitive<D>;
fn name() -> &'static str {
"Float"
}

View File

@ -1934,133 +1934,133 @@ impl<B: Backend> Numeric<B> for Float {
lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>,
) -> <Float as TensorKind<B>>::Primitive<D> {
B::add(lhs, rhs)
B::float_add(lhs, rhs)
}
fn add_scalar<const D: usize, E: ElementConversion>(
lhs: Self::Primitive<D>,
rhs: E,
) -> Self::Primitive<D> {
B::add_scalar(lhs, rhs.elem())
B::float_add_scalar(lhs, rhs.elem())
}
fn sub<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>,
) -> <Float as TensorKind<B>>::Primitive<D> {
B::sub(lhs, rhs)
B::float_sub(lhs, rhs)
}
fn sub_scalar<const D: usize, E: ElementConversion>(
lhs: Self::Primitive<D>,
rhs: E,
) -> Self::Primitive<D> {
B::sub_scalar(lhs, rhs.elem())
B::float_sub_scalar(lhs, rhs.elem())
}
fn div<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>,
) -> <Float as TensorKind<B>>::Primitive<D> {
B::div(lhs, rhs)
B::float_div(lhs, rhs)
}
fn div_scalar<const D: usize, E: ElementConversion>(
lhs: Self::Primitive<D>,
rhs: E,
) -> Self::Primitive<D> {
B::div_scalar(lhs, rhs.elem())
B::float_div_scalar(lhs, rhs.elem())
}
fn mul<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>,
) -> <Float as TensorKind<B>>::Primitive<D> {
B::mul(lhs, rhs)
B::float_mul(lhs, rhs)
}
fn mul_scalar<const D: usize, E: ElementConversion>(
lhs: Self::Primitive<D>,
rhs: E,
) -> Self::Primitive<D> {
B::mul_scalar(lhs, rhs.elem())
B::float_mul_scalar(lhs, rhs.elem())
}
fn neg<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<D> {
B::neg(tensor)
B::float_neg(tensor)
}
fn zeros<const D: usize>(shape: Shape<D>, device: &B::Device) -> Self::Primitive<D> {
B::zeros(shape, device)
B::float_zeros(shape, device)
}
fn ones<const D: usize>(shape: Shape<D>, device: &B::Device) -> Self::Primitive<D> {
B::ones(shape, device)
B::float_ones(shape, device)
}
fn full<const D: usize, E: ElementConversion>(
shape: Shape<D>,
fill_value: E,
device: &B::Device,
) -> Self::Primitive<D> {
B::full(shape, fill_value.elem(), device)
B::float_full(shape, fill_value.elem(), device)
}
fn sum<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
B::sum(tensor)
B::float_sum(tensor)
}
fn sum_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D> {
B::sum_dim(tensor, dim)
B::float_sum_dim(tensor, dim)
}
fn mean<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
B::mean(tensor)
B::float_mean(tensor)
}
fn mean_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D> {
B::mean_dim(tensor, dim)
B::float_mean_dim(tensor, dim)
}
fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool> {
Tensor::new(B::equal_elem(lhs, rhs))
Tensor::new(B::float_equal_elem(lhs, rhs))
}
fn greater<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>,
) -> Tensor<B, D, Bool> {
Tensor::new(B::greater(lhs, rhs))
Tensor::new(B::float_greater(lhs, rhs))
}
fn greater_elem<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Elem,
) -> Tensor<B, D, Bool> {
Tensor::new(B::greater_elem(lhs, rhs))
Tensor::new(B::float_greater_elem(lhs, rhs))
}
fn greater_equal<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>,
) -> Tensor<B, D, Bool> {
Tensor::new(B::greater_equal(lhs, rhs))
Tensor::new(B::float_greater_equal(lhs, rhs))
}
fn greater_equal_elem<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Elem,
) -> Tensor<B, D, Bool> {
Tensor::new(B::greater_equal_elem(lhs, rhs))
Tensor::new(B::float_greater_equal_elem(lhs, rhs))
}
fn lower<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>,
) -> Tensor<B, D, Bool> {
Tensor::new(B::lower(lhs, rhs))
Tensor::new(B::float_lower(lhs, rhs))
}
fn lower_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool> {
Tensor::new(B::lower_elem(lhs, rhs))
Tensor::new(B::float_lower_elem(lhs, rhs))
}
fn lower_equal<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>,
) -> Tensor<B, D, Bool> {
Tensor::new(B::lower_equal(lhs, rhs))
Tensor::new(B::float_lower_equal(lhs, rhs))
}
fn lower_equal_elem<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Elem,
) -> Tensor<B, D, Bool> {
Tensor::new(B::lower_equal_elem(lhs, rhs))
Tensor::new(B::float_lower_equal_elem(lhs, rhs))
}
fn mask_where<const D: usize>(
@ -2068,7 +2068,7 @@ impl<B: Backend> Numeric<B> for Float {
mask: Tensor<B, D, Bool>,
source: Self::Primitive<D>,
) -> Self::Primitive<D> {
B::mask_where(tensor, mask.primitive, source)
B::float_mask_where(tensor, mask.primitive, source)
}
fn mask_fill<const D: usize>(
@ -2076,7 +2076,7 @@ impl<B: Backend> Numeric<B> for Float {
mask: Tensor<B, D, Bool>,
value: Self::Elem,
) -> Self::Primitive<D> {
B::mask_fill(tensor, mask.primitive, value)
B::float_mask_fill(tensor, mask.primitive, value)
}
fn select<const D: usize>(
@ -2084,7 +2084,7 @@ impl<B: Backend> Numeric<B> for Float {
dim: usize,
indices: Tensor<B, 1, Int>,
) -> Self::Primitive<D> {
B::select(tensor, dim, indices.primitive)
B::float_select(tensor, dim, indices.primitive)
}
fn select_assign<const D: usize>(
@ -2093,7 +2093,7 @@ impl<B: Backend> Numeric<B> for Float {
indices: Tensor<B, 1, Int>,
values: Self::Primitive<D>,
) -> Self::Primitive<D> {
B::select_assign(tensor, dim, indices.primitive, values)
B::float_select_assign(tensor, dim, indices.primitive, values)
}
fn gather<const D: usize>(
@ -2101,7 +2101,7 @@ impl<B: Backend> Numeric<B> for Float {
tensor: Self::Primitive<D>,
indices: Tensor<B, D, Int>,
) -> Self::Primitive<D> {
B::gather(dim, tensor, indices.primitive)
B::float_gather(dim, tensor, indices.primitive)
}
fn scatter<const D: usize>(
@ -2110,51 +2110,51 @@ impl<B: Backend> Numeric<B> for Float {
indices: Tensor<B, D, Int>,
values: Self::Primitive<D>,
) -> Self::Primitive<D> {
B::scatter(dim, tensor, indices.primitive, values)
B::float_scatter(dim, tensor, indices.primitive, values)
}
fn argmax<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
) -> <B as Backend>::IntTensorPrimitive<D> {
B::argmax(tensor, dim)
B::float_argmax(tensor, dim)
}
fn argmin<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
) -> <B as Backend>::IntTensorPrimitive<D> {
B::argmin(tensor, dim)
B::float_argmin(tensor, dim)
}
fn max<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
B::max(tensor)
B::float_max(tensor)
}
fn max_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D> {
B::max_dim(tensor, dim)
B::float_max_dim(tensor, dim)
}
fn max_dim_with_indices<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
) -> (Self::Primitive<D>, <B as Backend>::IntTensorPrimitive<D>) {
B::max_dim_with_indices(tensor, dim)
B::float_max_dim_with_indices(tensor, dim)
}
fn min<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
B::min(tensor)
B::float_min(tensor)
}
fn min_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D> {
B::min_dim(tensor, dim)
B::float_min_dim(tensor, dim)
}
fn min_dim_with_indices<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
) -> (Self::Primitive<D>, <B as Backend>::IntTensorPrimitive<D>) {
B::min_dim_with_indices(tensor, dim)
B::float_min_dim_with_indices(tensor, dim)
}
fn clamp<const D: usize>(
@ -2162,53 +2162,53 @@ impl<B: Backend> Numeric<B> for Float {
min: B::FloatElem,
max: B::FloatElem,
) -> Self::Primitive<D> {
B::clamp(tensor, min, max)
B::float_clamp(tensor, min, max)
}
fn clamp_min<const D: usize>(
tensor: Self::Primitive<D>,
min: B::FloatElem,
) -> Self::Primitive<D> {
B::clamp_min(tensor, min)
B::float_clamp_min(tensor, min)
}
fn clamp_max<const D: usize>(
tensor: Self::Primitive<D>,
max: B::FloatElem,
) -> Self::Primitive<D> {
B::clamp_max(tensor, max)
B::float_clamp_max(tensor, max)
}
fn abs<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<D> {
B::abs(tensor)
B::float_abs(tensor)
}
fn powf<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>,
) -> Self::Primitive<D> {
B::powf(lhs, rhs)
B::float_powf(lhs, rhs)
}
fn powf_scalar<const D: usize, E: ElementConversion>(
lhs: Self::Primitive<D>,
rhs: E,
) -> Self::Primitive<D> {
B::powf_scalar(lhs, rhs.elem())
B::float_powf_scalar(lhs, rhs.elem())
}
fn powi<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>,
) -> Self::Primitive<D> {
B::powf(lhs, rhs)
B::float_powf(lhs, rhs)
}
fn powi_scalar<const D: usize, E: ElementConversion>(
lhs: Self::Primitive<D>,
rhs: E,
) -> Self::Primitive<D> {
B::powf_scalar(lhs, rhs.elem())
B::float_powf_scalar(lhs, rhs.elem())
}
}

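Note that `powi` and `powi_scalar` above are routed through the float power kernels, so an integer exponent on a float tensor is just a scalar `float_powf_scalar` call, as in this illustrative (non-commit) helper:

use burn_tensor::backend::Backend;
use burn_tensor::ops::FloatTensor;

// Elementwise cube of a float tensor via the renamed scalar power op.
fn cube<B: Backend, const D: usize>(x: FloatTensor<B, D>) -> FloatTensor<B, D> {
    B::float_powf_scalar(x, 3.0)
}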
View File

@ -50,7 +50,7 @@ use crate::tensor::Element;
/// Most of the documentation for each function can be found on the user API [tensor struct](crate::Tensor).
/// For modules, public functions are often created, which can be used by `burn-core` modules.
pub trait Backend:
TensorOps<Self>
FloatTensorOps<Self>
+ BoolTensorOps<Self>
+ IntTensorOps<Self>
+ ModuleOps<Self>
@ -72,7 +72,7 @@ pub trait Backend:
type FullPrecisionElem: Element;
/// Tensor primitive to be used for all float operations.
type TensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + core::fmt::Debug;
type FloatTensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + core::fmt::Debug;
/// Float element type.
type FloatElem: Element;

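For code that referred to the old associated type, the migration is a one-line alias swap; a hedged example of the new spelling, mirroring the alias pattern used elsewhere in this diff (the alias name itself is hypothetical):

use burn_tensor::backend::Backend;

// Downstream alias for the renamed float primitive associated type.
type MyFloatPrimitive<B, const D: usize> = <B as Backend>::FloatTensorPrimitive<D>;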
View File

@ -24,7 +24,7 @@ where
}
}
type TensorPrimitive<B, const D: usize> = <B as Backend>::TensorPrimitive<D>;
type TensorPrimitive<B, const D: usize> = <B as Backend>::FloatTensorPrimitive<D>;
impl<ID> TensorContainer<ID>
where

View File

@ -1,4 +1,4 @@
use crate::tensor::ops::tensor::TensorOps;
use crate::tensor::ops::tensor::FloatTensorOps;
use crate::{backend::Backend, ElementConversion};
use core::f64::consts::SQRT_2;
@ -18,9 +18,9 @@ pub trait ActivationOps<B: Backend> {
///
/// The output tensor.
fn relu<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D> {
let mask = B::lower_equal_elem(tensor.clone(), 0.elem());
let mask = B::float_lower_equal_elem(tensor.clone(), 0.elem());
B::mask_fill(tensor, mask, 0.elem())
B::float_mask_fill(tensor, mask, 0.elem())
}
/// Applies the ReLU activation function backward.
@ -36,9 +36,9 @@ pub trait ActivationOps<B: Backend> {
output: FloatTensor<B, D>,
grad: FloatTensor<B, D>,
) -> FloatTensor<B, D> {
let mask = B::lower_equal_elem(output, 0.elem());
let mask = B::float_lower_equal_elem(output, 0.elem());
B::mask_fill(grad, mask, 0.elem())
B::float_mask_fill(grad, mask, 0.elem())
}
/// Applies the Gelu activation function.
@ -51,12 +51,12 @@ pub trait ActivationOps<B: Backend> {
///
/// The output tensor.
fn gelu<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D> {
let x = B::div_scalar(tensor.clone(), SQRT_2.elem());
let x = B::erf(x);
let x = B::add_scalar(x, 1i32.elem());
let x = B::mul(tensor, x);
let x = B::float_div_scalar(tensor.clone(), SQRT_2.elem());
let x = B::float_erf(x);
let x = B::float_add_scalar(x, 1i32.elem());
let x = B::float_mul(tensor, x);
B::div_scalar(x, 2i32.elem())
B::float_div_scalar(x, 2i32.elem())
}
/// Applies the Gelu activation function backward.
@ -80,28 +80,28 @@ pub trait ActivationOps<B: Backend> {
let constant_3 = 0.0535161;
let constant_4 = 0.398942;
let x3 = B::powf_scalar(x.clone(), 3.0);
let x3 = B::float_powf_scalar(x.clone(), 3.0);
let c1 = B::mul_scalar(x3.clone(), constant_1.elem());
let c2 = B::mul_scalar(x.clone(), constant_2.elem());
let c3 = B::mul_scalar(x3, constant_3.elem());
let c4 = B::mul_scalar(x, constant_4.elem());
let c1 = B::float_mul_scalar(x3.clone(), constant_1.elem());
let c2 = B::float_mul_scalar(x.clone(), constant_2.elem());
let c3 = B::float_mul_scalar(x3, constant_3.elem());
let c4 = B::float_mul_scalar(x, constant_4.elem());
let inner1 = B::add(c1, c2);
let inner2 = B::add(c3, c4);
let inner1 = B::float_add(c1, c2);
let inner2 = B::float_add(c3, c4);
let tanh = B::tanh(inner1);
let tanh = B::float_tanh(inner1);
let sech = B::powf_scalar(tanh.clone(), 2.0);
let sech = B::neg(sech);
let sech = B::add_scalar(sech, 1.elem());
let sech = B::float_powf_scalar(tanh.clone(), 2.0);
let sech = B::float_neg(sech);
let sech = B::float_add_scalar(sech, 1.elem());
let y1 = B::mul_scalar(tanh, 0.5.elem());
let y2 = B::mul(inner2, sech);
let y2 = B::add_scalar(y2, 0.5.elem());
let y = B::add(y1, y2);
let y1 = B::float_mul_scalar(tanh, 0.5.elem());
let y2 = B::float_mul(inner2, sech);
let y2 = B::float_add_scalar(y2, 0.5.elem());
let y = B::float_add(y1, y2);
B::mul(y, grad)
B::float_mul(y, grad)
}
/// Applies the Sigmoid activation function.
@ -114,15 +114,15 @@ pub trait ActivationOps<B: Backend> {
///
/// The output tensor.
fn sigmoid<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D> {
let tensor_full = B::to_full_precision(&tensor);
let tensor_tmp = B::FullPrecisionBackend::exp(B::FullPrecisionBackend::neg(
B::FullPrecisionBackend::log(B::FullPrecisionBackend::add_scalar(
B::FullPrecisionBackend::exp(B::FullPrecisionBackend::neg(tensor_full)),
let tensor_full = B::float_to_full_precision(&tensor);
let tensor_tmp = B::FullPrecisionBackend::float_exp(B::FullPrecisionBackend::float_neg(
B::FullPrecisionBackend::float_log(B::FullPrecisionBackend::float_add_scalar(
B::FullPrecisionBackend::float_exp(B::FullPrecisionBackend::float_neg(tensor_full)),
1.0.elem(),
)),
));
B::from_full_precision(tensor_tmp)
B::float_from_full_precision(tensor_tmp)
}
/// Applies the Sigmoid activation function backward.
@ -139,7 +139,10 @@ pub trait ActivationOps<B: Backend> {
output: FloatTensor<B, D>,
grad: FloatTensor<B, D>,
) -> FloatTensor<B, D> {
let value = B::mul(output.clone(), B::add_scalar(B::neg(output), 1.0.elem()));
B::mul(value, grad)
let value = B::float_mul(
output.clone(),
B::float_add_scalar(B::float_neg(output), 1.0.elem()),
);
B::float_mul(value, grad)
}
}

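For reference, the `gelu` and `sigmoid` defaults above implement the following identities (read off the code, with the standard definitions of erf and the logistic function assumed):

$$\operatorname{gelu}(x) = \frac{x}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right), \qquad \sigma(x) = \frac{1}{1 + e^{-x}} = \exp\!\left(-\log\left(1 + e^{-x}\right)\right)$$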
View File

@ -14,7 +14,7 @@ pub type IntElem<B> = <B as Backend>::IntElem;
pub type FullPrecisionBackend<B> = <B as Backend>::FullPrecisionBackend;
/// Float tensor primitive type used by the backend.
pub type FloatTensor<B, const D: usize> = <B as Backend>::TensorPrimitive<D>;
pub type FloatTensor<B, const D: usize> = <B as Backend>::FloatTensorPrimitive<D>;
/// Integer tensor primitive type used by the backend.
pub type IntTensor<B, const D: usize> = <B as Backend>::IntTensorPrimitive<D>;
/// Boolean tensor primitive type used by the backend.

View File

@ -470,7 +470,10 @@ pub trait IntTensorOps<B: Backend> {
///
/// The elements of `lhs` raised to the power of the elements of `rhs`.
fn int_powi<const D: usize>(lhs: IntTensor<B, D>, rhs: IntTensor<B, D>) -> IntTensor<B, D> {
B::into_int(B::powf(B::int_into_float(lhs), B::int_into_float(rhs)))
B::float_into_int(B::float_powf(
B::int_into_float(lhs),
B::int_into_float(rhs),
))
}
/// Elementwise power with a FloatTensor.
@ -484,7 +487,7 @@ pub trait IntTensorOps<B: Backend> {
///
/// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
fn int_powf<const D: usize>(lhs: IntTensor<B, D>, rhs: FloatTensor<B, D>) -> IntTensor<B, D> {
B::into_int(B::powf(B::int_into_float(lhs), rhs))
B::float_into_int(B::float_powf(B::int_into_float(lhs), rhs))
}
/// Elementwise power with a scalar.
@ -498,7 +501,7 @@ pub trait IntTensorOps<B: Backend> {
///
/// The elements of `lhs` raised to the value of `rhs`.
fn int_powi_scalar<const D: usize>(lhs: IntTensor<B, D>, rhs: IntElem<B>) -> IntTensor<B, D> {
B::into_int(B::powf_scalar(
B::float_into_int(B::float_powf_scalar(
B::int_into_float(lhs),
rhs.to_f32().unwrap(),
))
@ -515,7 +518,7 @@ pub trait IntTensorOps<B: Backend> {
///
/// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
fn int_powf_scalar<const D: usize>(lhs: IntTensor<B, D>, rhs: f32) -> IntTensor<B, D> {
B::into_int(B::powf_scalar(B::int_into_float(lhs), rhs))
B::float_into_int(B::float_powf_scalar(B::int_into_float(lhs), rhs))
}
/// Clamps a tensor under a minimum value.

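The integer power defaults above round-trip through the float kernels; as a scalar-level illustration (plain Rust, not part of the commit), the computation amounts to:

// Scalar analog of the default `int_powf_scalar`: cast to float, powf, cast back.
fn int_pow_scalar_elem(lhs: i64, rhs: f32) -> i64 {
    (lhs as f64).powf(rhs as f64) as i64
}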
View File

@ -128,12 +128,12 @@ pub trait ModuleOps<B: Backend> {
/// The output tensor.
fn embedding(weights: FloatTensor<B, 2>, indices: IntTensor<B, 2>) -> FloatTensor<B, 3> {
let [batch_size, seq_length] = B::int_shape(&indices).dims;
let [_, d_model] = B::shape(&weights).dims;
let [_, d_model] = B::float_shape(&weights).dims;
let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
let output = B::select(weights, 0, indices);
let output = B::float_select(weights, 0, indices);
B::reshape(output, Shape::new([batch_size, seq_length, d_model]))
B::float_reshape(output, Shape::new([batch_size, seq_length, d_model]))
}
/// Embedding backward operation.
@ -153,14 +153,15 @@ pub trait ModuleOps<B: Backend> {
indices: IntTensor<B, 2>,
) -> FloatTensor<B, 2> {
let [batch_size, seq_length] = B::int_shape(&indices).dims;
let [n_embeddings, d_model] = B::shape(&weights).dims;
let device = B::device(&weights);
let [n_embeddings, d_model] = B::float_shape(&weights).dims;
let device = B::float_device(&weights);
let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
let output_grad = B::reshape(output_grad, Shape::new([batch_size * seq_length, d_model]));
let grad = B::zeros(Shape::new([n_embeddings, d_model]), &device);
let output_grad =
B::float_reshape(output_grad, Shape::new([batch_size * seq_length, d_model]));
let grad = B::float_zeros(Shape::new([n_embeddings, d_model]), &device);
B::select_assign(grad, 0, indices, output_grad)
B::float_select_assign(grad, 0, indices, output_grad)
}
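The embedding default above is a row lookup: the indices are flattened, the matching rows of `weights` are gathered with `float_select`, and the result is reshaped back to `[batch, seq, d_model]`. A plain-Rust reference of the per-row semantics (illustrative only):

// Each output row is the weight row selected by the corresponding index.
fn embedding_ref(weights: &[Vec<f32>], indices: &[usize]) -> Vec<Vec<f32>> {
    indices.iter().map(|&i| weights[i].clone()).collect()
}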
/// One dimensional convolution.
///

View File

@ -63,11 +63,11 @@ pub(crate) fn conv1d_backward<B: Backend>(
output_grad: FloatTensor<B, 3>,
options: ConvOptions<1>,
) -> Conv1dBackward<B> {
let weight_shape = B::shape(&weight);
let weight_device = B::device(&weight);
let weight_shape = B::float_shape(&weight);
let weight_device = B::float_device(&weight);
let [batch_size, _, length_in] = B::shape(&x).dims;
let [_batch_size, channels_out, length_out] = B::shape(&output_grad).dims;
let [batch_size, _, length_in] = B::float_shape(&x).dims;
let [_batch_size, channels_out, length_out] = B::float_shape(&output_grad).dims;
let [_, _, kernel_size] = weight_shape.dims;
let padding_out = calculate_padding_out(
@ -96,7 +96,7 @@ pub(crate) fn conv1d_backward<B: Backend>(
true => conv1d_weight_grad_no_groups::<B>(x, output_grad.clone(), weight_shape, options),
false => conv1d_weight_grad_groups::<B>(
x,
B::zeros(weight_shape, &weight_device),
B::float_zeros(weight_shape, &weight_device),
output_grad.clone(),
options,
),
@ -106,11 +106,11 @@ pub(crate) fn conv1d_backward<B: Backend>(
x_grad,
weight_grad,
bias.map(|b| {
let grad = B::swap_dims(output_grad, 0, 1);
let grad = B::reshape(grad, Shape::new([channels_out, batch_size * length_out]));
let grad = B::sum_dim(grad, 1);
let grad = B::float_swap_dims(output_grad, 0, 1);
let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
let grad = B::float_sum_dim(grad, 1);
B::reshape(grad, B::shape(&b))
B::float_reshape(grad, B::float_shape(&b))
}),
)
}
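The bias closure above computes the usual convolution bias gradient: the output gradient is summed over the batch and spatial dimensions for each output channel,

$$\frac{\partial L}{\partial b_c} = \sum_{n=1}^{N} \sum_{\ell=1}^{L_{\text{out}}} \frac{\partial L}{\partial y_{n,c,\ell}}$$

which is what the swap_dims / reshape / sum_dim sequence implements.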
@ -123,11 +123,11 @@ pub(crate) fn conv2d_backward<B: Backend>(
output_grad: FloatTensor<B, 4>,
options: ConvOptions<2>,
) -> Conv2dBackward<B> {
let weight_shape = B::shape(&weight);
let weight_device = B::device(&weight);
let weight_shape = B::float_shape(&weight);
let weight_device = B::float_device(&weight);
let [batch_size, _channels_in, height_in, width_in] = B::shape(&x).dims;
let [_, _, height_out, width_out] = B::shape(&output_grad).dims;
let [batch_size, _channels_in, height_in, width_in] = B::float_shape(&x).dims;
let [_, _, height_out, width_out] = B::float_shape(&output_grad).dims;
let [channels_out, _, kernel_size_1, kernel_size_2] = weight_shape.dims;
let padding_1_out = calculate_padding_out(
@ -164,7 +164,7 @@ pub(crate) fn conv2d_backward<B: Backend>(
true => conv2d_weight_grad_no_groups::<B>(x, output_grad.clone(), weight_shape, options),
false => conv2d_weight_grad_groups::<B>(
x,
B::zeros(weight_shape, &weight_device),
B::float_zeros(weight_shape, &weight_device),
output_grad.clone(),
options,
),
@ -174,14 +174,14 @@ pub(crate) fn conv2d_backward<B: Backend>(
x_grad,
weight_grad,
bias.map(|b| {
let grad = B::swap_dims(output_grad, 0, 1);
let grad = B::reshape(
let grad = B::float_swap_dims(output_grad, 0, 1);
let grad = B::float_reshape(
grad,
Shape::new([channels_out, batch_size * height_out * width_out]),
);
let grad = B::sum_dim(grad, 1);
let grad = B::float_sum_dim(grad, 1);
B::reshape(grad, B::shape(&b))
B::float_reshape(grad, B::float_shape(&b))
}),
)
}
@ -194,11 +194,11 @@ pub(crate) fn conv_transpose2d_backward<B: Backend>(
output_grad: FloatTensor<B, 4>,
options: ConvTransposeOptions<2>,
) -> Conv2dBackward<B> {
let weight_shape = B::shape(&weight);
let weight_device = B::device(&weight);
let weight_shape = B::float_shape(&weight);
let weight_device = B::float_device(&weight);
let [batch_size, _channels_in, _, _] = B::shape(&x).dims;
let [_, channels_out, height_out, width_out] = B::shape(&output_grad).dims;
let [batch_size, _channels_in, _, _] = B::float_shape(&x).dims;
let [_, channels_out, height_out, width_out] = B::float_shape(&output_grad).dims;
let x_grad = B::conv2d(
output_grad.clone(),
@ -221,7 +221,7 @@ pub(crate) fn conv_transpose2d_backward<B: Backend>(
),
false => conv_transpose2d_weight_grad_groups::<B>(
x,
B::zeros(weight_shape, &weight_device),
B::float_zeros(weight_shape, &weight_device),
output_grad.clone(),
options,
),
@ -231,14 +231,14 @@ pub(crate) fn conv_transpose2d_backward<B: Backend>(
x_grad,
weight_grad,
bias.map(|b| {
let grad = B::swap_dims(output_grad, 0, 1);
let grad = B::reshape(
let grad = B::float_swap_dims(output_grad, 0, 1);
let grad = B::float_reshape(
grad,
Shape::new([channels_out, batch_size * height_out * width_out]),
);
let grad = B::sum_dim(grad, 1);
let grad = B::float_sum_dim(grad, 1);
B::reshape(grad, B::shape(&b))
B::float_reshape(grad, B::float_shape(&b))
}),
)
}
@ -251,11 +251,11 @@ pub(crate) fn conv_transpose1d_backward<B: Backend>(
output_grad: FloatTensor<B, 3>,
options: ConvTransposeOptions<1>,
) -> Conv1dBackward<B> {
let weight_shape = B::shape(&weight);
let weight_device = B::device(&weight);
let weight_shape = B::float_shape(&weight);
let weight_device = B::float_device(&weight);
let [batch_size, _channels_in, _] = B::shape(&x).dims;
let [_, channels_out, length_out] = B::shape(&output_grad).dims;
let [batch_size, _channels_in, _] = B::float_shape(&x).dims;
let [_, channels_out, length_out] = B::float_shape(&output_grad).dims;
let x_grad = B::conv1d(
output_grad.clone(),
@ -278,7 +278,7 @@ pub(crate) fn conv_transpose1d_backward<B: Backend>(
),
false => conv_transpose1d_weight_grad_groups::<B>(
x,
B::zeros(weight_shape, &weight_device),
B::float_zeros(weight_shape, &weight_device),
output_grad.clone(),
options,
),
@ -288,11 +288,11 @@ pub(crate) fn conv_transpose1d_backward<B: Backend>(
x_grad,
weight_grad,
bias.map(|b| {
let grad = B::swap_dims(output_grad, 0, 1);
let grad = B::reshape(grad, Shape::new([channels_out, batch_size * length_out]));
let grad = B::sum_dim(grad, 1);
let grad = B::float_swap_dims(output_grad, 0, 1);
let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
let grad = B::float_sum_dim(grad, 1);
B::reshape(grad, B::shape(&b))
B::float_reshape(grad, B::float_shape(&b))
}),
)
}
@ -304,14 +304,14 @@ pub(crate) fn conv1d_from_conv2d<B: Backend>(
bias: Option<FloatTensor<B, 1>>,
options: ConvOptions<1>,
) -> FloatTensor<B, 3> {
let [channels_out, _channels_in, kernel_size] = B::shape(&weight).dims;
let [batch_size, channels_in, length_in] = B::shape(&x).dims;
let [channels_out, _channels_in, kernel_size] = B::float_shape(&weight).dims;
let [batch_size, channels_in, length_in] = B::float_shape(&x).dims;
let weight = B::reshape(
let weight = B::float_reshape(
weight,
Shape::new([channels_out, channels_in / options.groups, kernel_size, 1]),
);
let x = B::reshape(x, Shape::new([batch_size, channels_in, length_in, 1]));
let x = B::float_reshape(x, Shape::new([batch_size, channels_in, length_in, 1]));
let tensor = B::conv2d(
x,
@ -324,8 +324,8 @@ pub(crate) fn conv1d_from_conv2d<B: Backend>(
options.groups,
),
);
let [batch_size, channels_out, height_out, _weight_out] = B::shape(&tensor).dims;
B::reshape(tensor, Shape::from([batch_size, channels_out, height_out]))
let [batch_size, channels_out, height_out, _weight_out] = B::float_shape(&tensor).dims;
B::float_reshape(tensor, Shape::from([batch_size, channels_out, height_out]))
}
/// Execute a 1D transposed convolution using a 2D transposed convolution.
@ -335,14 +335,14 @@ pub(crate) fn conv_transpose1d_from_conv_transpose2d<B: Backend>(
bias: Option<FloatTensor<B, 1>>,
options: ConvTransposeOptions<1>,
) -> FloatTensor<B, 3> {
let [channels_in, channels_out, kernel_size] = B::shape(&weight).dims;
let [batch_size, _channels_in, length_in] = B::shape(&x).dims;
let [channels_in, channels_out, kernel_size] = B::float_shape(&weight).dims;
let [batch_size, _channels_in, length_in] = B::float_shape(&x).dims;
let weight = B::reshape(
let weight = B::float_reshape(
weight,
Shape::new([channels_in, channels_out, kernel_size, 1]),
);
let x = B::reshape(x, Shape::new([batch_size, channels_in, length_in, 1]));
let x = B::float_reshape(x, Shape::new([batch_size, channels_in, length_in, 1]));
let tensor = B::conv_transpose2d(
x,
@ -356,8 +356,8 @@ pub(crate) fn conv_transpose1d_from_conv_transpose2d<B: Backend>(
options.groups,
),
);
let [batch_size, channels_out, height_out, _weight_out] = B::shape(&tensor).dims;
B::reshape(tensor, Shape::from([batch_size, channels_out, height_out]))
let [batch_size, channels_out, height_out, _weight_out] = B::float_shape(&tensor).dims;
B::float_reshape(tensor, Shape::from([batch_size, channels_out, height_out]))
}
fn conv1d_weight_grad_groups<B: Backend>(
@ -366,11 +366,11 @@ fn conv1d_weight_grad_groups<B: Backend>(
output_grad: FloatTensor<B, 3>,
options: ConvOptions<1>,
) -> FloatTensor<B, 3> {
let [channels_out, increment_ci, kernel_size] = B::shape(&weight_grad).dims;
let [channels_out, increment_ci, kernel_size] = B::float_shape(&weight_grad).dims;
let increment_co = channels_out / options.groups;
let x_swapped = B::swap_dims(x, 0, 1);
let output_grad_swapped = B::swap_dims(output_grad, 0, 1);
let x_swapped = B::float_swap_dims(x, 0, 1);
let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
for g in 0..options.groups {
let start_idx_ci = g * increment_ci;
@ -378,16 +378,16 @@ fn conv1d_weight_grad_groups<B: Backend>(
let start_idx_co = g * increment_co;
let end_idx_co = (g + 1) * increment_co;
let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]);
let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]);
let x = B::float_slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]);
let grad = B::float_slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]);
let mut weight_grad_tmp = B::conv1d(
x,
grad,
None,
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
);
weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1);
weight_grad = B::slice_assign(
weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
weight_grad = B::float_slice_assign(
weight_grad,
[start_idx_co..end_idx_co, 0..increment_ci, 0..kernel_size],
weight_grad_tmp,
@ -403,11 +403,12 @@ fn conv2d_weight_grad_groups<B: Backend>(
output_grad: FloatTensor<B, 4>,
options: ConvOptions<2>,
) -> FloatTensor<B, 4> {
let [channels_out, increment_ci, kernel_size_1, kernel_size_2] = B::shape(&weight_grad).dims;
let [channels_out, increment_ci, kernel_size_1, kernel_size_2] =
B::float_shape(&weight_grad).dims;
let increment_co = channels_out / options.groups;
let x_swapped = B::swap_dims(x, 0, 1);
let output_grad_swapped = B::swap_dims(output_grad, 0, 1);
let x_swapped = B::float_swap_dims(x, 0, 1);
let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
for g in 0..options.groups {
let start_idx_ci = g * increment_ci;
@ -415,16 +416,16 @@ fn conv2d_weight_grad_groups<B: Backend>(
let start_idx_co = g * increment_co;
let end_idx_co = (g + 1) * increment_co;
let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]);
let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]);
let x = B::float_slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]);
let grad = B::float_slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]);
let mut weight_grad_tmp = B::conv2d(
x,
grad,
None,
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
);
weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1);
weight_grad = B::slice_assign(
weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
weight_grad = B::float_slice_assign(
weight_grad,
[
start_idx_co..end_idx_co,
@ -445,11 +446,12 @@ fn conv_transpose2d_weight_grad_groups<B: Backend>(
output_grad: FloatTensor<B, 4>,
options: ConvTransposeOptions<2>,
) -> FloatTensor<B, 4> {
let [channels_in, increment_co, kernel_size_1, kernel_size_2] = B::shape(&weight_grad).dims;
let [channels_in, increment_co, kernel_size_1, kernel_size_2] =
B::float_shape(&weight_grad).dims;
let increment_ci = channels_in / options.groups;
let x_swapped = B::swap_dims(x, 0, 1);
let output_grad_swapped = B::swap_dims(output_grad, 0, 1);
let x_swapped = B::float_swap_dims(x, 0, 1);
let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
for g in 0..options.groups {
let start_idx_ci = g * increment_ci;
@ -457,19 +459,19 @@ fn conv_transpose2d_weight_grad_groups<B: Backend>(
let start_idx_co = g * increment_co;
let end_idx_co = (g + 1) * increment_co;
let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]);
let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]);
let x = B::float_slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]);
let grad = B::float_slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]);
let mut weight_grad_tmp = B::conv2d(
grad,
x,
None,
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
);
weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1);
let [_, _, kernel_size_1_tmp, kernel_size_2_tmp] = B::shape(&weight_grad_tmp).dims;
weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
let [_, _, kernel_size_1_tmp, kernel_size_2_tmp] = B::float_shape(&weight_grad_tmp).dims;
if kernel_size_1_tmp != kernel_size_1 || kernel_size_2_tmp != kernel_size_2 {
weight_grad_tmp = B::slice(
weight_grad_tmp = B::float_slice(
weight_grad_tmp,
[
0..increment_ci,
@ -480,7 +482,7 @@ fn conv_transpose2d_weight_grad_groups<B: Backend>(
);
}
weight_grad = B::slice_assign(
weight_grad = B::float_slice_assign(
weight_grad,
[
start_idx_ci..end_idx_ci,
@ -501,18 +503,18 @@ fn conv1d_weight_grad_no_groups<B: Backend>(
weight_shape: Shape<3>,
options: ConvOptions<1>,
) -> FloatTensor<B, 3> {
let x_swapped = B::swap_dims(x, 0, 1);
let output_grad_swapped = B::swap_dims(output_grad, 0, 1);
let x_swapped = B::float_swap_dims(x, 0, 1);
let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
let weight_grad_swapped = B::conv1d(
x_swapped,
output_grad_swapped,
None,
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
);
let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1);
let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1);
if B::shape(&weight_grad) != weight_shape {
weight_grad = B::slice(
if B::float_shape(&weight_grad) != weight_shape {
weight_grad = B::float_slice(
weight_grad,
[
0..weight_shape.dims[0],
@ -530,11 +532,11 @@ fn conv_transpose1d_weight_grad_groups<B: Backend>(
output_grad: FloatTensor<B, 3>,
options: ConvTransposeOptions<1>,
) -> FloatTensor<B, 3> {
let [channels_in, increment_co, kernel_size] = B::shape(&weight_grad).dims;
let [channels_in, increment_co, kernel_size] = B::float_shape(&weight_grad).dims;
let increment_ci = channels_in / options.groups;
let x_swapped = B::swap_dims(x, 0, 1);
let output_grad_swapped = B::swap_dims(output_grad, 0, 1);
let x_swapped = B::float_swap_dims(x, 0, 1);
let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
for g in 0..options.groups {
let start_idx_ci = g * increment_ci;
@ -542,25 +544,25 @@ fn conv_transpose1d_weight_grad_groups<B: Backend>(
let start_idx_co = g * increment_co;
let end_idx_co = (g + 1) * increment_co;
let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]);
let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]);
let x = B::float_slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]);
let grad = B::float_slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]);
let mut weight_grad_tmp = B::conv1d(
grad,
x,
None,
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
);
weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1);
let [_, _, kernel_size_tmp] = B::shape(&weight_grad_tmp).dims;
weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
let [_, _, kernel_size_tmp] = B::float_shape(&weight_grad_tmp).dims;
if kernel_size_tmp != kernel_size {
weight_grad_tmp = B::slice(
weight_grad_tmp = B::float_slice(
weight_grad_tmp,
[0..increment_ci, 0..increment_co, 0..kernel_size],
);
}
weight_grad = B::slice_assign(
weight_grad = B::float_slice_assign(
weight_grad,
[start_idx_ci..end_idx_ci, 0..increment_co, 0..kernel_size],
weight_grad_tmp,
@ -576,18 +578,18 @@ fn conv2d_weight_grad_no_groups<B: Backend>(
weight_shape: Shape<4>,
options: ConvOptions<2>,
) -> FloatTensor<B, 4> {
let x_swapped = B::swap_dims(x, 0, 1);
let output_grad_swapped = B::swap_dims(output_grad, 0, 1);
let x_swapped = B::float_swap_dims(x, 0, 1);
let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
let weight_grad_swapped = B::conv2d(
x_swapped,
output_grad_swapped,
None,
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
);
let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1);
let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1);
if B::shape(&weight_grad) != weight_shape {
weight_grad = B::slice(
if B::float_shape(&weight_grad) != weight_shape {
weight_grad = B::float_slice(
weight_grad,
[
0..weight_shape.dims[0],
@ -606,20 +608,20 @@ fn conv_transpose1d_weight_grad_no_groups<B: Backend>(
weight_shape: Shape<3>,
options: ConvTransposeOptions<1>,
) -> FloatTensor<B, 3> {
let x_swapped = B::swap_dims(x, 0, 1);
let output_grad_swapped = B::swap_dims(output_grad, 0, 1);
let x_swapped = B::float_swap_dims(x, 0, 1);
let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
let weight_grad_swapped = B::conv1d(
output_grad_swapped,
x_swapped,
None,
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
);
let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1);
let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1);
let grad_shape = B::shape(&weight_grad);
let grad_shape = B::float_shape(&weight_grad);
if grad_shape != weight_shape {
weight_grad = B::slice(
weight_grad = B::float_slice(
weight_grad,
[
0..weight_shape.dims[0],
@ -637,20 +639,20 @@ fn conv_transpose2d_weight_grad_no_groups<B: Backend>(
weight_shape: Shape<4>,
options: ConvTransposeOptions<2>,
) -> FloatTensor<B, 4> {
let x_swapped = B::swap_dims(x, 0, 1);
let output_grad_swapped = B::swap_dims(output_grad, 0, 1);
let x_swapped = B::float_swap_dims(x, 0, 1);
let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
let weight_grad_swapped = B::conv2d(
output_grad_swapped,
x_swapped,
None,
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
);
let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1);
let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1);
let grad_shape = B::shape(&weight_grad);
let grad_shape = B::float_shape(&weight_grad);
if grad_shape != weight_shape {
weight_grad = B::slice(
weight_grad = B::float_slice(
weight_grad,
[
0..weight_shape.dims[0],

View File

@ -13,9 +13,9 @@ pub(crate) fn avg_pool1d_from_2d<B: Backend>(
padding: usize,
count_include_pad: bool,
) -> FloatTensor<B, 3> {
let [batch_size, channels, length] = B::shape(&x).dims;
let [batch_size, channels, length] = B::float_shape(&x).dims;
let x = B::reshape(x, Shape::from([batch_size, channels, length, 1]));
let x = B::float_reshape(x, Shape::from([batch_size, channels, length, 1]));
let x = B::avg_pool2d(
x,
[kernel_size, 1],
@ -24,9 +24,9 @@ pub(crate) fn avg_pool1d_from_2d<B: Backend>(
count_include_pad,
);
let [batch_size, channels, length, _] = B::shape(&x).dims;
let [batch_size, channels, length, _] = B::float_shape(&x).dims;
B::reshape(x, Shape::from([batch_size, channels, length]))
B::float_reshape(x, Shape::from([batch_size, channels, length]))
}
pub(crate) fn avg_pool1d_backward_from_2d<B: Backend>(
@ -37,11 +37,11 @@ pub(crate) fn avg_pool1d_backward_from_2d<B: Backend>(
padding: usize,
count_include_pad: bool,
) -> FloatTensor<B, 3> {
let [batch_size, channels, length_in] = B::shape(&x).dims;
let [_, _, length_out] = B::shape(&grad).dims;
let [batch_size, channels, length_in] = B::float_shape(&x).dims;
let [_, _, length_out] = B::float_shape(&grad).dims;
let x = B::reshape(x, Shape::from([batch_size, channels, length_in, 1]));
let grad_x = B::reshape(grad, Shape::from([batch_size, channels, length_out, 1]));
let x = B::float_reshape(x, Shape::from([batch_size, channels, length_in, 1]));
let grad_x = B::float_reshape(grad, Shape::from([batch_size, channels, length_out, 1]));
let grad_x = B::avg_pool2d_backward(
x,
@ -52,36 +52,36 @@ pub(crate) fn avg_pool1d_backward_from_2d<B: Backend>(
count_include_pad,
);
B::reshape(grad_x, Shape::from([batch_size, channels, length_in]))
B::float_reshape(grad_x, Shape::from([batch_size, channels, length_in]))
}
pub(crate) fn adaptive_avg_pool1d_from_2d<B: Backend>(
x: FloatTensor<B, 3>,
output_size: usize,
) -> FloatTensor<B, 3> {
let [batch_size, channels, length] = B::shape(&x).dims;
let [batch_size, channels, length] = B::float_shape(&x).dims;
let x = B::reshape(x, Shape::from([batch_size, channels, length, 1]));
let x = B::float_reshape(x, Shape::from([batch_size, channels, length, 1]));
let x = B::adaptive_avg_pool2d(x, [output_size, 1]);
let [batch_size, channels, length, _] = B::shape(&x).dims;
let [batch_size, channels, length, _] = B::float_shape(&x).dims;
B::reshape(x, Shape::from([batch_size, channels, length]))
B::float_reshape(x, Shape::from([batch_size, channels, length]))
}
pub(crate) fn adaptive_avg_pool1d_backward_from_2d<B: Backend>(
x: FloatTensor<B, 3>,
grad: FloatTensor<B, 3>,
) -> FloatTensor<B, 3> {
let [batch_size, channels, length_in] = B::shape(&x).dims;
let [_, _, length_out] = B::shape(&grad).dims;
let [batch_size, channels, length_in] = B::float_shape(&x).dims;
let [_, _, length_out] = B::float_shape(&grad).dims;
let x = B::reshape(x, Shape::from([batch_size, channels, length_in, 1]));
let grad_x = B::reshape(grad, Shape::from([batch_size, channels, length_out, 1]));
let x = B::float_reshape(x, Shape::from([batch_size, channels, length_in, 1]));
let grad_x = B::float_reshape(grad, Shape::from([batch_size, channels, length_out, 1]));
let grad_x = B::adaptive_avg_pool2d_backward(x, grad_x);
B::reshape(grad_x, Shape::from([batch_size, channels, length_in]))
B::float_reshape(grad_x, Shape::from([batch_size, channels, length_in]))
}
pub(crate) fn max_pool1d_from_2d<B: Backend>(
@ -91,9 +91,9 @@ pub(crate) fn max_pool1d_from_2d<B: Backend>(
padding: usize,
dilation: usize,
) -> FloatTensor<B, 3> {
let [batch_size, channels, length] = B::shape(&x).dims;
let [batch_size, channels, length] = B::float_shape(&x).dims;
let x = B::reshape(x, Shape::from([batch_size, channels, length, 1]));
let x = B::float_reshape(x, Shape::from([batch_size, channels, length, 1]));
let x = B::max_pool2d(
x,
[kernel_size, 1],
@ -102,9 +102,9 @@ pub(crate) fn max_pool1d_from_2d<B: Backend>(
[dilation, 1],
);
let [batch_size, channels, length, _] = B::shape(&x).dims;
let [batch_size, channels, length, _] = B::float_shape(&x).dims;
B::reshape(x, Shape::from([batch_size, channels, length]))
B::float_reshape(x, Shape::from([batch_size, channels, length]))
}
pub(crate) fn max_pool1d_with_indices_from_2d<B: Backend>(
@ -114,9 +114,9 @@ pub(crate) fn max_pool1d_with_indices_from_2d<B: Backend>(
padding: usize,
dilation: usize,
) -> MaxPool1dWithIndices<B> {
let [batch_size, channels, length] = B::shape(&x).dims;
let [batch_size, channels, length] = B::float_shape(&x).dims;
let x = B::reshape(x, Shape::from([batch_size, channels, 1, length]));
let x = B::float_reshape(x, Shape::from([batch_size, channels, 1, length]));
let x = B::max_pool2d_with_indices(
x,
[1, kernel_size],
@ -124,8 +124,8 @@ pub(crate) fn max_pool1d_with_indices_from_2d<B: Backend>(
[0, padding],
[1, dilation],
);
let [batch_size, channels, _, length] = B::shape(&x.output).dims;
let output = B::reshape(x.output, Shape::from([batch_size, channels, length]));
let [batch_size, channels, _, length] = B::float_shape(&x.output).dims;
let output = B::float_reshape(x.output, Shape::from([batch_size, channels, length]));
let indices = B::int_reshape(x.indices, Shape::from([batch_size, channels, length]));
MaxPool1dWithIndices::new(output, indices)
}
@ -139,11 +139,11 @@ pub(crate) fn max_pool1d_with_indices_backward_from_2d<B: Backend>(
output_grad: FloatTensor<B, 3>,
indices: IntTensor<B, 3>,
) -> MaxPool1dBackward<B> {
let [batch_size, channels, length_in] = B::shape(&x).dims;
let [_, _, length_out] = B::shape(&output_grad).dims;
let [batch_size, channels, length_in] = B::float_shape(&x).dims;
let [_, _, length_out] = B::float_shape(&output_grad).dims;
let x = B::reshape(x, Shape::from([batch_size, channels, length_in, 1]));
let grad_x = B::reshape(
let x = B::float_reshape(x, Shape::from([batch_size, channels, length_in, 1]));
let grad_x = B::float_reshape(
output_grad,
Shape::from([batch_size, channels, length_out, 1]),
);
@ -160,7 +160,7 @@ pub(crate) fn max_pool1d_with_indices_backward_from_2d<B: Backend>(
)
.x_grad;
MaxPool1dBackward::new(B::reshape(
MaxPool1dBackward::new(B::float_reshape(
grad_x,
Shape::from([batch_size, channels, length_in]),
))

View File

@ -54,7 +54,7 @@ pub(crate) fn create_unfolding_weight<B: Backend>(
}
}
B::from_data(Data::new(weight, shape), device)
B::float_from_data(Data::new(weight, shape), device)
}
/// Compute the unfold4d operation using the conv2d operations.
@ -63,8 +63,8 @@ pub(crate) fn unfold4d_using_conv2d<B: Backend>(
kernel_size: [usize; 2],
options: UnfoldOptions,
) -> FloatTensor<B, 3> {
let [_batch_size, in_channels, _in_height, _in_width] = B::shape(&x).dims;
let weight = create_unfolding_weight::<B>(in_channels, kernel_size, &B::device(&x));
let [_batch_size, in_channels, _in_height, _in_width] = B::float_shape(&x).dims;
let weight = create_unfolding_weight::<B>(in_channels, kernel_size, &B::float_device(&x));
let unfolded = B::conv2d(
x,
weight,
@ -77,9 +77,9 @@ pub(crate) fn unfold4d_using_conv2d<B: Backend>(
},
);
let [batch_size, channels_out, out_height, out_width] = B::shape(&unfolded).dims;
let [batch_size, channels_out, out_height, out_width] = B::float_shape(&unfolded).dims;
B::reshape(
B::float_reshape(
unfolded,
Shape::new([batch_size, channels_out, out_height * out_width]),
)
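
// Illustration (plain Rust, not Burn code): the unfolding weight built above
// is a one-hot kernel, so the conv2d extracts raw patches instead of mixing
// them. A minimal sketch of that weight for in_channels = 1 and a kh x kw
// kernel, assuming a flattened [kh * kw, 1, kh, kw] layout:
fn one_hot_unfold_weight(kh: usize, kw: usize) -> Vec<f32> {
    let out_channels = kh * kw;
    let mut weight = vec![0.0f32; out_channels * kh * kw];
    for c in 0..out_channels {
        // Output channel c copies exactly the c-th position of the window.
        weight[c * kh * kw + c] = 1.0;
    }
    weight
}
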

View File

@ -7,7 +7,7 @@ use core::ops::Range;
use num_traits::ToPrimitive;
/// Operations on float tensors.
pub trait TensorOps<B: Backend> {
pub trait FloatTensorOps<B: Backend> {
/// Creates a new tensor from the data structure.
///
/// # Arguments
@ -18,7 +18,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The tensor with the given data.
fn from_data<const D: usize>(
fn float_from_data<const D: usize>(
data: Data<FloatElem<B>, D>,
device: &Device<B>,
) -> FloatTensor<B, D>;
@ -34,7 +34,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The tensor with the given shape and random values.
fn random<const D: usize>(
fn float_random<const D: usize>(
shape: Shape<D>,
distribution: Distribution,
device: &Device<B>,
@ -50,8 +50,8 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The tensor with the given shape and zeros.
fn zeros<const D: usize>(shape: Shape<D>, device: &Device<B>) -> FloatTensor<B, D> {
Self::from_data(Data::zeros(shape), device)
fn float_zeros<const D: usize>(shape: Shape<D>, device: &Device<B>) -> FloatTensor<B, D> {
Self::float_from_data(Data::zeros(shape), device)
}
/// Creates a new tensor with ones.
@ -64,8 +64,8 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The tensor with the given shape and ones.
fn ones<const D: usize>(shape: Shape<D>, device: &Device<B>) -> FloatTensor<B, D> {
Self::from_data(Data::ones(shape), device)
fn float_ones<const D: usize>(shape: Shape<D>, device: &Device<B>) -> FloatTensor<B, D> {
Self::float_from_data(Data::ones(shape), device)
}
/// Creates a tensor filled with given value.
@ -79,12 +79,12 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The tensor filled with given value
fn full<const D: usize>(
fn float_full<const D: usize>(
shape: Shape<D>,
fill_value: FloatElem<B>,
device: &Device<B>,
) -> FloatTensor<B, D> {
Self::add_scalar(Self::zeros(shape, device), fill_value)
Self::float_add_scalar(Self::float_zeros(shape, device), fill_value)
}
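
// Illustration (plain Rust, not Burn code): the float_full default is simply
// "zeros, then add the scalar", which in flat form looks like this:
fn full_ref(num_elements: usize, fill_value: f32) -> Vec<f32> {
    vec![0.0f32; num_elements]
        .into_iter()
        .map(|x| x + fill_value)
        .collect()
}
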
/// Gets the shape of the tensor.
@ -96,7 +96,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The shape of the tensor.
fn shape<const D: usize>(tensor: &FloatTensor<B, D>) -> Shape<D>;
fn float_shape<const D: usize>(tensor: &FloatTensor<B, D>) -> Shape<D>;
/// Converts the tensor to a data structure.
///
@ -107,8 +107,8 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The data structure with the tensor's data.
fn to_data<const D: usize>(tensor: &FloatTensor<B, D>) -> Reader<Data<FloatElem<B>, D>> {
Self::into_data(tensor.clone())
fn float_to_data<const D: usize>(tensor: &FloatTensor<B, D>) -> Reader<Data<FloatElem<B>, D>> {
Self::float_into_data(tensor.clone())
}
/// Converts the tensor to a data structure.
@ -120,7 +120,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The data structure with the tensor's data.
fn into_data<const D: usize>(tensor: FloatTensor<B, D>) -> Reader<Data<FloatElem<B>, D>>;
fn float_into_data<const D: usize>(tensor: FloatTensor<B, D>) -> Reader<Data<FloatElem<B>, D>>;
/// Gets the device of the tensor.
///
@ -131,7 +131,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The device of the tensor.
fn device<const D: usize>(tensor: &FloatTensor<B, D>) -> Device<B>;
fn float_device<const D: usize>(tensor: &FloatTensor<B, D>) -> Device<B>;
/// Moves the tensor to the given device.
///
@ -143,7 +143,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The tensor on the given device.
fn to_device<const D: usize>(
fn float_to_device<const D: usize>(
tensor: FloatTensor<B, D>,
device: &Device<B>,
) -> FloatTensor<B, D>;
@ -162,8 +162,8 @@ pub trait TensorOps<B: Backend> {
/// # Remarks
///
/// Uses `arange_step` with a step size of 1 under the hood.
fn arange(range: Range<usize>, device: &Device<B>) -> IntTensor<B, 1> {
Self::arange_step(range, 1, device)
fn float_arange(range: Range<usize>, device: &Device<B>) -> IntTensor<B, 1> {
Self::float_arange_step(range, 1, device)
}
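
// Illustration (plain Rust, not Burn code): float_arange delegates to
// float_arange_step with a step of 1, and the step variant is essentially
// Rust's step_by over the range:
fn arange_step_ref(range: std::ops::Range<usize>, step: usize) -> Vec<i64> {
    range.step_by(step).map(|i| i as i64).collect()
}
// arange_step_ref(0..5, 1) yields [0, 1, 2, 3, 4].
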
/// Converts float tensor to int tensor.
@ -175,7 +175,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The int tensor with the same data as the float tensor.
fn into_int<const D: usize>(tensor: FloatTensor<B, D>) -> IntTensor<B, D>;
fn float_into_int<const D: usize>(tensor: FloatTensor<B, D>) -> IntTensor<B, D>;
/// Creates a new tensor with values from the given range with the given step size.
///
@ -188,7 +188,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The tensor with the given values.
fn arange_step(range: Range<usize>, step: usize, device: &Device<B>) -> IntTensor<B, 1> {
fn float_arange_step(range: Range<usize>, step: usize, device: &Device<B>) -> IntTensor<B, 1> {
let value = range
.step_by(step)
.map(|i| (i as i64).elem())
@ -208,7 +208,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The empty tensor with the given shape.
fn empty<const D: usize>(shape: Shape<D>, device: &Device<B>) -> FloatTensor<B, D>;
fn float_empty<const D: usize>(shape: Shape<D>, device: &Device<B>) -> FloatTensor<B, D>;
/// Repeat the tensor along the given dimension.
///
@ -221,12 +221,12 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The tensor with the given dimension repeated.
fn repeat<const D: usize>(
fn float_repeat<const D: usize>(
tensor: FloatTensor<B, D>,
dim: usize,
times: usize,
) -> FloatTensor<B, D> {
let mut shape = B::shape(&tensor);
let mut shape = B::float_shape(&tensor);
if shape.dims[dim] != 1 {
panic!("Can only repeat dimension with dim=1");
}
@ -240,11 +240,11 @@ pub trait TensorOps<B: Backend> {
start..end
});
let mut tensor_output = B::empty(shape, &B::device(&tensor));
let mut tensor_output = B::float_empty(shape, &B::float_device(&tensor));
for i in 0..times {
let mut indices = indices_select_all.clone();
indices[dim] = i..i + 1;
tensor_output = B::slice_assign(tensor_output, indices, tensor.clone());
tensor_output = B::float_slice_assign(tensor_output, indices, tensor.clone());
}
tensor_output
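
// Illustration (plain Rust, not Burn code): the float_repeat default only
// accepts a dimension of size 1 and writes one copy per step through
// float_slice_assign. A sketch for a [rows, 1] column repeated `times` along
// dim 1 in row-major layout:
fn repeat_dim1_ref(column: &[f32], times: usize) -> Vec<f32> {
    let rows = column.len();
    let mut out = vec![0.0f32; rows * times]; // the "empty" output tensor
    for i in 0..times {
        for r in 0..rows {
            out[r * times + i] = column[r]; // slice_assign into slot i
        }
    }
    out
}
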
@ -260,7 +260,10 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The result of adding the two tensors together.
fn add<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatTensor<B, D>) -> FloatTensor<B, D>;
fn float_add<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatTensor<B, D>,
) -> FloatTensor<B, D>;
/// Adds a scalar to a tensor.
///
@ -272,7 +275,10 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The result of adding the scalar to the tensor.
fn add_scalar<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatElem<B>) -> FloatTensor<B, D>;
fn float_add_scalar<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatElem<B>,
) -> FloatTensor<B, D>;
/// Clamps a tensor under a minimum value.
///
@ -284,13 +290,13 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The clamped tensor.
fn clamp_min<const D: usize>(
fn float_clamp_min<const D: usize>(
tensor: FloatTensor<B, D>,
min: FloatElem<B>,
) -> FloatTensor<B, D> {
// Default implementation
let mask = Self::lower_elem(tensor.clone(), min);
B::mask_fill(tensor, mask, min)
let mask = Self::float_lower_elem(tensor.clone(), min);
B::float_mask_fill(tensor, mask, min)
}
/// Clamps a tensor over a maximum value.
@ -303,13 +309,13 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The clamped tensor.
fn clamp_max<const D: usize>(
fn float_clamp_max<const D: usize>(
tensor: FloatTensor<B, D>,
max: FloatElem<B>,
) -> FloatTensor<B, D> {
// Default implementation
let mask = Self::greater_elem(tensor.clone(), max);
B::mask_fill(tensor, mask, max)
let mask = Self::float_greater_elem(tensor.clone(), max);
B::float_mask_fill(tensor, mask, max)
}
/// Clamps a tensor between a minimum and maximum value.
@ -323,13 +329,13 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The clamped tensor.
fn clamp<const D: usize>(
fn float_clamp<const D: usize>(
tensor: FloatTensor<B, D>,
min: FloatElem<B>,
max: FloatElem<B>,
) -> FloatTensor<B, D> {
// Default implementation
Self::clamp_min(Self::clamp_max(tensor, max), min)
Self::float_clamp_min(Self::float_clamp_max(tensor, max), min)
}
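
// Illustration (plain Rust, not Burn code): the three clamp defaults compose
// comparisons with float_mask_fill; in scalar form the same logic reads:
fn clamp_ref(values: &[f32], min: f32, max: f32) -> Vec<f32> {
    values
        .iter()
        .map(|&x| {
            let x = if x > max { max } else { x }; // clamp_max: mask = greater_elem
            if x < min { min } else { x }          // clamp_min: mask = lower_elem
        })
        .collect()
}
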
/// Subtracts two tensors.
@ -342,7 +348,10 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The result of subtracting the two tensors.
fn sub<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatTensor<B, D>) -> FloatTensor<B, D>;
fn float_sub<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatTensor<B, D>,
) -> FloatTensor<B, D>;
/// Subtracts a scalar from a tensor.
///
@ -354,10 +363,16 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The result of subtracting the scalar from the tensor.
fn sub_scalar<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatElem<B>) -> FloatTensor<B, D>;
fn float_sub_scalar<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatElem<B>,
) -> FloatTensor<B, D>;
/// Multiplies two tensors together element-wise.
fn mul<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatTensor<B, D>) -> FloatTensor<B, D>;
fn float_mul<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatTensor<B, D>,
) -> FloatTensor<B, D>;
/// Multiplies a tensor by a scalar.
///
@ -369,7 +384,10 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The result of multiplying the tensor by the scalar.
fn mul_scalar<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatElem<B>) -> FloatTensor<B, D>;
fn float_mul_scalar<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatElem<B>,
) -> FloatTensor<B, D>;
/// Divides two tensors element-wise.
///
@ -381,7 +399,10 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The result of dividing the two tensors.
fn div<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatTensor<B, D>) -> FloatTensor<B, D>;
fn float_div<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatTensor<B, D>,
) -> FloatTensor<B, D>;
/// Divides a tensor by a scalar.
///
@ -393,7 +414,10 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The result of dividing the tensor by the scalar.
fn div_scalar<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatElem<B>) -> FloatTensor<B, D>;
fn float_div_scalar<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatElem<B>,
) -> FloatTensor<B, D>;
/// Multiplies two tensors together using matrix multiplication.
///
@ -405,15 +429,18 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The result of multiplying the two tensors together using matrix multiplication.
fn matmul<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatTensor<B, D>) -> FloatTensor<B, D>;
fn float_matmul<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatTensor<B, D>,
) -> FloatTensor<B, D>;
/// Negates a tensor element-wise.
fn neg<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D> {
Self::mul_scalar(tensor, (-1.0_f32).elem::<FloatElem<B>>())
fn float_neg<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D> {
Self::float_mul_scalar(tensor, (-1.0_f32).elem::<FloatElem<B>>())
}
/// Calculates the reciprocal of each element.
fn recip<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
fn float_recip<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Transposes a tensor.
///
@ -424,8 +451,8 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The transposed tensor.
fn transpose<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D> {
Self::swap_dims(tensor, D - 2, D - 1)
fn float_transpose<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D> {
Self::float_swap_dims(tensor, D - 2, D - 1)
}
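
// Illustration (plain Rust, not Burn code): float_transpose is just
// float_swap_dims(tensor, D - 2, D - 1); for a 2-D [rows, cols] tensor in
// row-major layout that is the ordinary matrix transpose:
fn transpose_2d_ref(x: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; rows * cols];
    for r in 0..rows {
        for c in 0..cols {
            out[c * rows + r] = x[r * cols + c];
        }
    }
    out
}
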
/// Swaps two dimensions of a tensor.
@ -439,7 +466,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The tensor with the dimensions swapped.
fn swap_dims<const D: usize>(
fn float_swap_dims<const D: usize>(
tensor: FloatTensor<B, D>,
dim1: usize,
dim2: usize,
@ -455,7 +482,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The tensor with the new shape.
fn reshape<const D1: usize, const D2: usize>(
fn float_reshape<const D1: usize, const D2: usize>(
tensor: FloatTensor<B, D1>,
shape: Shape<D2>,
) -> FloatTensor<B, D2>;
@ -471,7 +498,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The gathered elements.
fn gather<const D: usize>(
fn float_gather<const D: usize>(
dim: usize,
tensor: FloatTensor<B, D>,
indices: IntTensor<B, D>,
@ -489,7 +516,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The tensor with the scattered elements.
fn scatter<const D: usize>(
fn float_scatter<const D: usize>(
dim: usize,
tensor: FloatTensor<B, D>,
indices: IntTensor<B, D>,
@ -507,7 +534,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The selected elements.
fn select<const D: usize>(
fn float_select<const D: usize>(
tensor: FloatTensor<B, D>,
dim: usize,
indices: IntTensor<B, 1>,
@ -526,7 +553,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The tensor with the selected elements assigned to the given value.
fn select_assign<const D: usize>(
fn float_select_assign<const D: usize>(
tensor: FloatTensor<B, D>,
dim: usize,
indices: IntTensor<B, 1>,
@ -543,7 +570,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The selected elements in a new tensor.
fn slice<const D1: usize, const D2: usize>(
fn float_slice<const D1: usize, const D2: usize>(
tensor: FloatTensor<B, D1>,
ranges: [Range<usize>; D2],
) -> FloatTensor<B, D1>;
@ -559,7 +586,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The tensor with the selected elements assigned to the given value.
fn slice_assign<const D1: usize, const D2: usize>(
fn float_slice_assign<const D1: usize, const D2: usize>(
tensor: FloatTensor<B, D1>,
ranges: [Range<usize>; D2],
value: FloatTensor<B, D1>,
@ -576,7 +603,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The tensor with the selected elements assigned to the given value.
fn mask_where<const D: usize>(
fn float_mask_where<const D: usize>(
tensor: FloatTensor<B, D>,
mask: BoolTensor<B, D>,
value: FloatTensor<B, D>,
@ -593,7 +620,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The tensor with the selected elements assigned to the given value.
fn mask_fill<const D: usize>(
fn float_mask_fill<const D: usize>(
tensor: FloatTensor<B, D>,
mask: BoolTensor<B, D>,
value: FloatElem<B>,
@ -609,7 +636,10 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn equal<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatTensor<B, D>) -> BoolTensor<B, D>;
fn float_equal<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatTensor<B, D>,
) -> BoolTensor<B, D>;
/// Equal comparison of a tensor and a scalar.
///
@ -621,7 +651,10 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn equal_elem<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatElem<B>) -> BoolTensor<B, D>;
fn float_equal_elem<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatElem<B>,
) -> BoolTensor<B, D>;
/// Greater than comparison of two tensors.
///
@ -633,7 +666,10 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn greater<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatTensor<B, D>) -> BoolTensor<B, D>;
fn float_greater<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatTensor<B, D>,
) -> BoolTensor<B, D>;
/// Greater than comparison of a tensor and a scalar.
///
@ -645,7 +681,10 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn greater_elem<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatElem<B>) -> BoolTensor<B, D>;
fn float_greater_elem<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatElem<B>,
) -> BoolTensor<B, D>;
/// Greater than or equal comparison of two tensors.
///
@ -657,7 +696,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn greater_equal<const D: usize>(
fn float_greater_equal<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatTensor<B, D>,
) -> BoolTensor<B, D>;
@ -672,7 +711,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn greater_equal_elem<const D: usize>(
fn float_greater_equal_elem<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatElem<B>,
) -> BoolTensor<B, D>;
@ -687,7 +726,10 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn lower<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatTensor<B, D>) -> BoolTensor<B, D>;
fn float_lower<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatTensor<B, D>,
) -> BoolTensor<B, D>;
/// Less than comparison of a tensor and a scalar.
///
@ -699,7 +741,10 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn lower_elem<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatElem<B>) -> BoolTensor<B, D>;
fn float_lower_elem<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatElem<B>,
) -> BoolTensor<B, D>;
/// Less than or equal comparison of two tensors.
///
@ -711,7 +756,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn lower_equal<const D: usize>(
fn float_lower_equal<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatTensor<B, D>,
) -> BoolTensor<B, D>;
@ -726,19 +771,19 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn lower_equal_elem<const D: usize>(
fn float_lower_equal_elem<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatElem<B>,
) -> BoolTensor<B, D>;
/// Detaches a tensor from the computation graph.
fn detach<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D> {
fn float_detach<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D> {
// Should only be overridden by autodiff backends.
tensor
}
/// Sets the `require_grad` flag of a tensor.
fn set_require_grad<const D: usize>(
fn float_set_require_grad<const D: usize>(
tensor: FloatTensor<B, D>,
_require_grad: bool,
) -> FloatTensor<B, D> {
@ -747,7 +792,7 @@ pub trait TensorOps<B: Backend> {
}
/// Returns the `require_grad` flag of a tensor.
fn is_require_grad<const D: usize>(_tensor: &FloatTensor<B, D>) -> bool {
fn float_is_require_grad<const D: usize>(_tensor: &FloatTensor<B, D>) -> bool {
// Should only be overridden by autodiff backends.
false
}
@ -761,7 +806,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A scalar tensor with the sum of all elements in `tensor`.
fn sum<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, 1>;
fn float_sum<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, 1>;
/// Sum of all elements in a tensor along a dimension.
///
@ -773,7 +818,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the sum of all elements in `tensor` along `dim`.
fn sum_dim<const D: usize>(tensor: FloatTensor<B, D>, dim: usize) -> FloatTensor<B, D>;
fn float_sum_dim<const D: usize>(tensor: FloatTensor<B, D>, dim: usize) -> FloatTensor<B, D>;
/// Mean of all elements in a tensor.
///
@ -784,9 +829,9 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A scalar tensor with the mean of all elements in `tensor`.
fn mean<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, 1> {
let num_elems = B::shape(&tensor).num_elements();
B::div_scalar(B::sum(tensor), (num_elems as i64).elem())
fn float_mean<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, 1> {
let num_elems = B::float_shape(&tensor).num_elements();
B::float_div_scalar(B::float_sum(tensor), (num_elems as i64).elem())
}
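
// Illustration (plain Rust, not Burn code): the float_mean default divides
// the scalar sum by the element count:
fn mean_ref(values: &[f32]) -> f32 {
    values.iter().sum::<f32>() / values.len() as f32
}
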
/// Mean of all elements in a tensor along a dimension.
@ -799,7 +844,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the mean of all elements in `tensor` along `dim`.
fn mean_dim<const D: usize>(tensor: FloatTensor<B, D>, dim: usize) -> FloatTensor<B, D>;
fn float_mean_dim<const D: usize>(tensor: FloatTensor<B, D>, dim: usize) -> FloatTensor<B, D>;
/// Converts a tensor to full precision.
///
@ -810,7 +855,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the same values as `tensor` but with full precision.
fn to_full_precision<const D: usize>(
fn float_to_full_precision<const D: usize>(
tensor: &FloatTensor<B, D>,
) -> FloatTensor<FullPrecisionBackend<B>, D>;
@ -823,7 +868,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the same values as `tensor` but with the precision of the backend.
fn from_full_precision<const D: usize>(
fn float_from_full_precision<const D: usize>(
tensor: FloatTensor<FullPrecisionBackend<B>, D>,
) -> FloatTensor<B, D>;
@ -836,7 +881,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the same shape as `tensor` with exponential values.
fn exp<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
fn float_exp<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Returns a new tensor with natural logarithm values.
///
@ -847,7 +892,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the same shape as `tensor` with natural logarithm values.
fn log<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
fn float_log<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Returns a new tensor with logarithm values of (1 + Xi).
///
@ -858,7 +903,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
fn log1p<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
fn float_log1p<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Elementwise power with a FloatTensor.
///
@ -870,7 +915,10 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The elements of `lhs` raised to the power of the elements of `rhs`.
fn powf<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatTensor<B, D>) -> FloatTensor<B, D>;
fn float_powf<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatTensor<B, D>,
) -> FloatTensor<B, D>;
/// Elementwise power with an IntTensor.
///
@ -882,8 +930,11 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The elements of `lhs` raised to the power of the elements of `rhs`.
fn powi<const D: usize>(lhs: FloatTensor<B, D>, rhs: IntTensor<B, D>) -> FloatTensor<B, D> {
Self::powf(lhs, B::int_into_float::<D>(rhs))
fn float_powi<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: IntTensor<B, D>,
) -> FloatTensor<B, D> {
Self::float_powf(lhs, B::int_into_float::<D>(rhs))
}
/// Raises a tensor to the power of an int scalar.
@ -896,8 +947,11 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The elements of `lhs` raised to the power of `rhs`.
fn powi_scalar<const D: usize>(lhs: FloatTensor<B, D>, rhs: IntElem<B>) -> FloatTensor<B, D> {
Self::powf_scalar(lhs, rhs.to_f32().unwrap())
fn float_powi_scalar<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: IntElem<B>,
) -> FloatTensor<B, D> {
Self::float_powf_scalar(lhs, rhs.to_f32().unwrap())
}
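
// Illustration (plain Rust, not Burn code): both integer-power defaults fall
// back to the float power op, i.e. x^n becomes powf(x, n as f32):
fn powi_scalar_ref(values: &[f32], exponent: i32) -> Vec<f32> {
    values.iter().map(|&x| x.powf(exponent as f32)).collect()
}
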
/// Returns a new tensor with values raised to the power of float `value`.
@ -910,7 +964,10 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the same shape as `tensor` with values raised to the power of `value`.
fn powf_scalar<const D: usize>(tensor: FloatTensor<B, D>, value: f32) -> FloatTensor<B, D>;
fn float_powf_scalar<const D: usize>(
tensor: FloatTensor<B, D>,
value: f32,
) -> FloatTensor<B, D>;
/// Returns a new tensor with square root values.
///
@ -921,7 +978,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the same shape as `tensor` with square root values.
fn sqrt<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
fn float_sqrt<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Returns a new tensor with absolute values.
///
@ -932,7 +989,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the same shape as `tensor` with absolute values.
fn abs<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
fn float_abs<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Returns a new tensor with cosine values.
///
@ -943,7 +1000,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the same shape as `tensor` with cosine values.
fn cos<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
fn float_cos<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Returns a new tensor with sine values.
///
@ -954,7 +1011,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the same shape as `tensor` with sine values.
fn sin<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
fn float_sin<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Returns a new tensor with tangent values.
///
@ -965,7 +1022,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the same shape as `tensor` with tangent values.
fn tanh<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
fn float_tanh<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Returns a new tensor with the error function values.
///
@ -976,7 +1033,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the same shape as `tensor` with error function values.
fn erf<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
fn float_erf<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Concatenates tensors along a dimension.
///
@ -988,7 +1045,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the concatenated tensors along `dim`.
fn cat<const D: usize>(tensors: Vec<FloatTensor<B, D>>, dim: usize) -> FloatTensor<B, D>;
fn float_cat<const D: usize>(tensors: Vec<FloatTensor<B, D>>, dim: usize) -> FloatTensor<B, D>;
/// Gets the indices of the maximum elements of a tensor along an axis.
///
@ -1000,7 +1057,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the indices of the maximum elements of `tensor` along `dim`.
fn argmax<const D: usize>(tensor: FloatTensor<B, D>, dim: usize) -> IntTensor<B, D>;
fn float_argmax<const D: usize>(tensor: FloatTensor<B, D>, dim: usize) -> IntTensor<B, D>;
/// Gets the indices of the minimum elements of a tensor along an axis.
///
@ -1012,7 +1069,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the indices of the minimum elements of `tensor` along `dim`.
fn argmin<const D: usize>(tensor: FloatTensor<B, D>, dim: usize) -> IntTensor<B, D>;
fn float_argmin<const D: usize>(tensor: FloatTensor<B, D>, dim: usize) -> IntTensor<B, D>;
/// Gets the maximum element of a tensor.
///
@ -1023,11 +1080,11 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the maximum element of `tensor`.
fn max<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, 1> {
let shape = B::shape(&tensor);
let tensor = B::reshape(tensor, Shape::new([shape.num_elements()]));
fn float_max<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, 1> {
let shape = B::float_shape(&tensor);
let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
B::max_dim(tensor, 0)
B::float_max_dim(tensor, 0)
}
/// Gets the maximum elements of a tensor along an axis.
@ -1040,10 +1097,10 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the maximum elements of `tensor` along `dim`.
fn max_dim<const D: usize>(tensor: FloatTensor<B, D>, dim: usize) -> FloatTensor<B, D> {
let index = B::argmax(tensor.clone(), dim);
fn float_max_dim<const D: usize>(tensor: FloatTensor<B, D>, dim: usize) -> FloatTensor<B, D> {
let index = B::float_argmax(tensor.clone(), dim);
B::gather(dim, tensor, index)
B::float_gather(dim, tensor, index)
}
/// Gets the maximum elements of a tensor along an axis and their indices.
@ -1056,12 +1113,12 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A tuple with the maximum elements of `tensor` along `dim` and their indices.
fn max_dim_with_indices<const D: usize>(
fn float_max_dim_with_indices<const D: usize>(
tensor: FloatTensor<B, D>,
dim: usize,
) -> (FloatTensor<B, D>, IntTensor<B, D>) {
let index = B::argmax(tensor.clone(), dim);
let values = B::gather(dim, tensor, index.clone());
let index = B::float_argmax(tensor.clone(), dim);
let values = B::float_gather(dim, tensor, index.clone());
(values, index)
}
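
// Illustration (plain Rust, not Burn code): the float_max_dim defaults take
// an argmax and then gather the values at those indices. A sketch for a
// row-major [rows, cols] matrix reduced along dim 1:
fn max_dim1_with_indices_ref(x: &[f32], rows: usize, cols: usize) -> (Vec<f32>, Vec<usize>) {
    let mut values = Vec::with_capacity(rows);
    let mut indices = Vec::with_capacity(rows);
    for r in 0..rows {
        let row = &x[r * cols..(r + 1) * cols];
        let (argmax, &max) = row
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .expect("rows must not be empty");
        values.push(max);
        indices.push(argmax);
    }
    (values, indices)
}
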
@ -1075,11 +1132,11 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the minimum element of `tensor`.
fn min<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, 1> {
let shape = B::shape(&tensor);
let tensor = B::reshape(tensor, Shape::new([shape.num_elements()]));
fn float_min<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, 1> {
let shape = B::float_shape(&tensor);
let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
B::min_dim(tensor, 0)
B::float_min_dim(tensor, 0)
}
/// Gets the minimum elements of a tensor along an axis.
@ -1092,10 +1149,10 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the minimum elements of `tensor` along `dim`.
fn min_dim<const D: usize>(tensor: FloatTensor<B, D>, dim: usize) -> FloatTensor<B, D> {
let index = B::argmin(tensor.clone(), dim);
fn float_min_dim<const D: usize>(tensor: FloatTensor<B, D>, dim: usize) -> FloatTensor<B, D> {
let index = B::float_argmin(tensor.clone(), dim);
B::gather(dim, tensor, index)
B::float_gather(dim, tensor, index)
}
/// Gets the minimum elements of a tensor along an axis and their indices.
@ -1108,12 +1165,12 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A tuple with the minimum elements of `tensor` along `dim` and their indices.
fn min_dim_with_indices<const D: usize>(
fn float_min_dim_with_indices<const D: usize>(
tensor: FloatTensor<B, D>,
dim: usize,
) -> (FloatTensor<B, D>, IntTensor<B, D>) {
let index = B::argmin(tensor.clone(), dim);
let values = B::gather(dim, tensor, index.clone());
let index = B::float_argmin(tensor.clone(), dim);
let values = B::float_gather(dim, tensor, index.clone());
(values, index)
}
@ -1133,7 +1190,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A new tensor with the given dimension narrowed to the given range.
fn narrow<const D: usize>(
fn float_narrow<const D: usize>(
tensor: FloatTensor<B, D>,
dim: usize,
start: usize,
@ -1154,7 +1211,7 @@ pub trait TensorOps<B: Backend> {
///
/// A vectors of tensors
///
fn chunk<const D: usize>(
fn float_chunk<const D: usize>(
tensor: FloatTensor<B, D>,
chunks: usize,
dim: usize,

View File

@ -38,7 +38,7 @@ impl<G: GraphicsApi + 'static, F: FloatElement, I: IntElement> Backend for Wgpu<
type FloatElem = F;
type IntElem = I;
type TensorPrimitive<const D: usize> = WgpuTensor<F, D>;
type FloatTensorPrimitive<const D: usize> = WgpuTensor<F, D>;
type IntTensorPrimitive<const D: usize> = WgpuTensor<I, D>;
type BoolTensorPrimitive<const D: usize> = WgpuTensor<u32, D>;

View File

@ -91,7 +91,7 @@ where
fn float_tensor<const D: usize>(
handle: Self::Handle,
shape: Shape<D>,
) -> Self::TensorPrimitive<D> {
) -> Self::FloatTensorPrimitive<D> {
handle.into_tensor(shape)
}
@ -109,7 +109,7 @@ where
handle.into_tensor(shape)
}
fn float_tensor_handle<const D: usize>(tensor: Self::TensorPrimitive<D>) -> Self::Handle {
fn float_tensor_handle<const D: usize>(tensor: Self::FloatTensorPrimitive<D>) -> Self::Handle {
tensor.into()
}

View File

@ -15,24 +15,24 @@ use crate::{unary, FloatElement, GraphicsApi, IntElement, Wgpu};
use burn_tensor::ops::{
BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor,
};
use burn_tensor::{ops::TensorOps, Data, Distribution, Shape};
use burn_tensor::{ops::FloatTensorOps, Data, Distribution, Shape};
use burn_tensor::{ElementConversion, Reader};
use std::ops::Range;
impl<G, F, I> TensorOps<Wgpu<G, F, I>> for Wgpu<G, F, I>
impl<G, F, I> FloatTensorOps<Wgpu<G, F, I>> for Wgpu<G, F, I>
where
G: GraphicsApi + 'static,
F: FloatElement,
I: IntElement,
{
fn from_data<const D: usize>(
fn float_from_data<const D: usize>(
data: Data<FloatElem<Self>, D>,
device: &Device<Self>,
) -> FloatTensor<Self, D> {
super::from_data::<G, F, D>(data, device)
}
fn random<const D: usize>(
fn float_random<const D: usize>(
shape: Shape<D>,
distribution: Distribution,
device: &Device<Self>,
@ -51,48 +51,50 @@ where
}
}
fn shape<const D: usize>(tensor: &FloatTensor<Self, D>) -> Shape<D> {
fn float_shape<const D: usize>(tensor: &FloatTensor<Self, D>) -> Shape<D> {
tensor.shape.clone()
}
fn into_data<const D: usize>(tensor: FloatTensor<Self, D>) -> Reader<Data<FloatElem<Self>, D>> {
fn float_into_data<const D: usize>(
tensor: FloatTensor<Self, D>,
) -> Reader<Data<FloatElem<Self>, D>> {
super::into_data(tensor)
}
fn device<const D: usize>(tensor: &FloatTensor<Self, D>) -> Device<Self> {
fn float_device<const D: usize>(tensor: &FloatTensor<Self, D>) -> Device<Self> {
tensor.device.clone()
}
fn to_device<const D: usize>(
fn float_to_device<const D: usize>(
tensor: FloatTensor<Self, D>,
device: &Device<Self>,
) -> FloatTensor<Self, D> {
super::to_device::<G, F, D>(tensor, device)
}
fn empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
fn float_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
super::empty::<G, F, D>(shape, device)
}
fn add<const D: usize>(
fn float_add<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
numeric::add(lhs, rhs)
}
fn add_scalar<const D: usize>(
fn float_add_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
numeric::add_scalar(lhs, rhs)
}
fn zeros<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
fn float_zeros<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
numeric::zeros::<G, F, D>(shape, device)
}
fn full<const D: usize>(
fn float_full<const D: usize>(
shape: Shape<D>,
fill_value: FloatElem<Self>,
device: &WgpuDevice,
@ -100,53 +102,53 @@ where
numeric::full::<G, F, D>(shape, device, fill_value)
}
fn ones<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
fn float_ones<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
numeric::ones::<G, F, D>(shape, device)
}
fn sub<const D: usize>(
fn float_sub<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
numeric::sub(lhs, rhs)
}
fn sub_scalar<const D: usize>(
fn float_sub_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
numeric::sub_scalar(lhs, rhs)
}
fn mul<const D: usize>(
fn float_mul<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
numeric::mul(lhs, rhs)
}
fn mul_scalar<const D: usize>(
fn float_mul_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
numeric::mul_scalar(lhs, rhs)
}
fn div<const D: usize>(
fn float_div<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
numeric::div(lhs, rhs)
}
fn div_scalar<const D: usize>(
fn float_div_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
numeric::div_scalar(lhs, rhs)
}
fn matmul<const D: usize>(
fn float_matmul<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
@ -162,7 +164,7 @@ where
}
}
fn swap_dims<const D: usize>(
fn float_swap_dims<const D: usize>(
tensor: FloatTensor<Self, D>,
dim1: usize,
dim2: usize,
@ -170,14 +172,14 @@ where
super::swap_dims(tensor, dim1, dim2)
}
fn reshape<const D1: usize, const D2: usize>(
fn float_reshape<const D1: usize, const D2: usize>(
tensor: FloatTensor<Self, D1>,
shape: Shape<D2>,
) -> FloatTensor<Self, D2> {
super::reshape(tensor, shape)
}
fn gather<const D: usize>(
fn float_gather<const D: usize>(
dim: usize,
tensor: FloatTensor<Self, D>,
indices: IntTensor<Self, D>,
@ -185,7 +187,7 @@ where
kernel::gather(dim, tensor, indices)
}
fn scatter<const D: usize>(
fn float_scatter<const D: usize>(
dim: usize,
tensor: FloatTensor<Self, D>,
indices: IntTensor<Self, D>,
@ -194,7 +196,7 @@ where
kernel::scatter(dim, tensor, indices, value)
}
fn select<const D: usize>(
fn float_select<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
indices: IntTensor<Self, 1>,
@ -202,7 +204,7 @@ where
kernel::select(tensor, dim, indices)
}
fn select_assign<const D: usize>(
fn float_select_assign<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
indices: IntTensor<Self, 1>,
@ -211,14 +213,14 @@ where
kernel::select_assign(tensor, dim, indices, value)
}
fn slice<const D1: usize, const D2: usize>(
fn float_slice<const D1: usize, const D2: usize>(
tensor: FloatTensor<Self, D1>,
ranges: [Range<usize>; D2],
) -> FloatTensor<Self, D1> {
kernel::slice(tensor, ranges)
}
fn slice_assign<const D1: usize, const D2: usize>(
fn float_slice_assign<const D1: usize, const D2: usize>(
tensor: FloatTensor<Self, D1>,
ranges: [Range<usize>; D2],
value: FloatTensor<Self, D1>,
@ -226,7 +228,7 @@ where
kernel::slice_assign(tensor, ranges, value)
}
fn mask_where<const D: usize>(
fn float_mask_where<const D: usize>(
tensor: FloatTensor<Self, D>,
mask: BoolTensor<Self, D>,
value: FloatTensor<Self, D>,
@ -234,7 +236,7 @@ where
kernel::mask_where(tensor, mask, value)
}
fn mask_fill<const D: usize>(
fn float_mask_fill<const D: usize>(
tensor: FloatTensor<Self, D>,
mask: BoolTensor<Self, D>,
value: FloatElem<Self>,
@ -242,81 +244,84 @@ where
kernel::mask_fill(tensor, mask, value)
}
fn equal<const D: usize>(
fn float_equal<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> BoolTensor<Self, D> {
kernel::equal(lhs, rhs)
}
fn equal_elem<const D: usize>(
fn float_equal_elem<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> BoolTensor<Self, D> {
kernel::equal_elem(lhs, rhs)
}
fn greater<const D: usize>(
fn float_greater<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> BoolTensor<Self, D> {
kernel::greater(lhs, rhs)
}
fn greater_elem<const D: usize>(
fn float_greater_elem<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> BoolTensor<Self, D> {
kernel::greater_elem(lhs, rhs)
}
fn greater_equal<const D: usize>(
fn float_greater_equal<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> BoolTensor<Self, D> {
kernel::greater_equal(lhs, rhs)
}
fn greater_equal_elem<const D: usize>(
fn float_greater_equal_elem<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> BoolTensor<Self, D> {
kernel::greater_equal_elem(lhs, rhs)
}
fn lower<const D: usize>(
fn float_lower<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> BoolTensor<Self, D> {
kernel::lower(lhs, rhs)
}
fn lower_elem<const D: usize>(
fn float_lower_elem<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> BoolTensor<Self, D> {
kernel::lower_elem(lhs, rhs)
}
fn lower_equal<const D: usize>(
fn float_lower_equal<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> BoolTensor<Self, D> {
kernel::lower_equal(lhs, rhs)
}
fn lower_equal_elem<const D: usize>(
fn float_lower_equal_elem<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> BoolTensor<Self, D> {
kernel::lower_equal_elem(lhs, rhs)
}
fn sum<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
fn float_sum<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
reduce::sum(tensor)
}
fn sum_dim<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> FloatTensor<Self, D> {
fn float_sum_dim<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> FloatTensor<Self, D> {
#[cfg(feature = "autotune")]
{
reduce::sum_dim_autotune(tensor, dim)
@ -329,7 +334,10 @@ where
}
}
fn mean_dim<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> FloatTensor<Self, D> {
fn float_mean_dim<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> FloatTensor<Self, D> {
#[cfg(feature = "autotune")]
{
reduce::mean_dim_autotune(tensor, dim)
@ -342,19 +350,19 @@ where
}
}
fn to_full_precision<const D: usize>(
fn float_to_full_precision<const D: usize>(
tensor: &FloatTensor<Self, D>,
) -> FloatTensor<FullPrecisionBackend<Self>, D> {
kernel::cast(tensor.clone())
}
fn from_full_precision<const D: usize>(
fn float_from_full_precision<const D: usize>(
tensor: FloatTensor<FullPrecisionBackend<Self>, D>,
) -> FloatTensor<Self, D> {
kernel::cast(tensor)
}
fn exp<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
fn float_exp<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(
operator: |elem: Elem| Operator::Exp {
input: Variable::Input(0, Item::Scalar(elem)),
@ -365,7 +373,7 @@ where
)
}
fn log<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
fn float_log<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(
operator: |elem: Elem| Operator::Log {
input: Variable::Input(0, Item::Scalar(elem)),
@ -376,7 +384,7 @@ where
)
}
fn log1p<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
fn float_log1p<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(
operator: |elem: Elem| Operator::Log1p {
input: Variable::Input(0, Item::Scalar(elem)),
@ -387,7 +395,10 @@ where
)
}
fn powf_scalar<const D: usize>(lhs: FloatTensor<Self, D>, rhs: f32) -> FloatTensor<Self, D> {
fn float_powf_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: f32,
) -> FloatTensor<Self, D> {
unary!(
operator: |elem: Elem| Operator::Powf {
lhs: Variable::Input(0, Item::Scalar(elem)),
@ -399,7 +410,7 @@ where
)
}
fn sqrt<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
fn float_sqrt<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(
operator: |elem: Elem| Operator::Sqrt {
input: Variable::Input(0, Item::Scalar(elem)),
@ -410,7 +421,7 @@ where
)
}
fn abs<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
fn float_abs<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(
operator: |elem: Elem| Operator::Abs {
input: Variable::Input(0, Item::Scalar(elem)),
@ -421,7 +432,7 @@ where
)
}
fn cos<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
fn float_cos<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(
operator: |elem: Elem| Operator::Cos {
input: Variable::Input(0, Item::Scalar(elem)),
@ -432,7 +443,7 @@ where
)
}
fn sin<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
fn float_sin<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(
operator: |elem: Elem| Operator::Sin {
input: Variable::Input(0, Item::Scalar(elem)),
@ -443,7 +454,7 @@ where
)
}
fn tanh<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
fn float_tanh<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(
operator: |elem: Elem| Operator::Tanh {
input: Variable::Input(0, Item::Scalar(elem)),
@ -454,7 +465,7 @@ where
)
}
fn erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
fn float_erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(
operator: |elem: Elem| Operator::Erf {
input: Variable::Input(0, Item::Scalar(elem)),
@ -465,23 +476,32 @@ where
)
}
fn cat<const D: usize>(tensors: Vec<FloatTensor<Self, D>>, dim: usize) -> FloatTensor<Self, D> {
fn float_cat<const D: usize>(
tensors: Vec<FloatTensor<Self, D>>,
dim: usize,
) -> FloatTensor<Self, D> {
kernel::cat(tensors, dim)
}
fn argmax<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
fn float_argmax<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> IntTensor<Self, D> {
reduce::argmax(tensor, dim)
}
fn argmin<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
fn float_argmin<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> IntTensor<Self, D> {
reduce::argmin(tensor, dim)
}
fn into_int<const D: usize>(tensor: FloatTensor<Self, D>) -> IntTensor<Self, D> {
fn float_into_int<const D: usize>(tensor: FloatTensor<Self, D>) -> IntTensor<Self, D> {
kernel::cast(tensor)
}
fn clamp<const D: usize>(
fn float_clamp<const D: usize>(
tensor: FloatTensor<Self, D>,
min: FloatElem<Self>,
max: FloatElem<Self>,
@ -489,7 +509,7 @@ where
kernel::clamp(tensor, min, max)
}
fn recip<const D: usize>(
fn float_recip<const D: usize>(
tensor: FloatTensor<Wgpu<G, F, I>, D>,
) -> FloatTensor<Wgpu<G, F, I>, D> {
unary!(
@ -502,7 +522,7 @@ where
)
}
fn repeat<const D: usize>(
fn float_repeat<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
times: usize,
@ -510,7 +530,7 @@ where
kernel::repeat(tensor, dim, times)
}
fn powf<const D: usize>(
fn float_powf<const D: usize>(
lhs: FloatTensor<Wgpu<G, F, I>, D>,
rhs: FloatTensor<Wgpu<G, F, I>, D>,
) -> FloatTensor<Wgpu<G, F, I>, D> {

View File

@ -51,8 +51,8 @@ impl<B: Backend> Backend for Autodiff<B> {
let (lhs, rhs, output, shape_bias) = ops.state;
// Fetch shapes of our tensor to support broadcasting.
let shape_lhs = B::shape(&lhs);
let shape_rhs = B::shape(&rhs);
let shape_lhs = B::float_shape(&lhs);
let shape_rhs = B::float_shape(&rhs);
// Compute the gradient of the output using the already existing `relu_backward`
// function in the basic Burn backend trait.
@ -61,13 +61,13 @@ impl<B: Backend> Backend for Autodiff<B> {
// Compute the lhs gradient, which is the derivative of matmul with support for
// broadcasting.
let grad_lhs = broadcast_shape::<B, D>(
B::matmul(grad_output.clone(), B::transpose(rhs)),
B::float_matmul(grad_output.clone(), B::float_transpose(rhs)),
&shape_lhs,
);
// Compute the rhs gradient, which is the derivative of matmul with support for
// broadcasting.
let grad_rhs = broadcast_shape::<B, D>(
B::matmul(B::transpose(lhs), grad_output.clone()),
B::float_matmul(B::float_transpose(lhs), grad_output.clone()),
&shape_rhs,
);
// The add derivative is only 1, so we just need to support broadcasting to
@ -101,7 +101,7 @@ impl<B: Backend> Backend for Autodiff<B> {
OpsKind::Tracked(prep) => {
// When at least one node is tracked, we should register our backward step.
// We compute the output and the state before finishing the preparation.
let bias_shape = B::shape(&bias.primitive);
let bias_shape = B::float_shape(&bias.primitive);
let output = B::fused_matmul_add_relu(
lhs.primitive.clone(),
rhs.primitive.clone(),

View File

@ -4,7 +4,8 @@ mod forward;
use burn::tensor::{activation, Tensor};
/// We use a type alias for better readability.
pub type FloatTensor<B, const D: usize> = <B as burn::tensor::backend::Backend>::TensorPrimitive<D>;
pub type FloatTensor<B, const D: usize> =
<B as burn::tensor::backend::Backend>::FloatTensorPrimitive<D>;
/// We create our own Backend trait that extends the Burn backend trait.
pub trait Backend: burn::tensor::backend::Backend {