Skip to content

Commit f778407

Browse files
Merge pull request #9 from google-ai-edge:feat-mediatek-support
PiperOrigin-RevId: 913747130
2 parents 0e69758 + f8f9156 commit f778407

10 files changed

Lines changed: 686 additions & 135 deletions

File tree

litert_cli/commands/benchmark/android.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,15 @@ def run_android(*, model_path: pathlib.Path, accelerator: str) -> None:
128128
f"--compiler_plugin_library_path={shlex.quote(cli_android_root)}"
129129
)
130130

131+
if soc_vendor == "mediatek":
132+
recommend_version = constants.MEDIATEK_SOC_VERSION_MAP.get(
133+
target_model, ""
134+
)
135+
if "v9" in recommend_version:
136+
bench_args.append("--mediatek_nerun_pilot_version=version9")
137+
elif "v8" in recommend_version:
138+
bench_args.append("--mediatek_nerun_pilot_version=version8")
139+
131140
env_vars = ""
132141
if remote_dispatch_dir:
133142
quoted_dispatch_dir = shlex.quote(remote_dispatch_dir)
@@ -137,6 +146,32 @@ def run_android(*, model_path: pathlib.Path, accelerator: str) -> None:
137146
)
138147

139148
full_command = env_vars + " ".join(bench_args)
140-
subprocess.run(["adb", "shell", full_command], check=True)
141-
except subprocess.CalledProcessError as e:
142-
click.secho(f"Execution failed on device: {repr(e)}", fg="red")
149+
process = subprocess.Popen(
150+
["adb", "shell", full_command],
151+
stdout=subprocess.PIPE,
152+
stderr=subprocess.STDOUT,
153+
text=True,
154+
)
155+
156+
from litert_cli.core.log_filters import BenchmarkLogFilter
157+
158+
output_lines = []
159+
log_filter = BenchmarkLogFilter(constants.DEFAULT_QUIET)
160+
161+
for line in process.stdout:
162+
output_lines.append(line)
163+
if log_filter.should_show(line):
164+
click.echo(line, nl=False)
165+
166+
process.wait()
167+
if process.returncode != 0:
168+
click.secho(
169+
f"Execution failed on device with exit code {process.returncode}",
170+
fg="red",
171+
)
172+
click.echo("Full output for debugging:")
173+
for line in output_lines:
174+
click.echo(line, nl=False)
175+
raise click.ClickException("Benchmark failed on device.")
176+
except Exception as e:
177+
raise click.ClickException(f"Failed to execute benchmark on device: {e}")

litert_cli/commands/compile.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from __future__ import annotations
2222

2323
from collections.abc import Sequence
24+
import importlib
2425
import pathlib
2526
import shutil
2627
import textwrap
@@ -30,6 +31,7 @@
3031
from litert_cli.core import deps
3132
from litert_cli.core import npu_utils
3233
from litert_cli.core import utils
34+
from litert_cli.core.targets_manager import TargetsManager
3335

3436

