bug(onnx): model created using tf2onnx panics with non-valid Ident #2878

martinjrobins opened this issue Mar 8, 2025 · 5 comments
Labels: bug, onnx

martinjrobins commented Mar 8, 2025

Describe the bug
I have an Equinox JAX model that I want to import into burn. I used jax2tf to get a TensorFlow model, then tf2onnx to obtain the ONNX file. When I try to import this into burn, it panics with the message:

ERROR burn_import::logger: PANIC => panicked at /home/mrobins/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/burn-import-0.16.0/src/burn/ty.rs:123:19:
  "jax2tf_rhs_/pjit_silu_/Const_2:0" is not a valid Ident   

I'd imagine the problem is the ":" or "/" characters in the ident, which seem to be used to identify blocks and outputs in the model. Is this panic expected, or is it a bug?
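For illustration, the panic can be reproduced outside of burn-import; this is a minimal sketch assuming the variable idents are built with proc_macro2::Ident::new, whose panic message matches the one above:

use proc_macro2::{Ident, Span};

fn main() {
    // ':' and '/' are not allowed in Rust identifiers, so this panics with
    // `"jax2tf_rhs_/pjit_silu_/Const_2:0" is not a valid Ident`.
    let _ident = Ident::new("jax2tf_rhs_/pjit_silu_/Const_2:0", Span::call_site());
}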

To Reproduce

You can get the onnx file here: https://github.com/martinjrobins/diffsol/raw/refs/heads/workspace/examples/neural-ode-weather-prediction/rhs.onnx

Then I read it in as per the onnx example: https://burn.dev/burn-book/import/onnx-model.html
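For reference, the build-script setup from that guide looks roughly like this (a sketch; the .onnx path here is an assumption):

// build.rs
use burn_import::onnx::ModelGen;

fn main() {
    ModelGen::new()
        .input("src/model/rhs.onnx") // assumed location of the exported model
        .out_dir("model/")
        .run_from_script();
}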

Expected behavior

The model to import without panic

Desktop (please complete the following information):

  • OS: Ubuntu 22.04

Additional context
Using burn v0.16.0

martinjrobins (Author) commented

Following on from this, I removed all the special characters from the onnx graph nodes manually, but then ran into the issue that the Split node is not supported in burn v0.16.0, so I upgraded to v0.17.0 using the main branch of this repo.

However, the model still fails to import with the following error:

DEBUG burn_import::burn::graph: Building the scope nodes len => '26'    
  ERROR burn_import::logger: PANIC => panicked at /home/mrobins/git/burn/crates/burn-import/src/formatter.rs:8:31:
  Valid token tree: BadSourceCode("error: expected type, found `{`\n --> <stdin>:1:1443\n  |\n1 | ... > { # [allow (unused_variables)] pub fn new (device : & B :: Device) -> Self { Self { phantom : core :: marker :: PhantomData , device : burn :: module :: Ignored (device . clone ()) , } } _blank_ ! () ; # [allow (clippy :: let_and_return , clippy :: approx_constant)] pub fn forward (& self , input1 : Tensor < B , 1 > , input2 : Tensor < B , 1 > ,) -> { let mut split_tensors = input1 . split_with_sizes ([128 , 64 , 2048 , 32 , 64 , 2 ,] , 0) ; let [split1_out1 , split1_out2 , split1_out3 , split1_out4 , split1_out5 , split1_out6] = split_tensors . try_into () . unwrap () ; let unsqueeze1_out1 : Tensor < B , 2 > = input2 . unsqueeze_dims (& [1 ,]) ; let reshape1_out1 = split1_out5 . reshape ([2 , 32 ,]) ; let reshape2_out1 = split1_out1 . reshape ([64 , 2 ,]) ; let matmul1_out1 = reshape2_out1 . matmul (unsqueeze1_out1) ; let squeeze1_out1 = matmul1_out1 . squeeze_dims (& []) ; let add1_out1 = squeeze1_out1 . add (split1_out2) ; let neg1_out1 = add1_out1 . clone () . neg () ; let exp1_out1 = neg1_out1 . exp () ; let add2_out1 = exp1_out1 . add_scalar (jax2tf_rhs__pjit_silu__Const_0) ; let div1_out1 = jax2tf_rhs__pjit_silu__Const_0 / add2_out1 ; let mul1_out1 = add1_out1 . mul_scalar (div1_out1) ; let unsqueeze2_out1 : Tensor < B , 2 > = mul1_out1 . unsqueeze_dims (& [1 ,]) ; let reshape3_out1 = split1_out3 . reshape ([32 , 64 ,]) ; let matmul2_out1 = reshape3_out1 . matmul (unsqueeze2_out1) ; let squeeze2_out1 = matmul2_out1 . squeeze_dims (& []) ; let add3_out1 = squeeze2_out1 . add (split1_out4) ; let neg2_out1 = add3_out1 . clone () . neg () ; let exp2_out1 = neg2_out1 . exp () ; let add4_out1 = exp2_out1 . add_scalar (jax2tf_rhs__pjit_silu__Const_0) ; let div2_out1 = jax2tf_rhs__pjit_silu__Const_0 / add4_out1 ; let mul2_out1 = add3_out1 . mul_scalar (div2_out1) ; let unsqueeze3_out1 : Tensor < B , 2 > = mul2_out1 . unsqueeze_dims (& [1 ,]) ; let matmul3_out1 = reshape1_out1 . matmul (unsqueeze3_out1) ; let squeeze3_out1 = matmul3_out1 . squeeze_dims (& []) ; let add5_out1 = squeeze3_out1 . 
add (split1_out6) ; } }\n  |       - while parsing this item list starting here  ^ expected type  - the item list ends here\n\n")

martinjrobins (Author) commented

OK, after a little more digging I've found that:

  • the original model (linked above) is parsed and run with no issues using candle-onnx, so I believe it to be a correct model
  • my manually edited model mentioned above is, however, incorrect. It seems the ":0" syntax is necessary to distinguish the different outputs of a node, and I have a Split node with several outputs. However, if ":0" or ":1" is left in the onnx file, burn panics with the original error

I note that the Split node was only added to burn 3 weeks ago, so perhaps it was not necessary to support multiple outputs when importing models previously?

nathanielsimard added the bug label Mar 10, 2025
laggui added the onnx label Mar 11, 2025

laggui (Member) commented Mar 11, 2025

Sorry for the delayed response!

> You can get the onnx file here: https://github.com/martinjrobins/diffsol/raw/refs/heads/workspace/examples/neural-ode-weather-prediction/rhs.onnx

The link is invalid, but there seems to be a script to generate it here?

Could you share the ONNX model in this issue? Might have to zip it for github to accept the upload.

> I'd imagine the problem is the ":" or "/" characters in the ident, which seem to be used to identify blocks and outputs in the model. Is this panic expected, or is it a bug?

If this is valid ONNX, then this is a bug. During code generation we parse the names and format them into variables, but it looks like not all symbols are handled.

> I note that the Split node was only added to burn 3 weeks ago, so perhaps it was not necessary to support multiple outputs when importing models previously?

As you noted, support for the split node was just added recently. Perhaps the implementation does not handle the whole specification. If you include the model I could take a look!

martinjrobins (Author) commented

Yes, that script will generate the models, but I'll also attach the model below. Thanks for looking into this :)

rhs_model.zip

laggui (Member) commented Mar 11, 2025

Ok so my initial hypothesis was correct.

> If this is valid ONNX, then this is a bug. During code generation we parse the names and format them into variables, but it looks like not all symbols are handled.

The name formatting really only handles alphanumeric values.

pub fn format_name(name: &str) -> String {
    let name_is_number = name.bytes().all(|digit| digit.is_ascii_digit());
    if name_is_number {
        format!("_{}", name)
    } else {
        name.to_string()
    }
}

If we change the last line to replace the invalid ident values in your model, it correctly parses the model.

name.to_string().replace(":", "_").replace("/", "_") // replace ":" -> "_", "/" -> "_"
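A more general sanitization could map every character that is not valid in a Rust identifier to "_"; a rough sketch (not the actual patch) would be:

pub fn format_name(name: &str) -> String {
    // Replace anything that cannot appear in a Rust identifier with '_'.
    let sanitized: String = name
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() || c == '_' { c } else { '_' })
        .collect();

    // Identifiers cannot start with a digit, so prefix those with '_'.
    if sanitized.chars().next().map_or(true, |c| c.is_ascii_digit()) {
        format!("_{}", sanitized)
    } else {
        sanitized
    }
}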

But I get another error:

ERROR burn_import::logger: PANIC => panicked at /home/laggui/workspace/burn/crates/burn-import/src/burn/node/binary.rs:158:18:
  Division is supported for tensor and scalar only

In your case, the model actually has an lhs scalar and rhs tensor, which is not handled.

A quick fix:

pub(crate) fn div(lhs: Type, rhs: Type, output: Type) -> Self {
    let function = match (&lhs, &rhs) {
        (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.div(#rhs) },
        (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.div_scalar(#rhs) },
        (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs / #rhs },
        // New arm: scalar / tensor, rewritten as (1 / tensor) * scalar.
        (Type::Scalar(_), Type::Tensor(_)) => {
            move |lhs, rhs| quote! { #rhs.recip().mul_scalar(#lhs) }
        }
        _ => panic!("Division is supported for tensor and scalar only"),
    };

    Self::new(lhs, rhs, output, BinaryType::Div, Arc::new(function))
}

And finally, the codegen worked... but it's not correct 😓

error[E0425]: cannot find value `jax2tf_rhs__pjit_silu__Const_1_0` in this scope
  --> /home/laggui/workspace/my_burn_app/target/debug/build/my_burn_app-42b3bff62a6673e2/out/model/rhs.rs:55:46
   |
55 | ...scalar(jax2tf_rhs__pjit_silu__Const_1_0);
   |           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ not found in this scope

error[E0425]: cannot find value `jax2tf_rhs__pjit_silu__Const_1_0` in this scope
  --> /home/laggui/workspace/my_burn_app/target/debug/build/my_burn_app-42b3bff62a6673e2/out/model/rhs.rs:58:25
   |
58 |             .mul_scalar(jax2tf_rhs__pjit_silu__Const_1_0);
   |                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ not found in this scope

error[E0425]: cannot find value `jax2tf_rhs__pjit_silu__Const_1_0` in this scope
  --> /home/laggui/workspace/my_burn_app/target/debug/build/my_burn_app-42b3bff62a6673e2/out/model/rhs.rs:66:46
   |
66 | ...scalar(jax2tf_rhs__pjit_silu__Const_1_0);
   |           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ not found in this scope

error[E0425]: cannot find value `jax2tf_rhs__pjit_silu__Const_1_0` in this scope
  --> /home/laggui/workspace/my_burn_app/target/debug/build/my_burn_app-42b3bff62a6673e2/out/model/rhs.rs:69:25
   |
69 |             .mul_scalar(jax2tf_rhs__pjit_silu__Const_1_0);
   |                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ not found in this scope

error[E0308]: mismatched types
    --> /home/laggui/workspace/my_burn_app/target/debug/build/my_burn_app-42b3bff62a6673e2/out/model/rhs.rs:44:57
     |
44   | ...= input1.split_with_sizes([256, 64, 2048, 32, 128, 4], 0);
     |             ---------------- ^^^^^^^^^^^^^^^^^^^^^^^^^^^- help: try using a conversion method: `.to_vec()`
     |             |                |
     |             |                expected `Vec<usize>`, found `[{integer}; 6]`
     |             arguments to this method are incorrect
     |
     = note: expected struct `Vec<usize>`
                 found array `[{integer}; 6]`
note: method defined here
    --> /home/laggui/workspace/burn/crates/burn-tensor/src/tensor/api/base.rs:1360:12
     |
1360 |     pub fn split_with_sizes(self, split_sizes: Vec<usize>, dim: usize) ...
     |            ^^^^^^^^^^^^^^^^

The SplitNode codegen seems to be incorrect here (the fix should be easy, as it currently passes an array instead of a Vec; see the sketch below), but then we will get stuck at the constant values as per #1882.
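For reference, a sketch of what the generated call would need to look like to match the split_with_sizes signature above (sizes taken from the error output; input1 and the output bindings come from the generated forward()):

// split_with_sizes expects a Vec<usize>, not an array literal.
let split_tensors = input1.split_with_sizes(vec![256, 64, 2048, 32, 128, 4], 0);
let [split1_out1, split1_out2, split1_out3, split1_out4, split1_out5, split1_out6] =
    split_tensors.try_into().unwrap();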
