Add tril_mask, triu_mask and diag_mask ops (#1479)

This commit is contained in:
Dilshod Tadjibaev 2024-03-18 10:15:40 -05:00 committed by GitHub
parent c729401fb2
commit 8a8300c1fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 241 additions and 50 deletions

View File

@ -279,13 +279,16 @@ Those operations are only available for `Int` tensors.
Those operations are only available for `Bool` tensors. Those operations are only available for `Bool` tensors.
| Burn API | PyTorch Equivalent | | Burn API | PyTorch Equivalent |
| ------------------- | ------------------------------- | | ----------------------------------- | ------------------------------- |
| `tensor.float()` | `tensor.to(torch.float)` | | `Tensor.diag_mask(shape, diagonal)` | N/A |
| `tensor.int()` | `tensor.to(torch.long)` | | `Tensor.tril_mask(shape, diagonal)` | N/A |
| `tensor.not()` | `tensor.logical_not()` | | `Tensor.triu_mask(shape, diagonal)` | N/A |
| `tensor.argwhere()` | `tensor.argwhere()` | | `tensor.argwhere()` | `tensor.argwhere()` |
| `tensor.nonzero()` | `tensor.nonzero(as_tuple=True)` | | `tensor.float()` | `tensor.to(torch.float)` |
| `tensor.int()` | `tensor.to(torch.long)` |
| `tensor.nonzero()` | `tensor.nonzero(as_tuple=True)` |
| `tensor.not()` | `tensor.logical_not()` |
## Activation Functions ## Activation Functions

View File

@ -10,6 +10,7 @@ pub fn generate_autoregressive_mask<B: Backend>(
seq_length: usize, seq_length: usize,
device: &B::Device, device: &B::Device,
) -> Tensor<B, 3, Bool> { ) -> Tensor<B, 3, Bool> {
// TODO: replace with the more efficient `triu_mask` and `expand` ops
let mut mask = Tensor::<B, 3, Int>::zeros([1, seq_length, seq_length], device); let mut mask = Tensor::<B, 3, Int>::zeros([1, seq_length, seq_length], device);
for i in 0..(seq_length - 1) { for i in 0..(seq_length - 1) {

View File

@ -1,8 +1,20 @@
use crate::{backend::Backend, Bool, Data, Int, Tensor}; use crate::{backend::Backend, Bool, Data, Int, Shape, Tensor};
use alloc::vec::Vec; use alloc::vec::Vec;
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
use crate::{argwhere, tensor::Shape}; use crate::argwhere;
/// The part of the tensor to keep when creating a triangular mask.
///
/// Selects which comparison `tri_mask` applies to the row/column index
/// difference when building the boolean mask.
enum TriPart {
/// Upper triangular part (on and above the shifted diagonal).
Upper,
/// Lower triangular part (on and below the shifted diagonal).
Lower,
/// Diagonal part (the single shifted diagonal).
Diagonal,
}
impl<B, const D: usize> Tensor<B, D, Bool> impl<B, const D: usize> Tensor<B, D, Bool>
where where
@ -42,7 +54,7 @@ where
.collect() .collect()
} }
/// Compute the indices of the elements that are non-zero. /// Compute the indices of the elements that are true.
/// ///
/// # Returns /// # Returns
/// ///
@ -59,7 +71,7 @@ where
.collect() .collect()
} }
/// Compute the indices of the elements that are non-zero, grouped by element. /// Compute the indices of the elements that are true, grouped by element.
/// ///
/// # Returns /// # Returns
/// ///
@ -70,7 +82,7 @@ where
Tensor::new(B::bool_argwhere(self.primitive)) Tensor::new(B::bool_argwhere(self.primitive))
} }
/// Compute the indices of the elements that are non-zero, grouped by element. /// Compute the indices of the elements that are true, grouped by element.
/// ///
/// # Returns /// # Returns
/// ///
@ -80,4 +92,102 @@ where
pub async fn argwhere(self) -> Tensor<B, 2, Int> { pub async fn argwhere(self) -> Tensor<B, 2, Int> {
Tensor::new(argwhere::<B, D>(self.primitive).await) Tensor::new(argwhere::<B, D>(self.primitive).await)
} }
/// Creates a mask for the upper, lower triangle, or diagonal of a matrix, which can be used to
/// fill the specified area with a value.
///
/// The mask convention follows `mask_fill`: elements belonging to the part selected by
/// `tri_part` are `false` (kept), everything else is `true` (to be filled).
///
/// # Arguments
///
/// * `shape`: The shape of the matrix; the last two dimensions are treated as the matrix,
///   so this requires `D >= 2`.
/// * `tri_part`: Which part of the matrix the mask keeps unmasked.
/// * `offset`: The offset from the main diagonal: 0 is the main diagonal, positive values
///   shift towards the upper triangle, negative values towards the lower.
/// * `device`: The device on which the tensor will be allocated.
fn tri_mask<S: Into<Shape<D>>>(
    shape: S,
    tri_part: TriPart,
    offset: i64,
    device: &B::Device,
) -> Self {
    let shape = shape.into();
    let height = shape.dims[D - 2];
    let width = shape.dims[D - 1];

    // Generate row and column index tensors.
    let row_indices: Tensor<B, 1, Int> = Tensor::arange(0..height as i64, device);
    let col_indices: Tensor<B, 1, Int> = Tensor::arange(0..width as i64, device);

    // Prepare shapes for broadcasting: rows vary along dim D-2, columns along dim D-1.
    let mut row_shape = [1; D];
    row_shape[D - 2] = height;
    let mut col_shape = [1; D];
    col_shape[D - 1] = width;

    // Reshape for broadcasting.
    let row_broadcast = row_indices.reshape(Shape::new(row_shape));
    let col_broadcast = col_indices.reshape(Shape::new(col_shape));

    // Broadcasting trick: matrix[i, j] = i - (j - offset), so its sign/zero pattern encodes
    // the position of (i, j) relative to the shifted diagonal. The operands are consumed
    // here and never used again, so no clones are needed.
    let matrix = row_broadcast - (col_broadcast - offset);

    // Select the appropriate comparison function based on `tri_part`.
    let compare = match tri_part {
        TriPart::Upper => Tensor::greater_elem,
        TriPart::Lower => Tensor::lower_elem,
        TriPart::Diagonal => Tensor::not_equal_elem,
    };

    // Generate and return the mask by applying the comparison to the matrix.
    compare(matrix, 0).unsqueeze()
}
/// Creates a mask for the upper triangle of a matrix, which can be used to fill the specified
/// area with a value.
///
/// This function generates a boolean tensor where the upper triangle is unmasked, so that a
/// subsequent `mask_fill` overwrites everything except the upper triangle.
///
/// # Arguments
///
/// * `shape`: The shape of the matrix.
/// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
/// towards the upper triangle.
/// * `device`: The device on which the tensor will be allocated.
///
/// # Returns
///
/// Returns a boolean tensor where `false` marks the elements of the upper triangle (taking the
/// specified `offset` into account) and `true` marks the elements outside it, i.e. the area
/// to fill.
pub fn triu_mask<S: Into<Shape<D>>>(shape: S, offset: i64, device: &B::Device) -> Self {
Self::tri_mask(shape, TriPart::Upper, offset, device)
}
/// Creates a mask for the lower triangle of a matrix, which can be used to fill the specified
/// area with a value.
///
/// This function generates a boolean tensor where the lower triangle is unmasked, so that a
/// subsequent `mask_fill` overwrites everything except the lower triangle.
///
/// # Arguments
///
/// * `shape`: The shape of the matrix.
/// * `offset`: The offset from the diagonal, where 0 means the diagonal, and negative values shift
/// towards the lower triangle.
/// * `device`: The device on which the tensor will be allocated.
///
/// # Returns
///
/// Returns a boolean tensor where `false` marks the elements of the lower triangle (taking the
/// specified `offset` into account) and `true` marks the elements outside it, i.e. the area
/// to fill.
pub fn tril_mask<S: Into<Shape<D>>>(shape: S, offset: i64, device: &B::Device) -> Self {
Self::tri_mask(shape, TriPart::Lower, offset, device)
}
/// Creates a mask for the diagonal of a matrix, which can be used to fill the specified
/// area with a value.
///
/// This function generates a boolean tensor where the selected diagonal is unmasked, so that
/// a subsequent `mask_fill` overwrites everything except that diagonal.
///
/// # Arguments
///
/// * `shape`: The shape of the matrix.
/// * `offset`: The offset from the main diagonal, where 0 means the main diagonal, positive
/// values select a diagonal above it and negative values one below.
/// * `device`: The device on which the tensor will be allocated.
///
/// # Returns
///
/// Returns a boolean tensor where `false` marks the elements of the selected diagonal and
/// `true` marks every other element, i.e. the area to fill.
pub fn diag_mask<S: Into<Shape<D>>>(shape: S, offset: i64, device: &B::Device) -> Self {
Self::tri_mask(shape, TriPart::Diagonal, offset, device)
}
} }

