mirror of https://github.com/tracel-ai/burn.git
feat: added range onnx import (#1834)
* feat: added range onnx import * fix: range input types
This commit is contained in:
parent
36d4bcd705
commit
44f1053219
|
@ -134,7 +134,7 @@ represent the corresponding Burn Op.
|
|||
| [RandomNormalLike][127] | ❌ | ✅ |
|
||||
| [RandomUniform][128] | ✅ | ✅ |
|
||||
| [RandomUniformLike][129] | ❌ | ✅ |
|
||||
| [Range][130] | ❌ | ✅ |
|
||||
| [Range][130] | ✅ | ✅ |
|
||||
| [Reciprocal][131] | ✅ | ✅ |
|
||||
| [ReduceL][132] | ❌ | ❌ |
|
||||
| [ReduceLogSum][133] | ❌ | ❌ |
|
||||
|
|
|
@ -75,6 +75,7 @@ fn main() {
|
|||
.input("tests/squeeze/squeeze_opset13.onnx")
|
||||
.input("tests/random_uniform/random_uniform.onnx")
|
||||
.input("tests/random_normal/random_normal.onnx")
|
||||
.input("tests/range/range.onnx")
|
||||
.out_dir("model/")
|
||||
.run_from_script();
|
||||
|
||||
|
|
|
@ -57,6 +57,7 @@ include_models!(
|
|||
less,
|
||||
less_or_equal,
|
||||
prelu,
|
||||
range,
|
||||
recip,
|
||||
reduce_max,
|
||||
reduce_mean,
|
||||
|
@ -1070,6 +1071,21 @@ mod tests {
|
|||
output.to_data().assert_approx_eq(&expected, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn range() {
|
||||
let device = Default::default();
|
||||
let model: range::Model<Backend> = range::Model::new(&device);
|
||||
|
||||
// Run the model
|
||||
let start = 0i64;
|
||||
let limit = 10i64;
|
||||
let delta = 2i64;
|
||||
let output = model.forward(start, limit, delta);
|
||||
|
||||
let expected = Data::from([0, 2, 4, 6, 8]);
|
||||
assert_eq!(output.to_data(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn recip() {
|
||||
// Initialize the model
|
||||
|
|
Binary file not shown.
|
@ -0,0 +1,34 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: onnx-tests/tests/range/range.onnx
|
||||
|
||||
import onnx
|
||||
from onnx import helper, TensorProto
|
||||
|
||||
def main():
|
||||
node = onnx.helper.make_node(
|
||||
'Range',
|
||||
name='range',
|
||||
inputs=['start', 'end', 'step'],
|
||||
outputs=['output']
|
||||
)
|
||||
|
||||
graph_def = helper.make_graph(
|
||||
nodes=[node],
|
||||
name='RangeGraph',
|
||||
inputs=[
|
||||
helper.make_tensor_value_info('start', TensorProto.INT64, []),
|
||||
helper.make_tensor_value_info('end', TensorProto.INT64, []),
|
||||
helper.make_tensor_value_info('step', TensorProto.INT64, [])
|
||||
],
|
||||
outputs=[
|
||||
helper.make_tensor_value_info('output', TensorProto.INT64, [5])
|
||||
],
|
||||
)
|
||||
|
||||
model_def = helper.make_model(graph_def, producer_name='range')
|
||||
|
||||
onnx.save(model_def, 'range.onnx')
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -6,8 +6,8 @@ use super::{
|
|||
global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode,
|
||||
mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode,
|
||||
max_pool2d::MaxPool2dNode, prelu::PReluNode, random_normal::RandomNormalNode,
|
||||
random_uniform::RandomUniformNode, reshape::ReshapeNode, squeeze::SqueezeNode,
|
||||
unary::UnaryNode, unsqueeze::UnsqueezeNode,
|
||||
random_uniform::RandomUniformNode, range::RangeNode, reshape::ReshapeNode,
|
||||
squeeze::SqueezeNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
|
||||
};
|
||||
use crate::burn::{BurnImports, Scope, Type};
|
||||
use burn::backend::NdArray;
|
||||
|
@ -97,6 +97,7 @@ pub enum Node<PS: PrecisionSettings> {
|
|||
Matmul(MatmulNode),
|
||||
MaxPool1d(MaxPool1dNode),
|
||||
MaxPool2d(MaxPool2dNode),
|
||||
Range(RangeNode),
|
||||
Reshape(ReshapeNode),
|
||||
Squeeze(SqueezeNode),
|
||||
Unary(UnaryNode),
|
||||
|
@ -130,6 +131,7 @@ macro_rules! match_all {
|
|||
Node::Matmul(node) => $func(node),
|
||||
Node::MaxPool1d(node) => $func(node),
|
||||
Node::MaxPool2d(node) => $func(node),
|
||||
Node::Range(node) => $func(node),
|
||||
Node::Reshape(node) => $func(node),
|
||||
Node::Squeeze(node) => $func(node),
|
||||
Node::Unary(node) => $func(node),
|
||||
|
@ -173,6 +175,7 @@ impl<PS: PrecisionSettings> Node<PS> {
|
|||
Node::Matmul(_) => "matmul",
|
||||
Node::MaxPool1d(_) => "max_pool1d",
|
||||
Node::MaxPool2d(_) => "max_pool2d",
|
||||
Node::Range(_) => "range",
|
||||
Node::Reshape(_) => "reshape",
|
||||
Node::Squeeze(_) => "squeeze",
|
||||
Node::Unary(unary) => unary.kind.as_str(),
|
||||
|
|
|
@ -23,6 +23,7 @@ pub(crate) mod max_pool2d;
|
|||
pub(crate) mod prelu;
|
||||
pub(crate) mod random_normal;
|
||||
pub(crate) mod random_uniform;
|
||||
pub(crate) mod range;
|
||||
pub(crate) mod reshape;
|
||||
pub(crate) mod squeeze;
|
||||
pub(crate) mod unary;
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
use super::{Node, NodeCodegen};
|
||||
use crate::burn::{ScalarType, Scope, TensorType, Type};
|
||||
use burn::record::PrecisionSettings;
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::quote;
|
||||
|
||||
#[derive(Debug, Clone, new)]
|
||||
pub struct RangeNode {
|
||||
pub start: ScalarType,
|
||||
pub end: ScalarType,
|
||||
pub step: ScalarType,
|
||||
pub output: TensorType,
|
||||
}
|
||||
|
||||
impl<PS: PrecisionSettings> NodeCodegen<PS> for RangeNode {
|
||||
fn output_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(self.output.clone())]
|
||||
}
|
||||
|
||||
fn input_types(&self) -> Vec<Type> {
|
||||
vec![
|
||||
Type::Scalar(self.start.clone()),
|
||||
Type::Scalar(self.end.clone()),
|
||||
Type::Scalar(self.step.clone()),
|
||||
]
|
||||
}
|
||||
|
||||
fn forward(&self, _scope: &mut Scope, _node_position: usize) -> TokenStream {
|
||||
let output = &self.output.name;
|
||||
|
||||
let start = &self.start.name;
|
||||
let end = &self.end.name;
|
||||
let step = &self.step.name;
|
||||
|
||||
quote! {
|
||||
let #output = Tensor::arange_step(#start..#end, #step as usize, &*self.device);
|
||||
}
|
||||
}
|
||||
fn into_node(self) -> Node<PS> {
|
||||
Node::Range(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::burn::graph::BurnGraph;
|
||||
use crate::burn::node::test::assert_tokens;
|
||||
use crate::burn::{ScalarKind, ScalarType};
|
||||
use burn::record::FullPrecisionSettings;
|
||||
|
||||
#[test]
|
||||
fn codegen_nodes_range() {
|
||||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
|
||||
graph.register(
|
||||
RangeNode::new(
|
||||
ScalarType::new("start", ScalarKind::Int64),
|
||||
ScalarType::new("end", ScalarKind::Int64),
|
||||
ScalarType::new("step", ScalarKind::Int64),
|
||||
TensorType::new_int("output", 1),
|
||||
)
|
||||
.into_node(),
|
||||
);
|
||||
graph.register_input_output(
|
||||
vec!["start".to_string(), "end".to_string(), "step".to_string()],
|
||||
vec!["output".to_string()],
|
||||
);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::tensor::Int;
|
||||
use burn::{
|
||||
module::Module,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
#[allow(unused_variables)]
|
||||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
pub fn forward(&self, start: i64, end: i64, step: i64) -> Tensor<B, 1, Int> {
|
||||
let output = Tensor::arange_step(start..end, step as usize, &*self.device);
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
assert_tokens(graph.codegen(), expected);
|
||||
}
|
||||
}
|
|
@ -51,6 +51,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
|
|||
NodeType::GreaterOrEqual => greater_or_equal_update_outputs(node),
|
||||
NodeType::Less => less_update_outputs(node),
|
||||
NodeType::LessOrEqual => less_or_equal_update_outputs(node),
|
||||
NodeType::Range => range_update_outputs(node),
|
||||
NodeType::Reciprocal => same_as_input(node),
|
||||
NodeType::ReduceMax => reduce_max_update_outputs(node),
|
||||
NodeType::ReduceMean => reduce_mean_update_outputs(node),
|
||||
|
@ -587,6 +588,18 @@ fn matmul_update_outputs(node: &mut Node) {
|
|||
}
|
||||
}
|
||||
|
||||
fn range_update_outputs(node: &mut Node) {
|
||||
if node.inputs.len() != 3 {
|
||||
panic!("Range: expected 3 inputs, found {}", node.inputs.len());
|
||||
}
|
||||
|
||||
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
||||
elem_type: ElementType::Int64,
|
||||
dim: 1,
|
||||
shape: None,
|
||||
});
|
||||
}
|
||||
|
||||
/// Infers the shape of a ReduceMax node and replaces the shape of the output tensor.
|
||||
fn reduce_max_update_outputs(node: &mut Node) {
|
||||
if node.inputs.len() != 1 {
|
||||
|
|
|
@ -38,6 +38,7 @@ use crate::{
|
|||
prelu::PReluNode,
|
||||
random_normal::RandomNormalNode,
|
||||
random_uniform::RandomUniformNode,
|
||||
range::RangeNode,
|
||||
reshape::ReshapeNode,
|
||||
squeeze::SqueezeNode,
|
||||
unary::UnaryNode,
|
||||
|
@ -279,6 +280,7 @@ impl OnnxGraph {
|
|||
NodeType::Tanh => graph.register(Self::tanh_conversion(node)),
|
||||
NodeType::Constant => graph.register(Self::constant_conversion::<PS>(node)),
|
||||
NodeType::Min => graph.register(Self::min_conversion(node)),
|
||||
NodeType::Range => graph.register(Self::range_conversion(node)),
|
||||
NodeType::ReduceMax => graph.register(Self::reduce_max_conversion(node)),
|
||||
NodeType::ReduceMean => graph.register(Self::reduce_mean_conversion(node)),
|
||||
NodeType::ReduceSum => graph.register(Self::reduce_sum_conversion(node)),
|
||||
|
@ -575,6 +577,29 @@ impl OnnxGraph {
|
|||
BinaryNode::min_pair(lhs, rhs, output)
|
||||
}
|
||||
|
||||
fn range_conversion(node: Node) -> RangeNode {
|
||||
fn convert_arg_to_scalar(arg: &Argument) -> ScalarType {
|
||||
match &arg.ty {
|
||||
ArgType::Scalar(scalar) => {
|
||||
ScalarType::new(arg.name.clone(), ScalarKind::from(scalar))
|
||||
}
|
||||
ArgType::Tensor(tensor) => {
|
||||
if tensor.dim != 0 {
|
||||
panic!("Range node requires scalar inputs");
|
||||
}
|
||||
ScalarType::new(arg.name.clone(), ScalarKind::from(&tensor.elem_type))
|
||||
}
|
||||
_ => panic!("Range node requires scalar inputs"),
|
||||
}
|
||||
}
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let start = convert_arg_to_scalar(node.inputs.first().unwrap());
|
||||
let end = convert_arg_to_scalar(node.inputs.get(1).unwrap());
|
||||
let step = convert_arg_to_scalar(node.inputs.get(2).unwrap());
|
||||
|
||||
RangeNode::new(start, end, step, output)
|
||||
}
|
||||
|
||||
fn reduce_max_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
|
|
Loading…
Reference in New Issue