mirror of https://github.com/tracel-ai/burn.git
Fix ONNX where op for scalar inputs (#2218)
* Fix ONNX where op dim_inference for scalar inputs * Rewrite ONNX Where codegen to support scalars * ONNX Where: Add tests for all_scalar inputs --------- Co-authored-by: Guillaume Lagrange <lagrange.guillaume.1@gmail.com>
This commit is contained in:
parent
59d41bd4b2
commit
6b51b73a5f
|
@ -56,6 +56,10 @@ fn main() {
|
|||
.input("tests/log/log.onnx")
|
||||
.input("tests/log_softmax/log_softmax.onnx")
|
||||
.input("tests/mask_where/mask_where.onnx")
|
||||
.input("tests/mask_where/mask_where_broadcast.onnx")
|
||||
.input("tests/mask_where/mask_where_scalar_x.onnx")
|
||||
.input("tests/mask_where/mask_where_scalar_y.onnx")
|
||||
.input("tests/mask_where/mask_where_all_scalar.onnx")
|
||||
.input("tests/matmul/matmul.onnx")
|
||||
.input("tests/max/max.onnx")
|
||||
.input("tests/maxpool1d/maxpool1d.onnx")
|
||||
|
|
|
@ -1,12 +1,8 @@
|
|||
pytorch2.1.2:×
|
||||
pytorch2.3.0:Å
|
||||
?
|
||||
onnx::Where_0
|
||||
onnx::Where_1
|
||||
onnx::Where_25/Where"Where
|
||||
A
|
||||
onnx::Where_0
|
||||
onnx::Where_3
|
||||
onnx::Where_46/Where_1"Where
|
||||
onnx::Where_23/Where"Where
|
||||
main_graphZ
|
||||
onnx::Where_0
|
||||
|
||||
|
@ -19,20 +15,8 @@ main_graphZ
|
|||
onnx::Where_2
|
||||
|
||||
|
||||
Z
|
||||
onnx::Where_3
|
||||
|
||||
|
||||
Z
|
||||
onnx::Where_4
|
||||
|
||||
|
||||
b
|
||||
5
|
||||
|
||||
|
||||
b
|
||||
6
|
||||
3
|
||||
|
||||
|
||||
B
|
|
@ -1,6 +1,10 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: onnx-tests/tests/mask_where/mask_where.onnx
|
||||
# used to generate models:
|
||||
# mask_where.onnx
|
||||
# mask_where_broadcast.onnx
|
||||
# mask_where_scalar_x.onnx
|
||||
# mask_where_scalar_y.onnx
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -10,23 +14,17 @@ class Model(nn.Module):
|
|||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
|
||||
def forward(self, condition, x1, y1, x2, y2):
|
||||
return torch.where(condition, x1, y1), torch.where(condition, x2, y2)
|
||||
def forward(self, condition, x, y):
|
||||
return torch.where(condition, x, y)
|
||||
|
||||
|
||||
def main():
|
||||
# Set random seed for reproducibility
|
||||
torch.manual_seed(0)
|
||||
|
||||
def create_model(name: str, device: torch.device, mask: torch.Tensor, x: torch.Tensor, y: torch.Tensor):
|
||||
print(f"--- {name} ---")
|
||||
# Export to onnx
|
||||
model = Model()
|
||||
model.eval()
|
||||
device = torch.device("cpu")
|
||||
onnx_name = "mask_where.onnx"
|
||||
x = torch.ones(2, 2, device=device)
|
||||
y = torch.zeros(2, 2, device=device)
|
||||
mask = torch.tensor([[True, False], [False, True]], device=device)
|
||||
test_input = (mask, x, y, x[0], y[0])
|
||||
onnx_name = f"{name}.onnx"
|
||||
test_input = (mask, x, y)
|
||||
|
||||
torch.onnx.export(model, (test_input), onnx_name, verbose=False, opset_version=16)
|
||||
|
||||
|
@ -37,6 +35,24 @@ def main():
|
|||
output = model.forward(*test_input)
|
||||
print(f"Test output data: {output}")
|
||||
|
||||
def main():
|
||||
# Set random seed for reproducibility
|
||||
torch.manual_seed(0)
|
||||
device = torch.device("cpu")
|
||||
|
||||
mask = torch.tensor([[True, False], [False, True]], device=device)
|
||||
x = torch.ones(2, 2, device=device)
|
||||
y = torch.zeros(2, 2, device=device)
|
||||
mask_scalar = torch.tensor(True, device=device)
|
||||
x_scalar = torch.tensor(1., device=device)
|
||||
y_scalar = torch.tensor(0., device=device)
|
||||
create_model("mask_where", device, mask, x, y)
|
||||
create_model("mask_where_broadcast", device, mask, x[0], y[0])
|
||||
create_model("mask_where_scalar_x", device, mask, x_scalar, y)
|
||||
create_model("mask_where_scalar_y", device, mask, x, y_scalar)
|
||||
create_model("mask_where_all_scalar", device, mask_scalar, x_scalar, y_scalar)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Binary file not shown.
|
@ -0,0 +1,22 @@
|
|||
pytorch2.3.0:½
|
||||
?
|
||||
onnx::Where_0
|
||||
onnx::Where_1
|
||||
onnx::Where_23/Where"Where
|
||||
main_graphZ
|
||||
onnx::Where_0
|
||||
|
||||
|
||||
Z
|
||||
onnx::Where_1
|
||||
|
||||
|
||||
Z
|
||||
onnx::Where_2
|
||||
|
||||
|
||||
b
|
||||
3
|
||||
|
||||
|
||||
B
|
Binary file not shown.
Binary file not shown.
|
@ -65,6 +65,10 @@ include_models!(
|
|||
log,
|
||||
log_softmax,
|
||||
mask_where,
|
||||
mask_where_broadcast,
|
||||
mask_where_scalar_x,
|
||||
mask_where_scalar_y,
|
||||
mask_where_all_scalar,
|
||||
matmul,
|
||||
max,
|
||||
maxpool1d,
|
||||
|
@ -1955,17 +1959,75 @@ mod tests {
|
|||
let device = Default::default();
|
||||
let model: mask_where::Model<Backend> = mask_where::Model::new(&device);
|
||||
|
||||
let x1 = Tensor::ones([2, 2], &device);
|
||||
let y1 = Tensor::zeros([2, 2], &device);
|
||||
let x2 = Tensor::ones([2], &device);
|
||||
let y2 = Tensor::zeros([2], &device);
|
||||
let x = Tensor::ones([2, 2], &device);
|
||||
let y = Tensor::zeros([2, 2], &device);
|
||||
let mask = Tensor::from_bool([[true, false], [false, true]].into(), &device);
|
||||
|
||||
let (output, output_broadcasted) = model.forward(mask, x1, y1, x2, y2);
|
||||
let output = model.forward(mask, x, y);
|
||||
let expected = TensorData::from([[1f32, 0.0], [0.0, 1.0]]);
|
||||
|
||||
output.to_data().assert_eq(&expected, true);
|
||||
output_broadcasted.to_data().assert_eq(&expected, true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mask_where_broadcast() {
|
||||
let device = Default::default();
|
||||
let model: mask_where_broadcast::Model<Backend> = mask_where_broadcast::Model::new(&device);
|
||||
|
||||
let x = Tensor::ones([2], &device);
|
||||
let y = Tensor::zeros([2], &device);
|
||||
let mask = Tensor::from_bool([[true, false], [false, true]].into(), &device);
|
||||
|
||||
let output = model.forward(mask, x, y);
|
||||
let expected = TensorData::from([[1f32, 0.0], [0.0, 1.0]]);
|
||||
|
||||
output.to_data().assert_eq(&expected, true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mask_where_scalar_x() {
|
||||
let device = Default::default();
|
||||
let model: mask_where_scalar_x::Model<Backend> = mask_where_scalar_x::Model::new(&device);
|
||||
|
||||
let x = 1.0f32;
|
||||
let y = Tensor::zeros([2, 2], &device);
|
||||
let mask = Tensor::from_bool([[true, false], [false, true]].into(), &device);
|
||||
|
||||
let output = model.forward(mask, x, y);
|
||||
let expected = TensorData::from([[1f32, 0.0], [0.0, 1.0]]);
|
||||
|
||||
output.to_data().assert_eq(&expected, true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mask_where_scalar_y() {
|
||||
let device = Default::default();
|
||||
let model: mask_where_scalar_y::Model<Backend> = mask_where_scalar_y::Model::new(&device);
|
||||
|
||||
let x = Tensor::ones([2, 2], &device);
|
||||
let y = 0.0f32;
|
||||
let mask = Tensor::from_bool([[true, false], [false, true]].into(), &device);
|
||||
|
||||
let output = model.forward(mask, x, y);
|
||||
let expected = TensorData::from([[1f32, 0.0], [0.0, 1.0]]);
|
||||
|
||||
output.to_data().assert_eq(&expected, true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mask_where_all_scalar() {
|
||||
let device = Default::default();
|
||||
let model: mask_where_all_scalar::Model<Backend> =
|
||||
mask_where_all_scalar::Model::new(&device);
|
||||
|
||||
let x = 1.0f32;
|
||||
let y = 0.0f32;
|
||||
let mask = true;
|
||||
|
||||
let output = model.forward(mask, x, y);
|
||||
let expected = 1.0f32;
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
@ -35,13 +35,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherNode {
|
|||
|
||||
let input = match &self.input {
|
||||
Type::Tensor(in_tensor) => scope.tensor_use_owned(in_tensor, node_position),
|
||||
Type::Shape(in_shape) => {
|
||||
let in_shape_name = &in_shape.name;
|
||||
// To copy just the values from the shape value without moving it
|
||||
// (which could lead to ownership problems if the same Shape is used multiple times)
|
||||
// borrow the array as a slice and use that to create the Tensor:
|
||||
quote! { Tensor::<B, 1, Int>::from_data(&#in_shape_name as &[_], &*self.device) }
|
||||
}
|
||||
Type::Shape(in_shape) => in_shape.to_tensor(),
|
||||
_ => panic!("Gather needs Scalar or Shape input, got {:?}!", self.input),
|
||||
};
|
||||
|
||||
|
@ -281,7 +275,7 @@ mod tests {
|
|||
let indices = tensor1;
|
||||
|
||||
let tensor2 = Tensor::select(
|
||||
Tensor::<B, 1, Int>::from_data(&shape1 as &[_], &*self.device),
|
||||
Tensor::<B, 1, burn::tensor::Int>::from_data(&shape1 as &[_], &*self.device),
|
||||
0,
|
||||
indices,
|
||||
);
|
||||
|
|
|
@ -1,68 +1,64 @@
|
|||
use core::cmp::max;
|
||||
|
||||
use super::{Node, NodeCodegen};
|
||||
use crate::burn::{BurnImports, TensorType, ToTokens, Type};
|
||||
use crate::burn::{BurnImports, ScalarType, ToTokens, Type};
|
||||
|
||||
use burn::record::PrecisionSettings;
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::quote;
|
||||
|
||||
#[derive(Debug, Clone, new)]
|
||||
pub struct WhereNode {
|
||||
/// Bool tensor. When True (nonzero), yield X, otherwise yield Y.
|
||||
pub condition: TensorType,
|
||||
pub condition: Type,
|
||||
/// Values selected at indices where condition is True.
|
||||
pub x: TensorType,
|
||||
pub x: Type,
|
||||
/// Values selected at indices where condition is False.
|
||||
pub y: TensorType,
|
||||
pub output: TensorType,
|
||||
pub y: Type,
|
||||
pub output: Type,
|
||||
}
|
||||
|
||||
impl<PS: PrecisionSettings> NodeCodegen<PS> for WhereNode {
|
||||
fn output_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(self.output.clone())]
|
||||
vec![self.output.clone()]
|
||||
}
|
||||
|
||||
fn input_types(&self) -> Vec<crate::burn::Type> {
|
||||
vec![
|
||||
Type::Tensor(self.condition.clone()),
|
||||
Type::Tensor(self.x.clone()),
|
||||
Type::Tensor(self.y.clone()),
|
||||
]
|
||||
vec![self.condition.clone(), self.x.clone(), self.y.clone()]
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
scope: &mut crate::burn::Scope,
|
||||
node_position: usize,
|
||||
) -> proc_macro2::TokenStream {
|
||||
let mut mask = scope.tensor_use_owned(&self.condition, node_position);
|
||||
let mut x = scope.tensor_use_owned(&self.x, node_position);
|
||||
let mut y = scope.tensor_use_owned(&self.y, node_position);
|
||||
let output = &self.output.name;
|
||||
fn forward(&self, scope: &mut crate::burn::Scope, node_position: usize) -> TokenStream {
|
||||
match &self.output {
|
||||
Type::Tensor(out) => {
|
||||
let cond = Self::input_as_tensor(&self.condition, out.dim, scope, node_position);
|
||||
let y = Self::input_as_tensor(&self.y, out.dim, scope, node_position);
|
||||
let out_id = &out.name;
|
||||
|
||||
// x, y and condition need to be broadcastable
|
||||
let broadcasted_dim = max(max(self.x.dim, self.y.dim), self.condition.dim);
|
||||
let unsqueeze_dims = broadcasted_dim.to_tokens();
|
||||
|
||||
if self.condition.dim < broadcasted_dim {
|
||||
mask = quote! { #mask.unsqueeze::<#unsqueeze_dims>()};
|
||||
}
|
||||
|
||||
if self.x.dim < broadcasted_dim {
|
||||
x = quote! { #x.unsqueeze::<#unsqueeze_dims>()};
|
||||
}
|
||||
|
||||
if self.y.dim < broadcasted_dim {
|
||||
y = quote! { #y.unsqueeze::<#unsqueeze_dims>()};
|
||||
}
|
||||
|
||||
quote! {
|
||||
let #output = #y.mask_where(#mask, #x);
|
||||
if let Type::Scalar(x) = &self.x {
|
||||
let x = &x.name;
|
||||
quote! {
|
||||
let #out_id = #y.mask_fill(#cond, #x);
|
||||
}
|
||||
} else {
|
||||
let x = Self::input_as_tensor(&self.x, out.dim, scope, node_position);
|
||||
quote! {
|
||||
let #out_id = #y.mask_where(#cond, #x);
|
||||
}
|
||||
}
|
||||
}
|
||||
Type::Scalar(out) => {
|
||||
// Scalar out means all inputs are scalars as well:
|
||||
let cond = self.condition.as_scalar();
|
||||
let x = self.x.as_scalar();
|
||||
let y = self.y.as_scalar();
|
||||
Self::forward_scalar(out, cond, x, y)
|
||||
}
|
||||
other => panic!("Where cannot handle {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn register_imports(&self, imports: &mut BurnImports) {
|
||||
imports.register("burn::tensor::Bool");
|
||||
if matches!(&self.output, Type::Tensor(_)) {
|
||||
imports.register("burn::tensor::Bool");
|
||||
}
|
||||
}
|
||||
|
||||
fn into_node(self) -> super::Node<PS> {
|
||||
|
@ -70,6 +66,50 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for WhereNode {
|
|||
}
|
||||
}
|
||||
|
||||
impl WhereNode {
|
||||
fn forward_scalar(
|
||||
out: &ScalarType,
|
||||
cond: &ScalarType,
|
||||
x: &ScalarType,
|
||||
y: &ScalarType,
|
||||
) -> TokenStream {
|
||||
let out_name = &out.name;
|
||||
let out_type = out.ty();
|
||||
let cond_name = &cond.name;
|
||||
let x_name = &x.name;
|
||||
let y_name = &y.name;
|
||||
|
||||
quote! {
|
||||
let #out_name : #out_type = if #cond_name {
|
||||
#x_name
|
||||
}
|
||||
else {
|
||||
#y_name
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn input_as_tensor(
|
||||
input: &Type,
|
||||
broadcast_rank: usize,
|
||||
scope: &mut crate::burn::Scope,
|
||||
node_position: usize,
|
||||
) -> TokenStream {
|
||||
let (tensor, rank) = match input {
|
||||
Type::Tensor(t) => (scope.tensor_use_owned(t, node_position), t.dim),
|
||||
Type::Scalar(s) => (s.to_full_tensor(&vec![1; broadcast_rank]), broadcast_rank),
|
||||
Type::Shape(s) => (s.to_tensor(), 1),
|
||||
_ => panic!("Where op: {input:?} input not implemented"),
|
||||
};
|
||||
if rank < broadcast_rank {
|
||||
let broadcast_rank_tokens = broadcast_rank.to_tokens();
|
||||
quote! { #tensor.unsqueeze::<#broadcast_rank_tokens>()}
|
||||
} else {
|
||||
tensor
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
|
@ -79,7 +119,7 @@ mod tests {
|
|||
use crate::burn::{
|
||||
graph::BurnGraph,
|
||||
node::{mask_where::WhereNode, test::assert_tokens},
|
||||
TensorType,
|
||||
ScalarKind, TensorType,
|
||||
};
|
||||
|
||||
#[test]
|
||||
|
@ -87,10 +127,10 @@ mod tests {
|
|||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
|
||||
graph.register(WhereNode::new(
|
||||
TensorType::new_bool("tensor1", 2),
|
||||
TensorType::new_float("tensor2", 2),
|
||||
TensorType::new_float("tensor3", 2),
|
||||
TensorType::new_float("tensor4", 2),
|
||||
Type::Tensor(TensorType::new_bool("tensor1", 2)),
|
||||
Type::Tensor(TensorType::new_float("tensor2", 2)),
|
||||
Type::Tensor(TensorType::new_float("tensor3", 2)),
|
||||
Type::Tensor(TensorType::new_float("tensor4", 2)),
|
||||
));
|
||||
|
||||
graph.register_input_output(
|
||||
|
@ -146,10 +186,10 @@ mod tests {
|
|||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
|
||||
graph.register(WhereNode::new(
|
||||
TensorType::new_bool("tensor1", 4),
|
||||
TensorType::new_float("tensor2", 2),
|
||||
TensorType::new_float("tensor3", 3),
|
||||
TensorType::new_float("tensor4", 4),
|
||||
Type::Tensor(TensorType::new_bool("tensor1", 4)),
|
||||
Type::Tensor(TensorType::new_float("tensor2", 2)),
|
||||
Type::Tensor(TensorType::new_float("tensor3", 3)),
|
||||
Type::Tensor(TensorType::new_float("tensor4", 4)),
|
||||
));
|
||||
|
||||
graph.register_input_output(
|
||||
|
@ -201,4 +241,181 @@ mod tests {
|
|||
|
||||
assert_tokens(graph.codegen(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_codegen_where_scalar_x() {
|
||||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
|
||||
graph.register(WhereNode::new(
|
||||
Type::Tensor(TensorType::new_bool("tensor1", 2)),
|
||||
Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float64)),
|
||||
Type::Tensor(TensorType::new_float("tensor3", 2)),
|
||||
Type::Tensor(TensorType::new_float("tensor4", 2)),
|
||||
));
|
||||
|
||||
graph.register_input_output(
|
||||
vec![
|
||||
"tensor1".to_string(),
|
||||
"scalar2".to_string(),
|
||||
"tensor3".to_string(),
|
||||
],
|
||||
vec!["tensor4".to_string()],
|
||||
);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::tensor::Bool;
|
||||
use burn::{
|
||||
module::Module,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
#[allow(unused_variables)]
|
||||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
tensor1: Tensor<B, 2, Bool>,
|
||||
scalar2: f64,
|
||||
tensor3: Tensor<B, 2>
|
||||
) -> Tensor<B, 2> {
|
||||
let tensor4 = tensor3.mask_fill(tensor1, scalar2);
|
||||
|
||||
tensor4
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
assert_tokens(graph.codegen(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_codegen_where_scalar_y() {
|
||||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
|
||||
graph.register(WhereNode::new(
|
||||
Type::Tensor(TensorType::new_bool("tensor1", 2)),
|
||||
Type::Tensor(TensorType::new_float("tensor2", 2)),
|
||||
Type::Scalar(ScalarType::new("scalar3", ScalarKind::Float64)),
|
||||
Type::Tensor(TensorType::new_float("tensor4", 2)),
|
||||
));
|
||||
|
||||
graph.register_input_output(
|
||||
vec![
|
||||
"tensor1".to_string(),
|
||||
"tensor2".to_string(),
|
||||
"scalar3".to_string(),
|
||||
],
|
||||
vec!["tensor4".to_string()],
|
||||
);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::tensor::Bool;
|
||||
use burn::{
|
||||
module::Module,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
#[allow(unused_variables)]
|
||||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
tensor1: Tensor<B, 2, Bool>,
|
||||
tensor2: Tensor<B, 2>,
|
||||
scalar3: f64
|
||||
) -> Tensor<B, 2> {
|
||||
let tensor4 = Tensor::<B, 2, burn::tensor::Float>::full([1, 1], scalar3, &*self.device)
|
||||
.mask_where(tensor1, tensor2);
|
||||
|
||||
tensor4
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
assert_tokens(graph.codegen(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_codegen_where_all_scalar() {
|
||||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
|
||||
graph.register(WhereNode::new(
|
||||
Type::Scalar(ScalarType::new("scalar1", ScalarKind::Bool)),
|
||||
Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float64)),
|
||||
Type::Scalar(ScalarType::new("scalar3", ScalarKind::Float64)),
|
||||
Type::Scalar(ScalarType::new("scalar4", ScalarKind::Float64)),
|
||||
));
|
||||
|
||||
graph.register_input_output(
|
||||
vec![
|
||||
"scalar1".to_string(),
|
||||
"scalar2".to_string(),
|
||||
"scalar3".to_string(),
|
||||
],
|
||||
vec!["scalar4".to_string()],
|
||||
);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::{
|
||||
module::Module,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
#[allow(unused_variables)]
|
||||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
scalar1: bool,
|
||||
scalar2: f64,
|
||||
scalar3: f64
|
||||
) -> f64 {
|
||||
let scalar4: f64 = if scalar1 { scalar2 } else { scalar3 };
|
||||
|
||||
scalar4
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
assert_tokens(graph.codegen(), expected);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -79,6 +79,27 @@ impl Type {
|
|||
Type::Other(other) => other.ty(),
|
||||
}
|
||||
}
|
||||
pub fn as_tensor(&self) -> &TensorType {
|
||||
if let Self::Tensor(t) = self {
|
||||
t
|
||||
} else {
|
||||
panic!("Called Type::as_tensor on {self:?}!");
|
||||
}
|
||||
}
|
||||
pub fn as_scalar(&self) -> &ScalarType {
|
||||
if let Self::Scalar(s) = self {
|
||||
s
|
||||
} else {
|
||||
panic!("Called Type::as_scalar on {self:?}!");
|
||||
}
|
||||
}
|
||||
pub fn as_shape(&self) -> &ShapeType {
|
||||
if let Self::Shape(s) = self {
|
||||
s
|
||||
} else {
|
||||
panic!("Called Type::as_shape on {self:?}!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarType {
|
||||
|
@ -100,6 +121,28 @@ impl ScalarType {
|
|||
ScalarKind::Bool => quote! { bool },
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper for Ops that need to process a Scalar as a tensor on device
|
||||
///
|
||||
/// Uploads the Scalar to the device as a full tensor using the given shape definition
|
||||
pub fn to_full_tensor(&self, shape: &[usize]) -> TokenStream {
|
||||
let name = &self.name;
|
||||
let shape_tokens = shape
|
||||
.iter()
|
||||
.map(ToTokens::to_tokens)
|
||||
.map(|s| quote! {#s, })
|
||||
.collect::<TokenStream>();
|
||||
let rank = shape.len();
|
||||
let rank_tokens = rank.to_tokens();
|
||||
let tensor_kind = match self.kind {
|
||||
ScalarKind::Int32 | ScalarKind::Int64 => quote! { burn::tensor::Int },
|
||||
ScalarKind::Float32 | ScalarKind::Float64 => quote! { burn::tensor::Float },
|
||||
ScalarKind::Bool => quote! { burn::tensor::Bool },
|
||||
};
|
||||
quote! {
|
||||
Tensor::<B, #rank_tokens, #tensor_kind>::full([#shape_tokens], #name, &*self.device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeType {
|
||||
|
@ -116,6 +159,17 @@ impl ShapeType {
|
|||
let dim = self.dim.to_tokens();
|
||||
quote! { [usize; #dim] }
|
||||
}
|
||||
|
||||
/// Helper for Ops that need to process a shape as a tensor on device
|
||||
///
|
||||
/// Uploads the Shape to the device as a rank 1 Int tensor
|
||||
pub fn to_tensor(&self) -> TokenStream {
|
||||
let shape_name = &self.name;
|
||||
// To copy just the values from the shape value without moving it
|
||||
// (which could lead to ownership problems if the same Shape is used multiple times)
|
||||
// borrow the array as a slice and use that to create the Tensor:
|
||||
quote! { Tensor::<B, 1, burn::tensor::Int>::from_data(&#shape_name as &[_], &*self.device) }
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorType {
|
||||
|
|
|
@ -747,10 +747,10 @@ impl ParsedOnnxGraph {
|
|||
}
|
||||
|
||||
fn where_conversion(node: Node) -> WhereNode {
|
||||
let condition = TensorType::from(node.inputs.first().unwrap());
|
||||
let x = TensorType::from(node.inputs.get(1).unwrap());
|
||||
let y = TensorType::from(node.inputs.get(2).unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let condition = Type::from(node.inputs.first().unwrap());
|
||||
let x = Type::from(node.inputs.get(1).unwrap());
|
||||
let y = Type::from(node.inputs.get(2).unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
|
||||
WhereNode::new(condition, x, y, output)
|
||||
}
|
||||
|
|
|
@ -766,17 +766,31 @@ fn reduce_sum_update_outputs(node: &mut Node) {
|
|||
}
|
||||
|
||||
fn where_update_outputs(node: &mut Node) {
|
||||
match (&node.inputs[0].ty, &node.inputs[1].ty, &node.inputs[2].ty) {
|
||||
(ArgType::Tensor(condition), ArgType::Tensor(x), ArgType::Tensor(y)) => {
|
||||
// With broadcasting support, output dim has to be computed based on the inputs
|
||||
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
||||
elem_type: x.elem_type.clone(),
|
||||
dim: max(condition.dim, max(x.dim, y.dim)),
|
||||
..Default::default()
|
||||
});
|
||||
set_broadcasting_output_shape(node);
|
||||
}
|
||||
_ => panic!("Only tensor input is valid"),
|
||||
let condition = &node.inputs[0].ty;
|
||||
let x = &node.inputs[1].ty;
|
||||
let y = &node.inputs[2].ty;
|
||||
let elem_type = x.elem_type().clone();
|
||||
assert_eq!(
|
||||
*condition.elem_type(),
|
||||
ElementType::Bool,
|
||||
"Where condition must be boolean!"
|
||||
);
|
||||
assert_eq!(
|
||||
elem_type,
|
||||
*y.elem_type(),
|
||||
"Where x and y have different element types!"
|
||||
);
|
||||
|
||||
let output_rank = max(condition.rank(), max(x.rank(), y.rank()));
|
||||
if output_rank == 0 {
|
||||
node.outputs[0].ty = ArgType::Scalar(elem_type);
|
||||
} else {
|
||||
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
||||
elem_type,
|
||||
dim: output_rank,
|
||||
..Default::default()
|
||||
});
|
||||
set_broadcasting_output_shape(node);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -92,7 +92,7 @@ pub enum AttributeValue {
|
|||
pub type Attributes = HashMap<String, AttributeValue>;
|
||||
|
||||
/// The type of an element.
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum ElementType {
|
||||
Float32,
|
||||
Float64,
|
||||
|
@ -135,6 +135,24 @@ impl ArgType {
|
|||
pub fn is_tensor(&self) -> bool {
|
||||
matches!(self, Self::Tensor(_))
|
||||
}
|
||||
|
||||
/// returns the rank (dimension) of the Arg
|
||||
pub fn rank(&self) -> usize {
|
||||
match self {
|
||||
ArgType::Scalar(_) => 0,
|
||||
ArgType::Shape(_) => 1,
|
||||
ArgType::Tensor(t) => t.dim,
|
||||
}
|
||||
}
|
||||
|
||||
/// returns the element type of the Arg
|
||||
pub fn elem_type(&self) -> &ElementType {
|
||||
match self {
|
||||
ArgType::Scalar(s) => s,
|
||||
ArgType::Shape(_) => panic!("ArgType::Shape has no ElementType"),
|
||||
ArgType::Tensor(t) => &t.elem_type,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Argument {
|
||||
|
|
Loading…
Reference in New Issue