View File

@ -1,3 +1,5 @@
use crate::alloc::borrow::ToOwned;
use crate::{ use crate::{
backend::Backend, check, check::TensorCheck, BasicOps, Bool, Distribution, Element, backend::Backend, check, check::TensorCheck, BasicOps, Bool, Distribution, Element,
ElementConversion, Float, Int, Shape, Tensor, TensorKind, ElementConversion, Float, Int, Shape, Tensor, TensorKind,
@ -482,42 +484,6 @@ where
Self::new(K::abs(self.primitive)) Self::new(K::abs(self.primitive))
} }
/// Returns the triangular part of a matrix (2-D tensor) or batch of matrices,
/// based on the specified comparison method, zeroing out the other elements.
///
/// # Parameters
///
/// - `diagonal`: The diagonal from which the triangular part is computed.
/// - `compare`: A comparison function determining which part of the triangle to zero out.
///   Use `Tensor::<B, D, Int>::greater_elem` for upper triangular
///   and `Tensor::<B, D, Int>::lower_elem` for lower triangular.
///
pub(crate) fn tri_compare<F>(self, diagonal: i64, compare: F) -> Self
where
F: FnOnce(Tensor<B, D, Int>, i64) -> Tensor<B, D, Bool>,
{
// Rejects tensors whose rank is too small to hold a matrix in the last two dims.
check!(TensorCheck::tri::<{ D }>());
let shape = self.shape();
// The last two dimensions are treated as the matrix (height x width).
let height = shape.dims[D - 2];
let width = shape.dims[D - 1];
// Row and column index vectors on the same device as `self`.
let row_indices: Tensor<B, 1, Int> = Tensor::arange(0..height as i64, &self.device());
let col_indices: Tensor<B, 1, Int> = Tensor::arange(0..width as i64, &self.device());
// Shapes of all-ones except the varying dim, so the two index tensors broadcast
// against each other into a full (.., height, width) grid.
let mut row_shape = [1; D];
row_shape[D - 2] = height;
let mut col_shape = [1; D];
col_shape[D - 1] = width;
let row_broadcast = row_indices.reshape(Shape::new(row_shape));
let col_broadcast = col_indices.reshape(Shape::new(col_shape));
// mask[i, j] = compare(i - (j - diagonal), 0): the sign of the index difference
// encodes the position of (i, j) relative to the shifted diagonal.
let mask = compare(row_broadcast - (col_broadcast - diagonal), 0).unsqueeze();
// Zero out the masked (non-kept) part.
self.mask_fill(mask, 0)
}
/// Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input, /// Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input,
/// the other elements of the result tensor out are set to 0. /// the other elements of the result tensor out are set to 0.
/// ///
@ -546,7 +512,13 @@ where
/// } /// }
/// ``` /// ```
pub fn triu(self, diagonal: i64) -> Self { pub fn triu(self, diagonal: i64) -> Self {
self.tri_compare(diagonal, Tensor::greater_elem) check!(TensorCheck::tri::<{ D }>());
// last two dimensions
let shape = &self.shape().dims[D - 2..].to_owned();
let mask = Tensor::<B, 2, Bool>::triu_mask(shape, diagonal, &self.device()).unsqueeze();
self.mask_fill(mask, 0)
} }
/// Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input, /// Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input,
@ -578,7 +550,13 @@ where
/// } /// }
/// ``` /// ```
pub fn tril(self, diagonal: i64) -> Self { pub fn tril(self, diagonal: i64) -> Self {
self.tri_compare(diagonal, Tensor::lower_elem) check!(TensorCheck::tri::<{ D }>());
// last two dimensions
let shape = &self.shape().dims[D - 2..].to_owned();
let mask = Tensor::<B, 2, Bool>::tril_mask(shape, diagonal, &self.device()).unsqueeze();
self.mask_fill(mask, 0)
} }
/// Applies element wise power operation with a float Tensor /// Applies element wise power operation with a float Tensor

