Skip to content

Commit 1fe3a5a

Browse files
feat: add fp16 inference support (torch/onnx) (#871)
* feat: add fp16 inference in clip_torch
* Revert "feat: add fp16 inference in clip_torch" (this reverts commit 326e265)
* feat: add fp16 inference in clip_torch
* fix: device
* fix: str to torch.dtype
* fix: layernorm
* feat: add fp16 inference in clip_trt
* feat: add fp16 inference in clip_onnx
* fix: housekeeping
* fix: ci
* fix: ci
* fix: ci
* fix: ci and get test path
* fix: dtype amp and gpu test dependency
* fix: layernorm
* fix: cast dtype in visiontransformer
* fix: clip_onnx
* fix: clip_onnx
* fix: convert onnx to fp16
* fix: dtype in preproc images
* fix: dtype in preproc images
* fix: typo
* fix: dtype in clip_torch and fp16 in trt
* fix: remove plain text in trt_test
* fix: test
* fix: typo
* fix: stash
* Revert "fix: stash" (this reverts commit f72fd99)
* fix: for test
* fix: onnx
* fix: for test
* fix: for test
* fix: trt
* fix: convert onnx to fp16 before convert trt
* fix: discard changes in trt
* fix: optimize fp16 test
* fix: move __cast_dtype__
* Revert "fix: move __cast_dtype__" (this reverts commit edf4629)
* fix: ci
1 parent fd16e5a commit 1fe3a5a

9 files changed

Lines changed: 82 additions & 21 deletions

File tree

.github/workflows/ci.yml

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -113,12 +113,11 @@ jobs:
113113
pip install --no-cache-dir "server/[onnx]"
114114
pip install --no-cache-dir "server/[transformers]"
115115
pip install --no-cache-dir "server/[search]"
116-
pip install --no-cache-dir "server/[transformers]"
117116
- name: Test
118117
id: test
119118
run: |
120119
pytest --suppress-no-test-exit-code --cov=clip_client --cov=clip_server --cov-report=xml \
121-
-v -s -m "not gpu" ${{ matrix.test-path }}
120+
-v -s ${{ matrix.test-path }}
122121
echo "::set-output name=codecov_flag::cas"
123122
timeout-minutes: 30
124123
- name: Check codecov file
@@ -158,6 +157,7 @@ jobs:
158157
python -m pip install wheel pytest pytest-cov nvidia-pyindex
159158
pip install -e "client/[test]"
160159
pip install -e "server/[tensorrt]"
160+
pip install -e "server/[onnx]"
161161
{
162162
pip install -e "server/[flash-attn]"
163163
} || {
@@ -168,6 +168,8 @@ jobs:
168168
run: |
169169
pytest --suppress-no-test-exit-code --cov=clip_client --cov=clip_server --cov-report=xml \
170170
-v -s -m "gpu" ./tests/test_tensorrt.py
171+
pytest --suppress-no-test-exit-code --cov=clip_client --cov=clip_server --cov-report=xml \
172+
-v -s -m "gpu" ./tests/test_simple.py
171173
echo "::set-output name=codecov_flag::cas"
172174
timeout-minutes: 30
173175
env:

server/clip_server/executors/clip_onnx.py

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@ def __init__(
2727
minibatch_size: int = 32,
2828
access_paths: str = '@r',
2929
model_path: Optional[str] = None,
30+
dtype: Optional[str] = None,
3031
**kwargs,
3132
):
3233
"""
@@ -41,8 +42,17 @@ def __init__(
4142
:param model_path: The path to the model to be used. If not specified, the model will be downloaded or loaded
4243
from the local cache. Visit https://clip-as-service.jina.ai/user-guides/server/#use-custom-model-for-onnx
4344
to learn how to finetune custom models.
45+
:param dtype: inference data type, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
4446
"""
4547
super().__init__(**kwargs)
48+
import torch
49+
50+
if not device:
51+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
52+
self._device = device
53+
if not dtype:
54+
dtype = 'fp32' if self._device in ('cpu', torch.device('cpu')) else 'fp16'
55+
self._dtype = dtype
4656

4757
self._minibatch_size = minibatch_size
4858
self._access_paths = access_paths
@@ -55,18 +65,11 @@ def __init__(
5565
self._num_worker_preprocess = num_worker_preprocess
5666
self._pool = ThreadPool(processes=num_worker_preprocess)
5767

58-
self._model = CLIPOnnxModel(name, model_path)
68+
self._model = CLIPOnnxModel(name, model_path, dtype)
5969
self._tokenizer = Tokenizer(name)
6070

6171
self._image_transform = clip._transform_ndarray(self._model.image_size)
6272

63-
import torch
64-
65-
if not device:
66-
self._device = 'cuda' if torch.cuda.is_available() else 'cpu'
67-
else:
68-
self._device = device
69-
7073
# define the priority order for the execution providers
7174
providers = ['CPUExecutionProvider']
7275

@@ -116,6 +119,7 @@ def _preproc_images(self, docs: 'DocumentArray', drop_image_content: bool):
116119
preprocess_fn=self._image_transform,
117120
return_np=True,
118121
drop_image_content=drop_image_content,
122+
dtype=self._dtype,
119123
)
120124

121125
def _preproc_texts(self, docs: 'DocumentArray'):

server/clip_server/executors/clip_torch.py

Lines changed: 19 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import warnings
33
from functools import partial
44
from multiprocessing.pool import ThreadPool
5-
from typing import Dict, Optional
5+
from typing import Dict, Union, Optional
66

77
import numpy as np
88
import torch
@@ -12,6 +12,7 @@
1212
set_rank,
1313
split_img_txt_da,
1414
)
15+
from clip_server.helper import __cast_dtype__
1516
from clip_server.model import clip
1617
from clip_server.model.clip_model import CLIPModel
1718
from clip_server.model.tokenization import Tokenizer
@@ -28,6 +29,7 @@ def __init__(
2829
num_worker_preprocess: int = 4,
2930
minibatch_size: int = 32,
3031
access_paths: str = '@r',
32+
dtype: Optional[Union[str, torch.dtype]] = None,
3133
**kwargs,
3234
):
3335
"""
@@ -40,6 +42,7 @@ def __init__(
4042
number if you encounter OOM errors.
4143
:param access_paths: The access paths to traverse on the input documents to get the images and texts to be
4244
processed. Visit https://docarray.jina.ai/fundamentals/documentarray/access-elements for more details.
45+
:param dtype: inference data type, if None defaults to torch.float32 if device == 'cpu' else torch.float16.
4346
"""
4447
super().__init__(**kwargs)
4548

@@ -52,9 +55,17 @@ def __init__(
5255
self._access_paths = kwargs['traversal_paths']
5356

5457
if not device:
55-
self._device = 'cuda' if torch.cuda.is_available() else 'cpu'
56-
else:
57-
self._device = device
58+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
59+
self._device = device
60+
if isinstance(dtype, str):
61+
dtype = __cast_dtype__.get(dtype)
62+
elif not dtype:
63+
dtype = (
64+
torch.float32
65+
if self._device in ('cpu', torch.device('cpu'))
66+
else torch.float16
67+
)
68+
self._dtype = dtype
5869

5970
if not self._device.startswith('cuda') and (
6071
'OMP_NUM_THREADS' not in os.environ
@@ -77,7 +88,9 @@ def __init__(
7788
self._num_worker_preprocess = num_worker_preprocess
7889
self._pool = ThreadPool(processes=num_worker_preprocess)
7990

80-
self._model = CLIPModel(name, device=self._device, jit=jit, **kwargs)
91+
self._model = CLIPModel(
92+
name, device=self._device, jit=jit, dtype=dtype, **kwargs
93+
)
8194
self._tokenizer = Tokenizer(name)
8295
self._image_transform = clip._transform_ndarray(self._model.image_size)
8396

@@ -96,6 +109,7 @@ def _preproc_images(self, docs: 'DocumentArray', drop_image_content: bool):
96109
device=self._device,
97110
return_np=False,
98111
drop_image_content=drop_image_content,
112+
dtype=self._dtype,
99113
)
100114

101115
def _preproc_texts(self, docs: 'DocumentArray'):

server/clip_server/executors/helper.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,9 @@
1-
from typing import Tuple, List, Callable, Any, Dict
1+
from typing import Tuple, List, Callable, Any, Dict, Union
22
import torch
33
import numpy as np
44
from docarray import Document, DocumentArray
55
from docarray.math.distance.numpy import cosine
6+
from clip_server.helper import __cast_dtype__
67

78

89
from clip_server.model.tokenization import Tokenizer
@@ -22,8 +23,12 @@ def preproc_image(
2223
device: str = 'cpu',
2324
return_np: bool = False,
2425
drop_image_content: bool = False,
26+
dtype: Union[str, torch.dtype] = torch.float32,
2527
) -> Tuple['DocumentArray', Dict]:
2628

29+
if isinstance(dtype, str):
30+
dtype = __cast_dtype__.get(dtype)
31+
2732
tensors_batch = []
2833

2934
for d in da:
@@ -42,7 +47,7 @@ def preproc_image(
4247
if drop_image_content:
4348
d.pop('blob', 'tensor')
4449

45-
tensors_batch = torch.stack(tensors_batch).type(torch.float32)
50+
tensors_batch = torch.stack(tensors_batch).type(dtype)
4651

4752
if return_np:
4853
tensors_batch = tensors_batch.cpu().numpy()

server/clip_server/helper.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import os
33
import sys
44
import threading
5+
import torch
56
from packaging.version import Version
67
from urllib.request import Request, urlopen
78

@@ -19,6 +20,9 @@
1920
)
2021

2122

23+
__cast_dtype__ = {'fp16': torch.float16, 'fp32': torch.float32, 'bf16': torch.bfloat16}
24+
25+
2226
def _version_check(package: str = None, github_repo: str = None):
2327
try:
2428

server/clip_server/model/clip_onnx.py

Lines changed: 21 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import Dict
2+
from typing import Dict, Optional
33

44
from clip_server.model.pretrained_models import (
55
download_model,
@@ -201,8 +201,11 @@
201201

202202

203203
class CLIPOnnxModel(BaseCLIPModel):
204-
def __init__(self, name: str, model_path: str = None):
204+
def __init__(
205+
self, name: str, model_path: str = None, dtype: Optional[str] = 'fp32'
206+
):
205207
super().__init__(name)
208+
self._dtype = dtype
206209
if name in _MODELS:
207210
if not model_path:
208211
cache_dir = os.path.expanduser(
@@ -237,6 +240,22 @@ def __init__(self, name: str, model_path: str = None):
237240
f'The given model path {model_path} should be a folder containing both '
238241
f'`textual.onnx` and `visual.onnx`.'
239242
)
243+
if dtype == 'fp16':
244+
import onnx
245+
from onnxmltools.utils import float16_converter
246+
247+
_textual_model_fp16 = (
248+
float16_converter.convert_float_to_float16_model_path(
249+
self._textual_path
250+
)
251+
)
252+
_visual_model_fp16 = (
253+
float16_converter.convert_float_to_float16_model_path(
254+
self._visual_path
255+
)
256+
)
257+
onnx.save_model(_textual_model_fp16, self._textual_path)
258+
onnx.save_model(_visual_model_fp16, self._visual_path)
240259
else:
241260
raise RuntimeError(
242261
'CLIP model {} not found or not supports ONNX backend; below is a list of all available models:\n{}'.format(

server/clip_server/model/model.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
from dataclasses import dataclass
1616
from typing import Tuple, Union, Optional
1717
from copy import deepcopy
18+
from clip_server.helper import __cast_dtype__
1819
from open_clip.transformer import QuickGELU, LayerNorm, LayerNormFp32, Attention
1920
from open_clip.timm_model import TimmModel
2021
from open_clip.factory import _MODEL_CONFIGS
@@ -81,6 +82,11 @@ def __init__(
8182
super().__init__(image_size, patch_size, output_dim=output_dim, **kwargs)
8283
self.transformer = Transformer(dtype=dtype, **kwargs)
8384

85+
def forward(self, x: torch.Tensor):
86+
dtype = self.transformer.get_cast_dtype()
87+
x = x.to(dtype)
88+
return super().forward(x)
89+
8490

8591
class TextTransformer(_TextTransformer):
8692
def __init__(
@@ -435,7 +441,9 @@ def load_openai_model(
435441
preprocess : Callable[[PIL.Image], torch.Tensor]
436442
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
437443
"""
438-
if dtype is None:
444+
if isinstance(dtype, str):
445+
dtype = __cast_dtype__.get(dtype, 'amp')
446+
elif dtype is None:
439447
dtype = (
440448
torch.float32 if device in ('cpu', torch.device('cpu')) else torch.float16
441449
)
@@ -550,7 +558,9 @@ def load_openclip_model(
550558
pretrained_image: bool = False,
551559
dtype: Optional[Union[str, torch.dtype]] = None,
552560
):
553-
if dtype is None:
561+
if isinstance(dtype, str):
562+
dtype = __cast_dtype__.get(dtype)
563+
elif dtype is None:
554564
dtype = (
555565
torch.float32 if device in ('cpu', torch.device('cpu')) else torch.float16
556566
)

server/setup.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,7 @@
5353
'onnx': [
5454
'onnxruntime',
5555
'onnx',
56+
'onnxmltools',
5657
]
5758
+ (['onnxruntime-gpu>=1.8.0'] if sys.platform != 'darwin' else []),
5859
'tensorrt': ['nvidia-tensorrt'],

tests/test_simple.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@ def test_protocols(port_generator, protocol, jit, pytestconfig):
2727
c.profile(content=f'{pytestconfig.rootdir}/tests/img/00000.jpg')
2828

2929

30+
@pytest.mark.gpu
3031
@pytest.mark.parametrize(
3132
'inputs',
3233
[
@@ -48,6 +49,7 @@ def test_plain_inputs(make_flow, inputs):
4849
)
4950

5051

52+
@pytest.mark.gpu
5153
@pytest.mark.parametrize(
5254
'inputs',
5355
[

0 commit comments

Comments (0)