Fix ONNX Where op for scalar inputs (#2218)

* Fix ONNX Where op dim_inference for scalar inputs

* Rewrite ONNX Where codegen to support scalars

* ONNX Where: Add tests for all_scalar inputs

---------

Co-authored-by: Guillaume Lagrange <lagrange.guillaume.1@gmail.com>
Adrian Müller 2024-09-03 17:17:18 +02:00 committed by GitHub
parent 59d41bd4b2
commit 6b51b73a5f
14 changed files with 497 additions and 112 deletions
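For orientation, the rewritten codegen boils down to three generated forward bodies. The lines below are a condensed sketch lifted from the expected token streams in the codegen tests further down in this diff (tensorN/scalarN are the placeholder names used there):

// Scalar x: fill the masked positions of y with the scalar directly.
let tensor4 = tensor3.mask_fill(tensor1, scalar2);

// Scalar y: materialize y as a broadcastable [1, 1] tensor, then pick x where the mask is true.
let tensor4 = Tensor::<B, 2, burn::tensor::Float>::full([1, 1], scalar3, &*self.device)
    .mask_where(tensor1, tensor2);

// All scalars: no tensors involved at all, just a conditional.
let scalar4: f64 = if scalar1 { scalar2 } else { scalar3 };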

View File

@ -56,6 +56,10 @@ fn main() {
.input("tests/log/log.onnx")
.input("tests/log_softmax/log_softmax.onnx")
.input("tests/mask_where/mask_where.onnx")
.input("tests/mask_where/mask_where_broadcast.onnx")
.input("tests/mask_where/mask_where_scalar_x.onnx")
.input("tests/mask_where/mask_where_scalar_y.onnx")
.input("tests/mask_where/mask_where_all_scalar.onnx")
.input("tests/matmul/matmul.onnx")
.input("tests/max/max.onnx")
.input("tests/maxpool1d/maxpool1d.onnx")

View File

@ -1,12 +1,8 @@
(binary ONNX protobuf diff, not human-readable: a model previously exported with pytorch 2.1.2 containing two Where nodes, re-exported with pytorch 2.3.0 with a single Where node; presumably mask_where.onnx regenerated by the updated script below.)

View File

@ -1,6 +1,10 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/mask_where/mask_where.onnx
# used to generate models:
# mask_where.onnx
# mask_where_broadcast.onnx
# mask_where_scalar_x.onnx
# mask_where_scalar_y.onnx
import torch
import torch.nn as nn
@ -10,23 +14,17 @@ class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, condition, x1, y1, x2, y2):
return torch.where(condition, x1, y1), torch.where(condition, x2, y2)
def forward(self, condition, x, y):
return torch.where(condition, x, y)
def main():
# Set random seed for reproducibility
torch.manual_seed(0)
def create_model(name: str, device: torch.device, mask: torch.Tensor, x: torch.Tensor, y: torch.Tensor):
print(f"--- {name} ---")
# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "mask_where.onnx"
x = torch.ones(2, 2, device=device)
y = torch.zeros(2, 2, device=device)
mask = torch.tensor([[True, False], [False, True]], device=device)
test_input = (mask, x, y, x[0], y[0])
onnx_name = f"{name}.onnx"
test_input = (mask, x, y)
torch.onnx.export(model, (test_input), onnx_name, verbose=False, opset_version=16)
@ -37,6 +35,24 @@ def main():
output = model.forward(*test_input)
print(f"Test output data: {output}")
def main():
# Set random seed for reproducibility
torch.manual_seed(0)
device = torch.device("cpu")
mask = torch.tensor([[True, False], [False, True]], device=device)
x = torch.ones(2, 2, device=device)
y = torch.zeros(2, 2, device=device)
mask_scalar = torch.tensor(True, device=device)
x_scalar = torch.tensor(1., device=device)
y_scalar = torch.tensor(0., device=device)
create_model("mask_where", device, mask, x, y)
create_model("mask_where_broadcast", device, mask, x[0], y[0])
create_model("mask_where_scalar_x", device, mask, x_scalar, y)
create_model("mask_where_scalar_y", device, mask, x, y_scalar)
create_model("mask_where_all_scalar", device, mask_scalar, x_scalar, y_scalar)
if __name__ == "__main__":
main()
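On the Burn side, scalar ONNX inputs surface as plain Rust scalars in the generated forward signature. A sketch of the shapes involved, taken from the expected token streams of the WhereNode codegen tests further down (those tests use f64; the models produced by this script are exercised with f32 in the onnx-tests below):

// scalar x, tensor y (mask_where_scalar_x-style):
//   fn forward(&self, tensor1: Tensor<B, 2, Bool>, scalar2: f64, tensor3: Tensor<B, 2>) -> Tensor<B, 2>
// everything scalar (mask_where_all_scalar-style):
//   fn forward(&self, scalar1: bool, scalar2: f64, scalar3: f64) -> f64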

View File

@ -0,0 +1,22 @@
(new binary ONNX protobuf, not human-readable: a model exported with pytorch 2.3.0 containing a single Where node over inputs onnx::Where_0, onnx::Where_1 and onnx::Where_2; one of the newly added mask_where_* models generated by the script above.)

View File

@ -65,6 +65,10 @@ include_models!(
log,
log_softmax,
mask_where,
mask_where_broadcast,
mask_where_scalar_x,
mask_where_scalar_y,
mask_where_all_scalar,
matmul,
max,
maxpool1d,
@ -1955,17 +1959,75 @@ mod tests {
let device = Default::default();
let model: mask_where::Model<Backend> = mask_where::Model::new(&device);
let x1 = Tensor::ones([2, 2], &device);
let y1 = Tensor::zeros([2, 2], &device);
let x2 = Tensor::ones([2], &device);
let y2 = Tensor::zeros([2], &device);
let x = Tensor::ones([2, 2], &device);
let y = Tensor::zeros([2, 2], &device);
let mask = Tensor::from_bool([[true, false], [false, true]].into(), &device);
let (output, output_broadcasted) = model.forward(mask, x1, y1, x2, y2);
let output = model.forward(mask, x, y);
let expected = TensorData::from([[1f32, 0.0], [0.0, 1.0]]);
output.to_data().assert_eq(&expected, true);
output_broadcasted.to_data().assert_eq(&expected, true);
}
#[test]
fn mask_where_broadcast() {
let device = Default::default();
let model: mask_where_broadcast::Model<Backend> = mask_where_broadcast::Model::new(&device);
let x = Tensor::ones([2], &device);
let y = Tensor::zeros([2], &device);
let mask = Tensor::from_bool([[true, false], [false, true]].into(), &device);
let output = model.forward(mask, x, y);
let expected = TensorData::from([[1f32, 0.0], [0.0, 1.0]]);
output.to_data().assert_eq(&expected, true);
}
#[test]
fn mask_where_scalar_x() {
let device = Default::default();
let model: mask_where_scalar_x::Model<Backend> = mask_where_scalar_x::Model::new(&device);
let x = 1.0f32;
let y = Tensor::zeros([2, 2], &device);
let mask = Tensor::from_bool([[true, false], [false, true]].into(), &device);
let output = model.forward(mask, x, y);
let expected = TensorData::from([[1f32, 0.0], [0.0, 1.0]]);
output.to_data().assert_eq(&expected, true);
}
#[test]
fn mask_where_scalar_y() {
let device = Default::default();
let model: mask_where_scalar_y::Model<Backend> = mask_where_scalar_y::Model::new(&device);
let x = Tensor::ones([2, 2], &device);
let y = 0.0f32;
let mask = Tensor::from_bool([[true, false], [false, true]].into(), &device);
let output = model.forward(mask, x, y);
let expected = TensorData::from([[1f32, 0.0], [0.0, 1.0]]);
output.to_data().assert_eq(&expected, true);
}
#[test]
fn mask_where_all_scalar() {
let device = Default::default();
let model: mask_where_all_scalar::Model<Backend> =
mask_where_all_scalar::Model::new(&device);
let x = 1.0f32;
let y = 0.0f32;
let mask = true;
let output = model.forward(mask, x, y);
let expected = 1.0f32;
assert_eq!(output, expected);
}
#[test]

View File

@ -35,13 +35,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherNode {
let input = match &self.input {
Type::Tensor(in_tensor) => scope.tensor_use_owned(in_tensor, node_position),
Type::Shape(in_shape) => {
let in_shape_name = &in_shape.name;
// To copy just the values from the shape value without moving it
// (which could lead to ownership problems if the same Shape is used multiple times)
// borrow the array as a slice and use that to create the Tensor:
quote! { Tensor::<B, 1, Int>::from_data(&#in_shape_name as &[_], &*self.device) }
}
Type::Shape(in_shape) => in_shape.to_tensor(),
_ => panic!("Gather needs Scalar or Shape input, got {:?}!", self.input),
};
@ -281,7 +275,7 @@ mod tests {
let indices = tensor1;
let tensor2 = Tensor::select(
Tensor::<B, 1, Int>::from_data(&shape1 as &[_], &*self.device),
Tensor::<B, 1, burn::tensor::Int>::from_data(&shape1 as &[_], &*self.device),
0,
indices,
);

View File

@ -1,68 +1,64 @@
use core::cmp::max;
use super::{Node, NodeCodegen};
use crate::burn::{BurnImports, TensorType, ToTokens, Type};
use crate::burn::{BurnImports, ScalarType, ToTokens, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;
#[derive(Debug, Clone, new)]
pub struct WhereNode {
/// Bool tensor. When True (nonzero), yield X, otherwise yield Y.
pub condition: TensorType,
pub condition: Type,
/// Values selected at indices where condition is True.
pub x: TensorType,
pub x: Type,
/// Values selected at indices where condition is False.
pub y: TensorType,
pub output: TensorType,
pub y: Type,
pub output: Type,
}
impl<PS: PrecisionSettings> NodeCodegen<PS> for WhereNode {
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
vec![self.output.clone()]
}
fn input_types(&self) -> Vec<crate::burn::Type> {
vec![
Type::Tensor(self.condition.clone()),
Type::Tensor(self.x.clone()),
Type::Tensor(self.y.clone()),
]
vec![self.condition.clone(), self.x.clone(), self.y.clone()]
}
fn forward(
&self,
scope: &mut crate::burn::Scope,
node_position: usize,
) -> proc_macro2::TokenStream {
let mut mask = scope.tensor_use_owned(&self.condition, node_position);
let mut x = scope.tensor_use_owned(&self.x, node_position);
let mut y = scope.tensor_use_owned(&self.y, node_position);
let output = &self.output.name;
fn forward(&self, scope: &mut crate::burn::Scope, node_position: usize) -> TokenStream {
match &self.output {
Type::Tensor(out) => {
let cond = Self::input_as_tensor(&self.condition, out.dim, scope, node_position);
let y = Self::input_as_tensor(&self.y, out.dim, scope, node_position);
let out_id = &out.name;
// x, y and condition need to be broadcastable
let broadcasted_dim = max(max(self.x.dim, self.y.dim), self.condition.dim);
let unsqueeze_dims = broadcasted_dim.to_tokens();
if self.condition.dim < broadcasted_dim {
mask = quote! { #mask.unsqueeze::<#unsqueeze_dims>()};
}
if self.x.dim < broadcasted_dim {
x = quote! { #x.unsqueeze::<#unsqueeze_dims>()};
}
if self.y.dim < broadcasted_dim {
y = quote! { #y.unsqueeze::<#unsqueeze_dims>()};
}
quote! {
let #output = #y.mask_where(#mask, #x);
if let Type::Scalar(x) = &self.x {
let x = &x.name;
quote! {
let #out_id = #y.mask_fill(#cond, #x);
}
} else {
let x = Self::input_as_tensor(&self.x, out.dim, scope, node_position);
quote! {
let #out_id = #y.mask_where(#cond, #x);
}
}
}
Type::Scalar(out) => {
// Scalar out means all inputs are scalars as well:
let cond = self.condition.as_scalar();
let x = self.x.as_scalar();
let y = self.y.as_scalar();
Self::forward_scalar(out, cond, x, y)
}
other => panic!("Where cannot handle {other:?}"),
}
}
fn register_imports(&self, imports: &mut BurnImports) {
imports.register("burn::tensor::Bool");
if matches!(&self.output, Type::Tensor(_)) {
imports.register("burn::tensor::Bool");
}
}
fn into_node(self) -> super::Node<PS> {
@ -70,6 +66,50 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for WhereNode {
}
}
impl WhereNode {
fn forward_scalar(
out: &ScalarType,
cond: &ScalarType,
x: &ScalarType,
y: &ScalarType,
) -> TokenStream {
let out_name = &out.name;
let out_type = out.ty();
let cond_name = &cond.name;
let x_name = &x.name;
let y_name = &y.name;
quote! {
let #out_name : #out_type = if #cond_name {
#x_name
}
else {
#y_name
};
}
}
fn input_as_tensor(
input: &Type,
broadcast_rank: usize,
scope: &mut crate::burn::Scope,
node_position: usize,
) -> TokenStream {
let (tensor, rank) = match input {
Type::Tensor(t) => (scope.tensor_use_owned(t, node_position), t.dim),
Type::Scalar(s) => (s.to_full_tensor(&vec![1; broadcast_rank]), broadcast_rank),
Type::Shape(s) => (s.to_tensor(), 1),
_ => panic!("Where op: {input:?} input not implemented"),
};
if rank < broadcast_rank {
let broadcast_rank_tokens = broadcast_rank.to_tokens();
quote! { #tensor.unsqueeze::<#broadcast_rank_tokens>()}
} else {
tensor
}
}
}
#[cfg(test)]
mod tests {
@ -79,7 +119,7 @@ mod tests {
use crate::burn::{
graph::BurnGraph,
node::{mask_where::WhereNode, test::assert_tokens},
TensorType,
ScalarKind, TensorType,
};
#[test]
@ -87,10 +127,10 @@ mod tests {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
graph.register(WhereNode::new(
TensorType::new_bool("tensor1", 2),
TensorType::new_float("tensor2", 2),
TensorType::new_float("tensor3", 2),
TensorType::new_float("tensor4", 2),
Type::Tensor(TensorType::new_bool("tensor1", 2)),
Type::Tensor(TensorType::new_float("tensor2", 2)),
Type::Tensor(TensorType::new_float("tensor3", 2)),
Type::Tensor(TensorType::new_float("tensor4", 2)),
));
graph.register_input_output(
@ -146,10 +186,10 @@ mod tests {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
graph.register(WhereNode::new(
TensorType::new_bool("tensor1", 4),
TensorType::new_float("tensor2", 2),
TensorType::new_float("tensor3", 3),
TensorType::new_float("tensor4", 4),
Type::Tensor(TensorType::new_bool("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 2)),
Type::Tensor(TensorType::new_float("tensor3", 3)),
Type::Tensor(TensorType::new_float("tensor4", 4)),
));
graph.register_input_output(
@ -201,4 +241,181 @@ mod tests {
assert_tokens(graph.codegen(), expected);
}
#[test]
fn test_codegen_where_scalar_x() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
graph.register(WhereNode::new(
Type::Tensor(TensorType::new_bool("tensor1", 2)),
Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float64)),
Type::Tensor(TensorType::new_float("tensor3", 2)),
Type::Tensor(TensorType::new_float("tensor4", 2)),
));
graph.register_input_output(
vec![
"tensor1".to_string(),
"scalar2".to_string(),
"tensor3".to_string(),
],
vec!["tensor4".to_string()],
);
let expected = quote! {
use burn::tensor::Bool;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(
&self,
tensor1: Tensor<B, 2, Bool>,
scalar2: f64,
tensor3: Tensor<B, 2>
) -> Tensor<B, 2> {
let tensor4 = tensor3.mask_fill(tensor1, scalar2);
tensor4
}
}
};
assert_tokens(graph.codegen(), expected);
}
#[test]
fn test_codegen_where_scalar_y() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
graph.register(WhereNode::new(
Type::Tensor(TensorType::new_bool("tensor1", 2)),
Type::Tensor(TensorType::new_float("tensor2", 2)),
Type::Scalar(ScalarType::new("scalar3", ScalarKind::Float64)),
Type::Tensor(TensorType::new_float("tensor4", 2)),
));
graph.register_input_output(
vec![
"tensor1".to_string(),
"tensor2".to_string(),
"scalar3".to_string(),
],
vec!["tensor4".to_string()],
);
let expected = quote! {
use burn::tensor::Bool;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(
&self,
tensor1: Tensor<B, 2, Bool>,
tensor2: Tensor<B, 2>,
scalar3: f64
) -> Tensor<B, 2> {
let tensor4 = Tensor::<B, 2, burn::tensor::Float>::full([1, 1], scalar3, &*self.device)
.mask_where(tensor1, tensor2);
tensor4
}
}
};
assert_tokens(graph.codegen(), expected);
}
#[test]
fn test_codegen_where_all_scalar() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
graph.register(WhereNode::new(
Type::Scalar(ScalarType::new("scalar1", ScalarKind::Bool)),
Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float64)),
Type::Scalar(ScalarType::new("scalar3", ScalarKind::Float64)),
Type::Scalar(ScalarType::new("scalar4", ScalarKind::Float64)),
));
graph.register_input_output(
vec![
"scalar1".to_string(),
"scalar2".to_string(),
"scalar3".to_string(),
],
vec!["scalar4".to_string()],
);
let expected = quote! {
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(
&self,
scalar1: bool,
scalar2: f64,
scalar3: f64
) -> f64 {
let scalar4: f64 = if scalar1 { scalar2 } else { scalar3 };
scalar4
}
}
};
assert_tokens(graph.codegen(), expected);
}
}

View File

@ -79,6 +79,27 @@ impl Type {
Type::Other(other) => other.ty(),
}
}
pub fn as_tensor(&self) -> &TensorType {
if let Self::Tensor(t) = self {
t
} else {
panic!("Called Type::as_tensor on {self:?}!");
}
}
pub fn as_scalar(&self) -> &ScalarType {
if let Self::Scalar(s) = self {
s
} else {
panic!("Called Type::as_scalar on {self:?}!");
}
}
pub fn as_shape(&self) -> &ShapeType {
if let Self::Shape(s) = self {
s
} else {
panic!("Called Type::as_shape on {self:?}!");
}
}
}
impl ScalarType {
@ -100,6 +121,28 @@ impl ScalarType {
ScalarKind::Bool => quote! { bool },
}
}
/// Helper for Ops that need to process a Scalar as a tensor on device
///
/// Uploads the Scalar to the device as a full tensor using the given shape definition
pub fn to_full_tensor(&self, shape: &[usize]) -> TokenStream {
let name = &self.name;
let shape_tokens = shape
.iter()
.map(ToTokens::to_tokens)
.map(|s| quote! {#s, })
.collect::<TokenStream>();
let rank = shape.len();
let rank_tokens = rank.to_tokens();
let tensor_kind = match self.kind {
ScalarKind::Int32 | ScalarKind::Int64 => quote! { burn::tensor::Int },
ScalarKind::Float32 | ScalarKind::Float64 => quote! { burn::tensor::Float },
ScalarKind::Bool => quote! { burn::tensor::Bool },
};
quote! {
Tensor::<B, #rank_tokens, #tensor_kind>::full([#shape_tokens], #name, &*self.device)
}
}
}
impl ShapeType {
@ -116,6 +159,17 @@ impl ShapeType {
let dim = self.dim.to_tokens();
quote! { [usize; #dim] }
}
/// Helper for Ops that need to process a shape as a tensor on device
///
/// Uploads the Shape to the device as a rank 1 Int tensor
pub fn to_tensor(&self) -> TokenStream {
let shape_name = &self.name;
// To copy just the values from the shape value without moving it
// (which could lead to ownership problems if the same Shape is used multiple times)
// borrow the array as a slice and use that to create the Tensor:
quote! { Tensor::<B, 1, burn::tensor::Int>::from_data(&#shape_name as &[_], &*self.device) }
}
}
impl TensorType {

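As a concrete reference for what these two helpers expand to, the expected code elsewhere in this diff contains exactly these expressions (scalar3 and shape1 are placeholder names from the tests):

// ScalarType::to_full_tensor(&[1, 1]) for a Float64 scalar:
Tensor::<B, 2, burn::tensor::Float>::full([1, 1], scalar3, &*self.device)
// ShapeType::to_tensor() for a shape argument:
Tensor::<B, 1, burn::tensor::Int>::from_data(&shape1 as &[_], &*self.device)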
View File

@ -747,10 +747,10 @@ impl ParsedOnnxGraph {
}
fn where_conversion(node: Node) -> WhereNode {
let condition = TensorType::from(node.inputs.first().unwrap());
let x = TensorType::from(node.inputs.get(1).unwrap());
let y = TensorType::from(node.inputs.get(2).unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let condition = Type::from(node.inputs.first().unwrap());
let x = Type::from(node.inputs.get(1).unwrap());
let y = Type::from(node.inputs.get(2).unwrap());
let output = Type::from(node.outputs.first().unwrap());
WhereNode::new(condition, x, y, output)
}

View File

@ -766,17 +766,31 @@ fn reduce_sum_update_outputs(node: &mut Node) {
}
fn where_update_outputs(node: &mut Node) {
match (&node.inputs[0].ty, &node.inputs[1].ty, &node.inputs[2].ty) {
(ArgType::Tensor(condition), ArgType::Tensor(x), ArgType::Tensor(y)) => {
// With broadcasting support, output dim has to be computed based on the inputs
node.outputs[0].ty = ArgType::Tensor(TensorType {
elem_type: x.elem_type.clone(),
dim: max(condition.dim, max(x.dim, y.dim)),
..Default::default()
});
set_broadcasting_output_shape(node);
}
_ => panic!("Only tensor input is valid"),
let condition = &node.inputs[0].ty;
let x = &node.inputs[1].ty;
let y = &node.inputs[2].ty;
let elem_type = x.elem_type().clone();
assert_eq!(
*condition.elem_type(),
ElementType::Bool,
"Where condition must be boolean!"
);
assert_eq!(
elem_type,
*y.elem_type(),
"Where x and y have different element types!"
);
let output_rank = max(condition.rank(), max(x.rank(), y.rank()));
if output_rank == 0 {
node.outputs[0].ty = ArgType::Scalar(elem_type);
} else {
node.outputs[0].ty = ArgType::Tensor(TensorType {
elem_type,
dim: output_rank,
..Default::default()
});
set_broadcasting_output_shape(node);
}
}
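A quick worked example of the new rank rule (using ArgType::rank defined below: Scalar is rank 0, Shape is rank 1, Tensor reports its dim):

// condition: Tensor with dim 2, x: Scalar, y: Scalar
//   -> output_rank = max(2, max(0, 0)) = 2: the output stays a rank-2 tensor,
//      and the scalar operands are broadcast at codegen time.
// condition, x and y all Scalar
//   -> output_rank = 0: the output becomes ArgType::Scalar(elem_type).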

View File

@ -92,7 +92,7 @@ pub enum AttributeValue {
pub type Attributes = HashMap<String, AttributeValue>;
/// The type of an element.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
pub enum ElementType {
Float32,
Float64,
@ -135,6 +135,24 @@ impl ArgType {
pub fn is_tensor(&self) -> bool {
matches!(self, Self::Tensor(_))
}
/// Returns the rank (dimension) of the Arg
pub fn rank(&self) -> usize {
match self {
ArgType::Scalar(_) => 0,
ArgType::Shape(_) => 1,
ArgType::Tensor(t) => t.dim,
}
}
/// Returns the element type of the Arg
pub fn elem_type(&self) -> &ElementType {
match self {
ArgType::Scalar(s) => s,
ArgType::Shape(_) => panic!("ArgType::Shape has no ElementType"),
ArgType::Tensor(t) => &t.elem_type,
}
}
}
impl Argument {