"""Convert DFT operations in an ONNX model to equivalent MatMul operations.""" import onnxscript import onnx_ir as ir import numpy as np class ReplaceDftWithMatMulRule(onnxscript.rewriter.RewriteRuleClassBase): def pattern(self, op, x, dft_length): x = op.Reshape(x, _allow_other_inputs=True) dft = op.DFT(x, dft_length, _outputs=["dft_output"]) return dft def rewrite(self, op, x: ir.Value, dft_length: ir.Value, dft_output: ir.Value): # Get the DFT node attributes dft_node = dft_output.producer() assert dft_node is not None dft_size = ir.convenience.get_const_tensor(dft_length).numpy().item() # Create one-sided DFT matrices (real and imaginary parts, DC to Nyquist) # Real part: Re(DFT[k]) = sum(x[n] * cos(2*pi*k*n/N)) # Imaginary part: Im(DFT[k]) = sum(x[n] * -sin(2*pi*k*n/N)) # For one-sided DFT, we only need frequencies from 0 to Nyquist (dft_size//2 + 1) num_freqs = dft_size // 2 + 1 # Vectorized creation of DFT matrices n = np.arange(dft_size, dtype=np.float32)[:, np.newaxis] # Shape: (dft_size, 1) k = np.arange(num_freqs, dtype=np.float32)[ np.newaxis, : ] # Shape: (1, num_freqs) # Real part (cosine) dft_matrix_real = np.cos( 2 * np.pi * k * n / dft_size ) # Shape: (dft_size, num_freqs) # Imaginary part (negative sine) dft_matrix_imag = -np.sin( 2 * np.pi * k * n / dft_size ) # Shape: (dft_size, num_freqs) # Stack real and imaginary parts: shape (dft_size, num_freqs * 2) # Interleave real and imaginary: [real_0, imag_0, real_1, imag_1, ...] dft_matrix = np.stack([dft_matrix_real, dft_matrix_imag], axis=-1).reshape( dft_size, num_freqs * 2 ) # Create constant node for the combined DFT matrix dft_matrix = op.initializer(ir.tensor(dft_matrix), name=f"{x.name}_dft_matrix") # Single matrix multiplication matmul_result = op.MatMul(x, dft_matrix) new_shape = op.initializer( ir.tensor([-1, 500, 513, 2], name=f"{x.name}_dft_reshaped_shape") ) result = op.Reshape(matmul_result, new_shape) return result model = ir.load("perch_v2.onnx") onnxscript.rewriter.rewrite( model, [ReplaceDftWithMatMulRule().rule()], ) onnxscript.optimizer.optimize(model) ir.save(model, "perch_v2_no_dft.onnx")