mirror of https://github.com/tracel-ai/burn.git
Add tril_mask, triu_mask and diag_mask ops (#1479)
This commit is contained in:
parent
c729401fb2
commit
8a8300c1fb
|
@ -279,13 +279,16 @@ Those operations are only available for `Int` tensors.
|
||||||
|
|
||||||
Those operations are only available for `Bool` tensors.
|
Those operations are only available for `Bool` tensors.
|
||||||
|
|
||||||
| Burn API | PyTorch Equivalent |
|
| Burn API | PyTorch Equivalent |
|
||||||
| ------------------- | ------------------------------- |
|
| ----------------------------------- | ------------------------------- |
|
||||||
| `tensor.float()` | `tensor.to(torch.float)` |
|
| `Tensor.diag_mask(shape, diagonal)` | N/A |
|
||||||
| `tensor.int()` | `tensor.to(torch.long)` |
|
| `Tensor.tril_mask(shape, diagonal)` | N/A |
|
||||||
| `tensor.not()` | `tensor.logical_not()` |
|
| `Tensor.triu_mask(shape, diagonal)` | N/A |
|
||||||
| `tensor.argwhere()` | `tensor.argwhere()` |
|
| `tensor.argwhere()` | `tensor.argwhere()` |
|
||||||
| `tensor.nonzero()` | `tensor.nonzero(as_tuple=True)` |
|
| `tensor.float()` | `tensor.to(torch.float)` |
|
||||||
|
| `tensor.int()` | `tensor.to(torch.long)` |
|
||||||
|
| `tensor.nonzero()` | `tensor.nonzero(as_tuple=True)` |
|
||||||
|
| `tensor.not()` | `tensor.logical_not()` |
|
||||||
|
|
||||||
## Activation Functions
|
## Activation Functions
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@ pub fn generate_autoregressive_mask<B: Backend>(
|
||||||
seq_length: usize,
|
seq_length: usize,
|
||||||
device: &B::Device,
|
device: &B::Device,
|
||||||
) -> Tensor<B, 3, Bool> {
|
) -> Tensor<B, 3, Bool> {
|
||||||
|
// TODO replace with more efficient op of `triu_mask` and `expand`
|
||||||
let mut mask = Tensor::<B, 3, Int>::zeros([1, seq_length, seq_length], device);
|
let mut mask = Tensor::<B, 3, Int>::zeros([1, seq_length, seq_length], device);
|
||||||
|
|
||||||
for i in 0..(seq_length - 1) {
|
for i in 0..(seq_length - 1) {
|
||||||
|
|
|
@ -1,8 +1,20 @@
|
||||||
use crate::{backend::Backend, Bool, Data, Int, Tensor};
|
use crate::{backend::Backend, Bool, Data, Int, Shape, Tensor};
|
||||||
use alloc::vec::Vec;
|
use alloc::vec::Vec;
|
||||||
|
|
||||||
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
|
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
|
||||||
use crate::{argwhere, tensor::Shape};
|
use crate::argwhere;
|
||||||
|
|
||||||
|
/// The part of the tensor to keep when creating a triangular mask.
|
||||||
|
enum TriPart {
|
||||||
|
/// Upper triangular part.
|
||||||
|
Upper,
|
||||||
|
|
||||||
|
/// Lower triangular part.
|
||||||
|
Lower,
|
||||||
|
|
||||||
|
/// Diagonal part.
|
||||||
|
Diagonal,
|
||||||
|
}
|
||||||
|
|
||||||
impl<B, const D: usize> Tensor<B, D, Bool>
|
impl<B, const D: usize> Tensor<B, D, Bool>
|
||||||
where
|
where
|
||||||
|
@ -42,7 +54,7 @@ where
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Compute the indices of the elements that are non-zero.
|
/// Compute the indices of the elements that are true.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
///
|
///
|
||||||
|
@ -59,7 +71,7 @@ where
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Compute the indices of the elements that are non-zero, grouped by element.
|
/// Compute the indices of the elements that are true, grouped by element.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
///
|
///
|
||||||
|
@ -70,7 +82,7 @@ where
|
||||||
Tensor::new(B::bool_argwhere(self.primitive))
|
Tensor::new(B::bool_argwhere(self.primitive))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Compute the indices of the elements that are non-zero, grouped by element.
|
/// Compute the indices of the elements that are true, grouped by element.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
///
|
///
|
||||||
|
@ -80,4 +92,102 @@ where
|
||||||
pub async fn argwhere(self) -> Tensor<B, 2, Int> {
|
pub async fn argwhere(self) -> Tensor<B, 2, Int> {
|
||||||
Tensor::new(argwhere::<B, D>(self.primitive).await)
|
Tensor::new(argwhere::<B, D>(self.primitive).await)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Creates a mask for the upper, lower triangle, or diagonal of a matrix, which can be used to
|
||||||
|
/// fill the specified area with a value.
|
||||||
|
fn tri_mask<S: Into<Shape<D>>>(
|
||||||
|
shape: S,
|
||||||
|
tri_part: TriPart,
|
||||||
|
offset: i64,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Self {
|
||||||
|
let shape = shape.into();
|
||||||
|
let height = shape.dims[D - 2];
|
||||||
|
let width = shape.dims[D - 1];
|
||||||
|
|
||||||
|
// Generate row and column index tensors.
|
||||||
|
let row_indices: Tensor<B, 1, Int> = Tensor::arange(0..height as i64, device);
|
||||||
|
let col_indices: Tensor<B, 1, Int> = Tensor::arange(0..width as i64, device);
|
||||||
|
|
||||||
|
// Prepare shapes for broadcasting.
|
||||||
|
let mut row_shape = [1; D];
|
||||||
|
row_shape[D - 2] = height;
|
||||||
|
let mut col_shape = [1; D];
|
||||||
|
col_shape[D - 1] = width;
|
||||||
|
|
||||||
|
// Reshape for broadcasting.
|
||||||
|
let row_broadcast = row_indices.reshape(Shape::new(row_shape));
|
||||||
|
let col_broadcast = col_indices.reshape(Shape::new(col_shape));
|
||||||
|
|
||||||
|
// Broadcasting trick to create a matrix that facilitates comparison for mask generation.
|
||||||
|
let matrix = row_broadcast.clone() - (col_broadcast.clone() - offset);
|
||||||
|
|
||||||
|
// Select the appropriate comparison function based on `tri_part`.
|
||||||
|
let compare = match tri_part {
|
||||||
|
TriPart::Upper => Tensor::greater_elem,
|
||||||
|
TriPart::Lower => Tensor::lower_elem,
|
||||||
|
TriPart::Diagonal => Tensor::not_equal_elem,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Generate and return the mask by applying the comparison to the matrix.
|
||||||
|
compare(matrix, 0).unsqueeze()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a mask for the upper triangle of a matrix, which can be used to fill the specified
|
||||||
|
/// area with a value.
|
||||||
|
///
|
||||||
|
/// This function generates a boolean tensor representing the mask of the upper triangle of a matrix.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `shape`: The shape of the matrix.
|
||||||
|
/// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
|
||||||
|
/// towards the upper triangle.
|
||||||
|
/// * `device`: The device on which the tensor will be allocated.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// Returns a boolean tensor where `true` indicates the elements of the matrix that are part of the
|
||||||
|
/// upper triangle taking into account the specified `offset`.
|
||||||
|
pub fn triu_mask<S: Into<Shape<D>>>(shape: S, offset: i64, device: &B::Device) -> Self {
|
||||||
|
Self::tri_mask(shape, TriPart::Upper, offset, device)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a mask for the lower triangle of a matrix, which can be used to fill the specified
|
||||||
|
/// area with a value.
|
||||||
|
///
|
||||||
|
/// This function generates a boolean tensor representing the mask of the lower triangle of a matrix.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `shape`: The shape of the matrix.
|
||||||
|
/// * `offset`: The offset from the diagonal, where 0 means the diagonal, and negative values shift
|
||||||
|
/// towards the lower triangle.
|
||||||
|
/// * `device`: The device on which the tensor will be allocated.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// Returns a boolean tensor where `true` indicates the elements of the matrix that are part of the
|
||||||
|
/// lower triangle taking into account the specified `offset`.
|
||||||
|
pub fn tril_mask<S: Into<Shape<D>>>(shape: S, offset: i64, device: &B::Device) -> Self {
|
||||||
|
Self::tri_mask(shape, TriPart::Lower, offset, device)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a mask for the diagonal of a matrix, which can be used to fill the specified
|
||||||
|
/// area with a value.
|
||||||
|
///
|
||||||
|
/// This function generates a boolean tensor representing the mask of the diagonal of a matrix.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `shape`: The shape of the matrix.
|
||||||
|
/// * `device`: The device on which the tensor will be allocated.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// Returns a boolean tensor where `true` indicates the elements of the matrix that are part of the
|
||||||
|
/// diagonal.
|
||||||
|
pub fn diag_mask<S: Into<Shape<D>>>(shape: S, offset: i64, device: &B::Device) -> Self {
|
||||||
|
Self::tri_mask(shape, TriPart::Diagonal, offset, device)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
use crate::alloc::borrow::ToOwned;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
backend::Backend, check, check::TensorCheck, BasicOps, Bool, Distribution, Element,
|
backend::Backend, check, check::TensorCheck, BasicOps, Bool, Distribution, Element,
|
||||||
ElementConversion, Float, Int, Shape, Tensor, TensorKind,
|
ElementConversion, Float, Int, Shape, Tensor, TensorKind,
|
||||||
|
@ -482,42 +484,6 @@ where
|
||||||
Self::new(K::abs(self.primitive))
|
Self::new(K::abs(self.primitive))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the triangular part of a matrix (2-D tensor) or batch of matrices,
|
|
||||||
/// based on the specified comparison method, zeroing out the other elements.
|
|
||||||
///
|
|
||||||
/// # Parameters
|
|
||||||
///
|
|
||||||
/// - `diagonal`: The diagonal from which the triangular part is computed.
|
|
||||||
/// - `compare`: A comparison function determining which part of the triangle to zero out.
|
|
||||||
/// Use `Tensor::<B, D, Int>::greater_elem` for upper triangular
|
|
||||||
/// and `Tensor::<B, D, Int>::lower_elem` for lower triangular.
|
|
||||||
///
|
|
||||||
pub(crate) fn tri_compare<F>(self, diagonal: i64, compare: F) -> Self
|
|
||||||
where
|
|
||||||
F: FnOnce(Tensor<B, D, Int>, i64) -> Tensor<B, D, Bool>,
|
|
||||||
{
|
|
||||||
check!(TensorCheck::tri::<{ D }>());
|
|
||||||
|
|
||||||
let shape = self.shape();
|
|
||||||
let height = shape.dims[D - 2];
|
|
||||||
let width = shape.dims[D - 1];
|
|
||||||
|
|
||||||
let row_indices: Tensor<B, 1, Int> = Tensor::arange(0..height as i64, &self.device());
|
|
||||||
let col_indices: Tensor<B, 1, Int> = Tensor::arange(0..width as i64, &self.device());
|
|
||||||
|
|
||||||
let mut row_shape = [1; D];
|
|
||||||
row_shape[D - 2] = height;
|
|
||||||
let mut col_shape = [1; D];
|
|
||||||
col_shape[D - 1] = width;
|
|
||||||
|
|
||||||
let row_broadcast = row_indices.reshape(Shape::new(row_shape));
|
|
||||||
let col_broadcast = col_indices.reshape(Shape::new(col_shape));
|
|
||||||
|
|
||||||
let mask = compare(row_broadcast - (col_broadcast - diagonal), 0).unsqueeze();
|
|
||||||
|
|
||||||
self.mask_fill(mask, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input,
|
/// Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input,
|
||||||
/// the other elements of the result tensor out are set to 0.
|
/// the other elements of the result tensor out are set to 0.
|
||||||
///
|
///
|
||||||
|
@ -546,7 +512,13 @@ where
|
||||||
/// }
|
/// }
|
||||||
/// ```
|
/// ```
|
||||||
pub fn triu(self, diagonal: i64) -> Self {
|
pub fn triu(self, diagonal: i64) -> Self {
|
||||||
self.tri_compare(diagonal, Tensor::greater_elem)
|
check!(TensorCheck::tri::<{ D }>());
|
||||||
|
|
||||||
|
// last two dimensions
|
||||||
|
let shape = &self.shape().dims[D - 2..].to_owned();
|
||||||
|
|
||||||
|
let mask = Tensor::<B, 2, Bool>::triu_mask(shape, diagonal, &self.device()).unsqueeze();
|
||||||
|
self.mask_fill(mask, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input,
|
/// Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input,
|
||||||
|
@ -578,7 +550,13 @@ where
|
||||||
/// }
|
/// }
|
||||||
/// ```
|
/// ```
|
||||||
pub fn tril(self, diagonal: i64) -> Self {
|
pub fn tril(self, diagonal: i64) -> Self {
|
||||||
self.tri_compare(diagonal, Tensor::lower_elem)
|
check!(TensorCheck::tri::<{ D }>());
|
||||||
|
|
||||||
|
// last two dimensions
|
||||||
|
let shape = &self.shape().dims[D - 2..].to_owned();
|
||||||
|
|
||||||
|
let mask = Tensor::<B, 2, Bool>::tril_mask(shape, diagonal, &self.device()).unsqueeze();
|
||||||
|
self.mask_fill(mask, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Applies element wise power operation with a float Tensor
|
/// Applies element wise power operation with a float Tensor
|
||||||
|
|
|
@ -89,6 +89,7 @@ macro_rules! testgen_all {
|
||||||
burn_tensor::testgen_bool!();
|
burn_tensor::testgen_bool!();
|
||||||
burn_tensor::testgen_argwhere_nonzero!();
|
burn_tensor::testgen_argwhere_nonzero!();
|
||||||
burn_tensor::testgen_sign!();
|
burn_tensor::testgen_sign!();
|
||||||
|
burn_tensor::testgen_tri_mask!();
|
||||||
|
|
||||||
// test stats
|
// test stats
|
||||||
burn_tensor::testgen_var!();
|
burn_tensor::testgen_var!();
|
||||||
|
|
|
@ -51,3 +51,4 @@ mod sub;
|
||||||
mod tanh;
|
mod tanh;
|
||||||
mod transpose;
|
mod transpose;
|
||||||
mod tri;
|
mod tri;
|
||||||
|
mod tri_mask;
|
||||||
|
|
|
@ -0,0 +1,97 @@
|
||||||
|
#[burn_tensor_testgen::testgen(tri_mask)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use burn_tensor::{Bool, Data, Tensor};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn square_diag() {
|
||||||
|
let device = Default::default();
|
||||||
|
let data_expected = Data::from([
|
||||||
|
[false, true, true],
|
||||||
|
[true, false, true],
|
||||||
|
[true, true, false],
|
||||||
|
]);
|
||||||
|
let tensor = Tensor::<TestBackend, 2, Bool>::diag_mask([3, 3], 0, &device);
|
||||||
|
assert_eq!(data_expected, tensor.into_data());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn square_diag_offset() {
|
||||||
|
let device = Default::default();
|
||||||
|
let data_expected =
|
||||||
|
Data::from([[true, false, true], [true, true, false], [true, true, true]]);
|
||||||
|
let tensor = Tensor::<TestBackend, 2, Bool>::diag_mask([3, 3], 1, &device);
|
||||||
|
assert_eq!(data_expected, tensor.into_data());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn square_tri_upper() {
|
||||||
|
let device = Default::default();
|
||||||
|
let data_expected = Data::from([
|
||||||
|
[false, false, false],
|
||||||
|
[true, false, false],
|
||||||
|
[true, true, false],
|
||||||
|
]);
|
||||||
|
let tensor = Tensor::<TestBackend, 2, Bool>::triu_mask([3, 3], 0, &device);
|
||||||
|
assert_eq!(data_expected, tensor.into_data());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn square_tri_upper_offset() {
|
||||||
|
let device = Default::default();
|
||||||
|
let data_expected = Data::from([
|
||||||
|
[true, false, false],
|
||||||
|
[true, true, false],
|
||||||
|
[true, true, true],
|
||||||
|
]);
|
||||||
|
let tensor = Tensor::<TestBackend, 2, Bool>::triu_mask([3, 3], 1, &device);
|
||||||
|
assert_eq!(data_expected, tensor.into_data());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn square_tri_lower() {
|
||||||
|
let device = Default::default();
|
||||||
|
|
||||||
|
let data_expected = Data::from([
|
||||||
|
[false, true, true],
|
||||||
|
[false, false, true],
|
||||||
|
[false, false, false],
|
||||||
|
]);
|
||||||
|
let tensor = Tensor::<TestBackend, 2, Bool>::tril_mask([3, 3], 0, &device);
|
||||||
|
assert_eq!(data_expected, tensor.into_data());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn square_tri_lower_offset() {
|
||||||
|
let device = Default::default();
|
||||||
|
|
||||||
|
let data_expected = Data::from([
|
||||||
|
[true, true, true],
|
||||||
|
[false, true, true],
|
||||||
|
[false, false, true],
|
||||||
|
]);
|
||||||
|
let tensor = Tensor::<TestBackend, 2, Bool>::tril_mask([3, 3], -1, &device);
|
||||||
|
assert_eq!(data_expected, tensor.into_data());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rect_diag() {
|
||||||
|
let device = Default::default();
|
||||||
|
let data_expected = Data::from([
|
||||||
|
[false, true, true, true],
|
||||||
|
[true, false, true, true],
|
||||||
|
[true, true, false, true],
|
||||||
|
]);
|
||||||
|
let tensor = Tensor::<TestBackend, 2, Bool>::diag_mask([3, 4], 0, &device);
|
||||||
|
assert_eq!(data_expected, tensor.into_data());
|
||||||
|
|
||||||
|
let data_expected = Data::from([
|
||||||
|
[false, true, true],
|
||||||
|
[true, false, true],
|
||||||
|
[true, true, false],
|
||||||
|
[true, true, true],
|
||||||
|
]);
|
||||||
|
let tensor = Tensor::<TestBackend, 2, Bool>::diag_mask([4, 3], 0, &device);
|
||||||
|
assert_eq!(data_expected, tensor.into_data());
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue