feat: added slice onnx import (#1856)

* feat: added slice onnx import

* fix: axes, steps handling
jachym committed 2024-06-11 13:50:03 +02:00
parent dd60446946
commit 671ec8c679
12 changed files with 318 additions and 3 deletions


@@ -171,7 +171,7 @@ represent the corresponding Burn Op.
| [Sin][164] | ✅ | ✅ |
| [Sinh][165] | ❌ | ❌ |
| [Size][166] | ❌ | ❌ |
| [Slice][167] | ✅ | ✅ |
| [Softmax][168] | ✅ | ✅ |
| [SoftmaxCrossEntropyLoss][169] | ❌ | ❌ |
| [Softplus][170] | ❌ | ❌ |


@@ -69,6 +69,7 @@ fn main() {
        .input("tests/conv_transpose2d/conv_transpose2d.onnx")
        .input("tests/pow/pow.onnx")
        .input("tests/pow/pow_int.onnx")
        .input("tests/slice/slice.onnx")
        .input("tests/sum/sum.onnx")
        .input("tests/sum/sum_int.onnx")
        .input("tests/unsqueeze/unsqueeze.onnx")


@@ -71,6 +71,7 @@ include_models!(
    sigmoid,
    sign,
    sin,
    slice,
    softmax,
    sqrt,
    sub_int,
@@ -459,6 +460,24 @@ mod tests {
        assert!(expected_sum_2d.approx_eq(output_sum_2d, (1.0e-4, 2)));
    }

    #[test]
    fn slice() {
        let model: slice::Model<Backend> = slice::Model::default();
        let device = Default::default();

        let input = Tensor::<Backend, 2>::from_floats(
            [
                [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.],
                [11., 12., 13., 14., 15., 16., 17., 18., 19., 20.],
            ],
            &device,
        );
        let output = model.forward(input);
        let expected = Data::from([[1., 2., 3., 4., 5.]]);

        assert_eq!(output.to_data(), expected);
    }

    #[test]
    fn softmax() {
        // Initialize the model without weights (because the exported file does not contain them)



@@ -0,0 +1,101 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/slice/slice.onnx

import onnx
from onnx import helper, TensorProto


def main() -> None:
    # Starts: begin the slice at index 0 on both axes
    starts_val = [0, 0]
    starts_tensor = helper.make_tensor(
        name="starts",
        data_type=TensorProto.INT64,
        dims=[len(starts_val)],
        vals=starts_val,
    )

    starts_node = helper.make_node(
        "Constant",
        name="starts_constant",
        inputs=[],
        outputs=["starts"],
        value=starts_tensor,
    )

    # Ends: stop before index 1 on axis 0 and before index 5 on axis 1
    ends_val = [1, 5]
    ends_tensor = helper.make_tensor(
        name="ends",
        data_type=TensorProto.INT64,
        dims=[len(ends_val)],
        vals=ends_val,
    )

    ends_node = helper.make_node(
        "Constant",
        name="ends_constant",
        inputs=[],
        outputs=["ends"],
        value=ends_tensor,
    )

    # Axes: slice along axes 0 and 1
    axes_val = [0, 1]
    axes_tensor = helper.make_tensor(
        name="axes",
        data_type=TensorProto.INT64,
        dims=[len(axes_val)],
        vals=axes_val,
    )

    axes_node = helper.make_node(
        "Constant",
        name="axes_constant",
        inputs=[],
        outputs=["axes"],
        value=axes_tensor,
    )

    # Steps: step of 1 along each axis (the importer only supports steps of 1)
    steps_val = [1, 1]
    steps_tensor = helper.make_tensor(
        name="steps",
        data_type=TensorProto.INT64,
        dims=[len(steps_val)],
        vals=steps_val,
    )

    steps_node = helper.make_node(
        "Constant",
        name="steps_constant",
        inputs=[],
        outputs=["steps"],
        value=steps_tensor,
    )

    # Define the Slice node that uses the outputs from the constant nodes
    slice_node = helper.make_node(
        "Slice",
        name="slice_node",
        inputs=["input_tensor", "starts", "ends", "axes", "steps"],
        outputs=["output"],
    )

    # Create the graph
    graph_def = helper.make_graph(
        nodes=[starts_node, ends_node, axes_node, steps_node, slice_node],
        name="SliceGraph",
        inputs=[
            helper.make_tensor_value_info("input_tensor", TensorProto.FLOAT, [2, 10]),
        ],
        outputs=[
            helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 5]),
        ],
    )

    # Create the model
    model_def = helper.make_model(graph_def, producer_name="slice")

    # Save the model to a file
    onnx.save(model_def, "slice.onnx")


if __name__ == "__main__":
    main()
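
For a quick sanity check of the generated model (not part of this commit), the graph can be run through onnxruntime, assuming that package is installed; the expected output mirrors the Rust slice() test above:

# verify_slice.py, a hypothetical helper script, not included in this PR
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("slice.onnx")
x = np.arange(1.0, 21.0, dtype=np.float32).reshape(2, 10)
(output,) = session.run(None, {"input_tensor": x})
# starts [0, 0] / ends [1, 5] keep the first row and its first five columns
assert output.tolist() == [[1.0, 2.0, 3.0, 4.0, 5.0]]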


@@ -7,7 +7,7 @@ use super::{
    layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
    max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode,
    random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
    reshape::ReshapeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode,
    unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
@@ -102,6 +102,7 @@ pub enum Node<PS: PrecisionSettings> {
    MaxPool2d(MaxPool2dNode),
    Range(RangeNode),
    Reshape(ReshapeNode),
    Slice(SliceNode),
    Squeeze(SqueezeNode),
    Sum(SumNode),
    Unary(UnaryNode),
@@ -139,6 +140,7 @@ macro_rules! match_all {
            Node::MaxPool2d(node) => $func(node),
            Node::Range(node) => $func(node),
            Node::Reshape(node) => $func(node),
            Node::Slice(node) => $func(node),
            Node::Squeeze(node) => $func(node),
            Node::Sum(node) => $func(node),
            Node::Unary(node) => $func(node),
@@ -186,6 +188,7 @@ impl<PS: PrecisionSettings> Node<PS> {
            Node::MaxPool2d(_) => "max_pool2d",
            Node::Range(_) => "range",
            Node::Reshape(_) => "reshape",
            Node::Slice(_) => "slice",
            Node::Squeeze(_) => "squeeze",
            Node::Sum(_) => "add",
            Node::Unary(unary) => unary.kind.as_str(),


@@ -27,6 +27,7 @@ pub(crate) mod random_normal;
pub(crate) mod random_uniform;
pub(crate) mod range;
pub(crate) mod reshape;
pub(crate) mod slice;
pub(crate) mod squeeze;
pub(crate) mod sum;
pub(crate) mod unary;


@@ -0,0 +1,90 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

#[derive(Debug, Clone, new)]
pub struct SliceNode {
    pub input: TensorType,
    pub output: TensorType,
    pub starts: Vec<usize>,
    pub ends: Vec<usize>,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for SliceNode {
    fn output_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.output.clone())]
    }

    fn input_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.input.clone())]
    }

    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
        let input = scope.tensor_use_owned(&self.input, node_position);
        let output = &self.output.name;
        let starts = &self.starts;
        let ends = &self.ends;
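
        // Generate `input.slice([s0..e0, s1..e1, ...])`, one start..end range per axis.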
        quote! {
            let #output = #input.slice([#(#starts..#ends),*]);
        }
    }

    fn into_node(self) -> Node<PS> {
        Node::Slice(self)
    }
}
#[cfg(test)]
mod tests {
    use burn::record::FullPrecisionSettings;

    use super::*;
    use crate::burn::{
        graph::BurnGraph,
        node::{slice::SliceNode, test::assert_tokens},
        TensorType,
    };

    #[test]
    fn test_codegen_slice() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
        graph.register(SliceNode::new(
            TensorType::new_float("tensor1", 4),
            TensorType::new_float("tensor2", 4),
            vec![0, 0, 0, 0],
            vec![1, 1, 1, 1],
        ));
        graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);

        let expected = quote! {
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };

            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
                phantom: core::marker::PhantomData<B>,
                device: burn::module::Ignored<B::Device>,
            }

            impl<B: Backend> Model<B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    Self {
                        phantom: core::marker::PhantomData,
                        device: burn::module::Ignored(device.clone()),
                    }
                }

                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
                    let tensor2 =
                        tensor1.slice([0usize..1usize, 0usize..1usize, 0usize..1usize, 0usize..1usize]);
                    tensor2
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }
}


@@ -63,6 +63,7 @@ pub fn dim_inference(node: &mut Node) {
        NodeType::Sigmoid => same_as_input(node),
        NodeType::Sign => same_as_input(node),
        NodeType::Sin => same_as_input(node),
        NodeType::Slice => slice_update_outputs(node),
        NodeType::Softmax => same_as_input(node),
        NodeType::Sqrt => same_as_input(node),
        NodeType::Sub => same_as_input(node),
@@ -423,6 +424,33 @@ fn squeeze_update_output(node: &mut Node) {
    });
}
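
/// Infer the output rank of a Slice node from the number of `starts` entries; the
/// concrete output shape is only known at runtime.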
fn slice_update_outputs(node: &mut Node) {
    let shape = match &node.inputs[1].value {
        Some(Data::Int64s(shape)) => shape.clone(),
        Some(_) => panic!("Slice: invalid input types"),
        None => panic!("Slice: invalid shape"),
    };

    let output = match &node.outputs[0].ty {
        ArgType::Tensor(tensor) => tensor.clone(),
        _ => panic!("Slice: invalid output types"),
    };

    node.outputs[0].ty = ArgType::Tensor(TensorType {
        dim: shape.len(),
        shape: None, // shape is calculated at runtime
        ..output
    });
}
/// Update the output tensor dimension based on the "axes" attribute or the second input
fn unsqueeze_update_output(node: &mut Node) {
    let axes = if node.inputs.len() == 2 {


@@ -18,7 +18,7 @@ use super::ir::{ArgType, Argument, Node, NodeType};
use protobuf::Message;

const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 11] = [
    NodeType::BatchNormalization,
    NodeType::Clip,
    NodeType::Conv1d,
@@ -28,6 +28,7 @@ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 10] = [
    NodeType::Reshape,
    NodeType::Unsqueeze,
    NodeType::ReduceSum,
    NodeType::Slice,
    NodeType::Squeeze,
];


@@ -1015,6 +1015,67 @@ pub fn shape_config(curr: &Node) -> (usize, usize) {
    (start_dim as usize, end_dim as usize)
}
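
/// Extract the constant `starts` and `ends` inputs of a Slice node as `usize` vectors.
/// Axes must be consecutive starting at 0, and all steps must be 1.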
pub fn slice_config(node: &Node) -> (Vec<usize>, Vec<usize>) {
    let start_value = &node.inputs[1].value;
    let end_value = &node.inputs[2].value;

    let starts = match &node.inputs[1].ty {
        ArgType::Tensor(tensor) => {
            assert_eq!(tensor.dim, 1, "Slice: starts tensor must be 1D");
            if let Some(Data::Int64s(shape)) = start_value.as_ref() {
                shape
                    .iter()
                    .map(|x| {
                        assert!(*x >= 0, "Slice: start must be non-negative");
                        *x as usize
                    })
                    .collect()
            } else {
                panic!("Tensor data type must be int64")
            }
        }
        _ => panic!("Only tensor input is valid for starts"),
    };

    let ends = match &node.inputs[2].ty {
        ArgType::Tensor(tensor) => {
            assert_eq!(tensor.dim, 1, "Slice: ends tensor must be 1D");
            if let Some(Data::Int64s(shape)) = end_value.as_ref() {
                shape
                    .iter()
                    .map(|x| {
                        assert!(*x >= 0, "Slice: end must be non-negative");
                        *x as usize
                    })
                    .collect()
            } else {
                panic!("Tensor data type must be int64")
            }
        }
        _ => panic!("Only tensor input is valid for ends"),
    };

    for (key, value) in node.attrs.iter() {
        match key.as_str() {
            "axes" => value
                .clone()
                .into_i64s()
                .iter()
                .enumerate()
                .for_each(|(i, x)| {
                    assert_eq!(*x, i as i64, "Slice: axes must be consecutive, starting at 0");
                }),
            "steps" => value.clone().into_i64s().into_iter().for_each(|x| {
                if x != 1 {
                    panic!("Slice: steps other than 1 are not supported");
                }
            }),
            _ => {}
        }
    }

    (starts, ends)
}
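
For the test model generated above, `slice_config` reads `starts = [0, 0]` and `ends = [1, 5]` from the constant inputs, and `SliceNode` codegen turns them into `tensor1.slice([0..1, 0..5])`, which matches the `[1, 5]` output declared in the graph.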
pub fn transpose_config(curr: &Node) -> Vec<i64> {
    if curr.inputs.len() != 1 {
        panic!(


@@ -42,6 +42,7 @@ use crate::{
        random_uniform::RandomUniformNode,
        range::RangeNode,
        reshape::ReshapeNode,
        slice::SliceNode,
        squeeze::SqueezeNode,
        sum::SumNode,
        unary::UnaryNode,
@@ -294,6 +295,7 @@ impl OnnxGraph {
                NodeType::Shape => graph.register(Self::shape_conversion(node)),
                NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)),
                NodeType::Sin => graph.register(Self::sin_conversion(node)),
                NodeType::Slice => graph.register(Self::slice_conversion(node)),
                NodeType::Sum => graph.register(Self::sum_conversion(node)),
                NodeType::Transpose => graph.register(Self::transpose_conversion(node)),
                NodeType::Concat => graph.register(Self::concat_conversion(node)),
@@ -686,6 +688,14 @@ impl OnnxGraph {
        UnaryNode::sin(input, output)
    }
    fn slice_conversion(node: Node) -> SliceNode {
        let input = node.inputs.first().unwrap().to_tensor_type();
        let output = node.outputs.first().unwrap().to_tensor_type();
        let (starts, ends) = slice_config(&node);

        SliceNode::new(input, output, starts, ends)
    }
    fn sum_conversion(node: Node) -> SumNode {
        let inputs = node
            .inputs