View File

@ -89,6 +89,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_bool!(); burn_tensor::testgen_bool!();
burn_tensor::testgen_argwhere_nonzero!(); burn_tensor::testgen_argwhere_nonzero!();
burn_tensor::testgen_sign!(); burn_tensor::testgen_sign!();
burn_tensor::testgen_tri_mask!();
// test stats // test stats
burn_tensor::testgen_var!(); burn_tensor::testgen_var!();

View File

@ -51,3 +51,4 @@ mod sub;
mod tanh; mod tanh;
mod transpose; mod transpose;
mod tri; mod tri;
mod tri_mask;

View File

@ -0,0 +1,97 @@
#[burn_tensor_testgen::testgen(tri_mask)]
mod tests {
    use super::*;
    use burn_tensor::{Bool, Data, Tensor};

    #[test]
    fn square_diag() {
        let device = Default::default();
        // `false` on the main diagonal, `true` everywhere else.
        let mask = Tensor::<TestBackend, 2, Bool>::diag_mask([3, 3], 0, &device);
        let expected = Data::from([
            [false, true, true],
            [true, false, true],
            [true, true, false],
        ]);
        assert_eq!(mask.into_data(), expected);
    }

    #[test]
    fn square_diag_offset() {
        let device = Default::default();
        // Offset 1 selects the first superdiagonal.
        let mask = Tensor::<TestBackend, 2, Bool>::diag_mask([3, 3], 1, &device);
        let expected =
            Data::from([[true, false, true], [true, true, false], [true, true, true]]);
        assert_eq!(mask.into_data(), expected);
    }

