Full support for ONNX scalar operators and Constants (#578)

Dilshod Tadjibaev 2023-08-04 15:51:51 -05:00 committed by GitHub
parent ca9a8808d9
commit 1554a3c898
49 changed files with 1462 additions and 561 deletions

View File

@ -11,6 +11,7 @@ members = [
"burn-dataset",
"burn-derive",
"burn-import",
"burn-import/onnx-tests",
"burn-ndarray",
"burn-no-std-tests",
"burn-tch",
@ -29,6 +30,7 @@ dashmap = "5.4.0"
dirs = "5.0.1"
fake = "2.6.1"
flate2 = "1.0.26"
float-cmp = "0.9.0"
gix-tempfile = {version = "7.0.0", features = ["signals"]}
hashbrown = "0.14.0"
indicatif = "0.17.5"

View File

@ -23,6 +23,12 @@ impl<T: Clone> Param<T> {
pub fn val(&self) -> T {
self.value.clone()
}
/// Execute the given function on the inner value.
pub fn map<F: FnOnce(T) -> T>(mut self, func: F) -> Self {
self.value = func(self.value);
self
}
}
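The new `map` helper lets callers transform a parameter's inner value while keeping the `Param` wrapper intact; the constant-node codegen later in this commit uses it to strip gradient tracking from a recorded tensor. A minimal sketch (the `record.a` field name is illustrative):

    // Hypothetical record field; mirrors the generated constant-node init code.
    let a = record.a.map(|tensor| tensor.set_require_grad(false));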
impl<T> core::ops::Deref for Param<T> {

View File

@ -1,3 +1,5 @@
use core::marker::PhantomData;
use crate::{
self as burn,
module::{ADModule, Module, ModuleMapper, ModuleVisitor},
@ -135,12 +137,10 @@ impl<const D: usize, B: Backend> Module<B> for Tensor<B, D> {
}
fn into_record(self) -> Self::Record {
// Treat as a constant and do not record
ConstantRecord::new()
ConstantRecord
}
fn load_record(self, _record: Self::Record) -> Self {
// Treat as a constant and do not load
self
}
}
@ -153,15 +153,49 @@ impl<const D: usize, B: ADBackend> ADModule<B> for Tensor<B, D> {
}
}
impl<B: Backend> Module<B> for PhantomData<B> {
type Record = ConstantRecord;
fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
// Nothing to do
}
fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
self
}
fn load_record(self, _record: Self::Record) -> Self {
self
}
fn into_record(self) -> Self::Record {
ConstantRecord::new()
}
}
impl<B: ADBackend> ADModule<B> for PhantomData<B> {
type InnerModule = PhantomData<B::InnerBackend>;
fn valid(&self) -> Self::InnerModule {
PhantomData
}
}
#[cfg(all(test, feature = "std"))]
mod tests {
use core::marker::PhantomData;
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;
use crate::module::Module;
use crate::TestBackend;
use crate::{
record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
TestADBackend,
};
use burn::module::Module;
use crate as burn;
#[test]
fn tensor_load_record_setting() {
@ -185,4 +219,16 @@ mod tests {
assert!(!no_grad_is_require_grad);
assert!(!with_default_is_require_grad);
}
#[test]
fn empty_module_with_phantom() {
#[derive(Module, Debug, new)]
struct EmptyModule<B: Backend> {
_phantom: PhantomData<B>,
}
let _module = EmptyModule::<TestBackend>::new();
assert_eq!(core::mem::size_of::<EmptyModule<TestBackend>>(), 0);
}
}

View File

@ -15,3 +15,5 @@ pub use settings::*;
mod file;
#[cfg(feature = "std")]
pub use file::*;
pub use primitive::ParamSerde;

View File

@ -32,7 +32,7 @@ List taken from [here](https://github.com/onnx/onnx/blob/main/docs/Operators.md)
- [ ] BitwiseOr
- [ ] BitwiseXor
- [ ] BlackmanWindow
- [ ] Cast
- [x] Cast
- [ ] CastLike
- [ ] Ceil
- [ ] Celu
@ -40,9 +40,9 @@ List taken from [here](https://github.com/onnx/onnx/blob/main/docs/Operators.md)
- [ ] Clip
- [ ] Col
- [ ] Compress
- [ ] Concat
- [x] Concat
- [ ] ConcatFromSequence
- [ ] Constant
- [x] Constant
- [ ] ConstantOfShape
- [ ] Conv
- [ ] Conv1d

View File

@ -0,0 +1,13 @@
[package]
name = "onnx-tests"
version = "0.9.0"
edition = "2021"
[dev-dependencies]
burn = { path = "../../burn" }
burn-ndarray = { path = "../../burn-ndarray" }
serde = { workspace = true }
float-cmp = { workspace = true }
[build-dependencies]
burn-import = { path = "../" }

View File

@ -0,0 +1,32 @@
# ONNX Tests
This crate contains ONNX models used to test the conversion of ONNX to Burn source code through the
`burn-import` crate. The tests are end-to-end: they ensure that ONNX models are accurately converted
into Burn source code and, most importantly, that the converted source code compiles without errors
and produces the same output as the original ONNX model.
Here is the directory structure of this crate:
- `tests/<model>`: This directory contains the ONNX model and the Python script to generate it.
- `tests/<model>/<model>.onnx`: The ONNX model generated by the script.
- `tests/<model>/<model>.py`: This is the Python script responsible for generating the ONNX model
using PyTorch.
- `tests/onnx_tests.rs`: This is the main test file, where all the tests are contained.
- `build.rs`: This build script converts the ONNX models into Burn source code and is executed by
  `cargo test` before running the actual tests.
## Adding new tests
Here are the steps to add a new test:
1. Add your Python script to the `tests/<model>` directory. Refer to existing scripts for examples.
2. Run your Python script to generate the ONNX model and inspect the output of the model with the
test data. Use the inputs and outputs in your test.
3. Make sure the ONNX output contains the desired operators by verifying with the
[Netron](https://github.com/lutzroeder/netron) app. Sometimes PyTorch will optimize the model and
remove operators that are not necessary for the model to run. If this happens, you can disable
optimization by setting `torch.onnx.export(..., do_constant_folding=False)`.
4. Add an entry to the `build.rs` file to account for the generation of the new ONNX model.
5. Include a test in `tests/onnx_tests.rs` to test the new ONNX model (see the sketch after this list).
6. Run `cargo test` to ensure your test passes.
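For step 5, a minimal test follows the same pattern as the existing ones in `tests/onnx_tests.rs` (the `my_model` module name and the expected data here are placeholders):

    #[test]
    fn my_model() {
        // Initialize the generated model (weights loaded from the exported file).
        let model: my_model::Model<Backend> = my_model::Model::default();

        // Run the model on the same input the Python script printed.
        let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]]);
        let output = model.forward(input);

        // Compare against the output printed by the Python script.
        let expected = Data::from([[[[2., 4., 6., 8.]]]]);
        assert_eq!(output.to_data(), expected);
    }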

View File

@ -0,0 +1,19 @@
use burn_import::onnx::ModelGen;
fn main() {
// Re-run this build script if the tests directory changes.
println!("cargo:rerun-if-changed=tests");
// Add onnx models.
ModelGen::new()
.input("tests/add/add.onnx")
.input("tests/sub/sub.onnx")
.input("tests/mul/mul.onnx")
.input("tests/div/div.onnx")
.input("tests/concat/concat.onnx")
.input("tests/conv2d/conv2d.onnx")
.out_dir("model/")
.run_from_script();
// panic!("Purposefully failing build to output logs.");
}
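Each registered `.onnx` input yields a generated Rust source file under `$OUT_DIR/model/`, which the test crate pulls in with `include!`, as in `tests/onnx_tests.rs`:

    pub mod add {
        include!(concat!(env!("OUT_DIR"), "/model/add.rs"));
    }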

View File

@ -0,0 +1 @@

Binary file not shown.

View File

@ -0,0 +1,57 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/add/add.onnx
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
# Declare a constant float tensor with ones
self.a = torch.ones(1, 1, 1, 4)
# Declare a scalar
self.b = 5.0
super(Model, self).__init__()
def forward(self, x, k):
# Add a tensor input and a constant tensor
x = x + self.a
# Add a scalar constant and a scalar input
d = self.b + k
# Add a tensor and a scalar
x = x + d
return x
def main():
# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "add.onnx"
dummy_input = torch.randn(1, 2, 3, 4, device=device)
scalar = 2.0
torch.onnx.export(model, (dummy_input, scalar), onnx_name,
verbose=False, opset_version=16)
print("Finished exporting model to {}".format(onnx_name))
# Output some test data for use in the test
test_input = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])
print("Test input data: {}, {}".format(test_input, scalar))
output = model.forward(test_input, scalar)
print("Test output data: {}".format(output))
if __name__ == '__main__':
main()
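Working through the test data: for input [1, 2, 3, 4] and k = 2, x + a = [2, 3, 4, 5] and d = 5 + 2 = 7, so the output is [9, 10, 11, 12], the value asserted in the Rust test below.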

View File

Binary file not shown.
View File

@ -1,9 +1,9 @@
# used to generate model: burn-import/tests/data/conv2d/conv2d.onnx
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/concat/concat.onnx
import torch
import torch.nn as nn
import onnx
from onnxoptimizer import optimize
class Model(nn.Module):
def __init__(self):
@ -24,9 +24,20 @@ def main():
model.eval()
device = torch.device("cpu")
onnx_name = "concat.onnx"
dummy_input = torch.randn(1,256,13,13, device=device)
dummy_input = torch.randn(1,2,3,5, device=device)
torch.onnx.export(model, dummy_input, onnx_name,
verbose=False, opset_version=16)
print("Finished exporting model to {}".format(onnx_name))
# Output some test data for use in the test
test_input = torch.randn(1,2,3,5, device=device)
print("Test input data shape: {}".format(test_input.shape))
output = model.forward(test_input)
print("Test output data shape: {}".format(output.shape))
if __name__ == '__main__':
main()

Binary file not shown.

View File

@ -0,0 +1,43 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/conv2d/conv2d.onnx
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(4, 6, (3, 5), groups = 2, stride=(2, 1), padding=(4, 2), dilation=(3, 1))
def forward(self, x):
x = self.conv1(x)
return x
def main():
# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
file_name = "conv2d.onnx"
test_input = torch.ones(2, 4, 10, 15, device=device)
torch.onnx.export(model, test_input, file_name,
verbose=False, opset_version=16)
print("Finished exporting model to {}".format(file_name))
# Output some test data for use in the test
print("Test input data shape of ones: {}".format(test_input.shape))
output = model.forward(test_input)
print("Test output data shape: {}".format(output.shape))
sum = output.sum().item()
print("Test output sum: {}".format(sum))
if __name__ == '__main__':
main()

Binary file not shown.

View File

@ -0,0 +1,47 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/div/div.onnx
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, x, k, m):
a = k / m
x = x / a
return x
def main():
# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "div.onnx"
dummy_input = torch.randn(1, 2, 3, 4, device=device)
scalar1, scalar2 = 9.0, 3.0
torch.onnx.export(model, (dummy_input, scalar1, scalar2), onnx_name,
verbose=False, opset_version=16)
print("Finished exporting model to {}".format(onnx_name))
# Output some test data for use in the test
test_input = torch.tensor([[[[3.0, 6.0, 6.0, 9.0]]]])
print("Test input data: {}, {}, {}".format(test_input, scalar1, scalar2))
output = model.forward(test_input, scalar1, scalar2)
print("Test output data: {}".format(output))
if __name__ == '__main__':
main()
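Working through the test data: for input [3, 6, 6, 9], k = 9 and m = 3, a = 9 / 3 = 3, so the output is [1, 2, 2, 3], the value asserted in the Rust test below.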

Binary file not shown.

View File

@ -0,0 +1,57 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/mul/mul.onnx
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
# Declare a constant float tensor
self.a = torch.full((1, 1, 1, 4), 3.0)
# Declare a scalar
self.b = 7.0
super(Model, self).__init__()
def forward(self, x, k):
# Multiply the input by the constant tensor
x = x * self.a
# Multiply the input scalar by the constant scalar
d = k * self.b
# Multiply the result of the previous multiplications
x = x * d
return x
def main():
# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "mul.onnx"
dummy_input = torch.randn(1, 2, 3, 4, device=device)
scalar = 6.0
torch.onnx.export(model, (dummy_input, scalar), onnx_name,
verbose=False, opset_version=16)
print("Finished exporting model to {}".format(onnx_name))
# Output some test data for use in the test
test_input = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])
print("Test input data: {}, {}".format(test_input, scalar))
output = model.forward(test_input, scalar)
print("Test output data: {}".format(output))
if __name__ == '__main__':
main()
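Working through the test data: for input [1, 2, 3, 4] and k = 6, x * a = [3, 6, 9, 12] and d = 6 * 7 = 42, so the output is [126, 252, 378, 504], the value asserted in the Rust test below.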

View File

@ -0,0 +1,128 @@
pub mod add {
include!(concat!(env!("OUT_DIR"), "/model/add.rs"));
}
pub mod sub {
include!(concat!(env!("OUT_DIR"), "/model/sub.rs"));
}
pub mod mul {
include!(concat!(env!("OUT_DIR"), "/model/mul.rs"));
}
pub mod div {
include!(concat!(env!("OUT_DIR"), "/model/div.rs"));
}
pub mod concat {
include!(concat!(env!("OUT_DIR"), "/model/concat.rs"));
}
pub mod conv2d {
include!(concat!(env!("OUT_DIR"), "/model/conv2d.rs"));
}
#[cfg(test)]
mod tests {
use super::*;
use burn::tensor::{Data, Shape, Tensor};
use float_cmp::ApproxEq;
type Backend = burn_ndarray::NdArrayBackend<f32>;
#[test]
fn add_scalar_to_tensor_and_tensor_to_tensor() {
// Initialize the model with weights (loaded from the exported file)
let model: add::Model<Backend> = add::Model::default();
// Run the model
let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]]);
let scalar = 2f64;
let output = model.forward(input, scalar);
let expected = Data::from([[[[9., 10., 11., 12.]]]]);
assert_eq!(output.to_data(), expected);
}
#[test]
fn sub_scalar_from_tensor_and_tensor_from_tensor() {
// Initialize the model with weights (loaded from the exported file)
let model: sub::Model<Backend> = sub::Model::default();
// Run the model
let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]]);
let scalar = 3.0f64;
let output = model.forward(input, scalar);
let expected = Data::from([[[[6., 7., 8., 9.]]]]);
assert_eq!(output.to_data(), expected);
}
#[test]
fn mul_scalar_with_tensor_and_tensor_with_tensor() {
// Initialize the model with weights (loaded from the exported file)
let model: mul::Model<Backend> = mul::Model::default();
// Run the model
let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]]);
let scalar = 6.0f64;
let output = model.forward(input, scalar);
let expected = Data::from([[[[126., 252., 378., 504.]]]]);
assert_eq!(output.to_data(), expected);
}
#[test]
fn div_tensor_by_scalar_and_tensor_by_tensor() {
// Initialize the model without weights (because the exported file does not contain them)
let model: div::Model<Backend> = div::Model::new();
// Run the model
let input = Tensor::<Backend, 4>::from_floats([[[[3., 6., 6., 9.]]]]);
let scalar1 = 9.0f64;
let scalar2 = 3.0f64;
let output = model.forward(input, scalar1, scalar2);
let expected = Data::from([[[[1., 2., 2., 3.]]]]);
assert_eq!(output.to_data(), expected);
}
#[test]
fn concat_tensors() {
// Initialize the model
let model: concat::Model<Backend> = concat::Model::new();
// Run the model
let input = Tensor::<Backend, 4>::zeros([1, 2, 3, 5]);
let output = model.forward(input);
let expected = Shape::from([1, 18, 3, 5]);
assert_eq!(output.shape(), expected);
}
#[test]
fn conv2d() {
// Initialize the model with weights (loaded from the exported file)
let model: conv2d::Model<Backend> = conv2d::Model::default();
// Run the model with ones as input for easier testing
let input = Tensor::<Backend, 4>::ones([2, 4, 10, 15]);
let output = model.forward(input);
let expected_shape = Shape::from([2, 6, 6, 15]);
assert_eq!(output.shape().clone(), expected_shape);
// We are using the sum of the output tensor to test the correctness of the conv2d node
// because the output tensor is too large to compare with the expected tensor.
let output_sum = output.sum().into_scalar();
let expected_sum = 24.004_995; // from pytorch
assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2)));
}
}

Binary file not shown.

View File

@ -0,0 +1,57 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/sub/sub.onnx
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
# Declare a constant float tensor with ones
self.a = torch.ones(1, 1, 1, 4)
# Declare a scalar
self.b = 9.0
super(Model, self).__init__()
def forward(self, x, k):
# Subtract a constant tensor from a tensor input
x = x - self.a
# Subtract a scalar constant from a scalar input
d = k - self.b
# Subtract a scalar from a tensor
x = x - d
return x
def main():
# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "sub.onnx"
dummy_input = torch.randn(1, 2, 3, 4, device=device)
scalar = 3.0
torch.onnx.export(model, (dummy_input, scalar), onnx_name,
verbose=False, opset_version=16)
print("Finished exporting model to {}".format(onnx_name))
# Output some test data for use in the test
test_input = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])
print("Test input data: {}, {}".format(test_input, scalar))
output = model.forward(test_input, scalar)
print("Test output data: {}".format(output))
if __name__ == '__main__':
main()
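Working through the test data: for input [1, 2, 3, 4] and k = 3, x - a = [0, 1, 2, 3] and d = 3 - 9 = -6, so the output is [6, 7, 8, 9], matching the expected output in the `sub` test of `tests/onnx_tests.rs`.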

View File

@ -9,7 +9,7 @@ use burn::record::{
use proc_macro2::TokenStream;
use quote::quote;
use serde::{ser::SerializeMap, Serialize};
use std::path::PathBuf;
use std::{collections::HashMap, path::PathBuf};
/// Burn graph intermediate representation of modules and tensor operations.
#[derive(Default, Debug)]
@ -21,6 +21,8 @@ pub struct BurnGraph<PS: PrecisionSettings> {
default: Option<TokenStream>,
blank_spaces: bool,
gen_new_fn: bool,
graph_input_types: Vec<Type>,
graph_output_types: Vec<Type>,
}
impl<PS: PrecisionSettings> BurnGraph<PS> {
@ -163,20 +165,22 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
fn build_scope(&mut self) {
log::debug!("Building the scope nodes len => '{}'", self.nodes.len());
let input = self.nodes.first().unwrap();
fn to_tensor(ty: Type<'_>) -> Option<&TensorType> {
fn to_tensor(ty: Type) -> Option<TensorType> {
match ty {
Type::Tensor(tensor) => Some(tensor),
Type::Tensor(tensor) => Some(tensor.clone()),
Type::Scalar(_) => None,
Type::Other(_) => None,
}
}
input
.input_types()
// Register graph tensor input with 0 as node position
self.graph_input_types
.clone()
.into_iter()
.flat_map(to_tensor)
.for_each(|tensor| self.scope.tensor_register_variable(tensor, 0));
.for_each(|tensor| {
self.scope.tensor_register_variable(&tensor, 0);
});
self.nodes
.iter()
@ -187,7 +191,7 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
.flat_map(to_tensor)
.for_each(|tensor| {
self.scope
.tensor_register_variable(tensor, node_position + 1)
.tensor_register_variable(&tensor, node_position + 1)
})
});
@ -198,7 +202,10 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
node.input_types()
.into_iter()
.flat_map(to_tensor)
.for_each(|tensor| self.scope.tensor_register_future_use(tensor, node_position))
.for_each(|tensor| {
self.scope
.tensor_register_future_use(&tensor, node_position)
})
});
}
@ -240,16 +247,33 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
let name = field.name();
let ty = field.ty();
quote! {
#name: #ty,
if matches!(&field, Type::Tensor(_)) {
quote! {
#name: burn::module::Param<#ty>,
}
} else {
quote! {
#name: #ty,
}
}
})
.for_each(|code| body.extend(code));
quote! {
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
#body
// Add dummy field if no field is present to avoid empty struct
// and make sure we can derive Module trait and use it in a model.
if body.is_empty() {
quote! {
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
_phantom: core::marker::PhantomData<B>,
}
}
} else {
quote! {
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
#body
}
}
}
}
@ -269,13 +293,24 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
.map(|field| field.name().clone())
.collect::<Vec<_>>();
quote! {
#[allow(dead_code)]
pub fn new() -> Self {
#body
if fields.is_empty() {
quote! {
#[allow(dead_code)]
pub fn new() -> Self {
Self {
_phantom: core::marker::PhantomData,
}
}
}
} else {
quote! {
#[allow(dead_code)]
pub fn new() -> Self {
#body
Self {
#(#fields,)*
Self {
#(#fields,)*
}
}
}
}
@ -295,12 +330,22 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
.map(|field| field.name().clone())
.collect::<Vec<_>>();
quote! {
pub fn new_with(record: ModelRecord<B>) -> Self {
#body
if fields.is_empty() {
quote! {
pub fn new_with(_record: ModelRecord<B>) -> Self {
Self {
_phantom: core::marker::PhantomData,
}
}
}
} else {
quote! {
pub fn new_with(record: ModelRecord<B>) -> Self {
#body
Self {
#(#fields,)*
Self {
#(#fields,)*
}
}
}
}
@ -311,26 +356,19 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
let mut output_type_def = quote! {};
let mut output_return_def = quote! {};
self.nodes
.first()
.unwrap()
.input_types()
.into_iter()
.for_each(|input| {
let name = input.name();
let ty = input.ty();
self.graph_input_types.iter().for_each(|input| {
let name = input.name().clone();
let ty = input.ty().clone();
input_def.extend(quote! {
#name: #ty,
input_def.extend(quote! {
#name: #ty,
})
});
})
});
let output_types = self.nodes.last().unwrap().output_types();
let multiple_output = self.graph_output_types.len() > 1;
let multiple_output = output_types.len() > 1;
output_types.into_iter().for_each(|output| {
self.graph_output_types.iter().for_each(|output| {
let name = output.name();
let ty = output.ty();
@ -379,6 +417,48 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
}
}
}
/// Register the input and output types of the graph using the passed in names.
/// The names must be unique and match the names of the inputs and outputs of the nodes.
/// The order will be preserved.
///
/// # Arguments
///
/// * `input_names` - The names of the inputs of the graph.
/// * `output_names` - The names of the outputs of the graph.
///
/// # Panics
///
/// Panics if the graph is empty.
pub fn register_input_output(&mut self, input_names: Vec<String>, output_names: Vec<String>) {
assert!(
!self.nodes.is_empty(),
"Cannot register input and output types for an empty graph."
);
// Get the unique names of each input of the nodes
let mut inputs = HashMap::new();
let mut outputs = HashMap::new();
for node in self.nodes.iter() {
for input in node.input_types() {
inputs.insert(input.name().to_string(), input);
}
for output in node.output_types() {
outputs.insert(output.name().to_string(), output);
}
}
// Get the input and output types of the graph using passed in names
input_names.iter().for_each(|input| {
self.graph_input_types
.push(inputs.get(input).unwrap().clone());
});
output_names.iter().for_each(|output| {
self.graph_output_types
.push(outputs.get(output).unwrap().clone());
});
}
}
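Callers register the graph boundary by name once all nodes are added; the node tests in this commit do, for example:

    graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);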
#[derive(new)]

View File

@ -77,7 +77,7 @@ pub enum Node<PS: PrecisionSettings> {
MaxPool2d(MaxPool2dNode),
Linear(LinearNode<PS>),
BatchNorm(BatchNormNode<PS>),
Constant(ConstantNode),
Constant(ConstantNode<PS>),
Unary(UnaryNode),
Reshape(ReshapeNode),
Concat(ConcatNode),
@ -174,7 +174,6 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for Node<PS> {
#[cfg(test)]
pub(crate) mod tests {
use crate::burn::{
codegen::ToTokens,
graph::BurnGraph,
node::{conv2d::Conv2dNode, matmul::MatmulNode, test::assert_tokens, NodeCodegen},
TensorType,
@ -185,14 +184,18 @@ pub(crate) mod tests {
use proc_macro2::TokenStream;
use quote::quote;
fn one_node_graph<T: NodeCodegen<FullPrecisionSettings> + 'static>(
pub(crate) fn one_node_graph<T: NodeCodegen<FullPrecisionSettings> + 'static>(
node_gen: T,
forward: TokenStream,
input_names: Vec<String>,
output_names: Vec<String>,
) {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
graph.register(node_gen);
graph.register_input_output(input_names, output_names);
let expected = quote! {
use burn::{
module::Module,
@ -200,11 +203,15 @@ pub(crate) mod tests {
};
#[derive(Module, Debug)]
pub struct Model <B: Backend>{}
pub struct Model<B: Backend> {
_phantom: core::marker::PhantomData<B>,
}
impl<B: Backend> Model <B> {
pub fn new_with(record: ModelRecord<B>) -> Self {
Self { }
pub fn new_with(_record: ModelRecord<B>) -> Self {
Self {
_phantom: core::marker::PhantomData,
}
}
#[allow(clippy::let_and_return)]
@ -215,42 +222,6 @@ pub(crate) mod tests {
assert_tokens(graph.codegen(), expected);
}
pub(crate) fn codegen_unary_operator<
const N: usize,
T: NodeCodegen<FullPrecisionSettings> + 'static,
>(
node_gen: T,
function: TokenStream,
) {
let forward = |function, tensor_dim| {
quote! {
pub fn forward(&self, tensor1: Tensor<B, #tensor_dim>) -> Tensor<B, #tensor_dim> {
#function
}
}
};
one_node_graph(node_gen, forward(function, N.to_tokens()));
}
pub(crate) fn codegen_binary_operator<
const N: usize,
T: NodeCodegen<FullPrecisionSettings> + 'static,
>(
node_gen: T,
function: TokenStream,
) {
let forward = |function, tensor_dim| {
quote! {
pub fn forward(&self, tensor1: Tensor<B, #tensor_dim>, tensor2: Tensor<B, #tensor_dim>) -> Tensor<B, #tensor_dim> {
#function
}
}
};
one_node_graph(node_gen, forward(function, N.to_tokens()));
}
#[test]
fn test_codegen_two_nodes() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
@ -269,6 +240,11 @@ pub(crate) mod tests {
Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid),
));
graph.register_input_output(
vec!["tensor1".to_string(), "tensor2".to_string()],
vec!["tensor4".to_string()],
);
let expected = quote! {
use burn::{
module::Module,
@ -333,6 +309,11 @@ pub(crate) mod tests {
TensorType::new_float("output", 4),
));
graph.register_input_output(
vec!["tensor1".to_string(), "tensor2".to_string()],
vec!["output".to_string()],
);
let expected = quote! {
use burn::{
module::Module,

View File

@ -102,13 +102,13 @@ macro_rules! batch_norm_serialize {
impl<PS: PrecisionSettings> NodeCodegen<PS> for BatchNormNode<PS> {
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(&self.input)]
vec![Type::Tensor(self.input.clone())]
}
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(&self.output)]
vec![Type::Tensor(self.output.clone())]
}
fn field_type(&self) -> Option<Type> {
Some(Type::Other(&self.field))
Some(Type::Other(self.field.clone()))
}
fn field_init(&self, with_record: bool) -> Option<TokenStream> {
@ -181,6 +181,8 @@ mod tests {
BatchNormConfig::new(128),
));
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
let expected = quote! {
use burn::{
module::Module,

View File

@ -1,5 +1,5 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, Type};
use crate::burn::{Scope, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;
@ -32,9 +32,9 @@ type FnPointer = Arc<dyn Fn(TokenStream, TokenStream) -> TokenStream>;
/// Node for all binary operators.
#[derive(Clone, new)]
pub struct BinaryNode {
pub lhs: TensorType,
pub rhs: TensorType,
pub output: TensorType,
pub lhs: Type,
pub rhs: Type,
pub output: Type,
pub binary_type: BinaryType,
function: FnPointer,
}
@ -56,17 +56,35 @@ impl std::fmt::Debug for BinaryNode {
impl<PS: PrecisionSettings> NodeCodegen<PS> for BinaryNode {
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(&self.output)]
vec![self.output.clone()]
}
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(&self.lhs), Type::Tensor(&self.rhs)]
vec![self.lhs.clone(), self.rhs.clone()]
}
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let lhs = scope.tensor_use_owned(&self.lhs, node_position);
let rhs = scope.tensor_use_owned(&self.rhs, node_position);
let output = &self.output.name;
// Get the lhs name as a token stream.
let lhs = match &self.lhs {
Type::Tensor(tensor) => scope.tensor_use_owned(tensor, node_position),
Type::Scalar(scalar) => {
let name = scalar.name.clone();
quote! { #name }
}
_ => panic!("lhs must be a tensor or scalar"),
};
// Get the rhs name as a token stream.
let rhs = match &self.rhs {
Type::Tensor(tensor) => scope.tensor_use_owned(tensor, node_position),
Type::Scalar(scalar) => {
let name = scalar.name.clone();
quote! { #name }
}
_ => panic!("rhs must be a tensor or scalar"),
};
let output = &self.output.name();
let function = (self.function)(lhs, rhs);
quote! {
@ -80,27 +98,53 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for BinaryNode {
}
impl BinaryNode {
pub(crate) fn add(lhs: TensorType, rhs: TensorType, output: TensorType) -> Self {
let function = move |lhs, rhs| quote! { #lhs.add(#rhs) };
pub(crate) fn add(lhs: Type, rhs: Type, output: Type) -> Self {
let function = match (&lhs, &rhs) {
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.add(#rhs) },
(Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.add_scalar(#rhs) },
(Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #rhs.add_scalar(#lhs) },
(Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs + #rhs },
_ => panic!("Addition is supported for tensor and scalar only"),
};
Self::new(lhs, rhs, output, BinaryType::Add, Arc::new(function))
}
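For a tensor/scalar pair, the selected closure emits a `*_scalar` call in the generated forward pass; with the names used in the tests below, `add` on a tensor and a scalar produces:

    let tensor3 = tensor1.add_scalar(scalar1);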
pub(crate) fn sub(lhs: TensorType, rhs: TensorType, output: TensorType) -> Self {
let function = move |lhs, rhs| quote! { #lhs.sub(#rhs) };
pub(crate) fn sub(lhs: Type, rhs: Type, output: Type) -> Self {
let function = match (&lhs, &rhs) {
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.sub(#rhs) },
(Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.sub_scalar(#rhs) },
(Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs - #rhs },
_ => panic!("Subtraction is supported for tensor and scalar only"),
};
Self::new(lhs, rhs, output, BinaryType::Sub, Arc::new(function))
}
pub(crate) fn mul(lhs: TensorType, rhs: TensorType, output: TensorType) -> Self {
let function = move |lhs, rhs| quote! { #lhs.mul(#rhs) };
pub(crate) fn mul(lhs: Type, rhs: Type, output: Type) -> Self {
let function = match (&lhs, &rhs) {
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.mul(#rhs) },
(Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.mul_scalar(#rhs) },
(Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #rhs.mul_scalar(#lhs) },
(Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs * #rhs },
_ => panic!("Multiplication is supported for tensor and scalar only"),
};
Self::new(lhs, rhs, output, BinaryType::Mul, Arc::new(function))
}
pub(crate) fn div(lhs: TensorType, rhs: TensorType, output: TensorType) -> Self {
let function = move |lhs, rhs| quote! { #lhs.div(#rhs) };
pub(crate) fn div(lhs: Type, rhs: Type, output: Type) -> Self {
let function = match (&lhs, &rhs) {
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.div(#rhs) },
(Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.div_scalar(#rhs) },
(Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs / #rhs },
_ => panic!("Division is supported for tensor and scalar only"),
};
Self::new(lhs, rhs, output, BinaryType::Div, Arc::new(function))
}
pub(crate) fn equal(lhs: TensorType, rhs: TensorType, output: TensorType) -> Self {
pub(crate) fn equal(lhs: Type, rhs: Type, output: Type) -> Self {
let function = move |lhs, rhs| quote! { #lhs.equal(#rhs) };
Self::new(lhs, rhs, output, BinaryType::Equal, Arc::new(function))
}
@ -110,48 +154,93 @@ impl BinaryNode {
mod tests {
use super::*;
use crate::burn::node::tests::codegen_binary_operator;
use crate::burn::TensorType;
use crate::burn::node::tests::one_node_graph;
use crate::burn::{ScalarKind, ScalarType, TensorType};
macro_rules! test_binary_operator {
macro_rules! test_binary_operator_on_tensors {
($operator:ident) => {{
codegen_binary_operator::<4, _>(
one_node_graph(
BinaryNode::$operator(
TensorType::new_float("tensor1", 4),
TensorType::new_float("tensor2", 4),
TensorType::new_float("tensor3", 4),
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
Type::Tensor(TensorType::new_float("tensor3", 4)),
),
quote! {
let tensor3 = tensor1.$operator(tensor2);
pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor3 = tensor1.$operator(tensor2);
tensor3
tensor3
}
},
vec!["tensor1".to_string(), "tensor2".to_string()],
vec!["tensor3".to_string()],
);
}};
}
macro_rules! test_binary_operator_on_tensor_and_scalar {
($operator:ident, $burn_operator:ident) => {{
one_node_graph(
BinaryNode::$operator(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float32)),
Type::Tensor(TensorType::new_float("tensor3", 4)),
),
quote! {
pub fn forward(&self, scalar1: f32, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor3 = tensor1.$burn_operator(scalar1);
tensor3
}
},
vec!["scalar1".to_string(), "tensor1".to_string()],
vec!["tensor3".to_string()],
);
}};
}
#[test]
fn test_binary_codegen_add() {
test_binary_operator!(add);
test_binary_operator_on_tensors!(add);
}
#[test]
fn test_binary_codegen_add_scalar() {
test_binary_operator_on_tensor_and_scalar!(add, add_scalar);
}
#[test]
fn test_binary_codegen_sub() {
test_binary_operator!(sub);
test_binary_operator_on_tensors!(sub);
}
#[test]
fn test_binary_codegen_sub_scalar() {
test_binary_operator_on_tensor_and_scalar!(sub, sub_scalar);
}
#[test]
fn test_binary_codegen_mul() {
test_binary_operator!(mul);
test_binary_operator_on_tensors!(mul);
}
#[test]
fn test_binary_codegen_mul_scalar() {
test_binary_operator_on_tensor_and_scalar!(mul, mul_scalar);
}
#[test]
fn test_binary_codegen_div() {
test_binary_operator!(div);
test_binary_operator_on_tensors!(div);
}
#[test]
fn test_binary_codegen_div_scalar() {
test_binary_operator_on_tensor_and_scalar!(div, div_scalar);
}
#[test]
fn test_binary_codegen_equal() {
test_binary_operator!(equal);
test_binary_operator_on_tensors!(equal);
}
}

View File

@ -14,11 +14,14 @@ pub struct ConcatNode {
impl<PS: PrecisionSettings> NodeCodegen<PS> for ConcatNode {
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(&self.output)]
vec![Type::Tensor(self.output.clone())]
}
fn input_types(&self) -> Vec<Type> {
self.inputs.iter().map(Type::Tensor).collect()
self.inputs
.iter()
.map(|t| Type::Tensor(t.clone()))
.collect()
}
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
@ -65,6 +68,11 @@ mod tests {
1,
));
graph.register_input_output(
vec!["tensor1".to_string(), "tensor2".to_string()],
vec!["tensor3".to_string()],
);
let expected = quote! {
use burn::{
module::Module,
@ -72,12 +80,17 @@ mod tests {
};
#[derive(Module, Debug)]
pub struct Model <B: Backend>{}
pub struct Model<B: Backend> {
_phantom: core::marker::PhantomData<B>,
}
impl<B: Backend> Model <B> {
pub fn new_with(record: ModelRecord<B>) -> Self {
Self { }
pub fn new_with(_record: ModelRecord<B>) -> Self {
Self {
_phantom: core::marker::PhantomData,
}
}
#[allow(clippy::let_and_return)]
pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor3 = burn::tensor::Tensor::cat(vec![tensor1, tensor2], 1);

View File

@ -1,72 +1,177 @@
use super::{Node, NodeCodegen};
use crate::burn::{OtherType, Scope, Type};
use burn::record::PrecisionSettings;
use crate::burn::{ScalarKind, ScalarType, Scope, TensorType, ToTokens, Type};
use burn::{
module::ParamId,
record::{ParamSerde, PrecisionSettings},
tensor::DataSerialize,
};
use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;
use serde::Serialize;
#[derive(Debug, Clone)]
pub struct ConstantNode {
pub struct ConstantNode<PS: PrecisionSettings> {
pub name: String,
pub value: ConstantValue,
output_ty: OtherType,
pub value: ConstantValue<PS>,
pub output: Type,
}
#[derive(Debug, Clone)]
pub enum TensorValue<PS: PrecisionSettings> {
Float(DataSerialize<PS::FloatElem>),
Int(DataSerialize<PS::IntElem>),
}
#[derive(Debug, Clone, new)]
pub enum ConstantValue {
Int(i32),
Float(f32),
Bool(bool),
pub enum ConstantValue<PS: PrecisionSettings> {
/// Float constant.
Float32(f32),
Float64(f64),
/// Integer constant.
Int32(i32),
Int64(i64),
/// Tensor constant.
Tensor(TensorType, TensorValue<PS>),
}
impl ConstantValue {
impl<PS: PrecisionSettings> ConstantValue<PS> {
pub fn ty_tokens(&self) -> TokenStream {
match self {
ConstantValue::Int(_) => quote! { i32 },
ConstantValue::Float(_) => quote! { f32 },
ConstantValue::Bool(_) => quote! { bool },
ConstantValue::Float32(_) => quote! { f32 },
ConstantValue::Float64(_) => quote! { f64 },
ConstantValue::Int32(_) => quote! { i32 },
ConstantValue::Int64(_) => quote! { i64 },
ConstantValue::Tensor(tensor_type, _) => {
let ty = tensor_type.ty();
quote! { burn::module::Param<#ty>}
}
}
}
pub fn val_tokens(&self) -> TokenStream {
match self {
ConstantValue::Int(val) => quote! { #val },
ConstantValue::Float(val) => quote! { #val },
ConstantValue::Bool(val) => quote! { #val },
ConstantValue::Float32(val) => quote! { #val },
ConstantValue::Float64(val) => quote! { #val },
ConstantValue::Int32(val) => quote! { #val },
ConstantValue::Int64(val) => quote! { #val },
ConstantValue::Tensor(_, _) => {
panic!("Tensor constant is not assignable.")
}
}
}
}
impl ConstantNode {
pub fn new(name: String, value: ConstantValue) -> Self {
let output_ty = OtherType::new(name.clone(), value.ty_tokens());
impl<PS: PrecisionSettings> ConstantNode<PS> {
pub fn new(name: String, value: ConstantValue<PS>, output: Type) -> Self {
Self {
name,
value,
output_ty,
output,
}
}
pub fn constant_value_into_type(&self) -> Type {
let name = Ident::new(self.name.as_str(), Span::call_site());
match &self.value {
ConstantValue::Float32(_) => Type::Scalar(ScalarType {
name,
kind: ScalarKind::Float32,
}),
ConstantValue::Float64(_) => Type::Scalar(ScalarType {
name,
kind: ScalarKind::Float64,
}),
ConstantValue::Int32(_) => Type::Scalar(ScalarType {
name,
kind: ScalarKind::Int32,
}),
ConstantValue::Int64(_) => Type::Scalar(ScalarType {
name,
kind: ScalarKind::Int64,
}),
ConstantValue::Tensor(tensor_type, _) => Type::Tensor(tensor_type.clone()),
}
}
}
impl<PS: PrecisionSettings> NodeCodegen<PS> for ConstantNode {
impl<PS: PrecisionSettings> NodeCodegen<PS> for ConstantNode<PS> {
fn output_types(&self) -> Vec<Type> {
vec![Type::Other(&self.output_ty)]
vec![self.output.clone()]
}
fn input_types(&self) -> Vec<Type> {
vec![]
}
fn field_type(&self) -> Option<Type> {
match &self.value {
ConstantValue::Tensor(tensor_type, _) => Some(Type::Tensor(tensor_type.clone())),
_ => None,
}
}
fn field_init(&self, with_record: bool) -> Option<TokenStream> {
match &self.value {
ConstantValue::Tensor(tensor_type, _) => {
let ty = tensor_type.ty();
let name = Ident::new(self.name.as_ref(), Span::call_site());
let shape = tensor_type.clone().shape.unwrap().to_tokens();
let dim = tensor_type.clone().dim.to_tokens();
if with_record {
Some(quote! {
let #name = record.#name.map(|tensor| tensor.set_require_grad(false));
})
} else {
Some(quote! {
let #name: burn::module::Param<#ty> = burn::module::Param::new(
burn::module::ParamId::new(),
Tensor::<B, #dim>::zeros(#shape).set_require_grad(false),
);
})
}
}
_ => None,
}
}
fn forward(&self, _scope: &mut Scope, _node_position: usize) -> TokenStream {
let name = Ident::new(self.name.as_ref(), Span::call_site());
let val = self.value.val_tokens();
let ty = self.value.ty_tokens();
let output = self.output.name();
quote! {
let #name: #ty = #val;
match &self.value {
ConstantValue::Tensor(_, _) => {
quote! {
let #output = self.#name.val();
}
}
_ => {
let val = self.value.val_tokens();
let ty = self.value.ty_tokens();
quote! {
let #output: #ty = #val;
}
}
}
}
fn into_node(self) -> Node<PS> {
Node::Constant(self)
}
fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
if let ConstantValue::Tensor(_, ds) = &self.value {
let data: DataSerialize<PS::FloatElem> = match ds {
TensorValue::Float(data) => data.clone().convert(),
TensorValue::Int(data) => data.clone().convert(),
};
let data = ParamSerde::new(ParamId::new().into_string(), data);
return data.serialize(serializer);
}
S::serialize_none(serializer)
}
}
// TODO: add missing test for constant node (@antimora 8/2/2023)

View File

@ -47,13 +47,13 @@ impl<PS: PrecisionSettings> Conv2dNode<PS> {
impl<PS: PrecisionSettings> NodeCodegen<PS> for Conv2dNode<PS> {
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(&self.input)]
vec![Type::Tensor(self.input.clone())]
}
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(&self.output)]
vec![Type::Tensor(self.output.clone())]
}
fn field_type(&self) -> Option<Type> {
Some(Type::Other(&self.field))
Some(Type::Other(self.field.clone()))
}
fn field_init(&self, with_record: bool) -> Option<TokenStream> {
@ -154,6 +154,8 @@ mod tests {
Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid),
));
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
let expected = quote! {
use burn::{
module::Module,

View File

@ -47,14 +47,14 @@ impl<PS: PrecisionSettings> LinearNode<PS> {
impl<PS: PrecisionSettings> NodeCodegen<PS> for LinearNode<PS> {
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(&self.input)]
vec![Type::Tensor(self.input.clone())]
}
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(&self.output)]
vec![Type::Tensor(self.output.clone())]
}
fn field_type(&self) -> Option<Type> {
Some(Type::Other(&self.field))
Some(Type::Other(self.field.clone()))
}
fn field_init(&self, with_record: bool) -> Option<TokenStream> {
@ -136,6 +136,8 @@ mod tests {
LinearConfig::new(128, 128),
));
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
let expected = quote! {
use burn::{
module::Module,

View File

@ -13,11 +13,14 @@ pub struct MatmulNode {
impl<PS: PrecisionSettings> NodeCodegen<PS> for MatmulNode {
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(&self.output)]
vec![Type::Tensor(self.output.clone())]
}
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(&self.lhs), Type::Tensor(&self.rhs)]
vec![
Type::Tensor(self.lhs.clone()),
Type::Tensor(self.rhs.clone()),
]
}
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
@ -57,6 +60,11 @@ mod tests {
TensorType::new_float("tensor3", 4),
));
graph.register_input_output(
vec!["tensor1".to_string(), "tensor2".to_string()],
vec!["tensor3".to_string()],
);
let expected = quote! {
use burn::{
module::Module,
@ -64,12 +72,17 @@ mod tests {
};
#[derive(Module, Debug)]
pub struct Model <B: Backend>{}
pub struct Model<B: Backend> {
_phantom: core::marker::PhantomData<B>,
}
impl<B: Backend> Model <B> {
pub fn new_with(record: ModelRecord<B>) -> Self {
Self { }
pub fn new_with(_record: ModelRecord<B>) -> Self {
Self {
_phantom: core::marker::PhantomData,
}
}
#[allow(clippy::let_and_return)]
pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor3 = tensor1.matmul(tensor2);

View File

@ -37,13 +37,13 @@ impl MaxPool2dNode {
impl<PS: PrecisionSettings> NodeCodegen<PS> for MaxPool2dNode {
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(&self.input)]
vec![Type::Tensor(self.input.clone())]
}
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(&self.output)]
vec![Type::Tensor(self.output.clone())]
}
fn field_type(&self) -> Option<Type> {
Some(Type::Other(&self.field))
Some(Type::Other(self.field.clone()))
}
fn field_init(&self, _with_record: bool) -> Option<TokenStream> {
@ -109,6 +109,8 @@ mod tests {
.with_padding(PaddingConfig2d::Valid),
));
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
let expected = quote! {
use burn::{
module::Module,

View File

@ -13,11 +13,11 @@ pub struct ReshapeNode {
impl<PS: PrecisionSettings> NodeCodegen<PS> for ReshapeNode {
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(&self.output)]
vec![Type::Tensor(self.output.clone())]
}
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(&self.input)]
vec![Type::Tensor(self.input.clone())]
}
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
@ -56,6 +56,8 @@ mod tests {
[4, 4, 4, 4].into(),
));
graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);
let expected = quote! {
use burn::{
module::Module,
@ -63,11 +65,15 @@ mod tests {
};
#[derive(Module, Debug)]
pub struct Model <B: Backend>{}
pub struct Model<B: Backend> {
_phantom: core::marker::PhantomData<B>,
}
impl<B: Backend> Model <B> {
pub fn new_with(record: ModelRecord<B>) -> Self {
Self { }
pub fn new_with(_record: ModelRecord<B>) -> Self {
Self {
_phantom: core::marker::PhantomData,
}
}
#[allow(clippy::let_and_return)]
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {

View File

@ -1,5 +1,5 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, ToTokens, Type};
use crate::burn::{Scope, ToTokens, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;
@ -11,8 +11,8 @@ type FnPointer = Arc<dyn Fn(TokenStream) -> TokenStream>;
/// Node for all unary operators.
#[derive(Clone, new)]
pub struct UnaryNode {
pub input: TensorType,
pub output: TensorType,
pub input: Type,
pub output: Type,
pub kind: UnaryNodeKind,
function: FnPointer,
}
@ -20,20 +20,22 @@ pub struct UnaryNode {
/// Type of unary node.
#[derive(Clone)]
pub enum UnaryNodeKind {
Cast,
Flatten,
LogSoftmax,
Relu,
Sigmoid,
LogSoftmax,
Transpose,
}
impl UnaryNodeKind {
pub fn as_str(&self) -> &str {
match self {
Self::Cast => "cast",
Self::Flatten => "flatten",
Self::LogSoftmax => "log_softmax",
Self::Relu => "relu",
Self::Sigmoid => "sigmoid",
Self::LogSoftmax => "log_softmax",
Self::Transpose => "transpose",
}
}
@ -55,16 +57,26 @@ impl std::fmt::Debug for UnaryNode {
impl<PS: PrecisionSettings> NodeCodegen<PS> for UnaryNode {
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(&self.output)]
vec![self.output.clone()]
}
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(&self.input)]
vec![self.input.clone()]
}
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;
// Get the input name as a token stream.
let input = match &self.input {
Type::Tensor(tensor) => scope.tensor_use_owned(tensor, node_position),
Type::Scalar(scalar) => {
let name = scalar.name.clone();
quote! { #name }
}
_ => panic!("lhs must be a tensor or scalar"),
};
let output = &self.output.name();
let function = (self.function)(input);
quote! {
@ -78,12 +90,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for UnaryNode {
}
impl UnaryNode {
pub(crate) fn flatten(
input: TensorType,
output: TensorType,
start_dim: usize,
end_dim: usize,
) -> Self {
pub(crate) fn flatten(input: Type, output: Type, start_dim: usize, end_dim: usize) -> Self {
let start_dim = start_dim.to_tokens();
let end_dim = end_dim.to_tokens();
let function = move |input| quote! { #input.flatten(#start_dim, #end_dim) };
@ -91,109 +98,193 @@ impl UnaryNode {
Self::new(input, output, UnaryNodeKind::Flatten, Arc::new(function))
}
pub(crate) fn relu(input: TensorType, output: TensorType) -> Self {
pub(crate) fn relu(input: Type, output: Type) -> Self {
let function = move |input| quote! { burn::tensor::activation::relu(#input) };
Self::new(input, output, UnaryNodeKind::Relu, Arc::new(function))
}
pub(crate) fn sigmoid(input: TensorType, output: TensorType) -> Self {
pub(crate) fn sigmoid(input: Type, output: Type) -> Self {
let function = move |input| quote! { burn::tensor::activation::sigmoid(#input) };
Self::new(input, output, UnaryNodeKind::Sigmoid, Arc::new(function))
}
pub(crate) fn log_softmax(input: TensorType, output: TensorType, dim: usize) -> Self {
pub(crate) fn log_softmax(input: Type, output: Type, dim: usize) -> Self {
let dim = dim.to_tokens();
let function = move |input| quote! { burn::tensor::activation::log_softmax(#input, #dim) };
Self::new(input, output, UnaryNodeKind::LogSoftmax, Arc::new(function))
}
pub(crate) fn transpose(input: TensorType, output: TensorType) -> Self {
pub(crate) fn transpose(input: Type, output: Type) -> Self {
let function = move |input| quote! { #input.transpose() };
Self::new(input, output, UnaryNodeKind::Transpose, Arc::new(function))
}
/// Casts the input to the output type.
///
/// Currently this function only supports the following conversions:
/// 1) scalar -> scalar
///
/// TODO: Implement the following conversions:
/// 2) tensor int -> tensor float
/// 3) tensor float -> tensor int
/// 4) tensor -> scalar
/// 5) scalar -> tensor
pub(crate) fn cast(input: Type, output: Type) -> Self {
let function = match output.clone() {
Type::Scalar(scalar) => {
let ty = scalar.ty();
move |input| quote! { #input as #ty }
}
Type::Tensor(_tensor) => {
// TODO: Implement this after tensor Int is implemented (@antimora 8/2/2023)
// TODO: If the input is scalar and the output type is a tensor,
// we should generate another code block. (@antimora 8/4/2023)
// Tensor::from_data(Data::from([#input]).convert()).unsqueeze();
todo!()
}
_ => panic!("output must be a tensor"),
};
Self::new(input, output, UnaryNodeKind::Cast, Arc::new(function))
}
}
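The scalar arm of `cast` lowers to a plain Rust `as` conversion; casting `scalar1: f64` to `f32`, as in the test below, generates:

    let scalar2 = scalar1 as f32;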
#[cfg(test)]
mod tests {
use super::*;
use crate::burn::node::tests::codegen_unary_operator;
use crate::burn::TensorType;
use crate::burn::node::tests::one_node_graph;
use crate::burn::{ScalarKind, ScalarType, TensorType};
#[test]
fn test_unary_codegen_flatten() {
codegen_unary_operator::<4, _>(
one_node_graph(
UnaryNode::flatten(
TensorType::new_float("tensor1", 4),
TensorType::new_float("tensor2", 4),
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
1,
2,
),
quote! {
let tensor2 = tensor1.flatten(1, 2);
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = tensor1.flatten(1, 2);
tensor2
tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
}
#[test]
fn test_unary_codegen_relu() {
codegen_unary_operator::<4, _>(
one_node_graph(
UnaryNode::relu(
TensorType::new_float("tensor1", 4),
TensorType::new_float("tensor2", 4),
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
),
quote! {
let tensor2 = burn::tensor::activation::relu(tensor1);
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = burn::tensor::activation::relu(tensor1);
tensor2
tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
}
#[test]
fn test_unary_codegen_sigmoid() {
codegen_unary_operator::<4, _>(
one_node_graph(
UnaryNode::sigmoid(
TensorType::new_float("tensor1", 4),
TensorType::new_float("tensor2", 4),
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
),
quote! {
let tensor2 = burn::tensor::activation::sigmoid(tensor1);
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = burn::tensor::activation::sigmoid(tensor1);
tensor2
tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
}
#[test]
fn test_unary_codegen_log_softmax() {
codegen_unary_operator::<4, _>(
one_node_graph(
UnaryNode::log_softmax(
TensorType::new_float("tensor1", 4),
TensorType::new_float("tensor2", 4),
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
1,
),
quote! {
let tensor2 = burn::tensor::activation::log_softmax(tensor1, 1);
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = burn::tensor::activation::log_softmax(tensor1, 1);
tensor2
tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
}
#[test]
fn test_unary_codegen_transpose() {
codegen_unary_operator::<4, _>(
one_node_graph(
UnaryNode::transpose(
TensorType::new_float("tensor1", 4),
TensorType::new_float("tensor2", 4),
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
),
quote! {
let tensor2 = tensor1.transpose();
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = tensor1.transpose();
tensor2
tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
}
#[test]
fn test_unary_codegen_cast() {
one_node_graph(
UnaryNode::cast(
Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float64)),
Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float32)),
),
quote! {
pub fn forward(&self, scalar1: f64) -> f32 {
let scalar2 = scalar1 as f32;
scalar2
}
},
vec!["scalar1".to_string()],
vec!["scalar2".to_string()],
);
one_node_graph(
UnaryNode::cast(
Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float32)),
Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float64)),
),
quote! {
pub fn forward(&self, scalar1: f32) -> f64 {
let scalar2 = scalar1 as f64;
scalar2
}
},
vec!["scalar1".to_string()],
vec!["scalar2".to_string()],
);
}
}

View File

@ -10,64 +10,114 @@ pub struct TensorType {
pub name: Ident,
pub dim: usize,
pub kind: TensorKind,
pub shape: Option<Vec<usize>>,
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Copy)]
pub enum TensorKind {
Int,
Float,
Bool,
}
#[derive(Debug, Clone)]
pub enum ScalarKind {
Int32,
Int64,
Float32,
Float64,
Bool,
}
#[derive(Debug, Clone)]
pub struct ScalarType {
pub name: Ident,
pub kind: ScalarKind,
}
#[derive(Debug, Clone)]
pub struct OtherType {
pub name: Ident,
pub ty: TokenStream,
}
pub enum Type<'a> {
Tensor(&'a TensorType),
Other(&'a OtherType),
#[derive(Debug, Clone)]
pub enum Type {
/// Tensor type.
Tensor(TensorType),
/// Scalar type.
Scalar(ScalarType),
/// Other type (more flexible type).
Other(OtherType),
}
impl<'a> Type<'a> {
impl Type {
pub fn name(&self) -> &Ident {
match self {
Type::Tensor(tensor) => &tensor.name,
Type::Scalar(scalar) => &scalar.name,
Type::Other(other) => &other.name,
}
}
pub fn ty(&self) -> TokenStream {
match self {
Type::Tensor(tensor) => tensor.ty(),
Type::Scalar(scalar) => scalar.ty(),
Type::Other(other) => other.ty(),
}
}
}
impl ScalarType {
pub fn new<S: AsRef<str>>(name: S, kind: ScalarKind) -> Self {
Self {
name: Ident::new(name.as_ref(), Span::call_site()),
kind,
}
}
pub fn ty(&self) -> TokenStream {
match self.kind {
ScalarKind::Int32 => quote! { i32 },
ScalarKind::Int64 => quote! { i64 },
ScalarKind::Float32 => quote! { f32 },
ScalarKind::Float64 => quote! { f64 },
ScalarKind::Bool => quote! { bool },
}
}
}
impl TensorType {
pub fn new<S: AsRef<str>>(name: S, dim: usize, kind: TensorKind) -> Self {
pub fn new<S: AsRef<str>>(
name: S,
dim: usize,
kind: TensorKind,
shape: Option<Vec<usize>>,
) -> Self {
Self {
name: Ident::new(name.as_ref(), Span::call_site()),
dim,
kind,
shape,
}
}
pub fn new_float<S: AsRef<str>>(name: S, dim: usize) -> Self {
Self::new(name, dim, TensorKind::Float)
Self::new(name, dim, TensorKind::Float, None)
}
pub fn new_int<S: AsRef<str>>(name: S, dim: usize) -> Self {
Self::new(name, dim, TensorKind::Int)
Self::new(name, dim, TensorKind::Int, None)
}
pub fn new_bool<S: AsRef<str>>(name: S, dim: usize) -> Self {
Self::new(name, dim, TensorKind::Bool)
Self::new(name, dim, TensorKind::Bool, None)
}
pub fn ty(&self) -> TokenStream {
let dim = self.dim.to_tokens();
// TODO use passed elem kind and do not assume float (@antimora 8/1/2023)
quote! {
Tensor<B, #dim>
}

View File

@ -1,8 +1,11 @@
use std::collections::HashMap;
use protobuf::Enum;
use super::{
ir::{ArgType, Argument, AttributeValue, Node, NodeType, TensorArg},
ir::{ArgType, Argument, AttributeValue, ElementType, Node, NodeType, TensorArg},
op_configuration::flatten_config,
protos::tensor_proto::DataType,
};
struct TensorDimUpdater {
@ -70,15 +73,13 @@ pub fn dim_inference(
NodeType::Sub => same_as_input(node),
NodeType::Pow => same_as_input(node),
NodeType::Mul => same_as_input(node),
NodeType::Cast => same_as_input(node),
NodeType::Cast => cast_update_outputs(node),
NodeType::Div => same_as_input(node),
NodeType::Sqrt => same_as_input(node),
NodeType::Softmax => same_as_input(node),
NodeType::Erf => same_as_input(node),
NodeType::ReduceMean => mean_update_outputs(node),
NodeType::Constant => {
node.outputs[0].ty = ArgType::Constant;
}
NodeType::Constant => constant_update_outputs(node),
NodeType::Equal => same_as_input(node),
NodeType::Shape => shape_update_outputs(node),
NodeType::Unsqueeze => unsqueeze_update_outputs(node),
@ -89,7 +90,9 @@ pub fn dim_inference(
NodeType::Concat => concat_update_outputs(node),
NodeType::Reshape => reshape_update_outputs(node),
NodeType::Dropout => same_as_input(node),
NodeType::GlobalAveragePool => same_as_input(node), //FIXME use correct output
//FIXME use correct output for GAP (@antimora 8/1/2023)
NodeType::GlobalAveragePool => same_as_input(node),
_ => todo!(
"shape inference for {:?} is not implemented",
node.node_type
@ -102,6 +105,20 @@ pub fn dim_inference(
updater.update_arguments(graph_outputs);
}
fn constant_update_outputs(node: &mut Node) {
// Fix the tensor dimension of the output when the value is a tensor
let output = &mut node.outputs[0];
match node.attrs.get("value") {
Some(value) => match &value {
AttributeValue::Tensor(tensor) => {
output.ty = ArgType::Tensor(TensorArg { dim: tensor.dim });
}
_ => {}
},
None => panic!("Constant node must have a value attribute"),
};
}
/// Infer the shape of the output tensor of a Linear node
fn linear_update_outputs(node: &mut Node) {
if node.inputs.len() != 1 {
@ -119,6 +136,49 @@ fn linear_update_outputs(node: &mut Node) {
}
}
/// Update the output type using "to" attribute
fn cast_update_outputs(node: &mut Node) {
if node.inputs.len() != 1 {
panic!("Cast: multiple inputs are not supported");
}
let output = &mut node.outputs[0];
// Extract cast type and update the output tensor
let elem_type = match node.attrs.get("to") {
Some(value) => match &value {
AttributeValue::Int64(type_id) => match DataType::from_i32(*type_id as i32).unwrap() {
DataType::FLOAT => ElementType::Float32,
DataType::INT32 => ElementType::Int32,
DataType::INT64 => ElementType::Int64,
DataType::DOUBLE => ElementType::Float64,
_ => panic!("Cast: unsupported type"),
},
_ => panic!("'to' attribute must be an Int64"),
},
None => panic!("Constant node must have a value attribute"),
};
match output.ty.clone() {
ArgType::Tensor(tensor) => {
if tensor.dim == 0 {
// treat 0-dim tensor as scalar
output.ty = ArgType::Scalar(elem_type);
} else {
todo!("Cast: update tensor type");
// TODO track the type of the tensor elements (@antimora 8/1/2023)
// output.ty = ArgType::Tensor(TensorArg {
// dim: tensor.dim,
// elem_type,
// });
}
}
ArgType::Scalar(_scalar) => {
output.ty = ArgType::Scalar(elem_type);
}
_ => panic!("Only tensor input is valid"),
}
}
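
Editorial note: the numeric ids matched above come from ONNX's TensorProto.DataType enum (FLOAT = 1, INT32 = 6, INT64 = 7, DOUBLE = 11). A self-contained sketch of that mapping, with a hypothetical `cast_elem_type` helper standing in for the attribute handling:

```rust
#[derive(Debug, PartialEq)]
enum ElementType {
    Float32,
    Float64,
    Int32,
    Int64,
}

// Map an ONNX TensorProto.DataType id (the Cast node's `to` attribute) to an
// element type, panicking on anything unsupported, as the pass above does.
fn cast_elem_type(type_id: i64) -> ElementType {
    match type_id {
        1 => ElementType::Float32,
        6 => ElementType::Int32,
        7 => ElementType::Int64,
        11 => ElementType::Float64,
        _ => panic!("Cast: unsupported type id {type_id}"),
    }
}

fn main() {
    // A Cast node with `to = 7` produces an Int64 output.
    assert_eq!(cast_elem_type(7), ElementType::Int64);
}
```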
fn concat_update_outputs(node: &mut Node) {
let tensor = node
.inputs
@ -184,7 +244,7 @@ fn unsqueeze_update_outputs(node: &mut Node) {
let dim = match node_input.clone().ty {
ArgType::Tensor(tensor) => tensor.dim,
ArgType::Shape(dim) => dim,
ArgType::Constant => panic!("Needs shape or tensor"),
ArgType::Scalar(_) => panic!("Needs shape or tensor"),
};
node.outputs[0].ty = ArgType::Tensor(TensorArg { dim: dim + 1 });

View File

@ -36,6 +36,16 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
let onnx_model: ModelProto =
Message::parse_from_reader(&mut file).expect("Unable to parse ONNX file");
log::debug!("Number of nodes: {:?}", onnx_model.graph.node.len());
log::debug!("Number of inputs: {:?}", onnx_model.graph.input.len());
log::debug!(
"Number of initializers: {:?}",
onnx_model.graph.initializer.len()
);
log::debug!("Number of outputs: {:?}", onnx_model.graph.output.len());
// Convert the nodes
let mut nodes: Vec<Node> = vec![];
for onnx_node in onnx_model.graph.node.iter() {
@ -46,10 +56,7 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
move_inputs_to_state(&mut nodes, &onnx_model.graph.initializer);
// Get the topological sort of the nodes and the top nodes
let (ts, top_nodes) = get_top_nodes(&nodes);
// Sort the nodes
top_sort_nodes(&mut nodes, ts);
top_sort_nodes(&mut nodes);
// Collect inputs, outputs and initializers
let check_if_initializer: HashSet<String> = onnx_model
@ -58,7 +65,8 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
.iter()
.map(|x| x.name.clone())
.collect();
let mut inputs = collect_inputs(&onnx_model, &check_if_initializer, top_nodes);
let mut inputs = collect_inputs(&onnx_model, &check_if_initializer);
let mut outputs = collect_outputs(&onnx_model, check_if_initializer);
let states = collect_states(onnx_model);
@ -90,10 +98,7 @@ fn collect_states(onnx_model: ModelProto) -> Vec<State> {
for initializer in onnx_model.graph.initializer.iter() {
let tensor_proto = initializer.clone();
let name = tensor_proto.name.clone();
// FIXME data conversion for the tensor is incorrect
let tensor: Tensor = tensor_proto.try_into().unwrap();
let ty = StateType::Tensor(tensor);
let arg = State { name, ty };
@ -108,7 +113,6 @@ fn collect_outputs(
onnx_model: &ModelProto,
check_if_initializer: HashSet<String>,
) -> Vec<Argument> {
// TODO: filter out the outputs that are not used in the graph
let outputs: Vec<Argument> = onnx_model
.graph
.output
@ -123,42 +127,30 @@ fn collect_outputs(
fn collect_inputs(
onnx_model: &ModelProto,
check_if_initializer: &HashSet<String>,
top_nodes: HashSet<String>,
) -> Vec<Argument> {
// Get the unique inputs
let inputs: Vec<Argument> = onnx_model
.graph
.input
.iter()
.filter(|x| !check_if_initializer.contains(x.name.as_str()))
.filter(|x| top_nodes.contains(&x.name))
// .filter(|x| top_nodes.contains(&x.name))
.map(|x| Argument::try_from(x.clone()).unwrap())
.collect();
inputs
// Convert to a vector and return
inputs.into_iter().collect()
}
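
Editorial note: ONNX graphs list weights and biases under `graph.input` alongside the true model inputs, which is why anything named in an initializer is filtered out above. A simplified sketch of that rule (hypothetical helper over plain string names):

```rust
use std::collections::HashSet;

// Keep only graph inputs that are not backed by an initializer; the rest are
// weights/biases, not runtime inputs.
fn collect_inputs(graph_inputs: &[&str], initializers: &HashSet<&str>) -> Vec<String> {
    graph_inputs
        .iter()
        .filter(|name| !initializers.contains(*name))
        .map(|name| name.to_string())
        .collect()
}

fn main() {
    let initializers: HashSet<&str> = ["conv1.weight", "conv1.bias"].into();
    let inputs = collect_inputs(&["input1", "conv1.weight", "conv1.bias"], &initializers);
    assert_eq!(inputs, ["input1"]);
}
```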
/// Sort the nodes in topological order
fn top_sort_nodes(nodes: &mut Vec<Node>, mut ts: TopologicalSort<Node>) {
fn top_sort_nodes(nodes: &mut Vec<Node>) {
let mut ts = topsort(nodes);
*nodes = vec![];
while let Some(node) = ts.pop() {
nodes.push(node);
}
}
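
Editorial note: the pop-loop drains the sort in dependency order, since a node is only popped once all of its predecessors have been. A minimal sketch, assuming the `topological_sort` crate this module already uses:

```rust
use topological_sort::TopologicalSort;

fn main() {
    let mut ts = TopologicalSort::<&str>::new();
    // conv -> relu -> flatten: each node consumes its predecessor's output.
    ts.add_dependency("conv", "relu");
    ts.add_dependency("relu", "flatten");

    // Same pop-loop shape as `top_sort_nodes` above.
    let mut order = Vec::new();
    while let Some(node) = ts.pop() {
        order.push(node);
    }
    assert_eq!(order, ["conv", "relu", "flatten"]);
}
```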
/// Get the top nodes in the graph
fn get_top_nodes(nodes: &Vec<Node>) -> (TopologicalSort<Node>, HashSet<String>) {
// Get the names of the top nodes (first nodes in the graph to receive the input)
// Sometimes onnx will pass inputs to be used as weights and biases but they are not truly inputs
let ts = topsort(nodes);
let mut top_nodes: HashSet<String> = HashSet::new();
for node in ts.peek_all() {
for input in node.inputs.iter() {
top_nodes.insert(input.name.clone());
}
}
(ts, top_nodes)
}
fn to_string(bytes: Vec<u8>) -> String {
from_utf8(bytes.as_slice()).unwrap().to_string()
}

View File

@ -14,9 +14,9 @@ pub struct Argument {
#[derive(Debug, Clone)]
pub enum ArgType {
Tensor(TensorArg),
Scalar(ElementType),
Shape(usize),
Constant,
Tensor(TensorArg),
}
#[derive(new, Default, Debug, Clone)]
@ -142,6 +142,15 @@ impl core::hash::Hash for Argument {
}
}
impl Eq for Argument {}
// Required by HashSet
impl PartialEq for Argument {
fn eq(&self, other: &Self) -> bool {
self.name == other.name
}
}
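
Editorial note: because equality and hashing both key on the name, collections such as `HashSet` deduplicate arguments per name. A small sketch of that behavior with a simplified `Argument`:

```rust
use std::collections::HashSet;

#[derive(Debug, Eq)]
struct Argument {
    name: String,
}

// Equality by name only, mirroring the impl above.
impl PartialEq for Argument {
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name
    }
}

// Hash must agree with Eq, so it also keys on the name.
impl core::hash::Hash for Argument {
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        self.name.hash(state)
    }
}

fn main() {
    let mut set = HashSet::new();
    set.insert(Argument { name: "input1".into() });
    set.insert(Argument { name: "input1".into() });
    // The second insert is a duplicate by name and is dropped.
    assert_eq!(set.len(), 1);
}
```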
/// The list of supported node types (ONNX operators and some extra ones to map easily to Burn's ops)
#[derive(Debug, Hash, Eq, PartialEq, EnumString, Clone, Display)]
pub enum NodeType {

View File

@ -16,7 +16,7 @@ use crate::{
batch_norm::BatchNormNode,
binary::BinaryNode,
concat::ConcatNode,
constant::{ConstantNode, ConstantValue},
constant::{ConstantNode, ConstantValue, TensorValue},
conv2d::Conv2dNode,
linear::LinearNode,
matmul::MatmulNode,
@ -24,7 +24,7 @@ use crate::{
reshape::ReshapeNode,
unary::UnaryNode,
},
TensorType,
ScalarKind, ScalarType, TensorKind, TensorType, Type,
},
format_tokens,
logger::init_log,
@ -39,7 +39,7 @@ use crate::{
use super::{
from_onnx::parse_onnx,
ir::{ArgType, Argument, ONNXGraph, State, StateType, Tensor, TensorData},
ir::{ArgType, Argument, ElementType, ONNXGraph, State, StateType, Tensor, TensorData},
op_configuration::concat_config,
};
@ -98,6 +98,8 @@ impl ModelGen {
fn run(&self, is_build_script: bool) {
log::info!("Starting to convert ONNX to Burn");
log::info!("Starting to convert ONNX to Burn");
// prepend the out_dir to the cargo_out_dir if this is a build script
let out_dir = if is_build_script {
let cargo_out_dir = env::var("OUT_DIR").expect("OUT_DIR env is not set");
@ -112,6 +114,8 @@ impl ModelGen {
log::debug!("Output directory: {:?}", out_dir);
log::debug!("Output directory: {:?}", out_dir);
create_dir_all(&out_dir).unwrap();
for input in self.inputs.iter() {
@ -122,10 +126,16 @@ impl ModelGen {
log::debug!("Input file name: {:?}", file_name);
log::debug!("Output file: {:?}", out_file);
log::info!("Converting {:?}", input);
log::debug!("Input file name: {:?}", file_name);
log::debug!("Output file: {:?}", out_file);
Self::generate_model(self.development, input, out_file);
}
log::info!("Finished converting ONNX to Burn");
log::info!("Finished converting ONNX to Burn");
}
/// Generate model source code and model state.
@ -134,6 +144,10 @@ impl ModelGen {
log::debug!("Development mode: {:?}", development);
log::debug!("Output file: {:?}", out_file);
log::info!("Generating model from {:?}", input);
log::debug!("Development mode: {:?}", development);
log::debug!("Output file: {:?}", out_file);
let graph = parse_onnx(input.as_ref());
if development {
@ -161,6 +175,8 @@ impl ModelGen {
fs::write(out_file.with_extension("rs"), code_str).unwrap();
log::info!("Model generated");
log::info!("Model generated");
}
}
@ -186,64 +202,112 @@ impl ONNXGraph {
NodeType::Relu => graph.register(Self::relu_conversion(node)),
NodeType::Flatten => graph.register(Self::flatten_conversion(node)),
NodeType::LogSoftmax => graph.register(Self::log_softmax_conversion(node)),
NodeType::Constant => graph.register(Self::constant_conversion(node)),
NodeType::Constant => graph.register(Self::constant_conversion::<PS>(node)),
NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)),
NodeType::Transpose => graph.register(Self::transpose_conversion(node)),
NodeType::Concat => graph.register(Self::concat_conversion(node)),
NodeType::Cast => graph.register(Self::cast_conversion(node)),
_ => panic!("Unsupported node conversion {}", node.node_type),
}
}
// Get input and output names
let input_names = self
.inputs
.iter()
.map(|input| input.name.clone())
.collect::<Vec<_>>();
let output_names = self
.outputs
.iter()
.map(|output| output.name.clone())
.collect::<Vec<_>>();
// Register inputs and outputs with the graph
graph.register_input_output(input_names, output_names);
graph
}
fn constant_conversion(mut node: Node) -> ConstantNode {
fn constant_conversion<PS: PrecisionSettings>(mut node: Node) -> ConstantNode<PS> {
let output = node.outputs.get(0).unwrap();
let value = node.attrs.remove("value").unwrap();
let value = match value {
AttributeValue::Float32(val) => ConstantValue::Float(val),
AttributeValue::Int64(val) => ConstantValue::Int(val as i32),
AttributeValue::Float32s(val) => ConstantValue::Float(val[0]),
AttributeValue::Int64s(val) => ConstantValue::Int(val[0] as i32),
_ => panic!("Unsupported constant node: {:?}", node),
AttributeValue::Float32(val) => ConstantValue::Float32(val),
AttributeValue::Int64(val) => ConstantValue::Int64(val),
AttributeValue::Tensor(tensor) => {
if tensor.dim == 0 {
// Treat zero dim tensor as scalar value by extracting the first element
// because PyTorch/ONNX uses zero dim tensor for scalar values
match tensor.data.unwrap() {
TensorData::Float32(val) => ConstantValue::Float32(val[0]),
TensorData::Float64(val) => ConstantValue::Float64(val[0]),
TensorData::Int32(val) => ConstantValue::Int32(val[0]),
TensorData::Int64(val) => ConstantValue::Int64(val[0]),
_ => panic!(
"Unsupported zero dim constant tensor type: {:?} ",
tensor.elem_type
),
}
} else {
let ds = match tensor.elem_type {
ElementType::Float32 | ElementType::Float64 => TensorValue::Float(
tensor.clone().into_data_serialize::<PS::FloatElem>(),
),
ElementType::Int32 | ElementType::Int64 => {
TensorValue::Int(tensor.clone().into_data_serialize::<PS::IntElem>())
}
_ => panic!("Unsupported constant tensor type: {:?} ", tensor.elem_type),
};
ConstantValue::<PS>::Tensor(
TensorType::new(
node.name.clone(),
tensor.dim,
tensor.elem_type.into(),
tensor.shape,
),
ds,
)
}
}
_ => panic!("Unsupported constant value: {:?} ", value),
};
ConstantNode::new(output.name.clone(), value)
ConstantNode::new(node.name.clone(), value, output.to_type())
}
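
Editorial note: the zero-dim branch above relies on PyTorch/ONNX exporting scalar constants as rank-0 tensors that hold exactly one element. A simplified sketch of that extraction (hypothetical helper, float-only):

```rust
#[derive(Debug, PartialEq)]
enum ConstantValue {
    Float32(f32),
}

// Lift the single element out of a rank-0 tensor's data buffer.
fn scalar_from_zero_dim(dim: usize, data: Vec<f32>) -> ConstantValue {
    assert_eq!(dim, 0, "only rank-0 tensors are treated as scalars");
    ConstantValue::Float32(data[0])
}

fn main() {
    assert_eq!(
        scalar_from_zero_dim(0, vec![3.14]),
        ConstantValue::Float32(3.14)
    );
}
```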
fn add_conversion(node: Node) -> BinaryNode {
// FIXME scalar vs tensor
let lhs = node.inputs.get(0).unwrap().to_tensor_type();
let rhs = node.inputs.get(1).unwrap().to_tensor_type();
let output = node.outputs.get(0).unwrap().to_tensor_type();
let lhs = node.inputs.get(0).unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();
BinaryNode::add(lhs, rhs, output)
BinaryNode::add(lhs.clone(), rhs.clone(), output.clone())
}
fn sub_conversion(node: Node) -> BinaryNode {
// FIXME scalar vs tensor
let lhs = node.inputs.get(0).unwrap().to_tensor_type();
let rhs = node.inputs.get(1).unwrap().to_tensor_type();
let output = node.outputs.get(0).unwrap().to_tensor_type();
let lhs = node.inputs.get(0).unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();
BinaryNode::sub(lhs, rhs, output)
}
fn mul_conversion(node: Node) -> BinaryNode {
// FIXME scalar vs tensor
let lhs = node.inputs.get(0).unwrap().to_tensor_type();
let rhs = node.inputs.get(1).unwrap().to_tensor_type();
let output = node.outputs.get(0).unwrap().to_tensor_type();
let lhs = node.inputs.get(0).unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();
BinaryNode::mul(lhs, rhs, output)
}
fn div_conversion(node: Node) -> BinaryNode {
let lhs = node.inputs.get(0).unwrap().to_tensor_type();
let rhs = node.inputs.get(1).unwrap().to_tensor_type();
let output = node.outputs.get(0).unwrap().to_tensor_type();
let lhs = node.inputs.get(0).unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();
BinaryNode::div(lhs, rhs, output)
}
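
Editorial note: switching the binary conversions from `to_tensor_type()` to `to_type()` lets a scalar operand stay a plain Rust number in the generated code instead of being materialized as a tensor. A rough usage sketch of what that enables downstream, assuming the `burn-ndarray` backend from this workspace and burn's `add_scalar` numeric op:

```rust
use burn::tensor::Tensor;
use burn_ndarray::NdArrayBackend;

fn main() {
    type B = NdArrayBackend<f32>;
    // A Scalar-typed rhs can be applied directly, no constant tensor needed.
    let lhs = Tensor::<B, 2>::ones([2, 2]);
    let out = lhs.add_scalar(5.0);
    println!("{:?}", out.into_data());
}
```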
@ -257,35 +321,42 @@ impl ONNXGraph {
}
fn equal_conversion(node: Node) -> BinaryNode {
let lhs = node.inputs.get(0).unwrap().to_tensor_type();
let rhs = node.inputs.get(1).unwrap().to_tensor_type();
let output = node.outputs.get(0).unwrap().to_tensor_type();
let lhs = node.inputs.get(0).unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();
BinaryNode::equal(lhs, rhs, output)
}
fn relu_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_tensor_type();
let output = node.outputs.get(0).unwrap().to_tensor_type();
let input = node.inputs.get(0).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();
UnaryNode::relu(input, output)
}
fn flatten_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_tensor_type();
let output = node.outputs.get(0).unwrap().to_tensor_type();
let input = node.inputs.get(0).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();
let (start_dim, end_dim) = flatten_config(&node);
UnaryNode::flatten(input, output, start_dim, end_dim)
}
fn transpose_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_tensor_type();
let output = node.outputs.get(0).unwrap().to_tensor_type();
let input = node.inputs.get(0).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();
UnaryNode::transpose(input, output)
}
fn cast_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();
UnaryNode::cast(input, output)
}
fn reshape_conversion(mut node: Node) -> ReshapeNode {
let input = node.inputs.get(0).unwrap().to_tensor_type();
let output = node.outputs.get(0).unwrap().to_tensor_type();
@ -299,15 +370,15 @@ impl ONNXGraph {
}
fn sigmoid_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_tensor_type();
let output = node.outputs.get(0).unwrap().to_tensor_type();
let input = node.inputs.get(0).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();
UnaryNode::sigmoid(input, output)
}
fn log_softmax_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_tensor_type();
let output = node.outputs.get(0).unwrap().to_tensor_type();
let input = node.inputs.get(0).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();
let dim = log_softmax_config(&node);
UnaryNode::log_softmax(input, output, dim)
@ -419,8 +490,54 @@ impl Argument {
pub fn to_tensor_type(&self) -> TensorType {
match &self.ty {
ArgType::Tensor(tensor) => TensorType::new_float(self.name.clone(), tensor.dim),
_ => panic!("Can't transform to tensor."),
}
}
pub fn to_type(&self) -> Type {
match &self.ty {
ArgType::Tensor(tensor) => {
// Treat tensor with dim 0 as scalar
if tensor.dim == 0 {
// FIXME Convert to correct scalar type (@antimora 8/1/2023)
// Currently this is harmless because no specific scalar type is relied on yet
Type::Scalar(ScalarType::new(self.name.clone(), ScalarKind::Float64))
} else {
Type::Tensor(TensorType::new_float(self.name.clone(), tensor.dim))
}
}
ArgType::Scalar(elem_type) => {
Type::Scalar(ScalarType::new(self.name.clone(), elem_type.into()))
}
ArgType::Shape(_shape) => panic!("Can't transform shape to tensor."),
ArgType::Constant => panic!("Can't transform constant to tensor."),
}
}
}
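
Editorial note: the dim-0 rule in `to_type` mirrors the dim-inference pass, in that rank-0 tensors are typed as scalars and everything else stays a tensor. A tiny sketch with stand-in types:

```rust
#[derive(Debug, PartialEq)]
enum Type {
    Scalar,
    Tensor { dim: usize },
}

// Rank 0 becomes a scalar; any positive rank stays a tensor.
fn to_type(dim: usize) -> Type {
    if dim == 0 {
        Type::Scalar
    } else {
        Type::Tensor { dim }
    }
}

fn main() {
    assert_eq!(to_type(0), Type::Scalar);
    assert_eq!(to_type(4), Type::Tensor { dim: 4 });
}
```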
impl From<&ElementType> for ScalarKind {
fn from(elem_type: &ElementType) -> Self {
match elem_type {
ElementType::Float32 => ScalarKind::Float32,
ElementType::Float64 => ScalarKind::Float64,
ElementType::Int32 => ScalarKind::Int32,
ElementType::Int64 => ScalarKind::Int64,
ElementType::Bool => ScalarKind::Bool,
ElementType::String => panic!("String tensor unsupported"),
ElementType::Float16 => panic!("Float16 tensor unsupported"),
}
}
}
impl From<ElementType> for TensorKind {
fn from(elem_type: ElementType) -> Self {
match elem_type {
ElementType::Float32 => TensorKind::Float,
ElementType::Float64 => TensorKind::Float,
ElementType::Int32 => TensorKind::Int,
ElementType::Int64 => TensorKind::Int,
ElementType::Bool => TensorKind::Bool,
_ => panic!("Unsupported tensor type"),
}
}
}

View File

@ -1,31 +0,0 @@
// Generated by integration tests
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Model<B: Backend> {}
impl<B: Backend> Model<B> {
pub fn new_with(record: ModelRecord<B>) -> Self {
Self {}
}
#[allow(clippy::let_and_return)]
pub fn forward(&self, input1: Tensor<B, 4>, input1: Tensor<B, 4>) -> Tensor<B, 4> {
let concat1_out1 = burn::tensor::Tensor::cat(vec![input1.clone(), input1.clone()], 1);
let concat2_out1 = burn::tensor::Tensor::cat(
vec![
input1.clone(),
concat1_out1.clone(),
concat1_out1.clone(),
concat1_out1.clone(),
concat1_out1,
],
1,
);
concat2_out1
}
}

View File

@ -1,36 +0,0 @@
# used to generate model: burn-import/tests/data/conv2d/conv2d.onnx
import torch
import torch.nn as nn
import onnx
from onnxoptimizer import optimize
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(16, 36, (3, 5), groups = 2, stride=(2, 1), padding=(4, 2), dilation=(3, 1))
def forward(self, x):
x = self.conv1(x)
return x
def main():
# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
dummy_input = torch.randn(20, 16, 50, 100, device=device)
torch.onnx.export(model, dummy_input, "conv2d.onnx",
verbose=False, opset_version=16)
# Apply the optimization pass to simplify the model
onnx_model = onnx.load("conv2d.onnx")
optimized_model = optimize(onnx_model)
# Save the optimized model
onnx.save(optimized_model, "conv2d.onnx")
if __name__ == '__main__':
main()

View File

@ -1,17 +0,0 @@
# Model1 test data files
This directory contains the test data for the model1 test. The test data is generated by running the
following command:
```bash
python3 model1.py
cargo run model1.onnx ./
```
The following files are generated:
- `model1.onnx`: The ONNX model
- `model1.rs`: The generated Rust code for the model (the path in the comment needs to be fixed for
the test)
- `model1.json`: The data of the model
- `model1.graph.txt`: The IR of the model

View File

@ -1,46 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import onnx
from onnxoptimizer import optimize
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 8, 3)
self.norm1 = nn.BatchNorm2d(8)
self.fc1 = nn.Linear(8*6*6, 10)
self.norm2 = nn.BatchNorm1d(10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.norm1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = self.norm2(x)
output = F.log_softmax(x, dim=1)
return output
def main():
# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
dummy_input = torch.randn(1, 1, 8, 8, device=device)
torch.onnx.export(model, dummy_input, "model1.onnx",
verbose=False, opset_version=16)
# Apply the optimization pass to simplify the model
onnx_model = onnx.load("model1.onnx")
optimized_model = optimize(onnx_model)
# Save the optimized model
onnx.save(optimized_model, "model1.onnx")
if __name__ == '__main__':
main()

View File

@ -1,62 +0,0 @@
// Generated by integration tests
use burn::nn::conv::Conv2d;
use burn::nn::conv::Conv2dConfig;
use burn::nn::BatchNorm;
use burn::nn::BatchNormConfig;
use burn::nn::Linear;
use burn::nn::LinearConfig;
use burn::nn::PaddingConfig2d;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
conv2d1: Conv2d<B>,
batchnormalization1: BatchNorm<B, 2>,
linear1: Linear<B>,
batchnormalization2: BatchNorm<B, 0>,
}
impl<B: Backend> Model<B> {
pub fn new_with(record: ModelRecord<B>) -> Self {
let conv2d1 = Conv2dConfig::new([1, 8], [3, 3])
.with_stride([1, 1])
.with_padding(PaddingConfig2d::Valid)
.with_dilation([1, 1])
.with_groups(1)
.with_bias(true)
.init_with(record.conv2d1);
let batchnormalization1 = BatchNormConfig::new(8)
.with_epsilon(0.000009999999747378752f64)
.with_momentum(0.8999999761581421f64)
.init_with(record.batchnormalization1);
let linear1 = LinearConfig::new(288, 10)
.with_bias(true)
.init_with(record.linear1);
let batchnormalization2 = BatchNormConfig::new(10)
.with_epsilon(0.000009999999747378752f64)
.with_momentum(0.8999999761581421f64)
.init_with(record.batchnormalization2);
Self {
conv2d1,
batchnormalization1,
linear1,
batchnormalization2,
}
}
#[allow(clippy::let_and_return)]
pub fn forward(&self, input1: Tensor<B, 4>) -> Tensor<B, 2> {
let conv2d1_out1 = self.conv2d1.forward(input1);
let relu1_out1 = burn::tensor::activation::relu(conv2d1_out1);
let batchnormalization1_out1 = self.batchnormalization1.forward(relu1_out1);
let flatten1_out1 = batchnormalization1_out1.flatten(1, 3);
let linear1_out1 = self.linear1.forward(flatten1_out1);
let batchnormalization2_out1 = self.batchnormalization2.forward(linear1_out1);
let logsoftmax1_out1 = burn::tensor::activation::log_softmax(batchnormalization2_out1, 1);
logsoftmax1_out1
}
}

View File

@ -1,42 +0,0 @@
#[cfg(test)]
#[cfg(feature = "onnx")]
mod tests {
use std::fs::read_to_string;
use std::path::Path;
use burn::record::FullPrecisionSettings;
use pretty_assertions::assert_eq;
use rstest::*;
fn code<P: AsRef<Path>>(onnx_path: P) -> String {
let graph = burn_import::onnx::parse_onnx(onnx_path.as_ref());
let graph = graph
.into_burn::<FullPrecisionSettings>()
.with_blank_space(true)
.with_top_comment(Some("Generated by integration tests".into()));
burn_import::format_tokens(graph.codegen())
}
#[rstest]
#[case::mixed("model1")]
#[case::conv2d("conv2d")]
#[case::concat("concat")]
// #[case::description_here("model2")] <- Add more models here
fn test_codegen(#[case] model_name: &str) {
let input_file = format!("tests/data/{model_name}/{model_name}.onnx");
let source_file = format!("tests/data/{model_name}/{model_name}.rs");
let source_expected: String =
read_to_string(source_file).expect("Expected source file is missing");
let generated_code = code(input_file);
// Uncomment this to update the expected code
// println!("Generated code:\n{}", generated_code);
assert_eq!(
source_expected, generated_code,
"Expected code is left, actual code is right"
);
}
}
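
Editorial note: the removed harness drove one assertion per model through rstest's case table, where each `#[case]` attribute expands into its own test. A minimal, generic sketch of that mechanism (hypothetical test, not part of this commit):

```rust
use rstest::rstest;

// Each #[case] line becomes a separately named test at compile time.
#[rstest]
#[case::two(2, 4)]
#[case::three(3, 9)]
fn squares(#[case] input: u32, #[case] expected: u32) {
    assert_eq!(input * input, expected);
}
```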