
Commit 6bc0fbc

zucchini-nlp (with stevhliu and ArthurZucker) committed
[WIP] Emu3: add model (#33770)
* model can convert to HF and be loaded back
* nit
* works in single batch generation but hallucinates
* use the image tokens
* add image generation
* now it works
* add tests
* update
* add modulare but it doesn't work for porting docstring :(
* skip some tests
* add slow tests
* modular removed the import?
* guess this works
* update
* update
* fix copies
* fix test
* fix copies
* update
* docs
* fix tests
* last fix tests?
* pls
* repo consistency
* more style
* style
* remove file
* address comments
* tiny bits
* update after the new modular
* fix tests
* add one more cond in check attributes
* decompose down/up/mid blocks
* allow static cache generation in VLMs
* nit
* fix copies
* Update docs/source/en/model_doc/emu3.md (Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>)
* Update docs/source/en/model_doc/emu3.md (Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>)
* Update docs/source/en/model_doc/emu3.md (Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>)
* Update docs/source/en/model_doc/emu3.md (Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>)
* Update docs/source/en/model_doc/emu3.md (Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>)
* Update docs/source/en/model_doc/emu3.md (Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>)
* Update docs/source/en/model_doc/emu3.md (Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>)
* Update docs/source/en/model_doc/emu3.md (Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>)
* fix VAE upsampling
* Update src/transformers/models/emu3/modular_emu3.py (Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>)
* address comments
* state overwritten stuff explicitly
* fix copies
* add the flag for flex attn

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
1 parent 59e28c3 commit 6bc0fbc

28 files changed: +5,722 −5 lines

docs/source/en/_toctree.yml (+2)

@@ -860,6 +860,8 @@
       title: DePlot
     - local: model_doc/donut
       title: Donut
+    - local: model_doc/emu3
+      title: Emu3
     - local: model_doc/flava
       title: FLAVA
     - local: model_doc/git

docs/source/en/index.md (+1)

@@ -137,6 +137,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [EfficientFormer](model_doc/efficientformer) ||||
 | [EfficientNet](model_doc/efficientnet) ||||
 | [ELECTRA](model_doc/electra) ||||
+| [Emu3](model_doc/emu3) ||||
 | [EnCodec](model_doc/encodec) ||||
 | [Encoder decoder](model_doc/encoder-decoder) ||||
 | [ERNIE](model_doc/ernie) ||||

docs/source/en/model_doc/emu3.md (new file, +179)

@@ -0,0 +1,179 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Emu3

## Overview

The Emu3 model was proposed in [Emu3: Next-Token Prediction is All You Need](https://arxiv.org/abs/2409.18869) by Xinlong Wang, Xiaosong Zhang, Zhengxiong Luo, Quan Sun, Yufeng Cui, Jinsheng Wang, Fan Zhang, Yueze Wang, Zhen Li, Qiying Yu, Yingli Zhao, Yulong Ao, Xuebin Min, Tao Li, Boya Wu, Bo Zhao, Bowen Zhang, Liangdong Wang, Guang Liu, Zheqi He, Xi Yang, Jingjing Liu, Yonghua Lin, Tiejun Huang, Zhongyuan Wang.

Emu3 is a multimodal LLM that uses vector quantization to tokenize images into discrete tokens. The discretized image tokens are then fused with text token ids for image and text generation. The model can additionally generate images by predicting image token ids.

The abstract from the paper is the following:

*While next-token prediction is considered a promising path towards artificial general intelligence, it has struggled to excel in multimodal tasks, which are still dominated by diffusion models (e.g., Stable Diffusion) and compositional approaches (e.g., CLIP combined with LLMs). In this paper, we introduce Emu3, a new suite of state-of-the-art multimodal models trained solely with next-token prediction. By tokenizing images, text, and videos into a discrete space, we train a single transformer from scratch on a mixture of multimodal sequences. Emu3 outperforms several well-established task-specific models in both generation and perception tasks, surpassing flagship models such as SDXL and LLaVA-1.6, while eliminating the need for diffusion or compositional architectures. Emu3 is also capable of generating high-fidelity video via predicting the next token in a video sequence. We simplify complex multimodal model designs by converging on a singular focus: tokens, unlocking great potential for scaling both during training and inference. Our results demonstrate that next-token prediction is a promising path towards building general multimodal intelligence beyond language. We open-source key techniques and models to support further research in this direction.*

Tips:

- We advise users to set `processor.tokenizer.padding_side = "left"` before batched generation, as it leads to more accurate results (see the short sketch after these tips).

- Note that the model has been trained with a specific prompt format for chatting. Use `processor.apply_chat_template(my_conversation_dict)` to correctly format your prompts.

- Emu3 has two different checkpoints, one for image generation and one for text generation; make sure to use the correct checkpoint when loading the model. To generate an image, it is advised to use `prefix_allowed_tokens_fn` so that the generated tokens are sampled only from possible image tokens. See the usage examples below.

> [!TIP]
> The Emu3 implementation in Transformers uses a special image token to indicate where to merge image embeddings. The special image token isn't new and reuses one of the reserved tokens: `<|extra_0|>`. You have to add `<image>` to your prompt in the place where the image should be embedded for correct generation.
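The first two tips can be combined into a small setup sketch. This is illustrative only and not part of the committed docs; it assumes the `Emu3-community/Emu3-Chat-hf` checkpoint used in the examples below, and the conversation schema shown here is the common Transformers chat format, which the checkpoint's chat template is assumed to accept:

```python
from transformers import Emu3Processor

processor = Emu3Processor.from_pretrained("Emu3-community/Emu3-Chat-hf")
processor.tokenizer.padding_side = "left"  # recommended before batched generation

# A conversation in the usual chat format; the image itself is passed to the
# processor separately, the template only inserts the <image> placeholder.
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What do you see in this image?"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
print(prompt)
```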

This model was contributed by [RaushanTurganbay](https://huggingface.co/RaushanTurganbay).
The original code can be found [here](https://github.com/baaivision/Emu3).


## Usage example

### Text generation inference

Here's how to load the model and perform inference in half-precision (`torch.bfloat16`) to generate textual output from text or text-and-image inputs:

```python
from transformers import Emu3Processor, Emu3ForConditionalGeneration
import torch
from PIL import Image
import requests

processor = Emu3Processor.from_pretrained("Emu3-community/Emu3-Chat-hf")
model = Emu3ForConditionalGeneration.from_pretrained("Emu3-community/Emu3-Chat-hf", torch_dtype=torch.bfloat16, device_map="cuda")

# prepare image and text prompt
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
prompt = "What do you see in this image?<image>"

inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, dtype=torch.bfloat16)

# autoregressively complete prompt
output = model.generate(**inputs, max_new_tokens=50)
print(processor.decode(output[0], skip_special_tokens=True))
```

### Image generation inference

Emu3 can also generate images from textual input. Here is how you can do it:

```python
import torch
from transformers import Emu3Processor, Emu3ForConditionalGeneration

processor = Emu3Processor.from_pretrained("Emu3-community/Emu3-Gen-hf")
model = Emu3ForConditionalGeneration.from_pretrained("Emu3-community/Emu3-Gen-hf", torch_dtype="bfloat16", device_map="auto", attn_implementation="flash_attention_2")


inputs = processor(
    text=["a portrait of young girl. masterpiece, film grained, best quality.", "a dog running under the rain"],
    padding=True,
    return_tensors="pt",
    return_for_image_generation=True,
)
inputs = inputs.to(device="cuda:0", dtype=torch.bfloat16)

neg_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry."
neg_inputs = processor(text=[neg_prompt] * 2, return_tensors="pt").to(device="cuda:0")

image_sizes = inputs.pop("image_sizes")
HEIGHT, WIDTH = image_sizes[0]
VISUAL_TOKENS = model.vocabulary_mapping.image_tokens

def prefix_allowed_tokens_fn(batch_id, input_ids):
    height, width = HEIGHT, WIDTH
    visual_tokens = VISUAL_TOKENS
    image_wrapper_token_id = torch.tensor([processor.tokenizer.image_wrapper_token_id], device=model.device)
    eoi_token_id = torch.tensor([processor.tokenizer.eoi_token_id], device=model.device)
    eos_token_id = torch.tensor([processor.tokenizer.eos_token_id], device=model.device)
    pad_token_id = torch.tensor([processor.tokenizer.pad_token_id], device=model.device)
    eof_token_id = torch.tensor([processor.tokenizer.eof_token_id], device=model.device)
    eol_token_id = processor.tokenizer.encode("<|extra_200|>", return_tensors="pt")[0]

    position = torch.nonzero(input_ids == image_wrapper_token_id, as_tuple=True)[0][0]
    offset = input_ids.shape[0] - position
    if offset % (width + 1) == 0:
        return (eol_token_id,)
    elif offset == (width + 1) * height + 1:
        return (eof_token_id,)
    elif offset == (width + 1) * height + 2:
        return (eoi_token_id,)
    elif offset == (width + 1) * height + 3:
        return (eos_token_id,)
    elif offset > (width + 1) * height + 3:
        return (pad_token_id,)
    else:
        return visual_tokens


out = model.generate(
    **inputs,
    max_new_tokens=50_000,  # make sure to have enough tokens for one image
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
    return_dict_in_generate=True,
    negative_prompt_ids=neg_inputs.input_ids,  # indicate for Classifier-Free Guidance
    negative_prompt_attention_mask=neg_inputs.attention_mask,
)

image = model.decode_image_tokens(out.sequences[:, inputs.input_ids.shape[1]:], height=HEIGHT, width=WIDTH)
images = processor.postprocess(list(image.float()), return_tensors="PIL.Image.Image")  # internally we convert to np but it's not supported in bf16 precision
for i, image in enumerate(images['pixel_values']):
    image.save(f"result{i}.png")
```


## Emu3Config

[[autodoc]] Emu3Config

## Emu3VQVAEConfig

[[autodoc]] Emu3VQVAEConfig

## Emu3TextConfig

[[autodoc]] Emu3TextConfig

## Emu3Processor

[[autodoc]] Emu3Processor

## Emu3ImageProcessor

[[autodoc]] Emu3ImageProcessor
    - preprocess

## Emu3VQVAE

[[autodoc]] Emu3VQVAE
    - forward

## Emu3TextModel

[[autodoc]] Emu3TextModel
    - forward

## Emu3ForCausalLM

[[autodoc]] Emu3ForCausalLM
    - forward

## Emu3ForConditionalGeneration

[[autodoc]] Emu3ForConditionalGeneration
    - forward

docs/source/en/perf_infer_gpu_one.md (+2)

@@ -49,6 +49,7 @@ FlashAttention-2 is currently supported for the following architectures:
 * [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
 * [DiffLlama](https://huggingface.co/docs/transformers/model_doc/diffllama#transformers.DiffLlamaModel)
 * [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
+* [Emu3](https://huggingface.co/docs/transformers/model_doc/emu3)
 * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
 * [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
 * [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)

@@ -245,6 +246,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
 * [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
 * [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
 * [EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder_decoder#transformers.EncoderDecoderModel)
+* [Emu3](https://huggingface.co/docs/transformers/model_doc/emu3)
 * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
 * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
 * [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
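For reference (not part of this diff), opting into either attention backend for Emu3 follows the usual `attn_implementation` pattern; a minimal sketch, assuming the `Emu3-community/Emu3-Chat-hf` checkpoint from the model doc above:

```python
import torch
from transformers import Emu3ForConditionalGeneration

# SDPA is typically picked by default when available; passing it explicitly documents the intent.
model = Emu3ForConditionalGeneration.from_pretrained(
    "Emu3-community/Emu3-Chat-hf",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # or "flash_attention_2" if flash-attn is installed
    device_map="auto",
)
```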

src/transformers/__init__.py (+30)

@@ -428,6 +428,12 @@
         "ElectraConfig",
         "ElectraTokenizer",
     ],
+    "models.emu3": [
+        "Emu3Config",
+        "Emu3Processor",
+        "Emu3TextConfig",
+        "Emu3VQVAEConfig",
+    ],
     "models.encodec": [
         "EncodecConfig",
         "EncodecFeatureExtractor",

@@ -1222,6 +1228,7 @@
     _import_structure["models.donut"].extend(["DonutFeatureExtractor", "DonutImageProcessor"])
     _import_structure["models.dpt"].extend(["DPTFeatureExtractor", "DPTImageProcessor"])
     _import_structure["models.efficientnet"].append("EfficientNetImageProcessor")
+    _import_structure["models.emu3"].append("Emu3ImageProcessor")
     _import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaImageProcessor", "FlavaProcessor"])
     _import_structure["models.fuyu"].extend(["FuyuImageProcessor", "FuyuProcessor"])
     _import_structure["models.glpn"].extend(["GLPNFeatureExtractor", "GLPNImageProcessor"])

@@ -2243,6 +2250,15 @@
             "load_tf_weights_in_electra",
         ]
     )
+    _import_structure["models.emu3"].extend(
+        [
+            "Emu3ForCausalLM",
+            "Emu3ForConditionalGeneration",
+            "Emu3PreTrainedModel",
+            "Emu3TextModel",
+            "Emu3VQVAE",
+        ]
+    )
     _import_structure["models.encodec"].extend(
         [
             "EncodecModel",

@@ -5440,6 +5456,12 @@
         ElectraConfig,
         ElectraTokenizer,
     )
+    from .models.emu3 import (
+        Emu3Config,
+        Emu3Processor,
+        Emu3TextConfig,
+        Emu3VQVAEConfig,
+    )
     from .models.encodec import (
         EncodecConfig,
         EncodecFeatureExtractor,

@@ -6270,6 +6292,7 @@
         from .models.donut import DonutFeatureExtractor, DonutImageProcessor
         from .models.dpt import DPTFeatureExtractor, DPTImageProcessor
         from .models.efficientnet import EfficientNetImageProcessor
+        from .models.emu3 import Emu3ImageProcessor
        from .models.flava import (
            FlavaFeatureExtractor,
            FlavaImageProcessor,

@@ -7139,6 +7162,13 @@
         ElectraPreTrainedModel,
         load_tf_weights_in_electra,
     )
+    from .models.emu3 import (
+        Emu3ForCausalLM,
+        Emu3ForConditionalGeneration,
+        Emu3PreTrainedModel,
+        Emu3TextModel,
+        Emu3VQVAE,
+    )
     from .models.encodec import (
         EncodecModel,
         EncodecPreTrainedModel,

src/transformers/generation/utils.py (+4 −3)

@@ -1634,17 +1634,18 @@ def _get_cache(
             cache_dtype = self.get_output_embeddings().weight.dtype

         def get_layer_device_map(execution_device_map: Optional[dict] = None):
+            num_hidden_layers = self.config.get_text_config().num_hidden_layers
             if execution_device_map is None:
                 return None
             elif len(execution_device_map) == 1 and "" in execution_device_map:
-                return {idx: execution_device_map[""] for idx in range(self.config.num_hidden_layers)}
+                return {idx: execution_device_map[""] for idx in range(num_hidden_layers)}
             layer_device_map = {}
             for layer in execution_device_map:
-                for idx in range(self.config.num_hidden_layers):
+                for idx in range(num_hidden_layers):
                     if f".{idx}." in f"{layer}.":
                         layer_device_map[idx] = execution_device_map[layer]
                         break
-            for idx in range(self.config.num_hidden_layers):
+            for idx in range(num_hidden_layers):
                 if idx not in layer_device_map:
                     raise RuntimeError(f"layer {idx} has not been mapped to a device.")
             return layer_device_map
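This change backs the "allow static cache generation in VLMs" item from the commit message: composite configs such as `Emu3Config` keep the decoder hyperparameters in a nested text config, so the generic cache code resolves them through `get_text_config()` instead of reading `num_hidden_layers` off the top-level config. A rough illustration of the lookup (illustrative only; it assumes the nested layout just described and reuses the chat checkpoint from the model doc):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("Emu3-community/Emu3-Chat-hf")

# For a composite config, get_text_config() returns the nested decoder config;
# for a plain decoder-only config it simply returns the config itself.
text_config = config.get_text_config()
print(type(config).__name__, "->", type(text_config).__name__)
print(text_config.num_hidden_layers)  # layer count used to build the static-cache device map
```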

src/transformers/models/__init__.py (+1)

@@ -86,6 +86,7 @@
     dpt,
     efficientnet,
     electra,
+    emu3,
     encodec,
     encoder_decoder,
     ernie,

src/transformers/models/auto/configuration_auto.py (+2)

@@ -103,6 +103,7 @@
     ("efficientformer", "EfficientFormerConfig"),
     ("efficientnet", "EfficientNetConfig"),
     ("electra", "ElectraConfig"),
+    ("emu3", "Emu3Config"),
     ("encodec", "EncodecConfig"),
     ("encoder-decoder", "EncoderDecoderConfig"),
     ("ernie", "ErnieConfig"),

@@ -420,6 +421,7 @@
     ("efficientformer", "EfficientFormer"),
     ("efficientnet", "EfficientNet"),
     ("electra", "ELECTRA"),
+    ("emu3", "Emu3"),
     ("encodec", "EnCodec"),
     ("encoder-decoder", "Encoder decoder"),
     ("ernie", "ERNIE"),

src/transformers/models/auto/modeling_auto.py (+3)

@@ -499,6 +499,7 @@
     ("dbrx", "DbrxForCausalLM"),
     ("diffllama", "DiffLlamaForCausalLM"),
     ("electra", "ElectraForCausalLM"),
+    ("emu3", "Emu3ForCausalLM"),
     ("ernie", "ErnieForCausalLM"),
     ("falcon", "FalconForCausalLM"),
     ("falcon_mamba", "FalconMambaForCausalLM"),

@@ -800,6 +801,7 @@
     ("blip", "BlipForConditionalGeneration"),
     ("blip-2", "Blip2ForConditionalGeneration"),
     ("chameleon", "ChameleonForConditionalGeneration"),
+    ("emu3", "Emu3ForConditionalGeneration"),
     ("fuyu", "FuyuForCausalLM"),
     ("git", "GitForCausalLM"),
     ("idefics", "IdeficsForVisionText2Text"),

@@ -1428,6 +1430,7 @@
     ("deberta-v2", "DebertaV2Model"),
     ("distilbert", "DistilBertModel"),
     ("electra", "ElectraModel"),
+    ("emu3", "Emu3TextModel"),
     ("flaubert", "FlaubertModel"),
     ("ibert", "IBertModel"),
     ("longformer", "LongformerModel"),

src/transformers/models/auto/processing_auto.py (+1)

@@ -59,6 +59,7 @@
     ("clipseg", "CLIPSegProcessor"),
     ("clvp", "ClvpProcessor"),
     ("colpali", "ColPaliProcessor"),
+    ("emu3", "Emu3Processor"),
     ("flava", "FlavaProcessor"),
     ("fuyu", "FuyuProcessor"),
     ("git", "GitProcessor"),

src/transformers/models/auto/tokenization_auto.py (+1)

@@ -186,6 +186,7 @@
         ),
     ),
     ("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
+    ("emu3", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
     ("ernie", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
     ("ernie_m", ("ErnieMTokenizer" if is_sentencepiece_available() else None, None)),
     ("esm", ("EsmTokenizer", None)),

src/transformers/models/chameleon/processing_chameleon.py (+1)

@@ -62,6 +62,7 @@ class ChameleonProcessor(ProcessorMixin):

     attributes = ["image_processor", "tokenizer"]
     tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
+    valid_kwargs = ["image_seq_length", "image_token"]
     image_processor_class = "ChameleonImageProcessor"

     def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = "<image>"):
