Skip to content

Commit 6d9f074

Browse files
authored
chore: remove un-needed cp imports (#532)
1 parent 6fc8489 commit 6d9f074

4 files changed

Lines changed: 5 additions & 46 deletions

File tree

src/cache_dit/__init__.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -72,43 +72,6 @@ def disable_compute_comm_overlap():
7272
pass
7373

7474

75-
try:
76-
from cache_dit.parallelism import disable_ulysses_anything
77-
from cache_dit.parallelism import enable_ulysses_anything
78-
from cache_dit.parallelism import disable_ulysses_anything_float8
79-
from cache_dit.parallelism import enable_ulysses_anything_float8
80-
from cache_dit.parallelism import disable_ulysses_float8
81-
from cache_dit.parallelism import enable_ulysses_float8
82-
83-
except ImportError as e: # noqa: F841
84-
err_msg = str(e)
85-
86-
def _raise_import_error(func_name: str): # noqa: F811
87-
raise ImportError(
88-
f"{func_name} requires additional dependencies. "
89-
"Please install cache-dit[parallelism] or cache-dit[all] "
90-
f"to use this feature. Error message: {err_msg}"
91-
)
92-
93-
def enable_ulysses_anything(*args, **kwargs):
94-
_raise_import_error("enable_ulysses_anything")
95-
96-
def disable_ulysses_anything(*args, **kwargs):
97-
_raise_import_error("disable_ulysses_anything")
98-
99-
def enable_ulysses_anything_float8(*args, **kwargs):
100-
_raise_import_error("enable_ulysses_anything_float8")
101-
102-
def disable_ulysses_anything_float8(*args, **kwargs):
103-
_raise_import_error("disable_ulysses_anything_float8")
104-
105-
def enable_ulysses_float8(*args, **kwargs):
106-
_raise_import_error("enable_ulysses_float8")
107-
108-
def disable_ulysses_float8(*args, **kwargs):
109-
_raise_import_error("disable_ulysses_float8")
110-
111-
11275
NONE = CacheType.NONE
11376
DBCache = CacheType.DBCache
11477
DBPrune = CacheType.DBPrune

src/cache_dit/caching/cache_blocks/pattern_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
except ImportError:
2424
ContextParallelSplitHook = None
2525
logger.warning(
26-
"Context parallelism requires the 'diffusers>=0.36.dev0'."
27-
"Please install latest version of diffusers from source: \n"
26+
"Context parallelism in cache-dit requires 'diffusers>=0.36.dev0.\n"
27+
"Please install latest version of diffusers from source via: \n"
2828
"pip3 install git+https://github.com/huggingface/diffusers.git"
2929
)
3030

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,4 @@
11
from cache_dit.parallelism.parallel_backend import ParallelismBackend
22
from cache_dit.parallelism.parallel_config import ParallelismConfig
3-
from cache_dit.parallelism.backends.native_diffusers import enable_ulysses_anything
4-
from cache_dit.parallelism.backends.native_diffusers import disable_ulysses_anything
5-
from cache_dit.parallelism.backends.native_diffusers import enable_ulysses_anything_float8
6-
from cache_dit.parallelism.backends.native_diffusers import disable_ulysses_anything_float8
7-
from cache_dit.parallelism.backends.native_diffusers import enable_ulysses_float8
8-
from cache_dit.parallelism.backends.native_diffusers import disable_ulysses_float8
93
from cache_dit.parallelism.parallel_interface import enable_parallelism
104
from cache_dit.parallelism.parallel_interface import maybe_pad_prompt

src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_templated_ulysses.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"Please install latest version of diffusers from source: \n"
1616
"pip3 install git+https://github.com/huggingface/diffusers.git"
1717
)
18-
from cache_dit.logger import init_logger
1918
from ._distributed_primitives import (
2019
_get_rank_world_size,
2120
_gather_size_by_comm,
@@ -27,8 +26,11 @@
2726
_all_to_all_single,
2827
)
2928

29+
from cache_dit.logger import init_logger
30+
3031
logger = init_logger(__name__)
3132

33+
3234
__all__ = [
3335
"TemplatedUlyssesAnythingAttention",
3436
"TemplatedUlyssesAnythingAttentionFloat8",

0 commit comments

Comments
 (0)