Skip to content

Commit 7496c6a

Browse files
authored
feat: load config support hybrid parallel (#777)
* load config support hybrid parallel * load config support hybrid parallel * load config support hybrid parallel * load config support hybrid parallel * load config support hybrid parallel * load config support hybrid parallel * load config support hybrid parallel
1 parent 33fc439 commit 7496c6a

6 files changed

Lines changed: 106 additions & 25 deletions

File tree

README.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434

3535
## 🔥Latest News
3636

37-
- [2026/01] **[🎉v1.2.0 Major Release](https://github.com/vipshop/cache-dit)** is ready: New Models Support(Z-Image, FLUX.2, LTX-2, etc), Request level Cache Context, HTTP Serving, [Ulysses Anything](https://cache-dit.readthedocs.io/en/latest/user_guide/CONTEXT_PARALLEL/#uaa-ulysses-anything-attention), TE-P, VAE-P, CN-P and [Ascend NPUs](https://cache-dit.readthedocs.io/en/latest/user_guide/ASCEND_NPU/) Support.
37+
- [2026/02] **[🎉v1.2.1](https://github.com/vipshop/cache-dit)** release is ready, the major updates including: [Ring](https://cache-dit.readthedocs.io/en/latest/user_guide/CONTEXT_PARALLEL) Attention w/ [batched P2P](https://cache-dit.readthedocs.io/en/latest/user_guide/CONTEXT_PARALLEL), [USP](https://cache-dit.readthedocs.io/en/latest/user_guide/CONTEXT_PARALLEL/) (Hybrid Ring and Ulysses), Hybrid 2D and 3D Parallelism (💥[USP + TP](https://cache-dit.readthedocs.io/en/latest/user_guide/HYBRID_PARALLEL/)), VAE-P Comm overhead reduce.
38+
- [2026/01] **[🎉v1.2.0](https://github.com/vipshop/cache-dit)** stable release is ready: New Models Support(Z-Image, FLUX.2, LTX-2, etc), Request level Cache Context, HTTP Serving, [Ulysses Anything](https://cache-dit.readthedocs.io/en/latest/user_guide/CONTEXT_PARALLEL/#uaa-ulysses-anything-attention), TE-P, VAE-P, CN-P and [Ascend NPUs](https://cache-dit.readthedocs.io/en/latest/user_guide/ASCEND_NPU/) support.
39+
3840

3941
## 🚀Quick Start
4042

@@ -55,14 +57,13 @@ Then accelerate your DiTs with just **♥️one line♥️** of code ~
5557
>>> cache_dit.enable_cache(
5658
... pipe, cache_config=DBCacheConfig(), # w/ default
5759
... parallelism_config=ParallelismConfig(ulysses_size=2))
60+
>>> # Or, Use Distributed Inference without Cache Acceleration.
61+
>>> cache_dit.enable_cache(
62+
... pipe, parallelism_config=ParallelismConfig(ulysses_size=2))
5863
>>> # Or, Hybrid Cache Acceleration + 2D Parallelism.
5964
>>> cache_dit.enable_cache(
6065
... pipe, cache_config=DBCacheConfig(), # w/ default
6166
... parallelism_config=ParallelismConfig(ulysses_size=2, tp_size=2))
62-
>>> # Or, Use Distributed Inference without Cache Acceleration.
63-
>>> cache_dit.enable_cache(
64-
... pipe, cache_config=None, # Set cache_config as None.
65-
... parallelism_config=ParallelismConfig(ulysses_size=2))
6667
>>> from cache_dit import load_configs
6768
>>> # Or, Load Acceleration config from a custom yaml file.
6869
>>> cache_dit.enable_cache(pipe, **load_configs("config.yaml"))

docs/README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@
4040

4141
## 🔥Latest News
4242

43-
- [2026/01] **[🎉v1.2.0 Major Release](https://github.com/vipshop/cache-dit)** is ready: New Models Support(Z-Image, FLUX.2, LTX-2, etc), Request level Cache Context, HTTP Serving, [Ulysses Anything](https://cache-dit.readthedocs.io/en/latest/user_guide/CONTEXT_PARALLEL/#uaa-ulysses-anything-attention), TE-P, VAE-P, CN-P and [Ascend NPUs](https://cache-dit.readthedocs.io/en/latest/user_guide/ASCEND_NPU/) Support.
43+
- [2026/02] **[🎉v1.2.1](https://github.com/vipshop/cache-dit)** release is ready, the major updates including: [Ring](https://cache-dit.readthedocs.io/en/latest/user_guide/CONTEXT_PARALLEL) Attention w/ [batched P2P](https://cache-dit.readthedocs.io/en/latest/user_guide/CONTEXT_PARALLEL), [USP](https://cache-dit.readthedocs.io/en/latest/user_guide/CONTEXT_PARALLEL/) (Hybrid Ring and Ulysses), Hybrid 2D and 3D Parallelism (💥[USP + TP](https://cache-dit.readthedocs.io/en/latest/user_guide/HYBRID_PARALLEL/)), VAE-P Comm overhead reduce.
44+
- [2026/01] **[🎉v1.2.0](https://github.com/vipshop/cache-dit)** stable release is ready: New Models Support(Z-Image, FLUX.2, LTX-2, etc), Request level Cache Context, HTTP Serving, [Ulysses Anything](https://cache-dit.readthedocs.io/en/latest/user_guide/CONTEXT_PARALLEL/#uaa-ulysses-anything-attention), TE-P, VAE-P, CN-P and [Ascend NPUs](https://cache-dit.readthedocs.io/en/latest/user_guide/ASCEND_NPU/) Support.
4445

4546
## 🚀Quick Start
4647

@@ -61,14 +62,13 @@ Then accelerate your DiTs with just **♥️one line♥️** of code ~
6162
>>> cache_dit.enable_cache(
6263
... pipe, cache_config=DBCacheConfig(), # w/ default
6364
... parallelism_config=ParallelismConfig(ulysses_size=2))
65+
>>> # Or, Use Distributed Inference without Cache Acceleration.
66+
>>> cache_dit.enable_cache(
67+
... pipe, parallelism_config=ParallelismConfig(ulysses_size=2))
6468
>>> # Or, Hybrid Cache Acceleration + 2D Parallelism.
6569
>>> cache_dit.enable_cache(
6670
... pipe, cache_config=DBCacheConfig(), # w/ default
6771
... parallelism_config=ParallelismConfig(ulysses_size=2, tp_size=2))
68-
>>> # Or, Use Distributed Inference without Cache Acceleration.
69-
>>> cache_dit.enable_cache(
70-
... pipe, cache_config=None, # Set cache_config as None.
71-
... parallelism_config=ParallelismConfig(ulysses_size=2))
7272
>>> from cache_dit import load_configs
7373
>>> # Or, Load Acceleration config from a custom yaml file.
7474
>>> cache_dit.enable_cache(pipe, **load_configs("config.yaml"))

docs/user_guide/LOAD_CONFIGS.md

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ Then, apply the acceleration config from yaml.
2727

2828
## Distributed inference
2929

30+
- 1D Parallelism
31+
3032
Define a parallelism only config yaml `parallel.yaml` file that contains:
3133

3234
```yaml
@@ -42,6 +44,43 @@ Then, apply the distributed inference acceleration config from yaml. `ulysses_si
4244
>>> cache_dit.enable_cache(pipe, **cache_dit.load_configs("parallel.yaml"))
4345
```
4446

47+
- 2D Parallelism
48+
49+
You can also define a 2D parallelism config yaml `parallel_2d.yaml` file that contains:
50+
51+
```yaml
52+
parallelism_config:
53+
ulysses_size: auto
54+
tp_size: 2
55+
parallel_kwargs:
56+
attention_backend: native
57+
extra_parallel_modules: ["text_encoder", "vae"]
58+
```
59+
Then, apply the 2D parallelism config from yaml. Here `tp_size: 2` means using tensor parallelism with size 2. The `ulysses_size: auto` means that cache-dit will auto detect the `world_size // tp_size` as the ulysses_size.
60+
```python
61+
>>> import cache_dit
62+
>>> cache_dit.enable_cache(pipe, **cache_dit.load_configs("parallel_2d.yaml"))
63+
```
64+
65+
- 3D Parallelism
66+
67+
You can also define a 3D parallelism config yaml `parallel_3d.yaml` file that contains:
68+
69+
```yaml
70+
parallelism_config:
71+
ulysses_size: 2
72+
ring_size: 2
73+
tp_size: 2
74+
parallel_kwargs:
75+
attention_backend: native
76+
extra_parallel_modules: ["text_encoder", "vae"]
77+
```
78+
Then, apply the 3D parallelism config from yaml. Here `ulysses_size: 2`, `ring_size: 2`, `tp_size: 2` means using ulysses parallelism with size 2, ring parallelism with size 2 and tensor parallelism with size 2.
79+
```python
80+
>>> import cache_dit
81+
>>> cache_dit.enable_cache(pipe, **cache_dit.load_configs("parallel_3d.yaml"))
82+
```
83+
4584
## Hybrid Cache and Parallelism
4685

4786
Define a hybrid cache and parallel acceleration config yaml `hybrid.yaml` file that contains:
@@ -81,6 +120,8 @@ pip3 install git+https://github.com/huggingface/diffusers.git # latest or >= 0.3
81120
pip3 install git+https://github.com/vipshop/cache-dit.git # latest
82121
83122
python3 -m cache_dit.generate flux --config cache.yaml
84-
torchrun --nproc_per_node=4 -m cache_dit.generate flux --config parallel.yaml
85123
torchrun --nproc_per_node=4 -m cache_dit.generate flux --config hybrid.yaml
124+
torchrun --nproc_per_node=4 -m cache_dit.generate flux --config parallel.yaml
125+
torchrun --nproc_per_node=4 -m cache_dit.generate flux --config parallel_2d.yaml
126+
torchrun --nproc_per_node=8 -m cache_dit.generate flux --config parallel.yaml
86127
```

examples/configs/parallel_2d.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
parallelism_config:
2+
ulysses_size: auto
3+
tp_size: 2
4+
parallel_kwargs:
5+
attention_backend: native
6+
extra_parallel_modules: ["text_encoder", "vae"]

examples/configs/parallel_3d.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
parallelism_config:
2+
ulysses_size: 2
3+
ring_size: 2
4+
tp_size: 2
5+
parallel_kwargs:
6+
attention_backend: native
7+
extra_parallel_modules: ["text_encoder", "vae"]

src/cache_dit/caching/utils.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,9 @@ def load_parallelism_config(
225225
backend_str = parallelism_config_kwargs["backend"]
226226
parallelism_config_kwargs["backend"] = ParallelismBackend.from_str(backend_str)
227227

228-
def _maybe_auto_parallel_size(size: str | int | None) -> Optional[int]:
228+
def _maybe_auto_parallel_size(
229+
size: str | int | None, partial_max_size: Optional[int] = None
230+
) -> Optional[int]:
229231
if size is None:
230232
return None
231233
if isinstance(size, int):
@@ -236,7 +238,11 @@ def _maybe_auto_parallel_size(size: str | int | None) -> Optional[int]:
236238
size = 1
237239
if dist.is_initialized():
238240
# Assume world size is the parallel size
239-
size = dist.get_world_size()
241+
world_size = dist.get_world_size()
242+
if partial_max_size is not None:
243+
size = world_size // partial_max_size
244+
else:
245+
size = world_size
240246
if size == 1:
241247
logger.warning(
242248
"Auto parallel size selected as 1. Make sure to run with torch.distributed "
@@ -247,20 +253,40 @@ def _maybe_auto_parallel_size(size: str | int | None) -> Optional[int]:
247253
return size
248254
raise ValueError(f"Invalid parallel size value: {size}. Must be int or 'auto'.")
249255

250-
if kwargs.get("auto_parallel_size", True):
251-
if "ulysses_size" in parallelism_config_kwargs:
252-
parallelism_config_kwargs["ulysses_size"] = _maybe_auto_parallel_size(
253-
parallelism_config_kwargs["ulysses_size"]
254-
)
255-
if "ring_size" in parallelism_config_kwargs:
256-
parallelism_config_kwargs["ring_size"] = _maybe_auto_parallel_size(
257-
parallelism_config_kwargs["ring_size"]
258-
)
259-
if "tp_size" in parallelism_config_kwargs:
260-
parallelism_config_kwargs["tp_size"] = _maybe_auto_parallel_size(
261-
parallelism_config_kwargs["tp_size"]
256+
def _maybe_auto_parallel_sizes(parallelism_config_kwargs: dict) -> dict:
257+
# Only allow one of the parallel size to be auto for simplicity
258+
auto_count = sum(
259+
1
260+
for key in ["ulysses_size", "ring_size", "tp_size"]
261+
if key in parallelism_config_kwargs and parallelism_config_kwargs[key] == "auto"
262+
)
263+
if auto_count > 1:
264+
raise ValueError(
265+
"Only one of 'ulysses_size', 'ring_size', or 'tp_size' can be set to 'auto'."
262266
)
263267

268+
ulysses_size = parallelism_config_kwargs.get("ulysses_size", 1)
269+
ring_size = parallelism_config_kwargs.get("ring_size", 1)
270+
tp_size = parallelism_config_kwargs.get("tp_size", 1)
271+
partial_max_size = None
272+
if isinstance(ulysses_size, str) and ulysses_size.lower() == "auto":
273+
partial_max_size = ring_size * tp_size
274+
elif isinstance(ring_size, str) and ring_size.lower() == "auto":
275+
partial_max_size = ulysses_size * tp_size
276+
elif isinstance(tp_size, str) and tp_size.lower() == "auto":
277+
partial_max_size = ulysses_size * ring_size
278+
279+
for key in ["ulysses_size", "ring_size", "tp_size"]:
280+
if key in parallelism_config_kwargs:
281+
parallelism_config_kwargs[key] = _maybe_auto_parallel_size(
282+
parallelism_config_kwargs[key], partial_max_size=partial_max_size
283+
)
284+
return parallelism_config_kwargs
285+
286+
if kwargs.get("auto_parallel_size", True):
287+
288+
parallelism_config_kwargs = _maybe_auto_parallel_sizes(parallelism_config_kwargs)
289+
264290
parallelism_config = ParallelismConfig(**parallelism_config_kwargs)
265291
return parallelism_config
266292

0 commit comments

Comments
 (0)