Skip to content

Commit 81d0e4d

Browse files
authored
[Feature] Add lmdeploy tis python backend model (#1014)
* add lmdeploy tis python backend model * fix pr check * update
1 parent 8fe7b27 commit 81d0e4d

3 files changed

Lines changed: 242 additions & 0 deletions

File tree

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from mmengine.config import read_base
2+
from opencompass.models.lmdeploy_tis import LmdeployTisModel
3+
4+
with read_base():
5+
# choose a list of datasets
6+
from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
7+
from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
8+
from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
9+
from .datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_7902a7 import WSC_datasets
10+
from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
11+
from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
12+
from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
13+
from .datasets.race.race_gen_69ee4f import race_datasets
14+
from .datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets
15+
# and output the results in a chosen format
16+
from .summarizers.medium import summarizer
17+
18+
# Concatenate, in namespace order, every value visible at this point whose
# variable name ends with ``_datasets`` into one flat list (``sum`` starts
# from the empty list and uses ``+``, so each value must be a list).
# ``locals()`` is the generator's outermost iterable, so it is evaluated in
# this module's scope rather than inside the generator.
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
19+
20+
# Prompt wrapping for one conversation round: the HUMAN and BOT turns are
# each enclosed in ``<|im_start|>…<|im_end|>`` markers, and the BOT turn is
# the one flagged for generation. ``eos_token_id`` is carried alongside the
# round definition.
meta_template = {
    'round': [
        {
            'role': 'HUMAN',
            'begin': '<|im_start|>user\n',
            'end': '<|im_end|>\n',
        },
        {
            'role': 'BOT',
            'begin': '<|im_start|>assistant\n',
            'end': '<|im_end|>\n',
            'generate': True,
        },
    ],
    'eos_token_id': 92542,
}
27+
28+
# One model entry: an LmdeployTisModel client pointed at the Triton
# inference server address given in ``tis_addr``, using the chat template
# defined above.
models = [
    {
        'type': LmdeployTisModel,
        'abbr': 'internlm-chat-20b-lmdeploy-tis',
        'path': 'internlm/internlm-chat-20b',
        'tis_addr': '0.0.0.0:33337',
        'max_out_len': 100,
        'max_seq_len': 2048,
        'batch_size': 8,
        'meta_template': meta_template,
        'run_cfg': {'num_gpus': 1, 'num_procs': 1},
        'end_str': '<|im_end|>',
    },
]

opencompass/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .lightllm_api import LightllmAPI # noqa: F401
1919
from .llama2 import Llama2, Llama2Chat # noqa: F401, F403
2020
from .lmdeploy_pytorch import LmdeployPytorchModel # noqa: F401
21+
from .lmdeploy_tis import LmdeployTisModel # noqa: F401
2122
from .minimax_api import MiniMax # noqa: F401
2223
from .mistral_api import Mistral # noqa: F401
2324
from .mixtral import Mixtral # noqa: F401

opencompass/models/lmdeploy_tis.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
import threading
2+
from concurrent.futures import ThreadPoolExecutor
3+
from functools import partial
4+
from queue import Queue
5+
from typing import Dict, List, Optional, Union
6+
7+
import numpy as np
8+
9+
from opencompass.models.base import BaseModel, LMTemplateParser
10+
from opencompass.utils.logging import get_logger
11+
from opencompass.utils.prompt import PromptList
12+
13+
# Type alias for a prompt: either an OpenCompass ``PromptList`` or a plain
# string.
PromptType = Union[PromptList, str]
14+
15+
16+
def valid_str(string, coding='utf-8'):
    """Return ``string`` with every U+FFFD replacement character removed.

    The text is then round-tripped through ``coding`` so that, as before,
    a string that cannot be encoded with ``coding`` raises
    ``UnicodeEncodeError``.

    Args:
        string (str): Text to clean.
        coding (str): Codec used for the encode/decode round trip.
            Defaults to ``'utf-8'``.

    Returns:
        str: The cleaned text.
    """
    # Strip the replacement character at the str level. The previous
    # byte-level replace of b'\xef\xbf\xbd' was only correct for UTF-8:
    # with other codecs it could delete legitimate characters that happen
    # to encode to those bytes, or miss U+FFFD altogether.
    cleaned = string.replace('\ufffd', '')
    return cleaned.encode(coding).decode(coding, errors='ignore')
24+
25+
26+
def prepare_tensor(name, input_tensor):
    """Wrap a numpy array in a Triton gRPC ``InferInput``.

    Args:
        name: Name given to the created input.
        input_tensor: Array whose ``shape``, ``dtype`` and data are copied
            into the input.

    Returns:
        The populated ``grpcclient.InferInput`` instance.
    """
    # tritonclient is imported at call time, not at module import time.
    import tritonclient.grpc as grpcclient
    from tritonclient.utils import np_to_triton_dtype

    shape = list(input_tensor.shape)
    triton_dtype = np_to_triton_dtype(input_tensor.dtype)
    infer_input = grpcclient.InferInput(name, shape, triton_dtype)
    infer_input.set_data_from_numpy(input_tensor)
    return infer_input
34+
35+
36+
def stream_callback(que, result, error):
    """Enqueue one ``(result, error)`` pair delivered by the triton client.

    Args:
        que: Queue-like object; only its ``put`` method is used.
        result: The result object passed to the callback.
        error: The error object passed to the callback.
    """
    outcome = (result, error)
    que.put(outcome)
39+
40+
41+
class LmdeployTisModel(BaseModel):
    """Model wrapper for LMDeploy Python Backend Triton Inference Server gRPC
    API.

    Args:
        path (str): Model identifier. It is forwarded to ``BaseModel`` and
            to ``lmdeploy.tokenizer.Tokenizer`` to load the tokenizer used
            by :meth:`get_token_len`.
        tis_addr (str): The address (ip:port format) of the triton
            inference server. Defaults to ``'0.0.0.0:33337'``.
        max_seq_len (int): The maximum allowed sequence length of a model.
            Note that the length of prompt + generated tokens shall not
            exceed this value. Defaults to 2048.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions. If it contains an
            ``eos_token_id`` key, that value is stored on the instance.
        end_str (str, optional): If set, each response is cut at the first
            occurrence of this string.
    """

    is_api: bool = True

    def __init__(self,
                 path: str,
                 tis_addr: str = '0.0.0.0:33337',
                 max_seq_len: int = 2048,
                 meta_template: Optional[Dict] = None,
                 end_str: Optional[str] = None):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template)
        # lmdeploy is imported at construction time, not at module import.
        from lmdeploy.tokenizer import Tokenizer

        self.logger = get_logger()
        self.template_parser = LMTemplateParser(meta_template)
        self.eos_token_id = None
        if meta_template and 'eos_token_id' in meta_template:
            self.eos_token_id = meta_template['eos_token_id']
        self.tis_addr = tis_addr
        self.tokenizer = Tokenizer(path)
        self.end_str = end_str

    def generate(
        self,
        inputs: List[Union[str, PromptList]],
        max_out_len: int = 512,
        temperature: float = 1.0,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Every input is sent to the server from a thread-pool worker; the
        results are returned in the same order as ``inputs``.

        Args:
            inputs (List[str or PromptList]): A list of prompts. Only plain
                strings are accepted by the underlying request code.
            max_out_len (int): The maximum length of the output.
                Defaults to 512.
            temperature (float): Sampling temperature forwarded to the
                server. Defaults to 1.0. Note that every request also pins
                ``top_k`` to 1 and ``top_p`` to 1.0.

        Returns:
            List[str]: A list of generated strings.
        """
        count = len(inputs)
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * count,
                             [temperature] * count,
                             [self.end_str] * count))
        return results

    def wait(self):
        """Wait till the next query can be sent.

        Applicable in both single-thread and multi-thread environments.

        Returns:
            Whatever ``token_bucket.get_token()`` returns, or ``None`` when
            the instance has no ``token_bucket``.
        """
        # This class never assigns ``token_bucket`` itself, so only
        # rate-limit when something else has attached one instead of
        # failing with AttributeError.
        bucket = getattr(self, 'token_bucket', None)
        if bucket is None:
            return None
        return bucket.get_token()

    def get_token_len(self, prompt: str) -> int:
        """Return the number of tokens ``prompt`` encodes to.

        Args:
            prompt (str): Text to tokenize.

        Returns:
            int: Length of the encoded id sequence.
        """
        input_ids = self.tokenizer.encode(prompt)
        return len(input_ids)

    def _call_triton_server(self, prompt, tis_addr, session_id,
                            request_output_len, temperature, res_que):
        """Send one streaming inference request and feed ``res_que``.

        Each streamed response is pushed onto ``res_que`` as a
        ``(result, error)`` tuple by ``stream_callback``. After the client
        context has been closed, ``None`` is pushed as the end-of-stream
        sentinel.

        Args:
            prompt (str): Prompt text; sent UTF-8 encoded.
            tis_addr (str): Server address in ip:port format.
            session_id: Value used as the request's ``sequence_id``.
            request_output_len (int): Sent as the ``max_tokens`` input.
            temperature (float): Sent as the ``temperature`` input.
            res_que: Queue receiving the streamed results and the sentinel.
        """
        import tritonclient.grpc as grpcclient

        with grpcclient.InferenceServerClient(tis_addr) as client:
            # ``np.float64`` replaces the ``np.float_`` alias, which no
            # longer exists in NumPy 2.0; the dtype is the same.
            inputs = [
                prepare_tensor('prompt',
                               np.array([prompt.encode()], dtype=np.object_)),
                prepare_tensor('max_tokens',
                               np.array([request_output_len], dtype=np.int32)),
                prepare_tensor('temperature',
                               np.array([temperature], dtype=np.float64)),
                prepare_tensor('top_p', np.array([1.0], dtype=np.float64)),
                prepare_tensor('top_k', np.array([1], dtype=np.int32)),
                prepare_tensor('ignore_eos', np.array([False],
                                                      dtype=np.bool_)),
                prepare_tensor('stream', np.array([True], dtype=np.bool_)),
            ]

            # async_stream
            client.start_stream(partial(stream_callback, res_que))
            client.async_stream_infer('lmdeploy_model',
                                      inputs,
                                      sequence_id=session_id,
                                      sequence_start=True,
                                      sequence_end=True)

        # Sentinel telling the consumer that no more results will arrive.
        res_que.put(None)
        return

    def _process_result(self, que):
        """Drain ``que`` and assemble the streamed response text.

        Items are read until the ``None`` sentinel is seen. Items carrying
        an error are logged and skipped; the ``response`` payload of every
        other item is decoded and appended.

        Args:
            que: Queue filled by :meth:`_call_triton_server`.

        Returns:
            str: Concatenation of all decoded response chunks.
        """
        chunks = []
        while True:
            item = que.get()
            if item is None:
                return ''.join(chunks)
            result, err = item
            if err is not None:
                # Best effort: report the failure and keep whatever text
                # the other chunks provide.
                self.logger.error(err)
                continue
            chunks.append(result.as_numpy('response').item().decode())

    def _generate(self,
                  prompt: Union[str, PromptList],
                  max_out_len: int,
                  temperature: float,
                  end_str: Optional[str] = None) -> str:
        """Generate the response for a single prompt.

        Args:
            prompt (str or PromptList): The prompt. Only plain strings are
                supported; anything else fails the assertion below.
            max_out_len (int): The maximum length of the output.
            temperature (float): Sampling temperature forwarded to the
                server.
            end_str (str, optional): If set, the response is cut at the
                first occurrence of this string.

        Returns:
            str: The generated string.

        Raises:
            AssertionError: If ``prompt`` is not a string.
        """
        assert isinstance(
            prompt, str
        ), 'We only support string for LMDeploy Python Backend TIS API'

        res_que = Queue()

        # The worker thread's ident serves as the sequence id of the
        # request. ``current_thread`` replaces the deprecated
        # ``currentThread`` alias.
        self._call_triton_server(prompt=prompt,
                                 tis_addr=self.tis_addr,
                                 session_id=threading.current_thread().ident,
                                 request_output_len=max_out_len,
                                 temperature=temperature,
                                 res_que=res_que)
        text = self._process_result(res_que)
        response = valid_str(text)
        if end_str:
            response = response.split(end_str)[0]
        return response

0 commit comments

Comments
 (0)