Commit 446edfd

misc: add quant and attn backend -> step mask example (#454)
* misc: support quantize and attn backend for flux example
* misc: add quant and attn backend -> step mask example
1 parent 2a418d2 commit 446edfd

5 files changed

Lines changed: 48 additions & 23 deletions

README.md

Lines changed: 4 additions & 5 deletions
@@ -11,14 +11,13 @@
 </h2>
 </p>

-
-|Baseline|SCM S S*|SCM S D*|SCM F D*|SCM U D*|+TaylorSeer|+compile|
+|Baseline|SCM S S*|SCM F D*|SCM U D*|+TS|+compile|+FP8*|
 |:---:|:---:|:---:|:---:|:---:|:---:|:---:|
-|24.85s|15.4s|17.1s|11.4s|8.2s|8.2s|7.1s|
-|<img src="https://github.com/vipshop/cache-dit/raw/main/assets/steps_mask/flux.NONE.png" width=90px>|<img src="assets/steps_mask/static.png" width=90px>|<img src="https://github.com/vipshop/cache-dit/raw/main/assets/steps_mask/flux.DBCache_F1B0_W8I1M0MC0_R0.15_SCM1111111101110011100110011000_dynamic_T0O0_S8.png" width=90px>|<img src="https://github.com/vipshop/cache-dit/raw/main/assets/steps_mask/flux.DBCache_F1B0_W8I1M0MC0_R0.2_SCM1111110100010000100000100000_dynamic_T0O0_S15.png" width=90px>|<img src="https://github.com/vipshop/cache-dit/raw/main/assets/steps_mask/flux.DBCache_F1B0_W8I1M0MC0_R0.3_SCM111101000010000010000001000000_dynamic_T0O0_S19.png" width=90px>|<img src="https://github.com/vipshop/cache-dit/raw/main/assets/steps_mask/flux.DBCache_F1B0_W8I1M0MC0_R0.35_SCM111101000010000010000001000000_dynamic_T1O1_S19.png" width=90px>|<img src="https://github.com/vipshop/cache-dit/raw/main/assets/steps_mask/flux.DBCache_F1B0_W8I1M0MC0_R0.35_SCM111101000010000010000001000000_dynamic_T1O1_S19.png" width=90px>|
+|24.85s|15.4s|11.4s|8.2s|8.2s|7.1s|4.5s|
+|<img src="https://github.com/vipshop/cache-dit/raw/main/assets/steps_mask/flux.NONE.png" width=90px>|<img src="assets/steps_mask/static.png" width=90px>|<img src="https://github.com/vipshop/cache-dit/raw/main/assets/steps_mask/flux.DBCache_F1B0_W8I1M0MC0_R0.2_SCM1111110100010000100000100000_dynamic_T0O0_S15.png" width=90px>|<img src="https://github.com/vipshop/cache-dit/raw/main/assets/steps_mask/flux.DBCache_F1B0_W8I1M0MC0_R0.3_SCM111101000010000010000001000000_dynamic_T0O0_S19.png" width=90px>|<img src="https://github.com/vipshop/cache-dit/raw/main/assets/steps_mask/flux.DBCache_F1B0_W8I1M0MC0_R0.35_SCM111101000010000010000001000000_dynamic_T1O1_S19.png" width=90px>|<img src="https://github.com/vipshop/cache-dit/raw/main/assets/steps_mask/flux.DBCache_F1B0_W8I1M0MC0_R0.35_SCM111101000010000010000001000000_dynamic_T1O1_S19.png" width=90px>|<img src="./assets/steps_mask/flux.C1_Q1_float8_DBCache_F1B0_W8I1M0MC0_R0.35_SCM111101000010000010000001000000_dynamic_T1O1_S19.png" width=90px>|

 <p align="center">
-Scheme: <b>DBCache + SCM(steps_computation_mask) + TaylorSeer</b>, L20x1, S*: static cache, <b>D*: dynamic cache</b>, <br><b>S</b>: Slow, <b>F</b>: Fast, <b>U</b>: Ultra Fast, FLUX.1-Dev, Steps: 28, Prompt: "A cat holding a sign that says hello world"
+Scheme: <b>DBCache + SCM(steps_computation_mask) + TS(TaylorSeer) + FP8*</b>, L20x1, S*: static cache, <br><b>D*: dynamic cache</b>, <b>S</b>: Slow, <b>F</b>: Fast, <b>U</b>: Ultra Fast, <b>TS</b>: TaylorSeer, <b>FP8*</b>: FP8 DQ + Sage, <b>FLUX.1</b>-Dev
 </p>

 <img src=https://github.com/vipshop/cache-dit/raw/main/assets/speedup_v4.png>
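To put the updated README table in perspective, the timings can be converted to speedup factors relative to the baseline. The sketch below is plain arithmetic on the numbers shown in the `+` rows of the table; it is not an additional benchmark, and the dictionary keys are simply the column labels.

```python
# Wall-clock times copied from the updated README table (seconds per image).
timings = {
    "Baseline": 24.85,
    "SCM S S*": 15.4,
    "SCM F D*": 11.4,
    "SCM U D*": 8.2,
    "+TS": 8.2,
    "+compile": 7.1,
    "+FP8*": 4.5,
}

baseline = timings["Baseline"]

# Speedup factor = baseline time / configuration time.
speedups = {name: baseline / seconds for name, seconds in timings.items()}

for name, factor in speedups.items():
    print(f"{name:>10}: {timings[name]:6.2f}s  {factor:.2f}x")
```

By this calculation the new `+FP8*` column is roughly 5.5x faster than the baseline, compared with 3.5x for `+compile` alone.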

docs/User_Guide.md

Lines changed: 6 additions & 6 deletions
@@ -615,7 +615,7 @@ cache_dit.enable_cache(

 <div align="center">

-|Baseline(L20x1)|F1B0 (0.12)|+TaylorSeer|F1B0 (0.15)|+TaylorSeer|+compile|
+|Baseline(L20x1)|F1B0 (0.12)|+TaylorSeer|F1B0 (0.15)|+TaylorSeer|+compile|
 |:---:|:---:|:---:|:---:|:---:|:---:|
 |24.85s|12.85s|12.86s|10.27s|10.28s|8.48s|
 |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T0ET0_R0.12_S14_T12.85s.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T1ET1_R0.12_S14_T12.86s.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T0ET0_R0.15_S17_T10.27s.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T1ET1_R0.15_S17_T10.28s.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBCACHE_F1B0S1W0T1ET1_R0.15_S17_T8.48s.png width=140px>|
@@ -674,13 +674,13 @@ As we can observe, in the case of **static cache**, the image of `SCM Slow S*` (

 <div align="center">

-|Baseline(L20x1)|SCM Slow S*|SCM Slow D*|SCM Fast D*|SCM Ultra D*|+TaylorSeer|+compile|
-|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
-|24.85s|15.4s|17.1s|11.4s|8.2s|8.2s|7.1s|
-|<img src="https://github.com/vipshop/cache-dit/raw/main/assets/steps_mask/flux.NONE.png" width=110px>|<img src="../assets/steps_mask/static.png" width=110px>|<img src="https://github.com/vipshop/cache-dit/raw/main/assets/steps_mask/flux.DBCache_F1B0_W8I1M0MC0_R0.15_SCM1111111101110011100110011000_dynamic_T0O0_S8.png" width=110px>|<img src="https://github.com/vipshop/cache-dit/raw/main/assets/steps_mask/flux.DBCache_F1B0_W8I1M0MC0_R0.2_SCM1111110100010000100000100000_dynamic_T0O0_S15.png" width=110px>|<img src="https://github.com/vipshop/cache-dit/raw/main/assets/steps_mask/flux.DBCache_F1B0_W8I1M0MC0_R0.3_SCM111101000010000010000001000000_dynamic_T0O0_S19.png" width=110px>|<img src="https://github.com/vipshop/cache-dit/raw/main/assets/steps_mask/flux.DBCache_F1B0_W8I1M0MC0_R0.35_SCM111101000010000010000001000000_dynamic_T1O1_S19.png" width=110px>|<img src="https://github.com/vipshop/cache-dit/raw/main/assets/steps_mask/flux.DBCache_F1B0_W8I1M0MC0_R0.35_SCM111101000010000010000001000000_dynamic_T1O1_S19.png" width=110px>|
+|Baseline|SCM S S*|SCM S D*|SCM F D*|SCM U D*|+TS|+compile|+FP8 +Sage|
+|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
+|24.85s|15.4s|17.1s|11.4s|8.2s|8.2s|7.1s|4.5s|
+|<img src="https://github.com/vipshop/cache-dit/raw/main/assets/steps_mask/flux.NONE.png" width=95px>|<img src="../assets/steps_mask/static.png" width=95px>|<img src="https://github.com/vipshop/cache-dit/raw/main/assets/steps_mask/flux.DBCache_F1B0_W8I1M0MC0_R0.15_SCM1111111101110011100110011000_dynamic_T0O0_S8.png" width=95px>|<img src="https://github.com/vipshop/cache-dit/raw/main/assets/steps_mask/flux.DBCache_F1B0_W8I1M0MC0_R0.2_SCM1111110100010000100000100000_dynamic_T0O0_S15.png" width=95px>|<img src="https://github.com/vipshop/cache-dit/raw/main/assets/steps_mask/flux.DBCache_F1B0_W8I1M0MC0_R0.3_SCM111101000010000010000001000000_dynamic_T0O0_S19.png" width=95px>|<img src="https://github.com/vipshop/cache-dit/raw/main/assets/steps_mask/flux.DBCache_F1B0_W8I1M0MC0_R0.35_SCM111101000010000010000001000000_dynamic_T1O1_S19.png" width=95px>|<img src="https://github.com/vipshop/cache-dit/raw/main/assets/steps_mask/flux.DBCache_F1B0_W8I1M0MC0_R0.35_SCM111101000010000010000001000000_dynamic_T1O1_S19.png" width=95px>|<img src="../assets/steps_mask/flux.C1_Q1_float8_DBCache_F1B0_W8I1M0MC0_R0.35_SCM111101000010000010000001000000_dynamic_T1O1_S19.png" width=95px>|

 <p align="center">
-DBCache + SCM(steps_computation_mask) + TaylorSeer, <b> L20x1 </b>, <br>S*: static cache, D*: dynamic cache, Steps: 28, "A cat holding a sign that says hello world"
+Scheme: <b>DBCache + SCM(steps_computation_mask) + TaylorSeer</b>, L20x1, S*: static cache, <b>D*: dynamic cache</b>, <br><b>S</b>: Slow, <b>F</b>: Fast, <b>U</b>: Ultra Fast, <b>TS</b>: TaylorSeer, FP8: FP8 DQ, Sage: SageAttention, <b>FLUX.1-Dev</b>, <br>Steps: 28, HxW=1024x1024, Prompt: "A cat holding a sign that says hello world"
 </p>

 |DBCache + SCM Slow S*|DBCache + SCM Ultra D* + TaylorSeer + compile|
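The asset filenames in this table embed the step mask after the `SCM` marker. As a rough illustration, the sketch below extracts that bit string and counts its entries. It assumes, based only on the name `steps_computation_mask`, that a `1` marks a step that is always computed and a `0` marks a step that may be served from cache; `summarize_mask` is a hypothetical helper, not part of cache-dit.

```python
import re


def summarize_mask(filename: str) -> tuple[int, int, int]:
    """Return (total steps, count of 1s, count of 0s) for the SCM mask in a filename.

    Assumption: 1 = step is computed, 0 = step may be cached.
    """
    match = re.search(r"SCM([01]+)", filename)
    if match is None:
        raise ValueError(f"no SCM mask found in {filename!r}")
    mask = match.group(1)
    ones = mask.count("1")
    return len(mask), ones, len(mask) - ones


# Filename of the "SCM F D*" image from the table above.
name = (
    "flux.DBCache_F1B0_W8I1M0MC0_R0.2_"
    "SCM1111110100010000100000100000_dynamic_T0O0_S15.png"
)
total, computed, cacheable = summarize_mask(name)
print(f"steps={total} computed={computed} cacheable={cacheable}")
```

For this filename the mask has 28 entries, which matches the `Steps: 28` setting in the caption.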

examples/api/run_steps_mask.py

Lines changed: 31 additions & 4 deletions
@@ -5,8 +5,8 @@

 import time
 import torch
-from diffusers import FluxPipeline
-from utils import get_args
+from diffusers import FluxPipeline, FluxTransformer2DModel
+from utils import get_args, strify
 import cache_dit


@@ -62,7 +62,7 @@
         "black-forest-labs/FLUX.1-dev",
     ),
     torch_dtype=torch.bfloat16,
-).to("cuda")
+)

 if args.cache:
     from cache_dit import DBCacheConfig, TaylorSeerCalibratorConfig
@@ -94,10 +94,36 @@
         ),
     )

+assert isinstance(pipe.transformer, FluxTransformer2DModel)
+if args.quantize:
+    pipe.transformer = cache_dit.quantize(
+        pipe.transformer,
+        quant_type=args.quantize_type,
+        exclude_layers=[
+            "embedder",
+            "embed",
+        ],
+    )
+    pipe.text_encoder_2 = cache_dit.quantize(
+        pipe.text_encoder_2,
+        quant_type=args.quantize_type,
+    )
+    print(f"Applied quantization: {args.quantize_type} to Transformer and Text Encoder 2.")
+
+pipe.to("cuda")
+
+if args.attn is not None:
+    if hasattr(pipe.transformer, "set_attention_backend"):
+        pipe.transformer.set_attention_backend(args.attn)
+        print(f"Set attention backend to {args.attn}")
+

 if args.compile:
     cache_dit.set_compile_configs()
     pipe.transformer = torch.compile(pipe.transformer)
+    pipe.text_encoder = torch.compile(pipe.text_encoder)
+    pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2)
+    pipe.vae = torch.compile(pipe.vae)


 def run_pipe():
@@ -121,7 +147,7 @@ def run_pipe():
 cache_dit.summary(pipe)

 time_cost = end - start
-save_path = f"flux.{cache_dit.strify(pipe)}.png"
+save_path = f"flux.{strify(args, pipe)}.png"
 print(f"Time cost: {time_cost:.2f}s")
 print(f"Saving image to {save_path}")
 image.save(save_path)
@@ -134,3 +160,4 @@ def run_pipe():
 # python3 run_steps_mask.py --cache --Fn 1 --step-mask u --step-policy dynamic --rdt 0.30
 # python3 run_steps_mask.py --cache --Fn 1 --step-mask u --step-policy dynamic --rdt 0.30 --taylorseer --taylorseer-order 1
 # python3 run_steps_mask.py --cache --Fn 1 --step-mask u --step-policy dynamic --rdt 0.30 --compile --taylorseer --taylorseer-order 1
+# python3 run_steps_mask.py --cache --Fn 1 --step-mask u --step-policy dynamic --rdt 0.35 --compile --taylorseer --taylorseer-order 1 --quantize --quantize-type float8 --attn sage
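The new example command relies on `--quantize`, `--quantize-type`, and `--attn`, which the script reads as `args.quantize`, `args.quantize_type`, and `args.attn`. The real `get_args` in `utils` is not part of this diff, so the following is only a hypothetical stand-in built with `argparse` to show how such flags could map to those attributes. The types and defaults here are assumptions.

```python
import argparse
import shlex


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical stand-in for utils.get_args; defaults are assumptions."""
    parser = argparse.ArgumentParser(description="steps-mask example flags (sketch)")
    parser.add_argument("--cache", action="store_true")
    parser.add_argument("--Fn", type=int, default=1)
    parser.add_argument("--step-mask", type=str, default=None)
    parser.add_argument("--step-policy", type=str, default=None)
    parser.add_argument("--rdt", type=float, default=None)
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--taylorseer", action="store_true")
    parser.add_argument("--taylorseer-order", type=int, default=1)
    # Flags exercised by the code added in this commit.
    parser.add_argument("--quantize", action="store_true")
    parser.add_argument("--quantize-type", type=str, default=None)
    parser.add_argument("--attn", type=str, default=None)
    return parser


# Flags taken verbatim from the example command added in this commit.
cmd = (
    "--cache --Fn 1 --step-mask u --step-policy dynamic --rdt 0.35 "
    "--compile --taylorseer --taylorseer-order 1 "
    "--quantize --quantize-type float8 --attn sage"
)
args = build_parser().parse_args(shlex.split(cmd))
print(args.quantize, args.quantize_type, args.attn)
```

Because `argparse` converts dashes to underscores, `--quantize-type` becomes `args.quantize_type`, which matches the attribute name used in the diff.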

examples/pipeline/run_flux.py

Lines changed: 7 additions & 8 deletions
@@ -48,6 +48,13 @@
         pipe.transformer.set_attention_backend(args.attn)
         print(f"Set attention backend to {args.attn}")

+if args.compile:
+    cache_dit.set_compile_configs()
+    pipe.transformer = torch.compile(pipe.transformer)
+    pipe.text_encoder = torch.compile(pipe.text_encoder)
+    pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2)
+    pipe.vae = torch.compile(pipe.vae)
+

 def run_pipe():
     image = pipe(
@@ -59,14 +66,6 @@ def run_pipe():
     ).images[0]
     return image

-
-if args.compile:
-    cache_dit.set_compile_configs()
-    pipe.transformer = torch.compile(pipe.transformer)
-    pipe.text_encoder = torch.compile(pipe.text_encoder)
-    pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2)
-    pipe.vae = torch.compile(pipe.vae)
-
 # warmup
 _ = run_pipe()
