"""Cleanup and optimize perch_v2_slim.onnx model. This script can be applied after completing these steps: 1. Use `tf2onnx` to convert the tflite model to onnx 2. Apply onnxslim and onnxscript.optimize.optimizer on the model 3. Manually edit the model to remove the first DFT node (no-op) and fuse the nodes that effectively takes the magnitude of the DFT output with ReduceL2. """ import onnx_ir as ir import onnx_ir.passes.common import onnxscript import numpy as np m = ir.load("perch_v2_slim.onnx") for node in m.graph: if node.op_type == "MatMul": print("Simplify MatMul + Reshape:", node.name) if node.inputs[0].producer().op_type == "Reshape": # Skip the reshape input = node.inputs[0].producer().inputs[0] node.replace_input_with(0, input) for usage in node.outputs[0].uses(): if usage.node.op_type == "Reshape": reshape_usages = list(usage.node.outputs[0].uses()) # Keep the last Reshape if reshape_usages[0].node.op_type == "ReduceMax": shape = ir.val( "reshape_shape", const_value=ir.tensor([-1, 16, 4, 14795, 4]) ) m.graph.initializers.add(shape) usage.node.replace_input_with(1, shape) continue reshape_node = usage.node output = reshape_node.outputs[0] output.replace_all_uses_with(node.outputs[0]) # Remove Expand if node.op_type == "Expand": print("Remove Expand:", node.name) input = node.inputs[0] output = node.outputs[0] output.replace_all_uses_with(input) # Clean up any unused nodes onnx_ir.passes.common.RemoveUnusedNodesPass()(m) # Do some const folding onnxscript.optimizer.optimize( m, input_size_limit=1024 * 1024 * 1024, output_size_limit=1024 * 1024 * 1024 ) one_1d = ir.val("1d_one", const_value=ir.tensor([1], dtype=ir.DataType.INT64)) m.graph.initializers.add(one_1d) # Simplify Unsqueeze + Reshape for node in m.graph: if node.op_type == "Reshape": print("Simplify Unsqueeze + Reshape:", node.name) if ( node.inputs[0].producer() and node.inputs[0].producer().op_type == "Unsqueeze" ): unsqueeze_node = node.inputs[0].producer() unsqueeze_node.replace_input_with(1, one_1d) node.outputs[0].replace_all_uses_with(unsqueeze_node.outputs[0]) unsqueeze_node.outputs[0].shape = ir.Shape(["batch", 160000, 1]) first_reshape_shape = ir.val( "first_reshape_shape", const_value=ir.tensor([-1, 1, 160000, 1]) ) m.graph.initializers.add(first_reshape_shape) # Simplify first Reshape + Unsqueeze for node in m.graph: if node.op_type == "Unsqueeze": print("Simplify Reshape + Unsqueeze:", node.name) if node.inputs[0].producer() and node.inputs[0].producer().op_type == "Reshape": reshape_node = node.inputs[0].producer() reshape_node.replace_input_with(1, first_reshape_shape) node.outputs[0].replace_all_uses_with(reshape_node.outputs[0]) reshape_node.outputs[0].shape = ir.Shape(["batch", 1, 160000, 1]) break # Fuse Conv + Sub into Conv for node in m.graph: if node.op_type == "Conv": print("Check Conv for fusion:", node.name) conv_node = node assert len(conv_node.outputs[0].uses()) == 1 for usage in conv_node.outputs[0].uses(): if usage.node.op_type == "Sub": sub_node = usage.node print(" Fuse Sub into Conv:", sub_node.name) sub_value = sub_node.inputs[1] new_bias = (np.negative(sub_value.const_value.numpy())).reshape((-1,)) new_bias_val = ir.val( f"{sub_value.name}_neg", const_value=ir.tensor(new_bias), ) m.graph.initializers.add(new_bias_val) if len(conv_node.inputs) == 2: # Bad access of private field conv_node._inputs = conv_node._inputs + (None,) conv_node.replace_input_with(2, new_bias_val) sub_node.outputs[0].replace_all_uses_with(conv_node.outputs[0]) # Clean up any unused nodes onnx_ir.passes.common.RemoveUnusedNodesPass()(m) # Clear all intermediate shapes and re-infer shapes for node in m.graph: for output in node.outputs: if output.is_graph_output(): continue output.shape = None m.graph.inputs[0].shape = ir.Shape(["batch", *m.graph.inputs[0].shape[1:]]) for output in m.graph.outputs: output.shape = ir.Shape(["batch", *output.shape[1:]]) onnxscript.optimizer.optimize( m, input_size_limit=1024 * 1024 * 1024, output_size_limit=1024 * 1024 * 1024 ) onnx_ir.passes.common.ClearMetadataAndDocStringPass()(m) # Replace None dim with "batch" for node in m.graph: for output in node.outputs: if output.shape is None: continue shape = ir.Shape(output.shape) for i in range(len(shape)): dim = shape[i] if isinstance(dim, ir.SymbolicDim) and dim.value is None: shape[i] = ir.SymbolicDim("batch") output.shape = shape # Rename IO and match the tflite model m.graph.inputs[0].name = "inputs" m.graph.outputs[0].name = "spatial_embedding" m.graph.outputs[1].name = "embedding" m.graph.outputs[2].name = "spectrogram" m.graph.outputs[3].name = "label" out_0 = m.graph.outputs[0] out_1 = m.graph.outputs[1] m.graph.outputs[1] = out_0 m.graph.outputs[0] = out_1 m.producer_name = "onnx-ir" m.producer_version = None m.ir_version = 10 ir.save(m, "perch_v2.onnx")