File size: 28,036 Bytes
0558aa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import shutil
import tarfile
import tempfile
from pathlib import Path
from time import time
from typing import List

import tensorrt as trt
import torch
import yaml
from omegaconf import OmegaConf
from PIL import Image
from tensorrt_llm._common import check_max_num_tokens
from tensorrt_llm.builder import BuildConfig, Builder
from tensorrt_llm.commands.build import build as build_trtllm
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import MLLaMAForCausalLM
from tensorrt_llm.plugin import PluginConfig
from transformers import AutoModel, AutoProcessor, MllamaForConditionalGeneration

from nemo.collections.multimodal.speech_llm.modules.perception_modules import AudioPerceptionModule
from nemo.core.classes.common import typecheck
from nemo.export.tensorrt_llm import TensorRTLLM
from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import load_nemo_model

from .converter import convert_mllama_nemo_to_hf

# Module-wide TensorRT logger at INFO severity; used for log messages here and
# passed to trt.Builder / trt.OnnxParser in build_trt_engine.
logger = trt.Logger(trt.Logger.INFO)


def build_trtllm_engine(
    model_dir: str,
    visual_checkpoint_path: str,
    llm_checkpoint_path: str = None,
    model_type: str = "neva",
    llm_model_type: str = "llama",
    tensor_parallelism_size: int = 1,
    max_input_len: int = 256,
    max_output_len: int = 256,
    max_batch_size: int = 1,
    max_multimodal_len: int = 1024,
    dtype: str = "bfloat16",
    use_lora_plugin: str = None,
    lora_target_modules: List[str] = None,
    max_lora_rank: int = 64,
    lora_ckpt_list: List[str] = None,
):
    """Export the language-model part of a checkpoint to a TensorRT-LLM engine in ``model_dir``.

    The export goes through NeMo's ``TensorRTLLM`` exporter (legacy, non-mcore path)
    without loading the built engine afterwards.

    Args:
        model_dir: output directory of the engine.
        visual_checkpoint_path: checkpoint exported when ``llm_checkpoint_path`` is None.
        llm_checkpoint_path: optional separate LLM checkpoint; takes precedence if given.
        model_type: accepted for interface compatibility; not used by this function.
        llm_model_type: model type forwarded to the exporter (e.g. ``"llama"``).
        max_multimodal_len: used as the prompt-embedding-table size.
        Remaining arguments are forwarded to ``TensorRTLLM.export`` unchanged;
        ``max_seq_len`` is derived as ``max_input_len + max_output_len``.
    """
    exporter = TensorRTLLM(model_dir=model_dir, lora_ckpt_list=lora_ckpt_list, load_model=False)

    if llm_checkpoint_path is not None:
        source_checkpoint = llm_checkpoint_path
    else:
        source_checkpoint = visual_checkpoint_path

    export_kwargs = dict(
        nemo_checkpoint_path=source_checkpoint,
        model_type=llm_model_type,
        tensor_parallelism_size=tensor_parallelism_size,
        max_input_len=max_input_len,
        max_output_len=max_output_len,
        max_seq_len=max_input_len + max_output_len,
        max_batch_size=max_batch_size,
        max_prompt_embedding_table_size=max_multimodal_len,
        dtype=dtype,
        load_model=False,
        use_lora_plugin=use_lora_plugin,
        lora_target_modules=lora_target_modules,
        max_lora_rank=max_lora_rank,
        use_mcore_path=False,
    )
    exporter.export(**export_kwargs)


