From 3863362906240c91aeffa99581341d1a744edcea Mon Sep 17 00:00:00 2001 From: lightseek-bot <243258330+lightseek-bot@users.noreply.github.com> Date: Tue, 16 Jun 2026 08:16:42 +0000 Subject: [PATCH] revert: remove instanttensor loader Signed-off-by: lightseek-bot <243258330+lightseek-bot@users.noreply.github.com> --- docs/.vitepress/config.mts | 3 +- docs/configuration/server.md | 2 +- docs/guides/instanttensor.md | 70 ------------ python/pyproject.toml | 1 - .../tokenspeed/runtime/configs/load_config.py | 5 - .../tokenspeed/runtime/model_loader/loader.py | 8 +- .../runtime/model_loader/weight_utils.py | 57 --------- python/tokenspeed/runtime/models/gpt_oss.py | 70 +++++------- python/tokenspeed/runtime/models/kimi_k25.py | 76 +++++------- .../tokenspeed/runtime/utils/server_args.py | 4 - .../deepseek-v4-flash-evalscope-gsm8k.yaml | 1 - ...deepseek-v4-flash-mtp-evalscope-gsm8k.yaml | 1 - ...oss-120b-mxfp4-evalscope-gpqa-diamond.yaml | 4 - .../kimi-k2.5-nvfp4-evalscope-aime25.yaml | 7 -- ...imi-k2.5-nvfp4-evalscope-gpqa-diamond.yaml | 1 - .../eval/kimi-k2.5-nvfp4-evalscope-gsm8k.yaml | 1 - .../eval/kimi-k2.5-nvfp4-evalscope-mmlu.yaml | 1 - .../kimi-k2.5-nvfp4-evalscope-ocr-bench.yaml | 1 - .../minimax-m2.7-nvfp4-evalscope-aime25.yaml | 1 - ...n3.5-397b-a17b-nvfp4-evalscope-aime25.yaml | 7 -- .../kimi-k2.5-nvfp4-evalscope-agentic.yaml | 2 - ...17b-nvfp4-evalscope-agentic-b200-8gpu.yaml | 1 - ...3.5-397b-a17b-nvfp4-evalscope-longctx.yaml | 1 - test/runtime/test_gpt_oss_mxfp4_streaming.py | 108 ------------------ test/runtime/test_instanttensor_loader.py | 95 --------------- 25 files changed, 60 insertions(+), 468 deletions(-) delete mode 100644 docs/guides/instanttensor.md delete mode 100644 test/runtime/test_gpt_oss_mxfp4_streaming.py delete mode 100644 test/runtime/test_instanttensor_loader.py diff --git a/docs/.vitepress/config.mts b/docs/.vitepress/config.mts index 05912edaa..1289eb32d 100644 --- a/docs/.vitepress/config.mts +++ b/docs/.vitepress/config.mts @@ -37,8 +37,7 @@ export default defineConfig({ text: "Guides", items: [ { text: "Getting Started", link: "/guides/getting-started" }, - { text: "Launching a Server", link: "/guides/launching" }, - { text: "InstantTensor Loading", link: "/guides/instanttensor" } + { text: "Launching a Server", link: "/guides/launching" } ] }, { diff --git a/docs/configuration/server.md b/docs/configuration/server.md index b9f1d043e..905fbbd9a 100644 --- a/docs/configuration/server.md +++ b/docs/configuration/server.md @@ -16,7 +16,7 @@ For a compact compatibility table, see | `--tokenizer` | Tokenizer path when it differs from the model path. | | `--tokenizer-mode` | Select tokenizer behavior. `auto` uses fast tokenizers and model-specific hooks when available. | | `--skip-tokenizer-init` | Skip tokenizer initialization for input-ID-only serving paths. | -| `--load-format` | Weight loading format: `auto`, `pt`, `safetensors`, `instanttensor`, `npcache`, `dummy`, or `extensible`. See [InstantTensor](/guides/instanttensor) for the accelerated NVIDIA loader. | +| `--load-format` | Weight loading format: `auto`, `pt`, `safetensors`, `npcache`, `dummy`, or `extensible`. | | `--trust-remote-code` | Allow custom model code from the model repository. | | `--revision` | Model branch, tag, or commit. | | `--download-dir` | Hugging Face download/cache directory. | diff --git a/docs/guides/instanttensor.md b/docs/guides/instanttensor.md deleted file mode 100644 index 4ee5a6e45..000000000 --- a/docs/guides/instanttensor.md +++ /dev/null @@ -1,70 +0,0 @@ -# Loading Weights with InstantTensor - -[InstantTensor](https://github.com/scitix/InstantTensor) accelerates loading -safetensors weights on NVIDIA GPUs through distributed loading, pipelined -prefetching, and direct I/O. It also supports GPUDirect Storage (GDS) when -available, which lets it fully utilize the bandwidth of high-speed networked -storage (e.g. 400 Gbps). - -InstantTensor only changes *how* the safetensors shards are read off disk and -moved onto the GPU — the resulting weights are bit-for-bit identical to the -default safetensors loader, so model accuracy is unaffected. - -## Installation - -InstantTensor ships as a dependency of TokenSpeed, so a normal install already -includes it — no extra step is needed. It is a CUDA-only package and is -imported lazily, so it is only loaded when you actually select -`--load-format instanttensor` on an NVIDIA GPU. - -## Usage - -Pass `--load-format instanttensor`. It works with any parallelism -configuration; when the job spans multiple ranks, the world process group is -handed to InstantTensor so reads are sharded across ranks. - -```bash -tokenspeed serve Qwen/Qwen3-30B-A3B --load-format instanttensor -``` - -```bash -tokenspeed serve deepseek-ai/DeepSeek-R1 \ - --load-format instanttensor \ - --tensor-parallel-size 8 \ - --enable-expert-parallel -``` - -## Memory considerations - -InstantTensor reads each checkpoint tensor **directly onto the GPU**, whereas -the default safetensors loader stages the full tensor in host (CPU) memory and -copies only the current rank's shard to the GPU. InstantTensor's own overhead is -small: it uses a GPU staging buffer (dynamically sized, configurable) that is -released before the KV cache is sized, plus a little fixed runtime overhead, so -its post-load GPU footprint is close to the default loader. - -Because tensors land on the GPU, a model's `load_weights` must **consume the -weight iterator lazily** — copying each tensor into its (pre-allocated) -parameter and then releasing it. A `load_weights` that instead collects the -whole iterator into a list keeps every loaded tensor resident on the GPU at once -and will OOM during loading on large models. This stays hidden with the -CPU-staging loaders, where the buffered tensors live in plentiful host RAM. -TokenSpeed's model loaders stream the iterator for this reason. - -Tuning: - -- `INSTANTTENSOR_BUFFER_SIZE` / `INSTANTTENSOR_MAX_FREE_MEM_USAGE` bound - InstantTensor's GPU I/O staging buffer, trading a little throughput for lower - peak memory. -- `--gpu-memory-utilization` only sizes the KV cache *after* weights are loaded; - it does not change peak memory during loading. - -## Notes - -- InstantTensor requires NVIDIA GPUs. Requesting it on a non-NVIDIA platform - raises an error. -- Only `*.safetensors` checkpoints are supported (same shard selection as - `--load-format safetensors`). - -For benchmarks and implementation details, see the -[InstantTensor repository](https://github.com/scitix/InstantTensor). diff --git a/python/pyproject.toml b/python/pyproject.toml index 2a756f363..76b830246 100755 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -41,7 +41,6 @@ dependencies = [ "fastapi", "hf_transfer", "huggingface_hub", - "instanttensor>=0.1.9", "modelscope", "msgspec", "ninja", diff --git a/python/tokenspeed/runtime/configs/load_config.py b/python/tokenspeed/runtime/configs/load_config.py index 79bf5bd56..ab6447634 100755 --- a/python/tokenspeed/runtime/configs/load_config.py +++ b/python/tokenspeed/runtime/configs/load_config.py @@ -33,7 +33,6 @@ class LoadFormat(str, enum.Enum): AUTO = "auto" PT = "pt" SAFETENSORS = "safetensors" - INSTANTTENSOR = "instanttensor" NPCACHE = "npcache" DUMMY = "dummy" SHARDED_STATE = "sharded_state" @@ -52,10 +51,6 @@ class LoadConfig: not available. "pt" will load the weights in the pytorch bin format. "safetensors" will load the weights in the safetensors format. - "instanttensor" will load the safetensors weights on NVIDIA GPUs - using InstantTensor, which accelerates loading via distributed - loading, pipelined prefetching, and direct I/O (with optional - GPUDirect Storage support). "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading. "dummy" will initialize the weights with random values, which is diff --git a/python/tokenspeed/runtime/model_loader/loader.py b/python/tokenspeed/runtime/model_loader/loader.py index 363b8e206..dda83d8f8 100755 --- a/python/tokenspeed/runtime/model_loader/loader.py +++ b/python/tokenspeed/runtime/model_loader/loader.py @@ -57,7 +57,6 @@ filter_files_not_needed_for_inference, get_quant_config, initialize_dummy_weights, - instanttensor_weights_iterator, np_cache_weights_iterator, pt_weights_iterator, safetensors_weights_iterator, @@ -254,10 +253,7 @@ def _prepare_weights( # Some quantized models use .pt files for storing the weights. if load_format == LoadFormat.AUTO: allow_patterns = ["*.safetensors", "*.bin"] - elif ( - load_format == LoadFormat.SAFETENSORS - or load_format == LoadFormat.INSTANTTENSOR - ): + elif load_format == LoadFormat.SAFETENSORS: use_safetensors = True allow_patterns = ["*.safetensors"] elif load_format == LoadFormat.MISTRAL: @@ -335,8 +331,6 @@ def _get_weights_iterator( hf_folder, hf_weights_files, ) - elif self.load_config.load_format == LoadFormat.INSTANTTENSOR: - weights_iterator = instanttensor_weights_iterator(hf_weights_files) elif use_safetensors: weights_iterator = safetensors_weights_iterator( hf_weights_files, diff --git a/python/tokenspeed/runtime/model_loader/weight_utils.py b/python/tokenspeed/runtime/model_loader/weight_utils.py index c0b3620e0..eff6bcecc 100755 --- a/python/tokenspeed/runtime/model_loader/weight_utils.py +++ b/python/tokenspeed/runtime/model_loader/weight_utils.py @@ -42,7 +42,6 @@ import torch from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator -from tokenspeed_kernel.platform import current_platform from tqdm.auto import tqdm from tokenspeed.runtime.configs.load_config import LoadConfig @@ -481,62 +480,6 @@ def safetensors_weights_iterator( yield from result.items() -def instanttensor_weights_iterator( - hf_weights_files: list[str], -) -> Generator[tuple[str, torch.Tensor], None, None]: - """Iterate over the weights in the model safetensor files using the - InstantTensor library. - - InstantTensor accelerates loading safetensors weights on NVIDIA GPUs - through distributed loading, pipelined prefetching, and direct I/O. When - the job spans multiple ranks, the world process group is passed to - InstantTensor so reads are sharded across ranks. - - Args: - hf_weights_files: Local paths to the ``*.safetensors`` shards to load. - - Yields: - ``(name, tensor)`` pairs for every tensor in the checkpoint, with the - tensors materialized on the current CUDA device. - """ - try: - import instanttensor - except ImportError as e: - raise ImportError( - "Please install instanttensor via `pip install instanttensor`" - ) from e - - if not current_platform().is_nvidia: - raise ValueError("InstantTensor requires NVIDIA GPUs") - - process_group = None - if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1: - # The default (world) group spans every rank in the job, matching the - # semantics InstantTensor expects for distributed loading. - process_group = torch.distributed.group.WORLD - - device = torch.cuda.current_device() - - enable_tqdm = ( - not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 - ) - - with instanttensor.safe_open( - hf_weights_files, framework="pt", device=device, process_group=process_group - ) as f: - # Since InstantTensor 0.1.9, tensors are cloned internally by default, - # so no extra clone is needed here. - yield from tqdm( - f.tensors(), - desc="Loading safetensors using InstantTensor loader", - disable=not enable_tqdm, - bar_format=_BAR_FORMAT, - position=tqdm._get_free_pos(), - total=len(f.keys()), - mininterval=1.0, - ) - - def pt_weights_iterator( hf_weights_files: list[str], ) -> Generator[tuple[str, torch.Tensor], None, None]: diff --git a/python/tokenspeed/runtime/models/gpt_oss.py b/python/tokenspeed/runtime/models/gpt_oss.py index 5e0810b44..280ce6f14 100644 --- a/python/tokenspeed/runtime/models/gpt_oss.py +++ b/python/tokenspeed/runtime/models/gpt_oss.py @@ -716,27 +716,17 @@ def _load_normal_weights( } def _load_mxfp4_weights(self, weights, weight_name_mapping: dict): - # The MoE expert tensors dominate the checkpoint. Stream them straight - # into their (pre-allocated) parameter slots as they arrive instead of - # buffering the whole iterator into a list. Buffering is invisible to - # CPU-staging loaders, but a GPU-direct loader (e.g. - # ``--load-format instanttensor``) yields tensors already on the GPU, - # so collecting every expert tensor would keep the entire checkpoint - # resident on the device at once and OOM mid-load. The remaining - # non-expert weights (attention, embeddings, norms, router) are small - # and are collected for the generic ``_load_normal_weights`` pass. + + mxfp4_weights = [] normal_weights = [] - def expert_weights(): - for name, weight in weights: - if ".experts" in name: - yield name, weight - else: - normal_weights.append((name, weight)) + for name, weight in weights: + if ".experts" in name: + mxfp4_weights.append((name, weight)) + else: + normal_weights.append((name, weight)) - # ``_load_mxfp4_experts_weights`` drains this generator fully, so by the - # time it returns ``normal_weights`` holds every non-expert tensor. - mxfp4_loaded_params = self._load_mxfp4_experts_weights(expert_weights()) + mxfp4_loaded_params = self._load_mxfp4_experts_weights(mxfp4_weights) self._load_normal_weights( normal_weights, weight_name_mapping=weight_name_mapping, @@ -782,33 +772,27 @@ def _copy_into_param(param, narrow_weight): ) param.data[slices].copy_(narrow_weight[slices]) - # The two MXFP4 expert checkpoint layouts are mutually exclusive and - # are detected from the first expert tensor (a checkpoint is uniformly - # one layout), reproducing the original whole-iterator ``any(...)`` - # probe for AMD-Quark per-expert checkpoints (e.g. - # ``amd/gpt-oss-120b-w-mxfp4-a-fp8``: one tensor set per expert plus a - # scalar ``input_scale`` for static FP8 activation quantization) - # without buffering the iterator. Each expert tensor is streamed - # straight into its slot as it arrives. - per_expert_re = re.compile(r"\.experts\.\d+\.(gate_up_proj|down_proj)\.") - per_expert_format = None + # Detect AMD-Quark per-expert checkpoints (e.g. + # ``amd/gpt-oss-120b-w-mxfp4-a-fp8``). These store one set of tensors + # per expert (``...experts.{e}.gate_up_proj.{weight,...}``) plus a + # scalar ``input_scale`` for static FP8 activation quantization. + if any( + re.search(r"\.experts\.\d+\.(gate_up_proj|down_proj)\.", n) + for n, _ in weights + ): + return self._load_mxfp4_per_expert_weights( + weights, + params_dict=params_dict, + moe_tp_rank_start=moe_tp_rank_start, + moe_tp_rank_end=moe_tp_rank_end, + moe_ep_rank_start=moe_ep_rank_start, + moe_ep_rank_end=moe_ep_rank_end, + moe_tp_rank=moe_tp_rank, + copy_into_param=_copy_into_param, + mxfp4_block=mxfp4_block, + ) for name, weight in weights: - if per_expert_format is None: - per_expert_format = per_expert_re.search(name) is not None - if per_expert_format: - loaded_params |= self._load_mxfp4_per_expert_weights( - [(name, weight)], - params_dict=params_dict, - moe_tp_rank_start=moe_tp_rank_start, - moe_tp_rank_end=moe_tp_rank_end, - moe_ep_rank_start=moe_ep_rank_start, - moe_ep_rank_end=moe_ep_rank_end, - moe_tp_rank=moe_tp_rank, - copy_into_param=_copy_into_param, - mxfp4_block=mxfp4_block, - ) - continue weight = _WeightCreator.maybe_materialize(weight) if "gate_up_proj_blocks" in name: diff --git a/python/tokenspeed/runtime/models/kimi_k25.py b/python/tokenspeed/runtime/models/kimi_k25.py index 18bf29126..330d07194 100644 --- a/python/tokenspeed/runtime/models/kimi_k25.py +++ b/python/tokenspeed/runtime/models/kimi_k25.py @@ -900,60 +900,44 @@ def forward( ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - """Load weights, streaming language weights to the language model. - - The language weights are forwarded to ``language_model.load_weights`` - lazily (as a generator) instead of being collected into a list first. - Materializing the whole iterator would keep every loaded tensor alive - at once; that is harmless for CPU-staged loaders but OOMs GPU-direct - loaders (e.g. ``--load-format instanttensor``), which would then hold - the entire model on the device during loading. Vision weights are - small and are still collected, then loaded after the language model. - """ - vision_weights: list[Tuple[str, torch.Tensor]] = [] - encoder_only = getattr(self.config, "encoder_only", False) - load_vision = self.is_multimodal_active and not getattr( - self.config, "language_only", False - ) + """Load weights for the model, separating vision and language weights""" + vision_weights = [] + language_weights = [] + + for name, loaded_weight in weights: + # nvidia/Kimi-K2.5-NVFP4 stores decoder layers under + # language_model.layers.*, while TokenSpeed's DeepSeek module + # expects model.layers.* after stripping language_model. + if name.startswith("language_model.layers."): + name = name.replace( + "language_model.layers.", "language_model.model.layers.", 1 + ) - def language_weights() -> Iterable[Tuple[str, torch.Tensor]]: - for name, loaded_weight in weights: - # nvidia/Kimi-K2.5-NVFP4 stores decoder layers under - # language_model.layers.*, while TokenSpeed's DeepSeek module - # expects model.layers.* after stripping language_model. - if name.startswith("language_model.layers."): - name = name.replace( - "language_model.layers.", "language_model.model.layers.", 1 - ) - - if "vision_tower" in name or "mm_projector" in name: - name = name.replace(r"wqkv.", r"attn.qkv_proj.") - name = name.replace(r"wo.", r"attn.proj.") - name = name.replace("mm_projector.proj.0", "mm_projector.linear_1") - name = name.replace("mm_projector.proj.2", "mm_projector.linear_2") - if load_vision: - vision_weights.append((name, loaded_weight)) - else: - yield name.replace("language_model.", ""), loaded_weight - - if not encoder_only: - # Consumes the iterator lazily; fills vision_weights as a side - # effect for the multimodal branch below. - self.language_model.load_weights(language_weights()) - elif load_vision: - # Encoder-only: still drain the iterator to collect vision weights. - for _ in language_weights(): - pass - - if load_vision: + if "vision_tower" in name or "mm_projector" in name: + name = name.replace(r"wqkv.", r"attn.qkv_proj.") + name = name.replace(r"wo.", r"attn.proj.") + name = name.replace("mm_projector.proj.0", "mm_projector.linear_1") + name = name.replace("mm_projector.proj.2", "mm_projector.linear_2") + vision_weights.append((name, loaded_weight)) + else: + name = name.replace("language_model.", "") + language_weights.append((name, loaded_weight)) + + if self.is_multimodal_active and not getattr( + self.config, "language_only", False + ): + vision_state_dict = dict(vision_weights) params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in vision_weights: + for name, loaded_weight in vision_state_dict.items(): if name not in params_dict: raise ValueError(f"Weight {name} not found in params_dict") param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + if not getattr(self.config, "encoder_only", False) and language_weights: + self.language_model.load_weights(language_weights) + @classmethod def get_model_config_for_expert_location(cls, config: KimiK25Config): text_config = config.text_config diff --git a/python/tokenspeed/runtime/utils/server_args.py b/python/tokenspeed/runtime/utils/server_args.py index f15c13e0c..5ac05fcab 100755 --- a/python/tokenspeed/runtime/utils/server_args.py +++ b/python/tokenspeed/runtime/utils/server_args.py @@ -717,7 +717,6 @@ def add_cli_args(parser: argparse.ArgumentParser): "auto", "pt", "safetensors", - "instanttensor", "npcache", "dummy", "extensible", @@ -728,9 +727,6 @@ def add_cli_args(parser: argparse.ArgumentParser): "is not available. " '"pt" will load the weights in the pytorch bin format. ' '"safetensors" will load the weights in the safetensors format. ' - '"instanttensor" accelerates safetensors loading on NVIDIA GPUs ' - "via distributed loading, pipelined prefetching, and direct I/O " - "(with optional GPUDirect Storage support). " '"npcache" will load the weights in pytorch format and store ' "a numpy cache to speed up the loading. " '"dummy" will initialize the weights with random values.', diff --git a/test/ci/eval/deepseek-v4-flash-evalscope-gsm8k.yaml b/test/ci/eval/deepseek-v4-flash-evalscope-gsm8k.yaml index b01cca44a..6d116405d 100644 --- a/test/ci/eval/deepseek-v4-flash-evalscope-gsm8k.yaml +++ b/test/ci/eval/deepseek-v4-flash-evalscope-gsm8k.yaml @@ -14,7 +14,6 @@ install: server: command: >- ts serve - --load-format instanttensor --model deepseek-ai/DeepSeek-V4-Flash --tensor-parallel-size 4 --enable-expert-parallel diff --git a/test/ci/eval/deepseek-v4-flash-mtp-evalscope-gsm8k.yaml b/test/ci/eval/deepseek-v4-flash-mtp-evalscope-gsm8k.yaml index 6e5ec9ca7..f5e420353 100644 --- a/test/ci/eval/deepseek-v4-flash-mtp-evalscope-gsm8k.yaml +++ b/test/ci/eval/deepseek-v4-flash-mtp-evalscope-gsm8k.yaml @@ -14,7 +14,6 @@ install: server: command: >- ts serve - --load-format instanttensor --model deepseek-ai/DeepSeek-V4-Flash --tensor-parallel-size 4 --enable-expert-parallel diff --git a/test/ci/eval/gpt-oss-120b-mxfp4-evalscope-gpqa-diamond.yaml b/test/ci/eval/gpt-oss-120b-mxfp4-evalscope-gpqa-diamond.yaml index a6892391b..280a63a09 100644 --- a/test/ci/eval/gpt-oss-120b-mxfp4-evalscope-gpqa-diamond.yaml +++ b/test/ci/eval/gpt-oss-120b-mxfp4-evalscope-gpqa-diamond.yaml @@ -13,9 +13,6 @@ runner: GPT_OSS_EVAL_ATTENTION_BACKEND: trtllm GPT_OSS_EVAL_MOE_BACKEND: flashinfer_trtllm GPT_OSS_EVAL_MODEL: openai/gpt-oss-120b - # InstantTensor is NVIDIA-only and the b200 pool has the memlock raised; - # AMD leaves this unset and uses the default loader. - GPT_OSS_EVAL_LOAD_FORMAT: instanttensor amd-mi35x-2gpu-test: GPT_OSS_EVAL_DISABLE_KVSTORE: "1" GPT_OSS_EVAL_DISABLE_PREFIX_CACHING: "1" @@ -28,7 +25,6 @@ server: command: >- ts serve --model ${GPT_OSS_EVAL_MODEL} - ${GPT_OSS_EVAL_LOAD_FORMAT:+--load-format ${GPT_OSS_EVAL_LOAD_FORMAT}} --attn-tp-size 2 --moe-tp-size 2 --max-model-len 80000 diff --git a/test/ci/eval/kimi-k2.5-nvfp4-evalscope-aime25.yaml b/test/ci/eval/kimi-k2.5-nvfp4-evalscope-aime25.yaml index 792d3ccaf..0375764b4 100644 --- a/test/ci/eval/kimi-k2.5-nvfp4-evalscope-aime25.yaml +++ b/test/ci/eval/kimi-k2.5-nvfp4-evalscope-aime25.yaml @@ -8,12 +8,6 @@ runner: labels: - b200-4gpu - b300-4gpu - # InstantTensor's io_uring loader needs a raised RLIMIT_MEMLOCK on the node. - # The b200 pool is configured for it; the b300 pool is not yet, so b300 falls - # back to the default loader (no KIMI_LOAD_FORMAT -> flag omitted below). - env: - b200-4gpu: - KIMI_LOAD_FORMAT: instanttensor env: CI: "true" install: @@ -21,7 +15,6 @@ install: server: command: >- ts serve - ${KIMI_LOAD_FORMAT:+--load-format ${KIMI_LOAD_FORMAT}} --model nvidia/Kimi-K2.5-NVFP4 --tp 4 --max-model-len 80000 diff --git a/test/ci/eval/kimi-k2.5-nvfp4-evalscope-gpqa-diamond.yaml b/test/ci/eval/kimi-k2.5-nvfp4-evalscope-gpqa-diamond.yaml index caf1fba1e..a1b3cd5fe 100644 --- a/test/ci/eval/kimi-k2.5-nvfp4-evalscope-gpqa-diamond.yaml +++ b/test/ci/eval/kimi-k2.5-nvfp4-evalscope-gpqa-diamond.yaml @@ -14,7 +14,6 @@ install: server: command: >- ts serve - --load-format instanttensor --model nvidia/Kimi-K2.5-NVFP4 --tp 4 --max-model-len 80000 diff --git a/test/ci/eval/kimi-k2.5-nvfp4-evalscope-gsm8k.yaml b/test/ci/eval/kimi-k2.5-nvfp4-evalscope-gsm8k.yaml index fdeab8114..64693b3fb 100644 --- a/test/ci/eval/kimi-k2.5-nvfp4-evalscope-gsm8k.yaml +++ b/test/ci/eval/kimi-k2.5-nvfp4-evalscope-gsm8k.yaml @@ -14,7 +14,6 @@ install: server: command: >- ts serve - --load-format instanttensor --model nvidia/Kimi-K2.5-NVFP4 --tp 4 --max-model-len 80000 diff --git a/test/ci/eval/kimi-k2.5-nvfp4-evalscope-mmlu.yaml b/test/ci/eval/kimi-k2.5-nvfp4-evalscope-mmlu.yaml index e24d27644..3227a470b 100644 --- a/test/ci/eval/kimi-k2.5-nvfp4-evalscope-mmlu.yaml +++ b/test/ci/eval/kimi-k2.5-nvfp4-evalscope-mmlu.yaml @@ -14,7 +14,6 @@ install: server: command: >- ts serve - --load-format instanttensor --model nvidia/Kimi-K2.5-NVFP4 --tp 4 --max-model-len 80000 diff --git a/test/ci/eval/kimi-k2.5-nvfp4-evalscope-ocr-bench.yaml b/test/ci/eval/kimi-k2.5-nvfp4-evalscope-ocr-bench.yaml index 78e5eb893..8f5007398 100644 --- a/test/ci/eval/kimi-k2.5-nvfp4-evalscope-ocr-bench.yaml +++ b/test/ci/eval/kimi-k2.5-nvfp4-evalscope-ocr-bench.yaml @@ -14,7 +14,6 @@ install: server: command: >- ts serve - --load-format instanttensor --model nvidia/Kimi-K2.5-NVFP4 --tp 4 --max-model-len 80000 diff --git a/test/ci/eval/minimax-m2.7-nvfp4-evalscope-aime25.yaml b/test/ci/eval/minimax-m2.7-nvfp4-evalscope-aime25.yaml index 33e00dedb..e93662b59 100644 --- a/test/ci/eval/minimax-m2.7-nvfp4-evalscope-aime25.yaml +++ b/test/ci/eval/minimax-m2.7-nvfp4-evalscope-aime25.yaml @@ -14,7 +14,6 @@ install: server: command: >- ts serve - --load-format instanttensor --model nvidia/MiniMax-M2.7-NVFP4 --attn-tp-size 2 --moe-tp-size 2 diff --git a/test/ci/eval/qwen3.5-397b-a17b-nvfp4-evalscope-aime25.yaml b/test/ci/eval/qwen3.5-397b-a17b-nvfp4-evalscope-aime25.yaml index 5154cffcf..02b3524e9 100644 --- a/test/ci/eval/qwen3.5-397b-a17b-nvfp4-evalscope-aime25.yaml +++ b/test/ci/eval/qwen3.5-397b-a17b-nvfp4-evalscope-aime25.yaml @@ -8,12 +8,6 @@ runner: labels: - b200-4gpu - b300-4gpu - # InstantTensor's io_uring loader needs a raised RLIMIT_MEMLOCK on the node. - # The b200 pool is configured for it; the b300 pool is not yet, so b300 falls - # back to the default loader (no QWEN_LOAD_FORMAT -> flag omitted below). - env: - b200-4gpu: - QWEN_LOAD_FORMAT: instanttensor env: CI: "true" install: @@ -21,7 +15,6 @@ install: server: command: >- ts serve - ${QWEN_LOAD_FORMAT:+--load-format ${QWEN_LOAD_FORMAT}} --model nvidia/Qwen3.5-397B-A17B-NVFP4 --tp 4 --max-model-len 80000 diff --git a/test/ci/perf/kimi-k2.5-nvfp4-evalscope-agentic.yaml b/test/ci/perf/kimi-k2.5-nvfp4-evalscope-agentic.yaml index 8276c266d..9931923fe 100644 --- a/test/ci/perf/kimi-k2.5-nvfp4-evalscope-agentic.yaml +++ b/test/ci/perf/kimi-k2.5-nvfp4-evalscope-agentic.yaml @@ -15,8 +15,6 @@ env: install: - bash test/ci_system/install_deps.sh server: - # b300-only job: the b300 pool does not yet have the raised RLIMIT_MEMLOCK - # that InstantTensor's io_uring loader requires, so use the default loader. command: >- ts serve --model nvidia/Kimi-K2.5-NVFP4 diff --git a/test/ci/perf/qwen3.5-397b-a17b-nvfp4-evalscope-agentic-b200-8gpu.yaml b/test/ci/perf/qwen3.5-397b-a17b-nvfp4-evalscope-agentic-b200-8gpu.yaml index f4e52f5e5..0300d66d7 100644 --- a/test/ci/perf/qwen3.5-397b-a17b-nvfp4-evalscope-agentic-b200-8gpu.yaml +++ b/test/ci/perf/qwen3.5-397b-a17b-nvfp4-evalscope-agentic-b200-8gpu.yaml @@ -14,7 +14,6 @@ install: server: command: >- ts serve - --load-format instanttensor --model nvidia/Qwen3.5-397B-A17B-NVFP4 --attn-tp-size 8 --moe-tp-size 8 diff --git a/test/ci/perf/qwen3.5-397b-a17b-nvfp4-evalscope-longctx.yaml b/test/ci/perf/qwen3.5-397b-a17b-nvfp4-evalscope-longctx.yaml index de9b3d7d5..bf66d8cb1 100644 --- a/test/ci/perf/qwen3.5-397b-a17b-nvfp4-evalscope-longctx.yaml +++ b/test/ci/perf/qwen3.5-397b-a17b-nvfp4-evalscope-longctx.yaml @@ -13,7 +13,6 @@ install: server: command: >- ts serve - --load-format instanttensor --model nvidia/Qwen3.5-397B-A17B-NVFP4 --attn-tp-size 8 --moe-tp-size 8 diff --git a/test/runtime/test_gpt_oss_mxfp4_streaming.py b/test/runtime/test_gpt_oss_mxfp4_streaming.py deleted file mode 100644 index 5349d9844..000000000 --- a/test/runtime/test_gpt_oss_mxfp4_streaming.py +++ /dev/null @@ -1,108 +0,0 @@ -"""Regression test: GPT-OSS MXFP4 weight loading streams the iterator. - -A GPU-direct loader (``--load-format instanttensor``) yields each checkpoint -tensor already resident on the GPU. ``_load_mxfp4_weights`` must therefore -consume the weight iterator lazily and copy each (large) MoE expert tensor -straight into its slot, rather than buffering every expert tensor into a list -first -- the latter keeps the whole checkpoint on the device at once and OOMs -mid-load. This test pins that behavior without needing a GPU or real weights. -""" - -import os -import sys -import types -import unittest - -# CI Registration (parsed via AST, runtime no-op) -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from ci_system.ci_register import register_cuda_ci - -register_cuda_ci(est_time=10, suite="runtime-1gpu") - -from tokenspeed.runtime.models.gpt_oss import GptOssForCausalLM - - -class TestGptOssMxfp4Streaming(unittest.TestCase): - def test_load_mxfp4_weights_streams_experts(self): - # Expert tensors interleaved with the small non-expert weights, in - # checkpoint order. The "weight" stand-in is just the name string, - # because the stubbed loaders below never touch tensor data. - items = [ - "model.embed_tokens.weight", - "model.layers.0.mlp.experts.gate_up_proj_blocks", - "model.layers.0.input_layernorm.weight", - "model.layers.0.mlp.experts.down_proj_blocks", - "lm_head.weight", - ] - - pulled = [] - - def source(): - for name in items: - pulled.append(name) - yield name, name - - seen_experts = [] - received = {} - - def fake_load_experts(weights): - # The dispatcher must hand us a lazy generator, not a materialized - # list of every expert tensor. - received["is_generator"] = isinstance(weights, types.GeneratorType) - iterator = iter(weights) - first_expert = next(iterator) - seen_experts.append(first_expert[0]) - # Reaching the first expert (item #2) must not have drained the - # whole source iterator (5 items) -- proof that loading is - # interleaved with iteration, i.e. streamed. - received["pulled_after_first_expert"] = len(pulled) - for name, _ in iterator: - seen_experts.append(name) - return {"loaded_expert_param"} - - normal_seen = {} - - def fake_load_normal( - normal_weights, *, weight_name_mapping, other_loaded_param_names - ): - normal_seen["names"] = [name for name, _ in normal_weights] - normal_seen["other"] = other_loaded_param_names - - fake_self = types.SimpleNamespace( - _load_mxfp4_experts_weights=fake_load_experts, - _load_normal_weights=fake_load_normal, - ) - - GptOssForCausalLM._load_mxfp4_weights( - fake_self, source(), weight_name_mapping={} - ) - - # Streamed, not buffered. - self.assertTrue(received["is_generator"]) - self.assertEqual(received["pulled_after_first_expert"], 2) - - # Expert tensors (matched by the ".experts" marker) are routed to the - # expert loader, in order. - self.assertEqual( - seen_experts, - [ - "model.layers.0.mlp.experts.gate_up_proj_blocks", - "model.layers.0.mlp.experts.down_proj_blocks", - ], - ) - - # Everything else is collected for the generic loader, and the set of - # already-loaded expert params is threaded through to it. - self.assertEqual( - normal_seen["names"], - [ - "model.embed_tokens.weight", - "model.layers.0.input_layernorm.weight", - "lm_head.weight", - ], - ) - self.assertEqual(normal_seen["other"], {"loaded_expert_param"}) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/runtime/test_instanttensor_loader.py b/test/runtime/test_instanttensor_loader.py deleted file mode 100644 index a29c8b863..000000000 --- a/test/runtime/test_instanttensor_loader.py +++ /dev/null @@ -1,95 +0,0 @@ -import argparse -import glob -import os -import sys -import tempfile -import unittest -from importlib.util import find_spec - -import torch - -# CI Registration (parsed via AST, runtime no-op) -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from ci_system.ci_register import register_cuda_ci - -register_cuda_ci(est_time=30, suite="runtime-1gpu") - -from tokenspeed_kernel.platform import current_platform - -from tokenspeed.runtime.configs.load_config import LoadConfig, LoadFormat -from tokenspeed.runtime.model_loader.loader import DefaultModelLoader -from tokenspeed.runtime.model_loader.weight_utils import ( - download_weights_from_hf, - instanttensor_weights_iterator, - safetensors_weights_iterator, -) -from tokenspeed.runtime.utils.server_args import ServerArgs - -INSTANTTENSOR_AVAILABLE = find_spec("instanttensor") is not None -# InstantTensor is NVIDIA-only. torch.cuda.is_available() is also True on ROCm, -# so guard on the platform vendor to keep the parity test off AMD runners. -IS_NVIDIA = current_platform().is_nvidia - - -class TestInstantTensorConfig(unittest.TestCase): - """Config/CLI wiring that needs neither a GPU nor instanttensor.""" - - def test_cli_flag_maps_to_load_format(self): - parser = argparse.ArgumentParser() - ServerArgs.add_cli_args(parser) - args = parser.parse_args( - ["--model", "test/model", "--load-format", "instanttensor"] - ) - self.assertEqual(args.load_format, "instanttensor") - - def test_load_config_normalizes_to_enum(self): - load_config = LoadConfig(load_format="instanttensor") - self.assertEqual(load_config.load_format, LoadFormat.INSTANTTENSOR) - - def test_prepare_weights_treats_instanttensor_as_safetensors(self): - with tempfile.TemporaryDirectory() as tmpdir: - # _prepare_weights only globs/filters file paths; it never reads the - # tensor data, so an empty placeholder shard is sufficient here. - open(os.path.join(tmpdir, "model.safetensors"), "wb").close() - - loader = DefaultModelLoader(LoadConfig(load_format="instanttensor")) - _, hf_weights_files, use_safetensors = loader._prepare_weights( - tmpdir, revision=None, fall_back_to_pt=False - ) - - self.assertTrue(use_safetensors) - self.assertEqual(len(hf_weights_files), 1) - self.assertTrue(hf_weights_files[0].endswith("model.safetensors")) - - -@unittest.skipIf(not IS_NVIDIA, "InstantTensor requires NVIDIA GPUs") -@unittest.skipIf(not INSTANTTENSOR_AVAILABLE, "instanttensor is not installed") -class TestInstantTensorWeights(unittest.TestCase): - """Iterator parity test (requires an NVIDIA GPU and instanttensor).""" - - def test_instanttensor_matches_safetensors(self): - model = "openai-community/gpt2" - with tempfile.TemporaryDirectory() as tmpdir: - download_weights_from_hf( - model, cache_dir=tmpdir, allow_patterns=["*.safetensors"] - ) - safetensors_files = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True) - self.assertGreater(len(safetensors_files), 0) - - instanttensor_tensors = {} - for name, tensor in instanttensor_weights_iterator(safetensors_files): - # Copy immediately in case InstantTensor exposes internal buffers. - instanttensor_tensors[name] = tensor.to("cpu") - - reference_tensors = dict(safetensors_weights_iterator(safetensors_files)) - - self.assertEqual(len(instanttensor_tensors), len(reference_tensors)) - for name, got in instanttensor_tensors.items(): - ref = reference_tensors[name] - self.assertEqual(got.dtype, ref.dtype) - self.assertEqual(got.shape, ref.shape) - self.assertTrue(torch.equal(got, ref)) - - -if __name__ == "__main__": - unittest.main()