# subhankarg's picture
# Upload folder using huggingface_hub
# 0558aa4 verified
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
import tensorrt as trt
from polygraphy.backend.trt import CreateConfig, Profile, engine_from_network, network_from_onnx_path, save_engine
HAVE_TRT = True
except (ImportError, ModuleNotFoundError):
HAVE_TRT = False
def build_engine(
    onnx_path,
    output_path,
    fp16,
    input_profile=None,
    enable_refit=False,
    enable_preview=False,
    timing_cache=None,
    workspace_size=0,
    int8=False,
    builder_optimization_level=None,
):
    """Build a TensorRT engine from an ONNX model and serialize it to disk.

    Args:
        onnx_path: Path to the source ONNX model.
        output_path: Path where the serialized engine is written.
        fp16: If True, enable FP16 precision in the builder config.
        input_profile: Optional dict mapping input tensor names to a
            3-sequence of shapes ``(min, opt, max)`` used to build the
            optimization profile.
        enable_refit: If True, build a refittable engine.
        enable_preview: Accepted for backward compatibility but currently
            unused — no preview features are enabled.
        timing_cache: Optional timing-cache path; it is both loaded before
            the build and saved back afterwards.
        workspace_size: Workspace memory-pool limit in bytes. 0 (default)
            keeps TensorRT's default pool size.
        int8: If True, enable INT8 precision in the builder config.
        builder_optimization_level: Optional TensorRT builder optimization
            level passed through to the config.

    Raises:
        ImportError: If tensorrt / polygraphy could not be imported.
        ValueError: If an ``input_profile`` entry does not have exactly
            3 shapes.
    """
    # Fail fast with a clear message instead of a NameError on Profile()
    # when the guarded imports at module top failed.
    if not HAVE_TRT:
        raise ImportError(
            "tensorrt and polygraphy are required to build a TensorRT engine "
            "but could not be imported."
        )

    print(f"Building TensorRT engine for {onnx_path}: {output_path}")

    profile = Profile()
    if input_profile:
        for name, dims in input_profile.items():
            # Validate explicitly: an assert would be stripped under `-O`.
            if len(dims) != 3:
                raise ValueError(
                    f"input_profile entry for '{name}' must have 3 shapes "
                    f"(min, opt, max), got {len(dims)}"
                )
            profile.add(name, min=dims[0], opt=dims[1], max=dims[2])

    # Deliberately None: enable_preview is kept in the signature for
    # backward compatibility, but no preview features are enabled here.
    preview_features = None

    config_kwargs = {}
    if workspace_size > 0:
        # Only cap the workspace pool when an explicit limit was requested.
        config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size}

    engine = engine_from_network(
        network_from_onnx_path(onnx_path),
        config=CreateConfig(
            fp16=fp16,
            refittable=enable_refit,
            profiles=[profile],
            preview_features=preview_features,
            load_timing_cache=timing_cache,
            int8=int8,
            builder_optimization_level=builder_optimization_level,
            **config_kwargs,
        ),
        save_timing_cache=timing_cache,
    )
    save_engine(engine, path=output_path)