Spaces:
Sleeping
Sleeping
from typing import Literal, Optional | |
import fire | |
from packaging.version import Version | |
from ..pip_utils import is_installed, run_pip, version | |
import platform | |
def get_cuda_version_from_torch() -> Optional[Literal["11", "12"]]:
    """Return the CUDA major version the installed PyTorch build targets.

    Returns:
        The major component of ``torch.version.cuda`` as a string (for
        example ``"12"``), or ``None`` when PyTorch is not importable or the
        build carries no CUDA version (``torch.version.cuda`` is ``None``,
        e.g. CPU-only builds). Other majors (e.g. ``"10"``) may be returned
        despite the annotation, so callers must validate the value.
    """
    try:
        import torch
    except ImportError:
        return None
    cuda_version = torch.version.cuda
    # Without this guard a CPU-only torch raises AttributeError here, which
    # breaks importing any module that calls this function at import time.
    if not cuda_version:
        return None
    return cuda_version.split(".")[0]
def install(cu: Optional[Literal["11", "12"]] = get_cuda_version_from_torch()):
    """Install TensorRT and its companion packages for a CUDA major version.

    Args:
        cu: CUDA major version, ``"11"`` or ``"12"``. Defaults to the value
            returned by ``get_cuda_version_from_torch()``, which is evaluated
            once when this module is imported. Integers are accepted too,
            because ``fire`` parses a command-line value like ``--cu 12``
            into the int ``12``.

    When ``cu`` is missing or unsupported, a message is printed and nothing
    is installed.
    """
    if cu is not None:
        # Normalize fire's int parsing ("12" -> 12) so that both the membership
        # check and the cuDNN package name below work.
        cu = str(cu)
    if cu is None:
        print("Could not detect CUDA version. Please specify manually.")
        return
    if cu not in ("11", "12"):
        print(f"Unsupported CUDA version {cu!r}; expected '11' or '12'.")
        return

    print("Installing TensorRT requirements...")

    # Remove pre-9.0.0 TensorRT so the pinned 9.x build below replaces it
    # (relies on is_installed() reporting False after the uninstall).
    if is_installed("tensorrt") and version("tensorrt") < Version("9.0.0"):
        run_pip("uninstall -y tensorrt")

    if not is_installed("tensorrt"):
        # cuDNN is installed first, matched to the requested CUDA major.
        cudnn_name = f"nvidia-cudnn-cu{cu}==8.9.4.25"
        run_pip(f"install {cudnn_name} --no-cache-dir")
        run_pip(
            "install --pre --extra-index-url https://pypi.nvidia.com tensorrt==9.0.1.post11.dev4 --no-cache-dir"
        )

    if not is_installed("polygraphy"):
        run_pip(
            "install polygraphy==0.47.1 --extra-index-url https://pypi.ngc.nvidia.com"
        )

    if not is_installed("onnx_graphsurgeon"):
        run_pip(
            "install onnx-graphsurgeon==0.3.26 --extra-index-url https://pypi.ngc.nvidia.com"
        )

    # pywin32 is only needed (and only installable) on Windows.
    if platform.system() == 'Windows' and not is_installed("pywin32"):
        run_pip(
            "install pywin32"
        )
if __name__ == "__main__":
    # Expose install() as a command-line interface, e.g. ``--cu 12``.
    fire.Fire(component=install)