def build_mllama_trtllm_engine(
    model_dir: str,
    hf_model_path: str,
    tensor_parallelism_size: int = 1,
    max_input_len: int = 256,
    max_output_len: int = 256,
    max_batch_size: int = 1,
    max_multimodal_len: int = 1024,
    dtype: str = "bfloat16",
    use_lora_plugin: str = None,
    lora_target_modules: List[str] = None,
    max_lora_rank: int = 64,
    lora_ckpt_list: List[str] = None,
):
    """Build the MLLaMA language-model TensorRT-LLM engine from a Hugging Face checkpoint.

    One engine per tensor-parallel rank is converted from ``hf_model_path`` and saved
    into ``model_dir``. Paged KV cache (128 tokens per block), removed input padding
    and paged context FMHA are enabled.

    Args:
        model_dir: output directory for the engine(s).
        hf_model_path: Hugging Face MLLaMA checkpoint directory.
        tensor_parallelism_size: number of TP ranks to build.
        max_input_len / max_output_len: their sum is the engine's ``max_seq_len``.
        max_batch_size: values below 4 are raised to 4.
        max_multimodal_len: used as ``max_encoder_input_len``.
        dtype: dtype passed to ``MLLaMAForCausalLM.from_hugging_face``.
        use_lora_plugin, lora_target_modules, max_lora_rank, lora_ckpt_list:
            accepted for signature compatibility; not used by this function.
    """
    min_safe_batch_size = 4
    if max_batch_size < min_safe_batch_size:
        print("TensorRT LLM may hit a runtime issue with batch size is smaller than 4 on some models. Force set to 4")
        max_batch_size = min_safe_batch_size

    tokens_per_block = 128

    # Plugin settings are applied in this order on purpose; keep it unchanged.
    plugin_config = PluginConfig()
    plugin_config.gpt_attention_plugin = "auto"
    plugin_config.gemm_plugin = "auto"
    plugin_config.enable_paged_kv_cache(tokens_per_block=tokens_per_block)
    plugin_config.remove_input_padding = True
    plugin_config.use_paged_context_fmha = True

    max_seq_len = max_input_len + max_output_len
    max_num_tokens, opt_num_tokens = check_max_num_tokens(
        max_num_tokens=None,
        opt_num_tokens=None,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
        max_input_len=max_input_len,
        max_beam_width=1,
        remove_input_padding=True,
        enable_context_fmha=plugin_config.context_fmha,
        tokens_per_block=tokens_per_block,
        multiple_profiles=False,
    )

    build_settings = dict(
        max_input_len=max_input_len,
        max_output_len=max_output_len,
        max_encoder_input_len=max_multimodal_len,
        max_batch_size=max_batch_size,
        max_beam_width=1,
        max_seq_len=max_seq_len,
        max_num_tokens=max_num_tokens,
        opt_num_tokens=opt_num_tokens,
        strongly_typed=True,
        builder_opt=None,
    )
    build_config = BuildConfig.from_dict(build_settings, plugin_config=plugin_config)

    for rank_id in range(tensor_parallelism_size):
        rank_mapping = Mapping(world_size=tensor_parallelism_size, rank=rank_id, tp_size=tensor_parallelism_size)
        rank_model = MLLaMAForCausalLM.from_hugging_face(hf_model_path, dtype, mapping=rank_mapping)
        rank_engine = build_trtllm(rank_model, build_config)
        rank_engine.save(model_dir)


def export_visual_wrapper_onnx(visual_wrapper, input, output_dir, input_names=None, dynamic_axes=None):
    """Export a visual wrapper module to ``<output_dir>/onnx/visual_encoder.onnx``.

    Args:
        visual_wrapper: ``torch.nn.Module`` to export.
        input: example input (tensor or tuple of tensors) used for tracing.
        output_dir: base directory; the ``onnx`` sub-directory is created if missing.
        input_names: ONNX input names; defaults to ``['input']``.
        dynamic_axes: ONNX dynamic axes; defaults to ``{'input': {0: 'batch'}}``.
            The single output is always named ``'output'``.
    """
    # Defaults are built per call instead of being shared mutable default arguments.
    if input_names is None:
        input_names = ['input']
    if dynamic_axes is None:
        dynamic_axes = {'input': {0: 'batch'}}

    logger.log(trt.Logger.INFO, "Exporting onnx")
    os.makedirs(f'{output_dir}/onnx', exist_ok=True)
    torch.onnx.export(
        visual_wrapper,
        input,
        f'{output_dir}/onnx/visual_encoder.onnx',
        opset_version=17,
        input_names=input_names,
        output_names=['output'],
        dynamic_axes=dynamic_axes,
    )


def export_perception_wrapper_onnx(
    perception_wrapper,
    input,
    output_dir,
    input_names=None,
    output_names=None,
    dynamic_axes=None,
):
    """Export a perception wrapper module to ``<output_dir>/onnx/perception_encoder.onnx``.

    Args:
        perception_wrapper: ``torch.nn.Module`` to export.
        input: example inputs used for tracing.
        output_dir: base directory; the ``onnx`` sub-directory is created if missing.
        input_names: defaults to ``['processed_signal', 'processed_signal_length']``.
        output_names: defaults to ``['encoded', 'encoded_length']``.
        dynamic_axes: defaults to a dynamic batch axis on all four tensors, plus a
            dynamic time axis (dim 2 of ``processed_signal``, dim 1 of ``encoded``).
    """
    # Defaults are built per call instead of being shared mutable default arguments.
    if input_names is None:
        input_names = ['processed_signal', 'processed_signal_length']
    if output_names is None:
        output_names = ['encoded', 'encoded_length']
    if dynamic_axes is None:
        dynamic_axes = {
            'processed_signal': {0: 'batch', 2: 'time'},
            'processed_signal_length': {0: 'batch'},
            'encoded': {0: 'batch', 1: 'time'},
            'encoded_length': {0: 'batch'},
        }

    logger.log(trt.Logger.INFO, "Exporting onnx")
    os.makedirs(f'{output_dir}/onnx', exist_ok=True)
    torch.onnx.export(
        perception_wrapper,
        input,
        f'{output_dir}/onnx/perception_encoder.onnx',
        opset_version=17,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
    )


