Feat/add tensor checks (#283)

Nathaniel Simard 2023-04-10 19:16:15 -04:00 committed by GitHub
parent 37b79bcc40
commit 2220965b5c
6 changed files with 631 additions and 35 deletions

View File

@ -12,7 +12,7 @@ pub fn generate_autoregressive_mask<B: Backend>(
) -> Tensor<B, 3, Bool> {
let mut mask = Tensor::<B, 3, Int>::zeros([1, seq_length, seq_length]);
for i in 0..seq_length {
for i in 0..(seq_length - 1) {
let values = Tensor::<B, 3, Int>::ones([1, 1, seq_length - (i + 1)]);
mask = mask.index_assign([0..1, i..i + 1, i + 1..seq_length], values);
}
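For intuition, here is a minimal standalone sketch (plain Vecs, no Backend trait; the helper name is ours) of the mask this loop builds. The loop now stops at seq_length - 1 because the last row has no future position to mask; iterating one step further would build the empty range i + 1..seq_length, which the index_assign checks introduced in this commit reject.

fn autoregressive_mask(seq_length: usize) -> Vec<Vec<bool>> {
    // Position (i, j) is masked (true) when j > i, i.e. when j is in the future.
    // Assumes seq_length >= 1.
    let mut mask = vec![vec![false; seq_length]; seq_length];
    for i in 0..(seq_length - 1) {
        for j in (i + 1)..seq_length {
            mask[i][j] = true;
        }
    }
    mask
}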

View File

@ -4,7 +4,9 @@ use alloc::vec;
use alloc::vec::Vec;
use core::{fmt::Debug, ops::Range};
use crate::{backend::Backend, Bool, Data, Float, Int, Shape, TensorKind};
use crate::{
backend::Backend, check, check::TensorCheck, Bool, Data, Float, Int, Shape, TensorKind,
};
#[derive(new, Clone, Debug)]
pub struct Tensor<B, const D: usize, K = Float>
@ -48,7 +50,10 @@ where
///
/// If the tensor cannot be reshaped to the given shape.
pub fn reshape<const D2: usize, S: Into<Shape<D2>>>(self, shape: S) -> Tensor<B, D2, K> {
Tensor::new(K::reshape::<D, D2>(self.primitive, shape.into()))
let shape = shape.into();
check!(TensorCheck::reshape(&self.shape(), &shape));
Tensor::new(K::reshape::<D, D2>(self.primitive, shape))
}
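The invariant this check enforces is element-count conservation; here is a sketch of the rule with plain slices (hypothetical helper, not part of the crate):

fn reshape_is_valid(original: &[usize], target: &[usize]) -> bool {
    // A reshape is valid iff both shapes describe the same number of elements.
    original.iter().product::<usize>() == target.iter().product::<usize>()
}
// reshape_is_valid(&[2, 2], &[1, 4]) -> true
// reshape_is_valid(&[2, 2], &[1, 3]) -> false; the check above would panic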
/// Flatten the tensor along a given range of dimensions.
@ -88,17 +93,7 @@ where
///
/// ```
pub fn flatten<const D2: usize>(self, start_dim: usize, end_dim: usize) -> Tensor<B, D2, K> {
if start_dim > end_dim {
panic!("The start dim ({start_dim}) must be smaller than the end dim ({end_dim})")
}
if D2 > D {
panic!("Result dim ({D2}) must be smaller than ({D})")
}
if D < end_dim + 1 {
panic!("The end dim ({end_dim}) must be greater than the tensor dim ({D2})")
}
check!(TensorCheck::flatten::<D, D2>(start_dim, end_dim));
let current_dims = self.shape().dims;
let mut new_dims: [usize; D2] = [0; D2];
@ -135,9 +130,7 @@ where
/// }
/// ```
pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2, K> {
if D2 < D {
panic!("Can't unsqueeze smaller tensor, got dim {D2}, expected > {D}")
}
check!(TensorCheck::unsqueeze::<D, D2>());
let mut dims = [1; D2];
let num_ones = D2 - D;
@ -169,6 +162,7 @@ where
/// }
/// ```
pub fn index<const D2: usize>(self, indexes: [core::ops::Range<usize>; D2]) -> Self {
check!(TensorCheck::index(&self.shape(), &indexes));
Self::new(K::index(self.primitive, indexes))
}
@ -199,6 +193,11 @@ where
indexes: [core::ops::Range<usize>; D2],
values: Self,
) -> Self {
check!(TensorCheck::index_assign(
&self.shape(),
&values.shape(),
&indexes
));
Self::new(K::index_assign(self.primitive, indexes, values.primitive))
}
@ -247,6 +246,7 @@ where
///
/// If the two tensors don't have the same shape.
pub fn equal(self, other: Self) -> Tensor<B, D, Bool> {
check!(TensorCheck::binary_ops_ew("Equal", &self, &other));
K::equal(self.primitive, other.primitive)
}
@ -262,6 +262,8 @@ where
///
/// If the tensors don't all have the same shape.
pub fn cat(tensors: Vec<Self>, dim: usize) -> Self {
check!(TensorCheck::cat(&tensors, dim));
Self::new(K::cat(
tensors.into_iter().map(|vector| vector.primitive).collect(),
dim,

View File

@ -0,0 +1,566 @@
use crate::{backend::Backend, BasicOps, Shape, Tensor};
use alloc::format;
use alloc::string::{String, ToString};
use alloc::vec;
use alloc::vec::Vec;
use core::ops::Range;
/// The struct should always be used with the [check](crate::check) macro.
///
/// This is a simple public crate data structure that efficiently checks tensor operations and
/// formats clear error messages. It's crucial that the checks are really fast; speed no longer
/// matters once a check has failed, since the program will panic.
///
/// # Notes
///
/// Failing tensor checks will always result in a panic.
/// As mentioned in [The Rust Programming Language book](https://doc.rust-lang.org/book/ch09-03-to-panic-or-not-to-panic.html),
/// when there is no way to recover, panic should be used instead of a result.
///
/// Most users would unwrap the results anyway, which would worsen the clarity of the code. Almost
/// all checks highlight programming errors, that is, invalid programs that should be fixed.
/// Checks are not the ideal way to help users write correct programs, but they are still better
/// than backend errors. Other forms of compile-time validation could be developed, such as named
/// tensors, but we have to carefully evaluate the ease of use of the Tensor API. Adding overly
/// complex type validation checks might drastically worsen the API and result in harder-to-maintain
/// programs.
///
/// # Design
///
/// Maybe the Backend API should return a result for each operation, which would allow handling
/// all checks, even the ones that can't be efficiently checked before performing an operation,
/// such as the `index_select` operation. The downside of that approach is that all backend
/// implementations might re-implement the same checks, which may result in unnecessary code
/// duplication. Maybe a combination of both strategies could help to cover all use cases.
pub enum TensorCheck {
Ok,
Failed(FailedTensorCheck),
}
impl TensorCheck {
/// Checks device and shape compatibility for element-wise binary operations.
pub fn binary_ops_ew<B: Backend, const D: usize, K: BasicOps<B>>(
ops: &str,
lhs: &Tensor<B, D, K>,
rhs: &Tensor<B, D, K>,
) -> Self {
Self::Ok
.binary_ops_device(ops, &lhs.device(), &rhs.device())
.binary_ops_ew_shape(ops, &lhs.shape(), &rhs.shape())
}
pub fn reshape<const D1: usize, const D2: usize>(
original: &Shape<D1>,
target: &Shape<D2>,
) -> Self {
let mut check = Self::Ok;
if original.num_elements() != target.num_elements() {
check = check.register("Reshape", TensorError::new(
"The given shape doesn't have the same number of elements as the current tensor.",
)
.details(format!(
"Current shape: {:?}, target shape: {:?}.",
original.dims, target.dims
)));
}
check
}
pub fn flatten<const D1: usize, const D2: usize>(start_dim: usize, end_dim: usize) -> Self {
let mut check = Self::Ok;
if start_dim > end_dim {
check = check.register(
"Flatten",
TensorError::new(format!(
"The start dim ({start_dim}) must be smaller than the end dim ({end_dim})"
)),
);
}
if D2 > D1 {
check = check.register(
"Flatten",
TensorError::new(format!("Result dim ({D2}) must be smaller than ({D1})")),
);
}
if D1 < end_dim + 1 {
check = check.register(
"Flatten",
TensorError::new(format!(
"The end dim ({end_dim}) must be greater than the tensor dim ({D2})"
)),
);
}
check
}
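Once those bounds hold, flatten collapses the dims in start_dim..=end_dim into their product and keeps the others; a sketch of the resulting shape with plain slices (hypothetical helper):

fn flatten_dims(dims: &[usize], start_dim: usize, end_dim: usize) -> Vec<usize> {
    let mut out = Vec::new();
    out.extend_from_slice(&dims[..start_dim]); // leading dims kept as-is
    out.push(dims[start_dim..=end_dim].iter().product()); // collapsed block
    out.extend_from_slice(&dims[end_dim + 1..]); // trailing dims kept as-is
    out
}
// flatten_dims(&[2, 3, 4], 0, 1) -> [6, 4]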
pub fn unsqueeze<const D1: usize, const D2: usize>() -> Self {
let mut check = Self::Ok;
if D2 < D1 {
check = check.register(
"Unsqueeze",
TensorError::new(format!(
"Can't unsqueeze smaller tensor, got dim {D2}, expected > {D1}"
)),
);
}
check
}
pub fn swap_dims<const D: usize>(dim1: usize, dim2: usize) -> Self {
let mut check = Self::Ok;
if dim1 >= D || dim2 >= D {
check = check.register(
"Swap Dims",
TensorError::new("The swap dimensions must be smaller than the tensor dimension")
.details(format!(
"Swap dims ({dim1}, {dim2}) on tensor with ({D}) dimensions."
)),
);
}
check
}
pub fn matmul<B: Backend, const D: usize>(lhs: &Tensor<B, D>, rhs: &Tensor<B, D>) -> Self {
let mut check = Self::Ok;
check = check.binary_ops_device("Matmul", &lhs.device(), &rhs.device());
if D < 2 {
return check;
}
let shape_lhs = lhs.shape();
let shape_rhs = rhs.shape();
let dim_lhs = shape_lhs.dims[D - 1];
let dim_rhs = shape_rhs.dims[D - 2];
if dim_lhs != dim_rhs {
check = check.register(
"Matmul",
TensorError::new(format!(
"The inner dimension of matmul should be the same, but got {} and {}.",
dim_lhs, dim_rhs
))
.details(format!(
"Lhs shape {:?}, rhs shape {:?}.",
shape_lhs.dims, shape_rhs.dims
)),
);
}
check
}
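The shape rule being checked, as a standalone predicate (our naming): for two tensors of equal rank with shapes [.., m, k] and [.., k2, n], the product is defined only when k == k2.

fn matmul_inner_dims_match(lhs_dims: &[usize], rhs_dims: &[usize]) -> bool {
    let d = lhs_dims.len();
    // The last dim of lhs must equal the second-to-last dim of rhs.
    d >= 2 && lhs_dims[d - 1] == rhs_dims[d - 2]
}
// matmul_inner_dims_match(&[2, 3], &[3, 4]) -> true
// matmul_inner_dims_match(&[2, 3], &[4, 5]) -> false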
pub fn cat<B: Backend, const D: usize, K: BasicOps<B>>(
tensors: &[Tensor<B, D, K>],
dim: usize,
) -> Self {
let mut check = Self::Ok;
if dim >= D {
check = check.register(
"Cat",
TensorError::new(
"Can't concatenate tensors on a dim that exceeds the tensors dimension",
)
.details(format!(
"Trying to concatenate tensors with {D} dimensions on axis {dim}."
)),
);
}
if tensors.is_empty() {
return check.register(
"Cat",
TensorError::new("Can't concatenate an empty list of tensors."),
);
}
let mut shape_reference = tensors.get(0).unwrap().shape();
shape_reference.dims[dim] = 1; // We want to check every dim except the one where the
// concatenation happens.
for tensor in tensors {
let mut shape = tensor.shape();
shape.dims[dim] = 1; // Ignore the concatenate dim.
if shape_reference != shape {
return check.register(
"Cat",
TensorError::new("Can't concatenate tensors with different shapes, except for the provided dimension").details(
format!(
"Provided dimension ({}), tensors shapes: {:?}",
dim,
tensors.iter().map(Tensor::shape).collect::<Vec<_>>()
),
),
);
}
}
check
}
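Neutralizing the concatenation dim and comparing whole shapes, as done above, is equivalent to this standalone predicate (hypothetical helper):

fn cat_compatible(a: &[usize], b: &[usize], dim: usize) -> bool {
    // Shapes must agree on every dimension except `dim` itself.
    a.len() == b.len()
        && a.iter().zip(b).enumerate().all(|(i, (x, y))| i == dim || x == y)
}
// cat_compatible(&[2, 3, 4], &[2, 5, 4], 1) -> true
// cat_compatible(&[2, 3, 4], &[2, 5, 3], 1) -> false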
pub fn index<const D1: usize, const D2: usize>(
shape: &Shape<D1>,
indexes: &[Range<usize>; D2],
) -> Self {
let mut check = Self::Ok;
let n_dims_tensor = D1;
let n_dims_indexes = D2;
if n_dims_tensor < n_dims_indexes {
check = check.register("Index",
TensorError::new ("The provided indexes array has a higher number of dimensions than the current tensor.")
.details(
format!(
"The indexes array must be smaller or equal to the tensor number of dimensions. \
Tensor number of dimensions: {n_dims_tensor}, indexes array length {n_dims_indexes}."
)));
}
for i in 0..usize::min(D1, D2) {
let d_tensor = shape.dims[i];
let index = indexes.get(i).unwrap();
if index.end > d_tensor {
check = check.register(
"Index",
TensorError::new("The provided indexes array has a range that exceeds the current tensor size.")
.details(format!(
"The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \
Tensor shape {:?}, provided indexes {:?}.",
index.start,
index.end,
d_tensor,
i,
shape.dims,
indexes,
)));
}
if index.start >= index.end {
check = check.register(
"Index",
TensorError::new("The provided indexes array has a range where the start index is bigger or equal to its end.")
.details(format!(
"The range at dimension '{}' starts at '{}' and is greater or equal to its end '{}'. \
Tensor shape {:?}, provided indexes {:?}.",
i,
index.start,
index.end,
shape.dims,
indexes,
)));
}
}
check
}
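Both range invariants combined into one standalone predicate that mirrors the loop above (hypothetical helper):

fn ranges_are_valid(dims: &[usize], indexes: &[core::ops::Range<usize>]) -> bool {
    indexes.len() <= dims.len()
        && indexes
            .iter()
            .zip(dims)
            // Each range must be non-empty and must stay within its dim size.
            .all(|(range, &dim)| range.start < range.end && range.end <= dim)
}
// ranges_are_valid(&[3, 5, 7], &[0..2, 0..4, 1..7]) -> true
// ranges_are_valid(&[3, 5, 7], &[0..2, 0..4, 1..8]) -> false (8 > 7)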
pub fn index_assign<const D1: usize, const D2: usize>(
shape: &Shape<D1>,
shape_value: &Shape<D1>,
indexes: &[Range<usize>; D2],
) -> Self {
let mut check = Self::Ok;
if D1 < D2 {
check = check.register("Index Assign",
TensorError::new ("The provided indexes array has a higher number of dimensions than the current tensor.")
.details(
format!(
"The indexes array must be smaller or equal to the tensor number of dimensions. \
Tensor number of dimensions: {D1}, indexes array length {D2}."
)));
}
for i in 0..usize::min(D1, D2) {
let d_tensor = shape.dims[i];
let d_tensor_value = shape_value.dims[i];
let index = indexes.get(i).unwrap();
if index.end > d_tensor {
check = check.register(
"Index Assign",
TensorError::new("The provided indexes array has a range that exceeds the current tensor size.")
.details(format!(
"The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \
Current tensor shape {:?}, value tensor shape {:?}, provided indexes {:?}.",
index.start,
index.end,
d_tensor,
i,
shape.dims,
shape_value.dims,
indexes,
)));
}
if index.end - index.start != d_tensor_value {
check = check.register(
"Index Assign",
TensorError::new("The value tensor must match the amount of elements selected with the indexes array")
.details(format!(
"The range ({}..{}) doesn't match the number of elements of the value tensor ({}) at dimension {}. \
Current tensor shape {:?}, value tensor shape {:?}, provided indexes {:?}.",
index.start,
index.end,
d_tensor_value,
i,
shape.dims,
shape_value.dims,
indexes,
)));
}
if index.start >= index.end {
check = check.register(
"Index Assign",
TensorError::new("The provided indexes array has a range where the start index is bigger or equal to its end.")
.details(format!(
"The range at dimension '{}' starts at '{}' and is greater or equal to its end '{}'. \
Current tensor shape {:?}, value tensor shape {:?}, provided indexes {:?}.",
i,
index.start,
index.end,
shape.dims,
shape_value.dims,
indexes,
)));
}
}
check
}
/// Checks the aggregation dimension for ops such as mean and sum.
pub fn aggregate_dim<const D: usize>(ops: &str, dim: usize) -> Self {
let mut check = Self::Ok;
if dim >= D {
check = check.register(
ops,
TensorError::new(format!(
"Can't aggregate a tensor with ({D}) dimensions on axis ({dim})"
)),
);
}
check
}
/// The goal is to minimize the cost of checks when there is no error, but it's way less
/// important when an error occurs, since crafting a comprehensive error message matters more
/// than optimizing string manipulation.
fn register(self, ops: &str, error: TensorError) -> Self {
let errors = match self {
Self::Ok => vec![error],
Self::Failed(mut failed) => {
failed.errors.push(error);
failed.errors
}
};
Self::Failed(FailedTensorCheck {
ops: ops.to_string(),
errors,
})
}
/// Checks if shapes are compatible for element-wise operations supporting broadcasting.
pub fn binary_ops_ew_shape<const D: usize>(
self,
ops: &str,
lhs: &Shape<D>,
rhs: &Shape<D>,
) -> Self {
let mut check = self;
for i in 0..D {
let d_lhs = lhs.dims[i];
let d_rhs = rhs.dims[i];
if d_lhs != d_rhs {
let is_broadcast = d_lhs == 1 || d_rhs == 1;
if is_broadcast {
continue;
}
check = check.register(ops,
TensorError::new("The provided tensors have incompatible shapes.")
.details(format!(
"Incompatible size at dimension '{}' => '{} != {}', which can't be broadcasted. \
Lhs tensor shape {:?}, Rhs tensor shape {:?}.",
i,
d_lhs,
d_rhs,
lhs.dims,
rhs.dims,
)));
}
}
check
}
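The broadcasting rule reduces to a per-dimension predicate (a sketch, our naming): sizes must match, or one of the two must be 1.

fn broadcast_compatible(lhs: &[usize], rhs: &[usize]) -> bool {
    // Equal ranks are assumed, as in the check above.
    lhs.iter().zip(rhs).all(|(&l, &r)| l == r || l == 1 || r == 1)
}
// broadcast_compatible(&[3, 5], &[1, 5]) -> true
// broadcast_compatible(&[3, 5], &[3, 6]) -> false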
/// Checks if tensor devices are equal.
fn binary_ops_device<Device: PartialEq + core::fmt::Debug>(
self,
ops: &str,
lhs: &Device,
rhs: &Device,
) -> Self {
match lhs != rhs {
true => self.register(
ops,
TensorError::new("The provided tensors are not on the same device.").details(
format!("Lhs tensor device {:?}, Rhs tensor device {:?}.", lhs, rhs,),
),
),
false => self,
}
}
}
pub struct FailedTensorCheck {
ops: String,
errors: Vec<TensorError>,
}
impl FailedTensorCheck {
/// Format all the checks into a single message ready to be printed by a [panic](core::panic).
pub fn format(self) -> String {
self.errors.into_iter().enumerate().fold(
format!(
"=== Tensor Operation Error ===\n Operation: '{}'\n Reason:",
self.ops
),
|accum, (number, error)| accum + error.format(number + 1).as_str(),
) + "\n"
}
}
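For reference, a failed reshape check rendered through this method panics with a message of roughly this shape (illustrative, assembled from the format strings above rather than captured from a run):

=== Tensor Operation Error ===
  Operation: 'Reshape'
  Reason:
    1. The given shape doesn't have the same number of elements as the current tensor. Current shape: [2, 2], target shape: [1, 3].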
struct TensorError {
description: String,
details: Option<String>,
}
impl TensorError {
pub fn new<S: Into<String>>(description: S) -> Self {
TensorError {
description: description.into(),
details: None,
}
}
pub fn details<S: Into<String>>(mut self, details: S) -> Self {
self.details = Some(details.into());
self
}
fn format(self, number: usize) -> String {
let mut message = format!("\n {number}. ");
message += self.description.as_str();
message += " ";
if let Some(details) = self.details {
message += details.as_str();
message += " ";
}
message
}
}
/// We use a macro for all checks so that the panic message's file and line number match the
/// function that performs the check, instead of pointing at a generic, unrelated, crate-private
/// file and line number.
#[macro_export(local_inner_macros)]
macro_rules! check {
($check:expr) => {
if let TensorCheck::Failed(check) = $check {
core::panic!("{}", check.format());
}
};
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
#[should_panic]
fn reshape_invalid_shape() {
check!(TensorCheck::reshape(
&Shape::new([2, 2]),
&Shape::new([1, 3])
));
}
#[test]
fn reshape_valid_shape() {
check!(TensorCheck::reshape(
&Shape::new([2, 2]),
&Shape::new([1, 4])
));
}
#[test]
#[should_panic]
fn index_range_exceed_dimension() {
check!(TensorCheck::index(
&Shape::new([3, 5, 7]),
&[0..2, 0..4, 1..8]
));
}
#[test]
#[should_panic]
fn index_range_exceed_number_of_dimensions() {
check!(TensorCheck::index(&Shape::new([3, 5]), &[0..1, 0..1, 0..1]));
}
#[test]
#[should_panic]
fn binary_ops_shapes_no_broadcast() {
check!(TensorCheck::binary_ops_ew_shape(
TensorCheck::Ok,
"TestOps",
&Shape::new([3, 5]),
&Shape::new([3, 6])
));
}
#[test]
fn binary_ops_shapes_with_broadcast() {
check!(TensorCheck::binary_ops_ew_shape(
TensorCheck::Ok,
"Test",
&Shape::new([3, 5]),
&Shape::new([1, 5])
));
}
#[test]
#[should_panic]
fn binary_ops_devices() {
check!(TensorCheck::binary_ops_device(
TensorCheck::Ok,
"Test",
&5, // We can pass anything that implements PartialEq as device
&8
));
}
}

View File

@ -3,6 +3,8 @@ use core::convert::TryInto;
use core::ops::Range;
use crate::backend::ADBackend;
use crate::check;
use crate::check::TensorCheck;
use crate::tensor::backend::Backend;
use crate::tensor::stats;
use crate::tensor::ElementConversion;
@ -196,6 +198,7 @@ where
///
/// If the given dimensions exceed the number of dimensions of the tensor.
pub fn swap_dims(self, dim1: usize, dim2: usize) -> Self {
check!(TensorCheck::swap_dims::<D>(dim1, dim2));
Self::new(B::swap_dims(self.primitive, dim1, dim2))
}
@ -207,6 +210,7 @@ where
///
/// If the two tensors don't have a compatible shape.
pub fn matmul(self, other: Self) -> Self {
check!(TensorCheck::matmul(&self, &other));
Self::new(B::matmul(self.primitive, other.primitive))
}

View File

@ -1,3 +1,5 @@
pub(crate) mod check;
mod base;
mod bool;
mod float;

View File

@ -1,5 +1,6 @@
use crate::{
backend::Backend, Bool, Element, ElementConversion, Float, Int, Shape, Tensor, TensorKind,
backend::Backend, check, check::TensorCheck, BasicOps, Bool, Element, ElementConversion, Float,
Int, Shape, Tensor, TensorKind,
};
impl<B, const D: usize, K> Tensor<B, D, K>
@ -12,6 +13,7 @@ where
/// `y = x2 + x1`
#[allow(clippy::should_implement_trait)]
pub fn add(self, other: Self) -> Self {
check!(TensorCheck::binary_ops_ew("Add", &self, &other));
Self::new(K::add(self.primitive, other.primitive))
}
@ -27,6 +29,7 @@ where
/// `y = x2 - x1`
#[allow(clippy::should_implement_trait)]
pub fn sub(self, other: Self) -> Self {
check!(TensorCheck::binary_ops_ew("Sub", &self, &other));
Self::new(K::sub(self.primitive, other.primitive))
}
@ -42,6 +45,7 @@ where
/// `y = x2 / x1`
#[allow(clippy::should_implement_trait)]
pub fn div(self, other: Self) -> Self {
check!(TensorCheck::binary_ops_ew("Div", &self, &other));
Self::new(K::div(self.primitive, other.primitive))
}
@ -57,6 +61,7 @@ where
/// `y = x2 * x1`
#[allow(clippy::should_implement_trait)]
pub fn mul(self, other: Self) -> Self {
check!(TensorCheck::binary_ops_ew("Mul", &self, &other));
Self::new(K::mul(self.primitive, other.primitive))
}
@ -107,11 +112,13 @@ where
/// Aggregate all elements along the given *dimension* or *axis* in the tensor with the mean operation.
pub fn mean_dim(self, dim: usize) -> Self {
check!(TensorCheck::aggregate_dim::<D>("Mean", dim));
Self::new(K::mean_dim(self.primitive, dim))
}
/// Aggregate all elements along the given *dimension* or *axis* in the tensor with the sum operation.
pub fn sum_dim(self, dim: usize) -> Self {
check!(TensorCheck::aggregate_dim::<D>("Sum", dim));
Self::new(K::sum_dim(self.primitive, dim))
}
@ -121,6 +128,7 @@ where
///
/// If the two tensors don't have the same shape.
pub fn greater(self, other: Self) -> Tensor<B, D, Bool> {
check!(TensorCheck::binary_ops_ew("Greater", &self, &other));
K::greater(self.primitive, other.primitive)
}
@ -130,6 +138,7 @@ where
///
/// If the two tensors don't have the same shape.
pub fn greater_equal(self, other: Self) -> Tensor<B, D, Bool> {
check!(TensorCheck::binary_ops_ew("Greater_equal", &self, &other));
K::greater_equal(self.primitive, other.primitive)
}
@ -139,6 +148,7 @@ where
///
/// If the two tensors don't have the same shape.
pub fn lower(self, other: Self) -> Tensor<B, D, Bool> {
check!(TensorCheck::binary_ops_ew("Lower", &self, &other));
K::lower(self.primitive, other.primitive)
}
@ -148,6 +158,7 @@ where
///
/// If the two tensors don't have the same shape.
pub fn lower_equal(self, other: Self) -> Tensor<B, D, Bool> {
check!(TensorCheck::binary_ops_ew("Lower_equal", &self, &other));
K::lower_equal(self.primitive, other.primitive)
}
@ -223,8 +234,8 @@ where
/// # Warnings
///
/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
pub trait Numeric<B: Backend>: TensorKind<B> {
type Elem: Element;
pub trait Numeric<B: Backend>: BasicOps<B> {
type NumElem: Element;
fn add<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Primitive<D>) -> Self::Primitive<D>;
fn add_scalar<const D: usize, E: ElementConversion>(
@ -257,28 +268,33 @@ pub trait Numeric<B: Backend>: TensorKind<B> {
lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>,
) -> Tensor<B, D, Bool>;
fn greater_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem)
-> Tensor<B, D, Bool>;
fn greater_elem<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::NumElem,
) -> Tensor<B, D, Bool>;
fn greater_equal<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>,
) -> Tensor<B, D, Bool>;
fn greater_equal_elem<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Elem,
rhs: Self::NumElem,
) -> Tensor<B, D, Bool>;
fn lower<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>,
) -> Tensor<B, D, Bool>;
fn lower_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool>;
fn lower_elem<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::NumElem,
) -> Tensor<B, D, Bool>;
fn lower_equal<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>,
) -> Tensor<B, D, Bool>;
fn lower_equal_elem<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Elem,
rhs: Self::NumElem,
) -> Tensor<B, D, Bool>;
fn index_select<const D: usize>(
tensor: Self::Primitive<D>,
@ -303,7 +319,7 @@ pub trait Numeric<B: Backend>: TensorKind<B> {
}
impl<B: Backend> Numeric<B> for Int {
type Elem = B::IntElem;
type NumElem = B::IntElem;
fn add<const D: usize>(
lhs: Self::Primitive<D>,
@ -384,7 +400,7 @@ impl<B: Backend> Numeric<B> for Int {
fn greater_elem<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Elem,
rhs: Self::NumElem,
) -> Tensor<B, D, Bool> {
Tensor::new(B::int_greater_elem(lhs, rhs))
}
@ -398,7 +414,7 @@ impl<B: Backend> Numeric<B> for Int {
fn greater_equal_elem<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Elem,
rhs: Self::NumElem,
) -> Tensor<B, D, Bool> {
Tensor::new(B::int_greater_equal_elem(lhs, rhs))
}
@ -410,7 +426,10 @@ impl<B: Backend> Numeric<B> for Int {
Tensor::new(B::int_lower(lhs, rhs))
}
fn lower_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool> {
fn lower_elem<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::NumElem,
) -> Tensor<B, D, Bool> {
Tensor::new(B::int_lower_elem(lhs, rhs))
}
@ -423,7 +442,7 @@ impl<B: Backend> Numeric<B> for Int {
fn lower_equal_elem<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Elem,
rhs: Self::NumElem,
) -> Tensor<B, D, Bool> {
Tensor::new(B::int_lower_equal_elem(lhs, rhs))
}
@ -461,7 +480,7 @@ impl<B: Backend> Numeric<B> for Int {
}
impl<B: Backend> Numeric<B> for Float {
type Elem = B::FloatElem;
type NumElem = B::FloatElem;
fn add<const D: usize>(
lhs: Self::Primitive<D>,
@ -542,7 +561,7 @@ impl<B: Backend> Numeric<B> for Float {
fn greater_elem<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Elem,
rhs: Self::NumElem,
) -> Tensor<B, D, Bool> {
Tensor::new(B::greater_elem(lhs, rhs))
}
@ -556,7 +575,7 @@ impl<B: Backend> Numeric<B> for Float {
fn greater_equal_elem<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Elem,
rhs: Self::NumElem,
) -> Tensor<B, D, Bool> {
Tensor::new(B::greater_equal_elem(lhs, rhs))
}
@ -568,7 +587,10 @@ impl<B: Backend> Numeric<B> for Float {
Tensor::new(B::lower(lhs, rhs))
}
fn lower_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool> {
fn lower_elem<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::NumElem,
) -> Tensor<B, D, Bool> {
Tensor::new(B::lower_elem(lhs, rhs))
}
@ -581,7 +603,7 @@ impl<B: Backend> Numeric<B> for Float {
fn lower_equal_elem<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Elem,
rhs: Self::NumElem,
) -> Tensor<B, D, Bool> {
Tensor::new(B::lower_equal_elem(lhs, rhs))
}