import os
import re
import time

import numpy as np
import PIL.Image
import torch
from transformers import AutoModelForCausalLM

from janus.models import MultiModalityCausalLM, VLChatProcessor

# Specify the path to the model
model_path = "deepseek-ai/Janus-1.3B"
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True
)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()


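# Build the Janus SFT-style chat prompt for a text-to-image request; the
# image start tag appended at the end cues the model to emit image tokens
# instead of text.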
def create_prompt(user_input: str) -> str:
    conversation = [
        {
            "role": "User",
            "content": user_input,
        },
        {"role": "Assistant", "content": ""},
    ]

    sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=conversation,
        sft_format=vl_chat_processor.sft_format,
        system_prompt="",
    )
    prompt = sft_format + vl_chat_processor.image_start_tag
    return prompt


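# Sample image tokens autoregressively with classifier-free guidance (CFG),
# then decode them to pixels with the VQ vision decoder. Each image uses
# 576 tokens, i.e. a 24x24 latent grid for a 384px image with 16px patches.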
@torch.inference_mode()
def generate(
    mmgpt: MultiModalityCausalLM,
    vl_chat_processor: VLChatProcessor,
    prompt: str,
    short_prompt: str,
    parallel_size: int = 16,
    temperature: float = 1.0,
    cfg_weight: float = 5.0,
    image_token_num_per_image: int = 576,
    img_size: int = 384,
    patch_size: int = 16,
):
    input_ids = vl_chat_processor.tokenizer.encode(prompt)
    input_ids = torch.LongTensor(input_ids)

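    # Duplicate the prompt for CFG: even rows keep the full (conditional)
    # prompt, odd rows are masked down to pad tokens (keeping only the first
    # and last token) to serve as the unconditional branch.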
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).cuda()
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id

    inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)

    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()
    outputs = None  # Initialize outputs for use in the loop

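    # Decode one image token per step. After the first step, only the newly
    # sampled token's embedding is fed in; the prompt context is carried by
    # the cached key/value states (use_cache / past_key_values).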
    for i in range(image_token_num_per_image):
        outputs = mmgpt.language_model.model(
            inputs_embeds=inputs_embeds,
            use_cache=True,
            past_key_values=outputs.past_key_values if i != 0 else None,
        )
        hidden_states = outputs.last_hidden_state

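        # Classifier-free guidance: split the paired rows and mix the logits
        # as uncond + cfg_weight * (cond - uncond) before sampling.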
        logits = mmgpt.gen_head(hidden_states[:, -1, :])
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]

        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)

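        # Sample the next token for each image, record it, then duplicate it
        # for the cond/uncond pair and embed it as the next step's input.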
        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, i] = next_token.squeeze(dim=-1)

        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
        img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
        inputs_embeds = img_embeds.unsqueeze(dim=1)

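    # Map the sampled token grid back to pixels with the VQ decoder; its
    # output is in [-1, 1], so rescale to 8-bit RGB.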
    dec = mmgpt.gen_vision_model.decode_code(
        generated_tokens.to(dtype=torch.int),
        shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size],
    )
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)

    dec = np.clip((dec + 1) / 2 * 255, 0, 255)

    visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

    os.makedirs('generated_samples', exist_ok=True)

    # Create a timestamp
    timestamp = time.strftime("%Y%m%d-%H%M%S")

    # Sanitize the short_prompt to ensure it's safe for filenames
    short_prompt = re.sub(r'\W+', '_', short_prompt)[:50]

    # Save images with timestamp and part of the user prompt in the filename
    for i in range(parallel_size):
        save_path = os.path.join('generated_samples', f"img_{timestamp}_{short_prompt}_{i}.jpg")
        PIL.Image.fromarray(visual_img[i]).save(save_path)


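# A minimal non-interactive usage sketch (the example prompt is hypothetical;
# everything else is defined above):
#
#   prompt = create_prompt("a cup of coffee on a wooden table")
#   generate(vl_gpt, vl_chat_processor, prompt,
#            short_prompt="coffee", parallel_size=4)
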
def interactive_image_generator():
    print("Welcome to the interactive image generator!")

    # Ask for the number of images at the start of the session
    while True:
        num_images_input = input("How many images would you like to generate per prompt? (Enter a positive integer): ")
        if num_images_input.isdigit() and int(num_images_input) > 0:
            parallel_size = int(num_images_input)
            break
        else:
            print("Invalid input. Please enter a positive integer.")

    while True:
        user_input = input("Please describe the image you'd like to generate (or type 'exit' to quit): ")

        if user_input.lower() == 'exit':
            print("Exiting the image generator. Goodbye!")
            break

        prompt = create_prompt(user_input)

        # Create a sanitized version of user_input for the filename
        short_prompt = re.sub(r'\W+', '_', user_input)[:50]

        print(f"Generating {parallel_size} image(s) for: '{user_input}'")
        generate(
            mmgpt=vl_gpt,
            vl_chat_processor=vl_chat_processor,
            prompt=prompt,
            short_prompt=short_prompt,
            parallel_size=parallel_size,  # Pass the user-specified number of images
        )

        print("Image generation complete! Check the 'generated_samples' folder for the output.\n")


if __name__ == "__main__":
    interactive_image_generator()