    #[test]
    fn square_tri_upper() {
        let device = Default::default();
        // `false` on and above the diagonal: the upper triangle is kept unmasked.
        let mask = Tensor::<TestBackend, 2, Bool>::triu_mask([3, 3], 0, &device);
        let expected = Data::from([
            [false, false, false],
            [true, false, false],
            [true, true, false],
        ]);
        assert_eq!(mask.into_data(), expected);
    }

    #[test]
    fn square_tri_upper_offset() {
        let device = Default::default();
        // A positive offset shifts the unmasked region above the main diagonal.
        let mask = Tensor::<TestBackend, 2, Bool>::triu_mask([3, 3], 1, &device);
        let expected = Data::from([
            [true, false, false],
            [true, true, false],
            [true, true, true],
        ]);
        assert_eq!(mask.into_data(), expected);
    }

    #[test]
    fn square_tri_lower() {
        let device = Default::default();
        // `false` on and below the diagonal: the lower triangle is kept unmasked.
        let mask = Tensor::<TestBackend, 2, Bool>::tril_mask([3, 3], 0, &device);
        let expected = Data::from([
            [false, true, true],
            [false, false, true],
            [false, false, false],
        ]);
        assert_eq!(mask.into_data(), expected);
    }

    #[test]
    fn square_tri_lower_offset() {
        let device = Default::default();
        // A negative offset shifts the unmasked region below the main diagonal.
        let mask = Tensor::<TestBackend, 2, Bool>::tril_mask([3, 3], -1, &device);
        let expected = Data::from([
            [true, true, true],
            [false, true, true],
            [false, false, true],
        ]);
        assert_eq!(mask.into_data(), expected);
    }

    #[test]
    fn rect_diag() {
        let device = Default::default();

        // Wide matrix: the diagonal spans the first `min(height, width)` rows.
        let wide = Tensor::<TestBackend, 2, Bool>::diag_mask([3, 4], 0, &device);
        let expected = Data::from([
            [false, true, true, true],
            [true, false, true, true],
            [true, true, false, true],
        ]);
        assert_eq!(wide.into_data(), expected);

        // Tall matrix: rows past the diagonal are entirely `true`.
        let tall = Tensor::<TestBackend, 2, Bool>::diag_mask([4, 3], 0, &device);
        let expected = Data::from([
            [false, true, true],
            [true, false, true],
            [true, true, false],
            [true, true, true],
        ]);
        assert_eq!(tall.into_data(), expected);
    }
}