Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 153 additions & 0 deletions backend/mnemorai/constants/languages.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,156 @@
with open(config.get("G2P").get("LANGUAGE_JSON")) as f:
G2P_LANGCODES = json.load(f)
G2P_LANGUAGES: dict = dict(map(reversed, G2P_LANGCODES.items()))

EPITRAN_LANGCODES = {
"aar-Latn": "Afar",
"afr-Latn": "Afrikaans",
"aii-Syrc": "Assyrian Neo-Aramaic",
"amh-Ethi": "Amharic",
"amh-Ethi-pp": "Amharic (more phonetic)",
"amh-Ethi-red": "Amharic (reduced)",
"ara-Arab": "Literary Arabic",
"ava-Cyrl": "Avaric",
"aze-Cyrl": "Azerbaijani (Cyrillic)",
"aze-Latn": "Azerbaijani",
"ben-Beng": "Bengali",
"ben-Beng-red": "Bengali (reduced)",
"ben-Beng-east": "Eastern Bengali",
"bho-Deva": "Bhojpuri",
"bxk-Latn": "Bukusu",
"cat-Latn": "Catalan",
"ceb-Latn": "Cebuano",
"ces-Latn": "Czech",
"cjy-Latn": "Jin (Wiktionary)",
"ckb-Arab": "Sorani",
"cmn-Hans": "Mandarin (Simplified)*",
"cmn-Hant": "Mandarin (Traditional)*",
"cmn-Latn": "Mandarin (Pinyin)*",
"csb-Latn": "Kashubian",
"deu-Latn": "German",
"deu-Latn-np": "German†",
"deu-Latn-nar": "German (more phonetic)",
"eng-Latn": "English",
"epo-Latn": "Esperanto",
"est-Latn": "Estonian",
"fas-Arab": "Farsi (Perso-Arabic)",
"fin-Latn": "Finnish",
"fra-Latn": "French",
"fra-Latn-np": "French†",
"fra-Latn-p": "French (more phonetic)",
"ful-Latn": "Fulah",
"gan-Latn": "Gan (Wiktionary)",
"glg-Latn": "Galician",
"got-Goth": "Gothic",
"got-Latn": "Gothic (Latin)",
"hak-Latn": "Hakka (pha̍k-fa-sṳ)",
"hat-Latn-bab": "Haitian (Latin-Babel)",
"hau-Latn": "Hausa",
"hin-Deva": "Hindi",
"hmn-Latn": "Hmong",
"hrv-Latn": "Croatian",
"hsn-Latn": "Xiang (Wiktionary)",
"hun-Latn": "Hungarian",
"ilo-Latn": "Ilocano",
"ind-Latn": "Indonesian",
"ita-Latn": "Italian",
"jam-Latn": "Jamaican",
"jav-Latn": "Javanese",
"jpn-Hrgn": "Japanese (Hiragana)",
"jpn-Hrgn-red": "Japanese (Hiragana, reduced)",
"jpn-Ktkn": "Japanese (Katakana)",
"jpn-Ktkn-red": "Japanese (Katakana, reduced)",
"jpn-Jpan": "Japanese (Hiragana, Katakana, Kanji)",
"jpn-Hira": "Japanese (Hiragana)",
"jpn-Hira-red": "Japanese (Hiragana, reduced)",
"jpn-Kana": "Japanese (Katakana)",
"jpn-Kana-red": "Japanese (Katakana, reduced)",
"kat-Geor": "Georgian",
"kaz-Cyrl": "Kazakh (Cyrillic)",
"kaz-Cyrl-bab": "Kazakh (Cyrillicβ€”Babel)",
"kaz-Latn": "Kazakh (Latin)",
"kbd-Cyrl": "Kabardian",
"khm-Khmr": "Khmer",
"kin-Latn": "Kinyarwanda",
"kir-Arab": "Kyrgyz (Perso-Arabic)",
"kir-Cyrl": "Kyrgyz (Cyrillic)",
"kir-Latn": "Kyrgyz (Latin)",
"kmr-Latn": "Kurmanji",
"kmr-Latn-red": "Kurmanji (reduced)",
"kor-Hang": "Korean",
"lao-Laoo": "Lao",
"lao-Laoo-prereform": "Lao (Before spelling reform)",
"lav-Latn": "Latvian",
"lez-Cyrl": "Lezgian",
"lij-Latn": "Ligurian",
"lit-Latn": "Lithuanian",
"lsm-Latn": "Saamia",
"ltc-Latn-bax": "Middle Chinese (Baxter and Sagart 2014)",
"lug-Latn": "Ganda / Luganda",
"mal-Mlym": "Malayalam",
"mar-Deva": "Marathi",
"mlt-Latn": "Maltese",
"mon-Cyrl-bab": "Mongolian (Cyrillic)",
"mri-Latn": "Maori",
"msa-Latn": "Malay",
"mya-Mymr": "Burmese",
"nan-Latn": "Hokkien (pe̍h-oΔ“-jΔ«)",
"nan-Latn-tl": "Hokkien (TΓ’i-lΓ΄)",
"nld-Latn": "Dutch",
"nya-Latn": "Chichewa",
"ood-Latn-alv": "Tohono O'odham (Alvarez-Hale)",
"ood-Latn-sax": "Tohono O'odham (Saxton)",
"ori-Orya": "Odia",
"orm-Latn": "Oromo",
"pan-Guru": "Punjabi (Eastern)",
"pol-Latn": "Polish",
"por-Latn": "Portuguese",
"quy-Latn": "Ayacucho Quechua / Quechua Chanka",
"ron-Latn": "Romanian",
"run-Latn": "Rundi",
"rus-Cyrl": "Russian",
"sag-Latn": "Sango",
"sin-Sinh": "Sinhala",
"slv-Latn": "Slovene / Slovenian",
"sna-Latn": "Shona",
"som-Latn": "Somali",
"spa-Latn": "Spanish",
"spa-Latn-eu": "Spanish (Iberian)",
"sqi-Latn": "Albanian",
"sro-Latn": "Sardinian (Campidanese)",
"srp-Latn": "Serbian (Latin)",
"srp-Cyrl": "Serbian (Cyrillic)",
"swa-Latn": "Swahili",
"swa-Latn-red": "Swahili (reduced)",
"swe-Latn": "Swedish",
"tam-Taml": "Tamil",
"tam-Taml-red": "Tamil (reduced)",
"tel-Telu": "Telugu",
"tgk-Cyrl": "Tajik",
"tgl-Latn": "Tagalog",
"tgl-Latn-red": "Tagalog (reduced)",
"tha-Thai": "Thai",
"tir-Ethi": "Tigrinya",
"tir-Ethi-pp": "Tigrinya (more phonemic)",
"tir-Ethi-red": "Tigrinya (reduced)",
"tok-Latn": "Toki Pona",
"tpi-Latn": "Tok Pisin",
"tuk-Cyrl": "Turkmen (Cyrillic)",
"tuk-Latn": "Turkmen (Latin)",
"tur-Latn": "Turkish (Latin)",
"tur-Latn-bab": "Turkish (Latinβ€”Babel)",
"tur-Latn-red": "Turkish (reduced)",
"ukr-Cyrl": "Ukrainian",
"urd-Arab": "Urdu",
"uig-Arab": "Uyghur (Perso-Arabic)",
"uzb-Cyrl": "Uzbek (Cyrillic)",
"uzb-Latn": "Uzbek (Latin)",
"vie-Latn": "Vietnamese",
"wuu-Latn": "Shanghainese Wu (Wiktionary)",
"xho-Latn": "Xhosa",
"yor-Latn": "Yoruba",
"yue-Latn": "Cantonese (Jyutping)",
"yue-Hant": "Cantonese (Character)",
"zha-Latn": "Zhuang",
"zul-Latn": "Zulu",
}
2 changes: 1 addition & 1 deletion backend/mnemorai/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,4 @@ async def generate_mnemonic_img(

if __name__ == "__main__":
pipeline = MnemonicPipeline()
print(asyncio.run(pipeline.generate_mnemonic_img("ratatouille", "eng-us")))
print(asyncio.run(pipeline.generate_mnemonic_img("tikus", "ind")))
155 changes: 101 additions & 54 deletions backend/mnemorai/services/imagine/image_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@

import torch
from diffusers import (
AutoencoderKL,
AutoPipelineForText2Image,
SanaPipeline,
SanaTransformer2DModel,
FlowMatchEulerDiscreteScheduler,
FluxPipeline,
)
from diffusers import (
BitsAndBytesConfig as DiffusersBitsAndBytesConfig,
)
from transformers import AutoModel
from huggingface_hub import hf_hub_download
from nunchaku import NunchakuT5EncoderModel
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision
from transformers import BitsAndBytesConfig as BitsAndBytesConfig
from transformers import CLIPTextModel, CLIPTokenizer, T5TokenizerFast

from mnemorai.constants.config import config
from mnemorai.logger import logger
Expand All @@ -37,69 +39,117 @@ def __init__(self, model: str = None):
os.makedirs(self.output_dir, exist_ok=True)
self.image_gen_params = self.config.get("PARAMS", {})

# if seed is provided, set it
if "seed" in self.image_gen_params:
if not isinstance(self.image_gen_params["seed"], int):
logger.warning("Seed must be an integer. Using no seed.")
else:
self.image_gen_params["generator"] = torch.Generator(
device="cuda"
).manual_seed(self.image_gen_params["seed"])
# remove seed from params to avoid passing it to the pipeline
del self.image_gen_params["seed"]

# Initialize pipe to None; will be loaded on first use
self.pipe = None

def _get_pipe_func(self):
if "sana" in self.model_name.lower():
return SanaPipeline
if "flux" in self.model_name.lower():
return FluxPipeline
else:
return AutoPipelineForText2Image

def _initialize_pipe(self):
"""Initialize the pipeline."""
pipe_func = self._get_pipe_func()
logger.debug(f"Initializing pipeline for model: {self.model}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

quantization = self.config.get("QUANTIZATION")
if "flux" in self.model_name.lower():
bfl_repo = "black-forest-labs/FLUX.1-dev"
device = "cuda"

if quantization != "4bit" and quantization != "8bit":
logger.debug("Using default model loading without quantization")
self.pipe = pipe_func.from_pretrained(
self.model,
torch_dtype=torch.float16,
variant="fp16",
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
bfl_repo,
subfolder="scheduler",
torch_dtype=dtype,
cache_dir="models",
)
else:
if "sana" in self.model_name.lower():
if quantization == "8bit":
quant_config = BitsAndBytesConfig(load_in_8bit=True)
logger.debug("Using 8-bit quantization for Sana model")
elif quantization == "4bit":
quant_config = BitsAndBytesConfig(load_in_4bit=True)
logger.debug("Using 4-bit quantization for Sana model")
else:
raise ValueError(
f"Invalid quantization type. Use '8bit' or '4bit'. Your quantization is: {quantization}"
text_encoder = CLIPTextModel.from_pretrained(
bfl_repo,
subfolder="text_encoder",
torch_dtype=dtype,
cache_dir="models",
)
# T5 encoder in int4
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
"mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors",
cache_dir="models",
)
tokenizer = CLIPTokenizer.from_pretrained(
bfl_repo,
subfolder="tokenizer",
torch_dtype=dtype,
clean_up_tokenization_spaces=True,
cache_dir="models",
)
tokenizer_2 = T5TokenizerFast.from_pretrained(
bfl_repo,
subfolder="tokenizer_2",
torch_dtype=dtype,
clean_up_tokenization_spaces=True,
cache_dir="models",
)
vae = AutoencoderKL.from_pretrained(
bfl_repo,
subfolder="vae",
torch_dtype=dtype,
cache_dir="models",
)
precision = (
get_precision()
) # auto-detect your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
offload=self.config.get("OFFLOAD_T5", True),
)
# Set attention implementation to fp16
transformer.set_attention_impl("nunchaku-fp16")

params = {
"scheduler": scheduler,
"vae": vae,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"transformer": transformer,
}
self.pipe = FluxPipeline(**params) # .to(device, dtype=dtype)

lora_config = self.config.get("FLUX_LORA", {})
if lora_config.get("USE_LORA", False):
logger.info("Loading LoRA weights for FLUX model.")
transformer.update_lora_params(
hf_hub_download(
lora_config.get("LORA_REPO"),
lora_config.get("LORA_FILE"),
)

text_encoder_8bit = AutoModel.from_pretrained(
self.model,
subfolder="text_encoder",
quantization_config=quant_config,
torch_dtype=torch.float16,
cache_dir="models",
)

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = SanaTransformer2DModel.from_pretrained(
self.model,
subfolder="transformer",
quantization_config=quant_config,
torch_dtype=torch.float16,
cache_dir="models",
)
transformer.set_lora_strength(lora_config.get("LORA_SCALE", 1.0))

self.pipe = SanaPipeline.from_pretrained(
self.model,
text_encoder=text_encoder_8bit,
transformer=transformer_8bit,
torch_dtype=torch.float16,
device_map="balanced",
)
else:
raise NotImplementedError("Quantization not supported for this model.")
# offload, does not decrease performance
if self.config.get("SEQUENTIAL_OFFLOAD", True):
logger.info("Enabling sequential CPU offload for FLUX model.")
self.pipe.enable_sequential_cpu_offload(device=device)
else:
self.pipe = pipe_func.from_pretrained(
self.model,
torch_dtype=dtype,
variant="fp16" if dtype == torch.float16 else None,
cache_dir="models",
)

@manage_memory(
targets=["pipe"],
Expand All @@ -108,7 +158,7 @@ def _initialize_pipe(self):
)
def generate_img(
self,
prompt: str = "Imagine a flashy bottle that stands out from the other bottles.",
prompt: str = "A flashy bottle that stands out from the other bottles.",
word1: str = "flashy",
word2: str = "bottle",
):
Expand All @@ -126,9 +176,6 @@ def generate_img(
"""
file_path = self.output_dir / f"{word1}_{word2}_{self.model_name}.png"

# Clean prompt by dropping "imagine " prefix
prompt = prompt.lower().lstrip("imagine").strip()

logger.info(f"Generating image for prompt: {prompt}")
image = self.pipe(prompt=prompt, **self.image_gen_params).images[0]
logger.info(f"Saving image to: {file_path}")
Expand Down
Loading
Loading