feat: added range onnx import (#1834)

* feat: added range onnx import

* fix: range input types
This commit is contained in:
jachym.putta 2024-05-31 23:40:54 +02:00 committed by GitHub
parent 36d4bcd705
commit 44f1053219
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 198 additions and 3 deletions

View File

@ -134,7 +134,7 @@ represent the corresponding Burn Op.
| [RandomNormalLike][127] | ❌ | ✅ |
| [RandomUniform][128] | ✅ | ✅ |
| [RandomUniformLike][129] | ❌ | ✅ |
| [Range][130] | ❌ | ✅ |
| [Range][130] | ✅ | ✅ |
| [Reciprocal][131] | ✅ | ✅ |
| [ReduceL][132] | ❌ | ❌ |
| [ReduceLogSum][133] | ❌ | ❌ |

View File

@ -75,6 +75,7 @@ fn main() {
.input("tests/squeeze/squeeze_opset13.onnx")
.input("tests/random_uniform/random_uniform.onnx")
.input("tests/random_normal/random_normal.onnx")
.input("tests/range/range.onnx")
.out_dir("model/")
.run_from_script();

View File

@ -57,6 +57,7 @@ include_models!(
less,
less_or_equal,
prelu,
range,
recip,
reduce_max,
reduce_mean,
@ -1070,6 +1071,21 @@ mod tests {
output.to_data().assert_approx_eq(&expected, 4);
}
#[test]
fn range() {
    // Build the imported ONNX model on the default backend device.
    let device = Default::default();
    let model: range::Model<Backend> = range::Model::new(&device);

    // arange(start = 0, limit = 10, delta = 2) must yield [0, 2, 4, 6, 8].
    let (start, limit, delta) = (0i64, 10i64, 2i64);
    let output = model.forward(start, limit, delta);

    assert_eq!(output.to_data(), Data::from([0, 2, 4, 6, 8]));
}
#[test]
fn recip() {
// Initialize the model

Binary file not shown.

View File

@ -0,0 +1,34 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/range/range.onnx

import onnx
from onnx import helper, TensorProto


def main():
    """Build and save a one-node ONNX graph exercising the Range operator."""
    # Single Range node: output = arange(start, end, step).
    range_node = onnx.helper.make_node(
        'Range',
        name='range',
        inputs=['start', 'end', 'step'],
        outputs=['output'],
    )

    # All three inputs are INT64 scalars (rank-0 value infos). The output
    # shape [5] matches the values driven by the Rust-side test
    # (start=0, end=10, step=2 -> five elements).
    graph_def = helper.make_graph(
        nodes=[range_node],
        name='RangeGraph',
        inputs=[
            helper.make_tensor_value_info('start', TensorProto.INT64, []),
            helper.make_tensor_value_info('end', TensorProto.INT64, []),
            helper.make_tensor_value_info('step', TensorProto.INT64, []),
        ],
        outputs=[
            helper.make_tensor_value_info('output', TensorProto.INT64, [5]),
        ],
    )

    model_def = helper.make_model(graph_def, producer_name='range')
    onnx.save(model_def, 'range.onnx')


if __name__ == '__main__':
    main()

View File

@ -6,8 +6,8 @@ use super::{
global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode,
mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode,
max_pool2d::MaxPool2dNode, prelu::PReluNode, random_normal::RandomNormalNode,
random_uniform::RandomUniformNode, reshape::ReshapeNode, squeeze::SqueezeNode,
unary::UnaryNode, unsqueeze::UnsqueezeNode,
random_uniform::RandomUniformNode, range::RangeNode, reshape::ReshapeNode,
squeeze::SqueezeNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
@ -97,6 +97,7 @@ pub enum Node<PS: PrecisionSettings> {
Matmul(MatmulNode),
MaxPool1d(MaxPool1dNode),
MaxPool2d(MaxPool2dNode),
Range(RangeNode),
Reshape(ReshapeNode),
Squeeze(SqueezeNode),
Unary(UnaryNode),
@ -130,6 +131,7 @@ macro_rules! match_all {
Node::Matmul(node) => $func(node),
Node::MaxPool1d(node) => $func(node),
Node::MaxPool2d(node) => $func(node),
Node::Range(node) => $func(node),
Node::Reshape(node) => $func(node),
Node::Squeeze(node) => $func(node),
Node::Unary(node) => $func(node),
@ -173,6 +175,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Matmul(_) => "matmul",
Node::MaxPool1d(_) => "max_pool1d",
Node::MaxPool2d(_) => "max_pool2d",
Node::Range(_) => "range",
Node::Reshape(_) => "reshape",
Node::Squeeze(_) => "squeeze",
Node::Unary(unary) => unary.kind.as_str(),

View File

@ -23,6 +23,7 @@ pub(crate) mod max_pool2d;
pub(crate) mod prelu;
pub(crate) mod random_normal;
pub(crate) mod random_uniform;
pub(crate) mod range;
pub(crate) mod reshape;
pub(crate) mod squeeze;
pub(crate) mod unary;

View File

@ -0,0 +1,102 @@
use super::{Node, NodeCodegen};
use crate::burn::{ScalarType, Scope, TensorType, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;
/// Codegen node for the ONNX `Range` operator: generates code that builds a
/// rank-1 tensor from three scalar inputs (start, end, step).
#[derive(Debug, Clone, new)]
pub struct RangeNode {
pub start: ScalarType, // inclusive start of the range
pub end: ScalarType, // exclusive end of the range
pub step: ScalarType, // stride between consecutive values
pub output: TensorType, // rank-1 output tensor
}
impl<PS: PrecisionSettings> NodeCodegen<PS> for RangeNode {
    /// The three scalar inputs, in ONNX operand order: start, end, step.
    fn input_types(&self) -> Vec<Type> {
        [&self.start, &self.end, &self.step]
            .iter()
            .map(|scalar| Type::Scalar((*scalar).clone()))
            .collect()
    }

    /// A single rank-1 tensor output.
    fn output_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.output.clone())]
    }

    /// Emits `let <output> = Tensor::arange_step(<start>..<end>, <step> as usize, &*self.device);`.
    fn forward(&self, _scope: &mut Scope, _node_position: usize) -> TokenStream {
        let out = &self.output.name;
        let (start, end, step) = (&self.start.name, &self.end.name, &self.step.name);

        // NOTE(review): `as usize` assumes a non-negative integer step;
        // negative or floating-point steps are not representable here —
        // confirm against the supported Range input types.
        quote! {
            let #out = Tensor::arange_step(#start..#end, #step as usize, &*self.device);
        }
    }

    fn into_node(self) -> Node<PS> {
        Node::Range(self)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::burn::graph::BurnGraph;
use crate::burn::node::test::assert_tokens;
use crate::burn::{ScalarKind, ScalarType};
use burn::record::FullPrecisionSettings;
// Verifies that a RangeNode with three Int64 scalar inputs generates
// exactly the expected model source, token for token.
#[test]
fn codegen_nodes_range() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
// Register a single Range node: (start, end, step) -> rank-1 int tensor.
graph.register(
RangeNode::new(
ScalarType::new("start", ScalarKind::Int64),
ScalarType::new("end", ScalarKind::Int64),
ScalarType::new("step", ScalarKind::Int64),
TensorType::new_int("output", 1),
)
.into_node(),
);
graph.register_input_output(
vec!["start".to_string(), "end".to_string(), "step".to_string()],
vec!["output".to_string()],
);
// The expected generated model: a stateless Module whose forward
// delegates to Tensor::arange_step on the stored device.
// (Do not put comments inside quote!{} — keep the token stream exact.)
let expected = quote! {
use burn::tensor::Int;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, start: i64, end: i64, step: i64) -> Tensor<B, 1, Int> {
let output = Tensor::arange_step(start..end, step as usize, &*self.device);
output
}
}
};
assert_tokens(graph.codegen(), expected);
}
}

View File

@ -51,6 +51,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
NodeType::GreaterOrEqual => greater_or_equal_update_outputs(node),
NodeType::Less => less_update_outputs(node),
NodeType::LessOrEqual => less_or_equal_update_outputs(node),
NodeType::Range => range_update_outputs(node),
NodeType::Reciprocal => same_as_input(node),
NodeType::ReduceMax => reduce_max_update_outputs(node),
NodeType::ReduceMean => reduce_mean_update_outputs(node),
@ -587,6 +588,18 @@ fn matmul_update_outputs(node: &mut Node) {
}
}
/// Infers the output type of a Range node: a rank-1 Int64 tensor of
/// statically unknown length.
fn range_update_outputs(node: &mut Node) {
    // Range(start, limit, delta) always takes exactly three inputs.
    let n_inputs = node.inputs.len();
    if n_inputs != 3 {
        panic!("Range: expected 3 inputs, found {}", n_inputs);
    }

    // NOTE(review): the ONNX spec says Range's output element type follows
    // its inputs; Int64 is hard-coded here, matching the integer-only
    // codegen for this operator — confirm when float ranges are added.
    node.outputs[0].ty = ArgType::Tensor(TensorType {
        elem_type: ElementType::Int64,
        dim: 1,
        shape: None,
    });
}
/// Infers the shape of a ReduceMax node and replaces the shape of the output tensor.
fn reduce_max_update_outputs(node: &mut Node) {
if node.inputs.len() != 1 {

View File

@ -38,6 +38,7 @@ use crate::{
prelu::PReluNode,
random_normal::RandomNormalNode,
random_uniform::RandomUniformNode,
range::RangeNode,
reshape::ReshapeNode,
squeeze::SqueezeNode,
unary::UnaryNode,
@ -279,6 +280,7 @@ impl OnnxGraph {
NodeType::Tanh => graph.register(Self::tanh_conversion(node)),
NodeType::Constant => graph.register(Self::constant_conversion::<PS>(node)),
NodeType::Min => graph.register(Self::min_conversion(node)),
NodeType::Range => graph.register(Self::range_conversion(node)),
NodeType::ReduceMax => graph.register(Self::reduce_max_conversion(node)),
NodeType::ReduceMean => graph.register(Self::reduce_mean_conversion(node)),
NodeType::ReduceSum => graph.register(Self::reduce_sum_conversion(node)),
@ -575,6 +577,29 @@ impl OnnxGraph {
BinaryNode::min_pair(lhs, rhs, output)
}
/// Converts an ONNX Range node into a RangeNode, coercing each of the three
/// inputs (start, end, step) into a scalar.
fn range_conversion(node: Node) -> RangeNode {
    // A Range input must be a scalar; a rank-0 tensor is accepted as one.
    fn to_scalar(arg: &Argument) -> ScalarType {
        let kind = match &arg.ty {
            ArgType::Scalar(scalar) => ScalarKind::from(scalar),
            ArgType::Tensor(tensor) if tensor.dim == 0 => ScalarKind::from(&tensor.elem_type),
            _ => panic!("Range node requires scalar inputs"),
        };
        ScalarType::new(arg.name.clone(), kind)
    }

    let output = node.outputs.first().unwrap().to_tensor_type();
    let start = to_scalar(node.inputs.first().unwrap());
    let end = to_scalar(node.inputs.get(1).unwrap());
    let step = to_scalar(node.inputs.get(2).unwrap());

    RangeNode::new(start, end, step, output)
}
fn reduce_max_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();