Skip to content

Commit e330921

Browse files
Merge pull request #7 from google-ai-edge:fix-quantize-recipe-bug
PiperOrigin-RevId: 912611645
2 parents 71b16eb + 4d8394a commit e330921

6 files changed

Lines changed: 99 additions & 44 deletions

File tree

README.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,14 +128,16 @@ litert compile model.tflite --target sm8750 --target mt6989 --export-aipack my_n
128128
litert run model.tflite --desktop --cpu
129129
litert run my_model_ref --desktop --cpu
130130

131-
# Run with GPU acceleration
132-
litert run model.tflite --gpu
131+
# Run with GPU acceleration and CPU fallback (multi-accelerator)
132+
litert run model.tflite --gpu --cpu
133+
litert run model.tflite --accelerator gpu,cpu
133134

134135
# Run on connected Android device
135136
litert run model.tflite --android
136137

137-
# Run on connected Android device with NPU acceleration (JIT mode)
138-
litert run model.tflite --android --npu
138+
# Run on connected Android device with NPU acceleration and CPU fallback
139+
litert run model.tflite --android --npu --cpu
140+
litert run model.tflite --android --accelerator npu,cpu
139141

140142
# Run on connected Android device with NPU AOT-compiled model
141143
litert run model_sm8450.tflite --android --npu

litert_cli/commands/run/android.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,28 @@
2222
1. Run a model on an Android device:
2323
$ litert run /path/to/model.tflite --android
2424
25-
2. Run with custom inputs:
25+
2. Run with NPU acceleration and CPU fallback:
26+
$ litert run /path/to/model.tflite --android --npu --cpu
27+
OR
28+
$ litert run /path/to/model.tflite --android --accelerator npu,cpu
29+
30+
3. Run with custom inputs:
2631
$ litert run /path/to/model.tflite --android --input input_name=value
2732
28-
3. Run with multiple inputs:
33+
4. Run with multiple inputs:
2934
$ litert run /path/to/model.tflite --android --input input1=value1 --input
3035
input2=value2
3136
32-
4. Run with specific signature:
37+
5. Run with specific signature:
3338
$ litert run /path/to/model.tflite --android --signature_index 0
3439
35-
5. Run with multiple iterations:
40+
6. Run with multiple iterations:
3641
$ litert run /path/to/model.tflite --android --iterations 10
3742
38-
6. Print tensor details:
43+
7. Print tensor details:
3944
$ litert run /path/to/model.tflite --android --print-tensors
4045
41-
7. Run with sample size:
46+
8. Run with sample size:
4247
$ litert run /path/to/model.tflite --android --sample-size 100
4348
"""
4449

@@ -183,6 +188,7 @@ def run_android(
183188
Raises:
184189
click.ClickException: If device setup fails or execution on the device fails.
185190
"""
191+
accel_list = [a.strip().lower() for a in accelerator.split(",") if a.strip()]
186192
click.echo("Preparing to run on Android device via adb...")
187193
android_utils.check_adb()
188194

@@ -270,11 +276,11 @@ def run_android(
270276
# Pass None as device_id to use the default connected device.
271277
remote_dispatch_dir = (
272278
npu.push_npu_runtime_libraries(None, android_root)
273-
if accelerator == "npu"
279+
if "npu" in accel_list
274280
else ""
275281
)
276282

277-
if accelerator == "npu":
283+
if "npu" in accel_list:
278284
# Download and push SOC-specific LiteRT dispatch and compiler plugin libraries
279285
target_model = npu.get_soc_target_model(None)
280286
soc_vendor = "mediatek" if "mt" in target_model else "qualcomm"

litert_cli/commands/run/cli.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,14 @@
7878
7. Print detailed tensor outputs:
7979
8080
$ litert run model.tflite --print-tensors --sample-size 10
81+
82+
8. Run with multiple accelerators (npu -> gpu -> cpu fallback):
83+
84+
$ litert run model.tflite --npu --gpu --cpu
85+
86+
OR explicitly:
87+
88+
$ litert run model.tflite --accelerator npu,gpu,cpu
8189
"""),
8290
)
8391
@deps.require_extra("run")
@@ -118,23 +126,24 @@
118126
flag_value="android",
119127
help="Target Android platform to run.",
120128
)
129+
@click.option(
130+
"--accelerator",
131+
type=str,
132+
help="Comma-separated list of hardware accelerators (e.g. npu,gpu,cpu).",
133+
)
121134
@click.option(
122135
"--cpu",
123-
"accelerator",
124-
flag_value="cpu",
125-
default=True,
126-
help="Use CPU accelerator (Default).",
136+
is_flag=True,
137+
help="Use CPU accelerator.",
127138
)
128139
@click.option(
129140
"--gpu",
130-
"accelerator",
131-
flag_value="gpu",
141+
is_flag=True,
132142
help="Use GPU accelerator.",
133143
)
134144
@click.option(
135145
"--npu",
136-
"accelerator",
137-
flag_value="npu",
146+
is_flag=True,
138147
help="Use NPU accelerator.",
139148
)
140149
@click.option(
@@ -169,7 +178,10 @@ def run_cmd(
169178
model_params: Sequence[str],
170179
model_help: bool,
171180
target: str,
172-
accelerator: str,
181+
accelerator: str | None,
182+
cpu: bool,
183+
gpu: bool,
184+
npu: bool,
173185
signature_index: int,
174186
iterations: int,
175187
print_tensors: bool,
@@ -185,11 +197,33 @@ def run_cmd(
185197
model_help: Show help specific to the matched model plugin.
186198
target: Execution target ('desktop' or 'android').
187199
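accelerator: Comma-separated list of hardware accelerators in priority order (e.g. 'npu,gpu,cpu'). If omitted, it is built from the cpu/gpu/npu flags and defaults to 'cpu'.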
200+
cpu: Use CPU accelerator.
201+
gpu: Use GPU accelerator.
202+
npu: Use NPU accelerator.
188203
signature_index: Index of model signature to run.
189204
iterations: Number of times to execute the model for benchmarking.
190205
print_tensors: Whether to print output tensor elements.
191206
sample_size: Number of sample elements to print from tensors.
192207
"""
208+
# Resolve accelerator order: --accelerator wins; otherwise flags are combined as npu, gpu, cpu.
209+
accelerator_list = []
210+
if accelerator:
211+
accelerator_list = [
212+
a.strip().lower() for a in accelerator.split(",") if a.strip()
213+
]
214+
else:
215+
if npu:
216+
accelerator_list.append("npu")
217+
if gpu:
218+
accelerator_list.append("gpu")
219+
if cpu:
220+
accelerator_list.append("cpu")
221+
222+
if not accelerator_list:
223+
accelerator_list = ["cpu"]
224+
225+
accelerator = ",".join(accelerator_list)
226+
193227
# Quiet if default is true
194228
if constants.DEFAULT_QUIET:
195229

litert_cli/commands/run/desktop.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,18 @@
2626
OR
2727
$ litert run /path/to/model.tflite --desktop --accelerator gpu
2828
29-
3. Run with custom inputs:
29+
3. Run with multiple accelerators (gpu -> cpu native fallback):
30+
$ litert run /path/to/model.tflite --desktop --gpu --cpu
31+
OR
32+
$ litert run /path/to/model.tflite --desktop --accelerator gpu,cpu
33+
34+
4. Run with custom inputs:
3035
$ litert run /path/to/model.tflite --desktop --input input_name=value
3136
32-
4. Run with multiple iterations (benchmark):
37+
5. Run with multiple iterations (benchmark):
3338
$ litert run /path/to/model.tflite --desktop --iterations 10
3439
35-
5. Print tensor details:
40+
6. Print tensor details:
3641
$ litert run /path/to/model.tflite --desktop --print-tensors
3742
"""
3843

@@ -240,22 +245,31 @@ def run_desktop(
240245
click.ClickException: On loading failure or inference execution errors.
241246
"""
242247

243-
click.echo(
244-
f"Loading model on desktop: {model_path} with {accelerator.upper()}"
245-
)
248+
accel_list = [a.strip().lower() for a in accelerator.split(",") if a.strip()]
246249

247250
# pylint: disable=g-import-not-at-top,reimported
248251
from ai_edge_litert.compiled_model import CompiledModel
249252
from ai_edge_litert.compiled_model import Environment
250253
from ai_edge_litert.hardware_accelerator import HardwareAccelerator
251254

252-
hw_accel = HardwareAccelerator.CPU
253-
if accelerator == "gpu":
254-
hw_accel = HardwareAccelerator.GPU
255-
elif accelerator == "npu":
256-
raise click.ClickException(
257-
"NPU accelerator is not yet formally supported via desktop API."
258-
)
255+
hw_accel = HardwareAccelerator(0)
256+
for accel in accel_list:
257+
if accel == "cpu":
258+
hw_accel |= HardwareAccelerator.CPU
259+
elif accel == "gpu":
260+
hw_accel |= HardwareAccelerator.GPU
261+
elif accel == "npu":
262+
hw_accel |= HardwareAccelerator.NPU
263+
else:
264+
raise click.ClickException(f"Unsupported hardware accelerator: {accel!r}")
265+
266+
if hw_accel == HardwareAccelerator(0):
267+
hw_accel = HardwareAccelerator.CPU
268+
269+
click.echo(
270+
f"Loading model on desktop: {model_path} with native hardware"
271+
f" accelerators: {hw_accel}"
272+
)
259273

260274
ctx = utils.silence_stderr() if quiet else contextlib.nullcontext()
261275
with ctx:

test_scripts/models/efficientnet.sh

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,6 @@ export TEST_DATA_DIR="$REPO_ROOT/litert_cli/test_data"
5353
echo -e "${YELLOW}Installing litert-cli from source...${NC}"
5454
pip install -e "$REPO_ROOT"
5555

56-
57-
58-
5956
# --- 1. Download EfficientNet-B1 model ---
6057
run_case "Download: EfficientNet-B1 from HuggingFace" \
6158
litert download litert-community/efficientnet_b1 --file "*.tflite" --output "$MODEL_DIR/efficientnet"
@@ -113,14 +110,14 @@ if has_android_device; then
113110
run_case "Benchmark: EfficientNet Dynamic INT8 on Android" \
114111
litert benchmark "$MODEL_DIR/efficientnet/efficientnet_b1_int8_dynamic.tflite" --android
115112
else
116-
echo -e "\n${YELLOW}No Android device detected. Skipping benchmarks (litert benchmark only supports Android/GCP).${NC}"
113+
echo -e "\n${YELLOW}No Android device detected. Skipping benchmarks on Android.${NC}"
117114
fi
118115

119116

120117
# --- 5. Compile (AOT Compilation) ---
121118
# TODO: Add this back when we fix the NPU compile issue.
122119
# run_case "Compile: EfficientNet FP32 for Qualcomm sm8750 NPU" \
123-
# litert compile "$EFFICIENTNET_TFLITE" --target sm8750 --output-dir "$MODEL_DIR/efficientnet"
120+
# litert compile "$EFFICIENTNET_TFLITE" --target sm8750 --output-dir "$MODEL_DIR/efficientnet"
124121

125122
# --- 6. Visualize (Model Explorer) ---
126123
run_case "Visualize: Launch Model Explorer in the background" \

test_scripts/models/yamnet.sh

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,10 @@ if has_android_device; then
9393
echo -e "\n${GREEN}Android device detected. Running Android inference...${NC}"
9494
run_case "Run: YamNet FP32 on Android (CPU)" \
9595
litert run "$YAMNET_TFLITE" --android --cpu --iterations 1
96-
97-
run_case "Run: YamNet FP32 on Android (GPU)" \
98-
litert run "$YAMNET_TFLITE" --android --gpu --iterations 1
96+
97+
# Disabled: this run works on the Qualcomm SM8750 NPU, but not on GPU.
98+
# run_case "Run: YamNet FP32 on Android (GPU)" \
99+
# litert run "$YAMNET_TFLITE" --android --gpu --iterations 1
99100

100101
run_case "Run: YamNet Dynamic INT8 on Android (CPU)" \
101102
litert run "$MODEL_DIR/yamnet/yamnet_int8_dynamic.tflite" --android --cpu --iterations 1
@@ -107,8 +108,9 @@ if has_android_device; then
107108
run_case "Benchmark: YamNet FP32 on Android (CPU)" \
108109
litert benchmark "$YAMNET_TFLITE" --android
109110

110-
run_case "Benchmark: YamNet FP32 on Android (GPU)" \
111-
litert benchmark "$YAMNET_TFLITE" --android --gpu
111+
# Disabled: this benchmark works on the Qualcomm SM8750 NPU, but not on GPU.
112+
# run_case "Benchmark: YamNet FP32 on Android (GPU)" \
113+
# litert benchmark "$YAMNET_TFLITE" --android --gpu
112114

113115
run_case "Benchmark: YamNet Dynamic INT8 on Android" \
114116
litert benchmark "$MODEL_DIR/yamnet/yamnet_int8_dynamic.tflite" --android

0 commit comments

Comments
 (0)