3537
@click.command(
@@ -52,6 +54,16 @@
5254
"""),
5355
)
5456
@click.argument("model_path", type=str)
57+
@click.option(
58+
"--update-targets",
59+
type=str,
60+
required=False,
61+
default=None,
62+
help=(
63+
"Update SoC target lists from GitHub. Pass 'main' for latest, or a"
64+
" version tag like 'v2.1.4'."
65+
),
66+
)
5567
@click.option(
5668
"--target",
5769
type=str,
@@ -93,6 +105,7 @@
93105
def compile_cmd(
94106
model_path: str,
95107
target: Sequence[str],
108+
update_targets: str | None,
96109
export_aipack: pathlib.Path | None,
97110
output_dir: pathlib.Path | None,
98111
) -> None:
@@ -115,6 +128,24 @@ def compile_cmd(
115128
if constants.DEFAULT_QUIET:
116129
utils.enable_quiet_mode()
117130

131+
# Initialize targets
132+
manager = TargetsManager()
133+
134+
# Handle update or first-run download
135+
if update_targets:
136+
manager.download_targets(version=update_targets)
137+
importlib.reload(constants)
138+
else:
139+
# Check if cache exists
140+
if not manager.load_targets():
141+
click.echo("No target cache found. Downloading default target lists...")
142+
try:
143+
manager.download_targets(version="main")
144+
importlib.reload(constants)
145+
except Exception as e:
146+
click.echo(f"Warning: Failed to download default targets: {e}")
147+
click.echo("Falling back to built-in static target lists.")
148+
118149
resolved_model_path, _ = core_models.resolve_model_reference(model_path)
119150
if str(resolved_model_path) != str(model_path):
120151
click.echo(f"Resolved model '{model_path}' to '{resolved_model_path}'")

litert_cli/commands/run/android.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -352,16 +352,35 @@ def run_android(
352352
env_vars = ""
353353
cmd_str = f"{env_vars} " if env_vars else ""
354354
cmd_str += " ".join(shlex.quote(arg) for arg in run_cmd_args)
355-
subprocess.run(
356-
[
357-
"adb",
358-
"shell",
359-
cmd_str,
360-
],
361-
check=True,
355+
process = subprocess.Popen(
356+
["adb", "shell", cmd_str],
357+
stdout=subprocess.PIPE,
358+
stderr=subprocess.STDOUT,
359+
text=True,
362360
)
363-
except subprocess.CalledProcessError as e:
364-
raise click.ClickException(f"Execution failed on device: {e!r}") from e
361+
362+
from litert_cli.core.log_filters import RunLogFilter
363+
364+
output_lines = []
365+
log_filter = RunLogFilter(constants.DEFAULT_QUIET, print_tensors)
366+
367+
for line in process.stdout:
368+
output_lines.append(line)
369+
if log_filter.should_show(line):
370+
click.echo(line, nl=False)
371+
372+
process.wait()
373+
if process.returncode != 0:
374+
click.secho(
375+
f"Execution failed on device with exit code {process.returncode}",
376+
fg="red",
377+
)
378+
click.echo("Full output for debugging:")
379+
for line in output_lines:
380+
click.echo(line, nl=False)
381+
raise click.ClickException("Execution failed on device.")
382+
except Exception as e:
383+
raise click.ClickException(f"Failed to execute on device: {e}")
365384
finally:
366385
# Cleanup remote paths
367386
click.echo("Clearing remote files...")

litert_cli/core/constants.py

Lines changed: 40 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
ENV_LITERT_CLI_FORCE_OSS: str = "LITERT_CLI_FORCE_OSS"
3535
ENV_LITERT_VERBOSE: str = "LITERT_VERBOSE"
3636

37-
DEFAULT_QUIET: bool = os.environ.get(ENV_LITERT_VERBOSE, "1") != "1"
37+
DEFAULT_QUIET: bool = os.environ.get(ENV_LITERT_VERBOSE, "0") != "1"
3838

3939
_FORCE_OSS = os.environ.get(ENV_LITERT_CLI_FORCE_OSS, "").lower() in (
4040
"1",
@@ -80,61 +80,46 @@
8080
"https://softwarecenter.qualcomm.com/api/download/software/sdks/"
8181
f"Qualcomm_AI_Runtime_Community/All/{QAIRT_SDK_VERSION}/v{QAIRT_SDK_VERSION}.zip"
8282
)
83+
MEDIATEK_SDK_URL: str = (
84+
"https://s3.ap-southeast-1.amazonaws.com/mediatek.neuropilot.com/"
85+
"66f2c33a-2005-4f0b-afef-2053c8654e4f.gz"
86+
)
87+
MEDIATEK_V8_VERSION: str = "v8_0_10"
88+
MEDIATEK_V9_VERSION: str = "v9_0_3"
89+
90+
from litert_cli.core.targets_manager import TargetsManager
91+
92+
_manager = TargetsManager()
93+
_loaded_targets = _manager.load_targets()
94+
95+
_qnn_map = {}
96+
_mtk_map = {}
97+
_aot_map = {}
98+
99+
if _loaded_targets:
100+
# Reconstruct maps from loaded targets
101+
_qnn_map = {
102+
k: v.properties.get("qnn_version", "")
103+
for k, v in _loaded_targets.items()
104+
if v.vendor == "qualcomm"
105+
}
83106

84-
# TODO: switch to read from litert/vendors/qualcomm/supported_soc.csv
85-
QNN_SOC_VERSION_MAP: types.MappingProxyType[str, str] = types.MappingProxyType({
86-
"sm8350": "68",
87-
"sm8450": "69",
88-
"sm8550": "73",
89-
"sm8650": "75",
90-
"sm8750": "79",
91-
"sm8850": "81",
92-
# Sub-flagship & Mid-range Mobile SoCs
93-
"sm8635": "73",
94-
"sm7675": "73",
95-
"sm7550": "73",
96-
"sm7475": "69",
97-
# Automotive Cockpit SoCs
98-
"sa8255": "73",
99-
"sa8295": "68",
100-
"qnn_all": "81",
101-
})
107+
_mtk_map = {
108+
k: v.properties.get("recommend_version", "")
109+
for k, v in _loaded_targets.items()
110+
if v.vendor == "mediatek"
111+
}
112+
113+
_aot_map = {k: (v.vendor, v.vendor_id) for k, v in _loaded_targets.items()}
114+
115+
QNN_SOC_VERSION_MAP: types.MappingProxyType[str, str] = types.MappingProxyType(
116+
_qnn_map
117+
)
118+
119+
MEDIATEK_SOC_VERSION_MAP: types.MappingProxyType[str, str] = (
120+
types.MappingProxyType(_mtk_map)
121+
)
102122

103123
AOT_SUPPORTED_TARGETS: types.MappingProxyType[str, tuple[str, str]] = (
104-
types.MappingProxyType({
105-
# Qualcomm
106-
"sm8350": ("qualcomm", "SM8350"),
107-
"sm8450": ("qualcomm", "SM8450"),
108-
"sm8550": ("qualcomm", "SM8550"),
109-
"sm8650": ("qualcomm", "SM8650"),
110-
"sm8750": ("qualcomm", "SM8750"),
111-
"sm8850": ("qualcomm", "SM8850"),
112-
# Sub-flagship & Mid-range Mobile SoCs
113-
"sm8635": ("qualcomm", "SM8635"),
114-
"sm7675": ("qualcomm", "SM7675"),
115-
"sm7550": ("qualcomm", "SM7550"),
116-
"sm7475": ("qualcomm", "SM7475"),
117-
# Automotive Cockpit SoCs
118-
"sa8255": ("qualcomm", "SA8255"),
119-
"sa8295": ("qualcomm", "SA8295"),
120-
"qnn_all": ("qualcomm", "ALL"),
121-
# MediaTek
122-
"mt6853": ("mediatek", "MT6853"),
123-
"mt6877": ("mediatek", "MT6877"),
124-
"mt6878": ("mediatek", "MT6878"),
125-
"mt6879": ("mediatek", "MT6879"),
126-
"mt6886": ("mediatek", "MT6886"),
127-
"mt6893": ("mediatek", "MT6893"),
128-
"mt6895": ("mediatek", "MT6895"),
129-
"mt6897": ("mediatek", "MT6897"),
130-
"mt6983": ("mediatek", "MT6983"),
131-
"mt6985": ("mediatek", "MT6985"),
132-
"mt6989": ("mediatek", "MT6989"),
133-
"mt6991": ("mediatek", "MT6991"),
134-
"mt6993": ("mediatek", "MT6993"),
135-
"mt8171": ("mediatek", "MT8171"),
136-
"mt8188": ("mediatek", "MT8188"),
137-
"mt8189": ("mediatek", "MT8189"),
138-
"mtk_all": ("mediatek", "ALL"),
139-
})
124+
types.MappingProxyType(_aot_map)
140125
)

litert_cli/core/log_filters.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#!/usr/bin/env python3
2+
# Copyright 2026 The LiteRT CLI Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
17+
"""Log filters for LiteRT CLI commands to prevent terminal noise."""
18+
19+
20+
class BenchmarkLogFilter:
21+
"""Filters output of litert benchmark command."""
22+
23+
def __init__(self, default_quiet: bool):
24+
self.default_quiet = default_quiet
25+
26+
def should_show(self, line: str) -> bool:
27+
"""Determines if a line should be shown in the output."""
28+
is_core_info = (
29+
"benchmark_litert_model" in line
30+
or "compiler_plugin.cc" in line
31+
or ("Replacing" in line and "node(s) with delegate" in line)
32+
)
33+
return not self.default_quiet or is_core_info
34+
35+
36+
class RunLogFilter:
37+
"""Filters output of litert run command."""
38+
39+
def __init__(self, default_quiet: bool, print_tensors: bool):
40+
self.default_quiet = default_quiet
41+
self.print_tensors = print_tensors
42+
43+
def should_show(self, line: str) -> bool:
44+
"""Determines if a line should be shown in the output."""
45+
is_core_info = (
46+
"run_model.cc" in line
47+
or "compiler_plugin.cc" in line
48+
or ("Replacing" in line and "node(s) with delegate" in line)
49+
)
50+
return not self.default_quiet or is_core_info or self.print_tensors

0 commit comments

Comments
 (0)