def build_trt_engine(
    model_type,
    input_sizes,
    output_dir,
    vision_max_batch_size,
    dtype=torch.bfloat16,
    image_size=None,
    num_frames=None,
    nemo_config=None,
    part_name='visual_encoder',
):
    """Build a TensorRT engine from ``<output_dir>/onnx/<part_name>.onnx``.

    Writes ``nemo_config.yaml``, ``<part_name>.engine`` and ``config.json`` into
    ``output_dir``. The ``onnx`` sub-directory is removed after a successful parse.

    Args:
        model_type: stored in the builder config as ``model_type``.
        input_sizes: one of three forms:
            * a list of ints (e.g. ``[3, H, W]``): a static per-sample shape for the
              first network input; a dynamic batch dim is prepended;
            * a list of three int lists ``[min, opt, max]`` (batch dim excluded) for
              the first network input;
            * a list of three dicts ``[min, opt, max]`` mapping every network input
              name to its full shape (batch dim included).
        output_dir: directory holding the ``onnx`` sub-directory and receiving outputs.
        vision_max_batch_size: for the two list forms, the batch range is
            ``[1, vision_max_batch_size]`` with opt ``max(1, vision_max_batch_size // 2)``;
            ignored for the dict form.
        dtype: torch dtype whose name (e.g. ``bfloat16``) is the builder precision.
        image_size, num_frames: optional extra builder-config fields.
        nemo_config: object dumped to ``nemo_config.yaml`` (``None`` is written as ``null``).
        part_name: basename of the ONNX file and of the produced engine.

    Raises:
        AssertionError: if ``input_sizes`` is not a list.
        ValueError: if ``input_sizes`` has an unsupported form.
        RuntimeError: if ONNX parsing or engine building fails.
    """
    onnx_file = '%s/onnx/%s.onnx' % (output_dir, part_name)
    engine_file = '%s/%s.engine' % (output_dir, part_name)
    config_file = '%s/%s' % (output_dir, "config.json")
    nemo_config_file = '%s/%s' % (output_dir, "nemo_config.yaml")

    with open(nemo_config_file, 'w') as f:
        yaml.dump(nemo_config, f)

    logger.log(trt.Logger.INFO, "Building TRT engine for %s" % part_name)

    builder = trt.Builder(logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    profile = builder.create_optimization_profile()

    config_args = {"precision": str(dtype).split('.')[-1], "model_type": model_type}
    if image_size is not None:
        config_args["image_size"] = image_size
    if num_frames is not None:
        config_args["num_frames"] = num_frames

    config_wrapper = Builder().create_builder_config(**config_args)
    config = config_wrapper.trt_builder_config

    parser = trt.OnnxParser(network, logger)
    with open(onnx_file, 'rb') as model:
        parsed = parser.parse(model.read(), os.path.abspath(onnx_file))
    if not parsed:
        logger.log(trt.Logger.ERROR, "Failed parsing %s" % onnx_file)
        for error_idx in range(parser.num_errors):
            # Convert explicitly: the logger is otherwise only given str messages,
            # and get_error() may return a non-str error object.
            logger.log(trt.Logger.ERROR, str(parser.get_error(error_idx)))
        # Fail fast (keeping the ONNX files for inspection) instead of building a
        # network from a partially parsed graph.
        raise RuntimeError("Failed parsing %s" % onnx_file)
    logger.log(trt.Logger.INFO, "Succeeded parsing %s" % onnx_file)

    # Delete onnx files since we don't need them now
    shutil.rmtree(f'{output_dir}/onnx')

    nBS = -1  # -1 marks the batch dimension as dynamic
    nMinBS = 1
    nOptBS = max(nMinBS, vision_max_batch_size // 2)
    nMaxBS = vision_max_batch_size

    inputT = network.get_input(0)

    # input sizes can be a list of ints (e.g., [3, H, W]) when inputs are images,
    # or a list of three int lists (e.g., [[1, 1, 2700], [1, 500, 2700], [1, 4096, 2700]]).
    # or a list of three list of lists
    # (e.g., [{input1: min_shape, input2: min_shape, }, \
    #     {input1: opt_shape, input2: opt_shape}, \
    # {input1: max_shape, input2: max_shape}] )
    assert isinstance(input_sizes, list), "input_sizes must be a list"
    # An empty list falls through to the ValueError below instead of an IndexError.
    if input_sizes and isinstance(input_sizes[0], int):
        logger.log(trt.Logger.INFO, f"Processed input sizes {input_sizes}")
        inputT.shape = [nBS, *input_sizes]
        min_size = opt_size = max_size = input_sizes
    elif len(input_sizes) == 3 and isinstance(input_sizes[0], list):
        min_size, opt_size, max_size = input_sizes
        logger.log(trt.Logger.INFO, f"Processed min/opt/max input sizes {min_size}/{opt_size}/{max_size}")
    elif len(input_sizes) == 3 and isinstance(input_sizes[0], dict):
        logger.log(trt.Logger.INFO, f"Processed min/opt/max input sizes {input_sizes}")
    else:
        raise ValueError(f"invalid input sizes: {input_sizes}")

    if isinstance(input_sizes[0], dict):
        # Per-input shapes already include the batch dimension.
        for i in range(network.num_inputs):
            inputT = network.get_input(i)
            input_name = inputT.name
            min_size = input_sizes[0][input_name]
            opt_size = input_sizes[1][input_name]
            max_size = input_sizes[2][input_name]
            logger.log(trt.Logger.INFO, f"{input_name} min/opt/max input sizes {min_size}/{opt_size}/{max_size}")
            profile.set_shape(input_name, min_size, opt_size, max_size)
    else:
        profile.set_shape(inputT.name, [nMinBS, *min_size], [nOptBS, *opt_size], [nMaxBS, *max_size])

    config.add_optimization_profile(profile)

    t0 = time()
    engine_string = builder.build_serialized_network(network, config)
    t1 = time()
    if engine_string is None:
        raise RuntimeError("Failed building %s" % (engine_file))
    logger.log(trt.Logger.INFO, "Succeeded building %s in %d s" % (engine_file, t1 - t0))
    with open(engine_file, 'wb') as f:
        f.write(engine_string)

    Builder.save_config(config_wrapper, config_file)


def build_neva_engine(
    model_type: str,
    model_dir: str,
    visual_checkpoint_path: str,
    vision_max_batch_size: int = 1,
):
    """Build neva visual engine.

    Builds a vision tower consisting of the Hugging Face vision encoder named by
    ``nemo_config["mm_cfg"]["vision_encoder"]["from_pretrained"]``, followed by the
    multimodal projector whose weights are read from the NeMo checkpoint. The tower
    is exported to ONNX and then built into a TensorRT engine in ``model_dir``.

    Args:
        model_type: one of ``neva``, ``lita``, ``vila``, ``vita``. It selects where the
            image size is read from and whether ``num_frames`` is recorded
            (``lita``/``vita`` only).
        model_dir: output directory for the engine and its config files.
        visual_checkpoint_path: either an unpacked checkpoint directory (containing
            ``model_config.yaml`` and ``model_weights.ckpt`` or
            ``mp_rank_00/model_weights.ckpt``) or a packed NeMo checkpoint file.
        vision_max_batch_size: maximum batch size forwarded to ``build_trt_engine``.

    Raises:
        ValueError: if ``mm_cfg.mm_mlp_adapter_type`` is not ``mlp2x_gelu``,
            ``linear`` or ``mlp_downsample``.
    """
    # Weights, the wrapper and the dummy input are placed on GPU when available.
    device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

    if os.path.isdir(visual_checkpoint_path):
        # load untar checkpoint
        config_path = os.path.join(visual_checkpoint_path, 'model_config.yaml')
        with open(config_path, 'r') as f:
            nemo_config = yaml.safe_load(f)
        try:
            # Checkpoint trained without TP: weights at the top level.
            weights_path = os.path.join(visual_checkpoint_path, 'model_weights.ckpt')
            mp0_weights = torch.load(weights_path, map_location=device)
        except FileNotFoundError:
            # Checkpoint trained with TP: only the mp_rank_00 shard is read.
            weights_path = os.path.join(visual_checkpoint_path, 'mp_rank_00/model_weights.ckpt')
            mp0_weights = torch.load(weights_path, map_location=device)
    else:
        # extract NeMo checkpoint
        with tempfile.TemporaryDirectory() as temp:
            temp_path = Path(temp)
            mp0_weights, nemo_config, _ = load_nemo_model(visual_checkpoint_path, temp_path)

    vision_config = nemo_config["mm_cfg"]["vision_encoder"]

    class DownSampleBlock(torch.nn.Module):
        # pylint: disable=C0115,C0116
        # Merges each 2x2 neighbourhood of patch tokens into one token with 4x the
        # channels (the LayerNorm below uses hidden_size * 4). It assumes the token
        # count is a perfect square, since h = w = sqrt(num_tokens).
        def forward(self, x):
            vit_embeds = x
            h = w = int(vit_embeds.shape[1] ** 0.5)
            vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
            vit_embeds = self.flat_square(vit_embeds)
            vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
            return vit_embeds

        def flat_square(self, x):
            n, w, h, c = x.size()
            # Zero-pad odd spatial dims to even so the 2x2 merge below is exact.
            if w % 2 == 1:
                x = torch.cat([x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
                n, w, h, c = x.size()
            if h % 2 == 1:
                x = torch.cat([x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
                n, w, h, c = x.size()
            # Fold pairs along one spatial dim into channels, transpose, then fold
            # pairs along the other spatial dim.
            x = x.view(n, w, int(h / 2), int(c * 2))
            x = x.permute(0, 2, 1, 3).contiguous()
            x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
            return x

    class VisionEncoderWrapper(torch.nn.Module):
        # pylint: disable=C0115,C0116
        # Vision encoder + projector. The second-to-last hidden layer of the encoder
        # is used as the image features.
        def __init__(self, encoder, connector):
            super().__init__()
            self.encoder = encoder
            self.connector = connector

        def forward(self, images):
            vision_x = self.encoder(pixel_values=images, output_hidden_states=True)
            vision_x = vision_x.hidden_states[-2]
            vision_x = self.connector(vision_x)
            return vision_x

    encoder = AutoModel.from_pretrained(
        vision_config["from_pretrained"],
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        attn_implementation='eager',
    )
    vision_encoder = encoder.vision_model
    hf_config = encoder.config
    # NOTE(review): weights are loaded as bfloat16 above, but `dtype` comes from
    # hf_config.torch_dtype. Confirm these agree for all supported encoders.
    dtype = hf_config.torch_dtype

    # connector
    if nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp2x_gelu":
        vision_connector = torch.nn.Sequential(
            torch.nn.Linear(vision_config["hidden_size"], nemo_config["hidden_size"], bias=True),
            torch.nn.GELU(),
            torch.nn.Linear(nemo_config["hidden_size"], nemo_config["hidden_size"], bias=True),
        ).to(dtype=dtype)

        key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector"
        # Indices 0 and 2 are the Linear layers; index 1 (GELU) has no weights.
        for layer in range(0, 3, 2):
            vision_connector[layer].load_state_dict(
                {
                    'weight': mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype),
                    'bias': mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype),
                }
            )
    elif nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "linear":
        # Not cast here; the wrapper's .to(device, dtype) below casts it.
        vision_connector = torch.nn.Linear(vision_config["hidden_size"], nemo_config["hidden_size"], bias=True)
        key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector"
        vision_connector.load_state_dict(
            {
                'weight': mp0_weights[f"{key_prefix}.weight"].to(dtype),
                'bias': mp0_weights[f"{key_prefix}.bias"].to(dtype),
            }
        )
    elif nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp_downsample":
        vision_connector = torch.nn.Sequential(
            DownSampleBlock(),
            torch.nn.LayerNorm(vision_config["hidden_size"] * 4),
            torch.nn.Linear(vision_config["hidden_size"] * 4, nemo_config["hidden_size"], bias=True),
            torch.nn.GELU(),
            torch.nn.Linear(nemo_config["hidden_size"], nemo_config["hidden_size"], bias=True),
        ).to(dtype=dtype)
        key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector"
        # Index 1 is the LayerNorm and 2/4 are the Linears; 0 (DownSampleBlock) and
        # 3 (GELU) have no weights.
        for layer in [1, 2, 4]:
            vision_connector[layer].load_state_dict(
                {
                    'weight': mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype),
                    'bias': mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype),
                }
            )

    else:
        raise ValueError(f"Unknown projector type: {nemo_config['mm_cfg']['mm_mlp_adapter_type']}")

    # export the whole wrapper
    lita_num_frames = None
    wrapper = VisionEncoderWrapper(vision_encoder, vision_connector).to(device, dtype)
    # lita/vila read image_size from the top-level HF config; the other types read
    # it from hf_config.vision_config.
    if model_type == "lita" or model_type == "vila":
        image_size = hf_config.image_size
        if model_type == "lita":
            lita_num_frames = nemo_config['mm_cfg']['lita']['sample_frames']
    else:
        image_size = hf_config.vision_config.image_size
        if model_type == "vita":
            lita_num_frames = nemo_config['mm_cfg']['lita']['sample_frames']
    # Uninitialized values; presumably only the shape and dtype matter for ONNX tracing.
    dummy_image = torch.empty(
        1, 3, image_size, image_size, dtype=dtype, device=device
    )  # dummy image shape [B, C, H, W]

    export_visual_wrapper_onnx(wrapper, dummy_image, model_dir)
    build_trt_engine(
        model_type,
        [3, image_size, image_size],
        model_dir,
        vision_max_batch_size,
        dtype,
        image_size=image_size,
        num_frames=lita_num_frames if model_type == "lita" or model_type == 'vita' else None,
        nemo_config=nemo_config,
    )


def build_video_neva_engine(
    model_dir: str,
    visual_checkpoint_path: str,
    vision_max_batch_size: int = 1,
):
    """Build video neva visual engine.

    Reads ``model_config.yaml`` and the projector weights from a packed NeMo
    checkpoint (tar). It then builds a wrapper made of the Hugging Face vision encoder
    and a linear projector, exports it to ONNX and builds a TensorRT engine in
    ``model_dir`` with input shape ``[B, num_frames, 3, H, W]``.

    Args:
        model_dir: output directory for the engine and its config files.
        visual_checkpoint_path: path to the packed ``.nemo`` (tar) checkpoint.
        vision_max_batch_size: maximum batch size forwarded to ``build_trt_engine``.

    Raises:
        AssertionError: if the projector type is not ``linear``.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
    # extract NeMo checkpoint (members are addressed with a "./" prefix; a missing
    # member raises KeyError)
    with tarfile.open(visual_checkpoint_path) as tar:
        nemo_config = yaml.safe_load(tar.extractfile("./model_config.yaml"))
        try:
            # trained without TP
            mp0_weights = torch.load(tar.extractfile("./model_weights.ckpt"), map_location=device)
        except KeyError:
            # trained with TP
            mp0_weights = torch.load(tar.extractfile("./mp_rank_00/model_weights.ckpt"), map_location=device)

    vision_config = nemo_config["mm_cfg"]["vision_encoder"]

    class VisionEncoderWrapper(torch.nn.Module):
        # pylint: disable=C0115,C0116
        # Encodes every frame independently, drops the first token of the
        # second-to-last hidden layer, and projects the per-frame features.
        def __init__(self, encoder, connector):
            super().__init__()
            self.encoder = encoder
            self.connector = connector

        def forward(self, images):
            b, num_frames, c, h, w = images.shape
            images = images.view(b * num_frames, c, h, w)
            vision_x = self.encoder(pixel_values=images, output_hidden_states=True)  # [(B num_frames), C, H, W]
            vision_x = vision_x.hidden_states[-2]
            vision_x = vision_x[:, 1:]

            # reshape back to [B, num_frames, img_size, hidden_size]
            vision_x = vision_x.view(b, num_frames, -1, vision_x.shape[-1])

            vision_x = self.connector(vision_x)
            return vision_x

    encoder = AutoModel.from_pretrained(
        vision_config["from_pretrained"],
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        attn_implementation='eager',
    )
    vision_encoder = encoder.vision_model
    hf_config = encoder.config
    dtype = hf_config.torch_dtype

    # connector
    assert nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "linear"
    vision_connector = torch.nn.Linear(vision_config["hidden_size"], nemo_config["hidden_size"], bias=True)

    key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector"
    vision_connector.load_state_dict(
        {
            'weight': mp0_weights[f"{key_prefix}.weight"].to(dtype),
            'bias': mp0_weights[f"{key_prefix}.bias"].to(dtype),
        }
    )

    # export the whole wrapper
    wrapper = VisionEncoderWrapper(vision_encoder, vision_connector).to(device, dtype)
    image_size = hf_config.vision_config.image_size
    num_frames = nemo_config['data']['num_frames']
    dummy_video = torch.empty(1, num_frames, 3, image_size, image_size, dtype=dtype, device=device)  # dummy image
    export_visual_wrapper_onnx(wrapper, dummy_video, model_dir)
    build_trt_engine(
        "video-neva",
        [num_frames, 3, image_size, image_size],  # [num_frames, 3, H, W]
        model_dir,
        vision_max_batch_size,
        dtype,
        image_size=image_size,
        num_frames=num_frames,
        # Pass the config as build_neva_engine does; otherwise nemo_config.yaml
        # is written as `null`.
        nemo_config=nemo_config,
    )


def build_perception_engine(
    model_dir: str,
    perception_checkpoint_path: str,
    model_type: str = "salm",
    max_batch_size: int = 1,
):
    """Build the SALM audio perception engine.

    Loads an ``AudioPerceptionModule`` from ``perception_checkpoint_path``
    (``model_weights.ckpt`` + ``model_config.yaml``). The module's preprocessor is
    dumped to ``<model_dir>/feature_extractor.ts``. Its encoder, modality adapter and
    projection are exported to ONNX and built into a float16 TensorRT engine named
    ``perception_encoder``.

    Args:
        model_dir: output directory; created if missing.
        perception_checkpoint_path: directory containing the perception checkpoint.
        model_type: must be ``"salm"``.
        max_batch_size: maximum batch size of the engine's optimization profile.

    Raises:
        AssertionError: if ``model_type`` is not ``"salm"``.
    """
    assert model_type == "salm", f"Invalid model type {model_type}"

    def load_perception_model(perception_checkpoint_path):
        weights = "model_weights.ckpt"
        # map_location="cpu": export below runs on CPU tensors, and load_state_dict
        # copies into the module's own parameters. This avoids needing a GPU just
        # to read CUDA-saved weights.
        perception_state_dict = torch.load(os.path.join(perception_checkpoint_path, weights), map_location="cpu")
        config = "model_config.yaml"
        config = OmegaConf.load(os.path.join(perception_checkpoint_path, config))
        perception = AudioPerceptionModule(cfg=config)
        perception.load_state_dict(perception_state_dict)
        perception.eval()
        return perception

    os.makedirs(model_dir, exist_ok=True)
    # load perception model
    perception_model = load_perception_model(perception_checkpoint_path)
    feature_extractor = perception_model.preprocessor
    # Dummy 1000-sample signal used only to trace/export the models.
    input_signal = torch.randn(1, 1000, dtype=torch.float32)
    input_signal_length = torch.tensor([1000], dtype=torch.int32)

    processed_signal, processed_signal_length = feature_extractor(
        input_signal=input_signal, length=input_signal_length
    )
    processed_signal_length = processed_signal_length.to(torch.int32)
    dump_path = os.path.join(model_dir, "feature_extractor.ts")  # dump the feature extractor as torchscript
    feature_extractor.export(dump_path, (input_signal, input_signal_length))

    class PerceptionWrapper(torch.nn.Module):
        # pylint: disable=C0115,C0116
        def __init__(self, encoder, modality_adapter, proj):
            super().__init__()
            self.encoder = encoder
            self.modality_adapter = modality_adapter
            self.proj = proj

        @typecheck.disable_checks()
        def forward(self, processed_signal, processed_signal_length):
            encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length)
            encoded, encoded_len = self.modality_adapter(audio_signal=encoded, length=encoded_len)
            # b, c, t -> b, t, c
            encoded = self.proj(encoded.transpose(1, 2))
            encoded_len = encoded_len.to(torch.int32)
            return encoded, encoded_len

    perception = PerceptionWrapper(perception_model.encoder, perception_model.modality_adapter, perception_model.proj)
    export_perception_wrapper_onnx(perception, (processed_signal, processed_signal_length), model_dir)
    # export the onnx perception model to tensorrt engine
    # 512 -> 5.12 sec, 3072 -> 30.72 sec
    opt_batch_size = max(1, max_batch_size // 2)
    shapes = [
        {"processed_signal": [1, 80, 64], "processed_signal_length": [1]},
        {"processed_signal": [opt_batch_size, 80, 512], "processed_signal_length": [opt_batch_size]},
        {"processed_signal": [max_batch_size, 80, 3072], "processed_signal_length": [max_batch_size]},
    ]
    build_trt_engine(
        model_type,
        shapes,
        model_dir,
        max_batch_size,
        dtype=torch.float16,
        nemo_config=None,
        part_name='perception_encoder',
    )


def build_mllama_visual_engine(
    model_dir: str,
    hf_model_path: str,
    processor_name: str = "meta-llama/Llama-3.2-11B-Vision-Instruct",
    vision_max_batch_size: int = 1,
):
    """Build the MLLaMA vision TensorRT engine from a Hugging Face checkpoint.

    The vision model and multimodal projector of ``hf_model_path`` are wrapped,
    exported to ONNX using example inputs produced by ``processor_name`` from a
    blank 2048x2688 RGB image, and built into a TensorRT engine in ``model_dir``.
    The min/opt profile shapes are the example-input shapes; the max profile
    uses a batch dimension of ``vision_max_batch_size``.
    """
    hf_model = MllamaForConditionalGeneration.from_pretrained(hf_model_path, torch_dtype="auto", device_map="auto")
    model_dtype = hf_model.dtype

    class MLLaMAVisionWrapper(torch.nn.Module):
        # pylint: disable=C0115,C0116
        def __init__(self, vision_model, output_proj):
            super().__init__()
            self.vision_model = vision_model
            self.output_proj = output_proj

        def forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask):
            hidden = self.vision_model(pixel_values, aspect_ratio_ids, aspect_ratio_mask).last_hidden_state
            return self.output_proj(hidden)

    wrapper = MLLaMAVisionWrapper(hf_model.vision_model, hf_model.multi_modal_projector)

    processor = AutoProcessor.from_pretrained(processor_name)
    blank_image = Image.new('RGB', [2048, 2688])
    inputs = processor(images=blank_image, return_tensors="pt").to(model_dtype)
    input_names = list(inputs)

    export_visual_wrapper_onnx(
        wrapper,
        tuple(inputs[name] for name in input_names),
        model_dir,
        input_names=input_names,
        dynamic_axes={name: {0: "batch"} for name in input_names},
    )

    example_shapes = {name: list(tensor.shape) for name, tensor in inputs.items()}
    max_shapes = {name: [vision_max_batch_size] + shape[1:] for name, shape in example_shapes.items()}
    build_trt_engine(
        "mllama", [example_shapes, example_shapes, max_shapes], model_dir, vision_max_batch_size, model_dtype
    )


def build_visual_engine(
    model_dir: str,
    visual_checkpoint_path: str,
    model_type: str = "neva",
    vision_max_batch_size: int = 1,
):
    """Dispatch visual-engine building by ``model_type``.

    ``neva``/``lita``/``vila``/``vita`` go to ``build_neva_engine`` and
    ``video-neva`` goes to ``build_video_neva_engine``.

    Raises:
        RuntimeError: for any other ``model_type``.
    """
    if model_type in ('neva', 'lita', 'vila', 'vita'):
        build_neva_engine(model_type, model_dir, visual_checkpoint_path, vision_max_batch_size)
        return

    if model_type == "video-neva":
        build_video_neva_engine(model_dir, visual_checkpoint_path, vision_max_batch_size)
        return

    raise RuntimeError(f"Invalid model type {model_type}")


def extract_lora_ckpt(
    lora_ckpt: str,
    output_dir: str,
):
    """Extract the LLM LoRA weights from an unpacked NeMo LoRA checkpoint.

    Reads ``model_weights.ckpt`` from ``lora_ckpt`` (or ``lora_ckpt/mp_rank_00``) and
    drops every entry whose key contains ``mm_projector``. The remaining weights are
    packed together with ``model_config.yaml`` into an uncompressed tar at
    ``<output_dir>/llm_lora.nemo``.

    Args:
        lora_ckpt: directory of the unpacked LoRA checkpoint.
        output_dir: existing directory in which ``llm_lora.nemo`` is written.

    Returns:
        Path of the written ``llm_lora.nemo`` archive.

    Raises:
        RuntimeError: if the weights file or ``model_config.yaml`` is missing.
    """
    weight_candidates = (
        os.path.join(lora_ckpt, "model_weights.ckpt"),
        os.path.join(lora_ckpt, "mp_rank_00", "model_weights.ckpt"),
    )
    weights_path = next((path for path in weight_candidates if os.path.exists(path)), None)
    if weights_path is None:
        raise RuntimeError("Imcompatible lora checkpoint format")

    model_config = os.path.join(lora_ckpt, "model_config.yaml")
    # Validate the config before loading the (potentially large) weights file.
    if not os.path.exists(model_config):
        raise RuntimeError("Imcompatible lora checkpoint format")

    # Weights are only filtered and re-saved, so load them on CPU.
    model_weight = torch.load(weights_path, map_location="cpu")
    llm_lora_weight = {k: v for k, v in model_weight.items() if "mm_projector" not in k}

    llm_lora_path = os.path.join(output_dir, "llm_lora.nemo")
    with tempfile.TemporaryDirectory() as tmp_dir:
        llm_weight_path = os.path.join(tmp_dir, "model_weights.ckpt")
        torch.save(llm_lora_weight, llm_weight_path)

        with tarfile.open(llm_lora_path, "w") as tar:
            tar.add(llm_weight_path, arcname="model_weights.ckpt")
            tar.add(model_config, arcname="model_config.yaml")

    return llm_lora_path


def build_mllama_engine(
    model_dir: str,
    checkpoint_path: str,
    processor_name: str = "meta-llama/Llama-3.2-11B-Vision-Instruct",
    vision_max_batch_size: int = 1,
    tensor_parallelism_size: int = 1,
    max_input_len: int = 256,
    max_output_len: int = 256,
    max_batch_size: int = 1,
    max_multimodal_len: int = 1024,
    dtype: str = "bfloat16",
    use_lora_plugin: str = None,
    lora_target_modules: List[str] = None,
    max_lora_rank: int = 64,
    lora_ckpt_list: List[str] = None,
):
    """Build the MLLaMA visual and LLM engines from a NeMo checkpoint.

    The NeMo checkpoint is converted to a Hugging Face ``MllamaForConditionalGeneration``
    (cast to bfloat16) and saved to a temporary directory. Two engines are built from
    it: the visual engine in ``<model_dir>/visual_engine`` and the TensorRT-LLM engine
    in ``<model_dir>/llm_engine``.

    Args:
        model_dir: output root directory.
        checkpoint_path: NeMo MLLaMA checkpoint.
        processor_name: HF processor used both for conversion and for the
            visual engine's example inputs.
        vision_max_batch_size: max batch size of the visual engine.
        tensor_parallelism_size, max_input_len, max_output_len, max_batch_size,
        max_multimodal_len, dtype: forwarded to ``build_mllama_trtllm_engine``.
        use_lora_plugin, lora_target_modules, max_lora_rank, lora_ckpt_list:
            accepted for signature compatibility; not used by this function.
    """
    new_state_dict, config = convert_mllama_nemo_to_hf(checkpoint_path, processor_name)

    hf_model = MllamaForConditionalGeneration(config)
    hf_model = hf_model.to(torch.bfloat16)
    hf_model.load_state_dict(new_state_dict)

    with tempfile.TemporaryDirectory() as tmp_dir:
        hf_model_path = os.path.join(tmp_dir, "hf_checkpoint")
        hf_model.save_pretrained(hf_model_path)
        # Release the in-memory copies; both engines are built from the saved checkpoint.
        del hf_model, new_state_dict

        build_mllama_visual_engine(
            os.path.join(model_dir, "visual_engine"),
            hf_model_path,
            # Use the same processor as the conversion rather than the default one.
            processor_name=processor_name,
            vision_max_batch_size=vision_max_batch_size,
        )
        build_mllama_trtllm_engine(
            os.path.join(model_dir, "llm_engine"),
            hf_model_path,
            tensor_parallelism_size=tensor_parallelism_size,
            max_input_len=max_input_len,
            max_output_len=max_output_len,
            max_batch_size=max_batch_size,
            max_multimodal_len=max_multimodal_len,
            dtype=dtype,
        )