From 9e363aca9d1d3538e1c679265a27e49569b8fa9f Mon Sep 17 00:00:00 2001
From: Nino Risteski <95188570+NinoRisteski@users.noreply.github.com>
Date: Fri, 21 Feb 2025 14:04:15 +0100
Subject: [PATCH] Update vision_utils.py

My attempt at a data collator for VLMs that supports batches mixing image
and text-only examples.
---
 unsloth_zoo/vision_utils.py | 50 +++++++++++++++++++++++--------------
 1 file changed, 31 insertions(+), 19 deletions(-)

diff --git a/unsloth_zoo/vision_utils.py b/unsloth_zoo/vision_utils.py
index a655ec1..f6f7b6b 100644
--- a/unsloth_zoo/vision_utils.py
+++ b/unsloth_zoo/vision_utils.py
@@ -255,41 +255,53 @@ def __init__(self, model, processor, formatting_func = None, ignore_index = -100
         self.processor = processor
         self.formatting_func = formatting_func
         return
-    pass
 
     def __call__(self, examples):
-        # [TODO] Support non image inputs as well
-        # The issue is batch = self.processor( forces tensors to be returned and not None.
-        texts = []
+        # Support mixed text & image examples
+        texts = []
         images = []
-
+
         if self.formatting_func is not None:
             examples = [self.formatting_func(example) for example in examples]
-
-        for example in examples:
+
+        # Determine whether any example has image data.
+        has_any_image = any(
+            "images" in example and example["images"] is not None and len(example["images"]) > 0
+            for example in examples
+        )
+
+        # If at least one example contains images, use a dummy image for text-only examples.
+        dummy_image = None
+        if has_any_image:
+            from PIL import Image
+            dummy_image = Image.new("RGB", (1, 1))
+
+        for example in examples:
             messages = example["messages"]
             message = self.processor.apply_chat_template(
                 messages,
                 tokenize = False,
                 add_generation_prompt = False,
             )
-            # Dataset with 2 columns messages / images
-            if "images" in example:
+            texts.append(message)
+
+            # Use the image provided or set to a dummy if we are in a mixed batch.
+            if "images" in example and example["images"] and example["images"][0] is not None:
                 image = example["images"][0]
             else:
-                image, video = process_vision_info(messages)
-            texts .append(message)
+                image = dummy_image if has_any_image else None
             images.append(image)
-        pass
 
-        # Tokenize the texts and process the images
+        # If the batch is entirely text-only then set images to None.
+        if not has_any_image:
+            images = None
+
+        # Tokenize texts and process images (if any)
         batch = self.processor(
-            text = texts,
-            images = images,
-            padding = True,
-            # [TODO] Truncating to max_seq_length does NOT work for VLMs
-            # truncation = True,
-            return_tensors = "pt",
+            text=texts,
+            images=images,
+            padding=True,
+            return_tensors="pt",
         )
 
         batch.pop("token_type_ids", None)
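
A minimal usage sketch of the collator after this change. It assumes the class is `UnslothVisionDataCollator` from `unsloth_zoo.vision_utils` and that `model` and `processor` are already loaded (for example via Unsloth's `FastVisionModel.from_pretrained`); the model name, message layout, and dataset rows below are illustrative assumptions, not part of this patch.

from PIL import Image
from unsloth_zoo.vision_utils import UnslothVisionDataCollator

# Assumed already loaded elsewhere, e.g. (not part of this patch):
#   from unsloth import FastVisionModel
#   model, processor = FastVisionModel.from_pretrained("unsloth/Llama-3.2-11B-Vision-Instruct")
collator = UnslothVisionDataCollator(model, processor)

# One example with an image ...
image_example = {
    "messages": [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": "What colour is this square?"},
        ]},
        {"role": "assistant", "content": [{"type": "text", "text": "Red."}]},
    ],
    "images": [Image.new("RGB", (64, 64), "red")],
}

# ... and one text-only example with no "images" key. Because the batch also
# contains a real image, the collator substitutes the 1x1 dummy image here.
text_example = {
    "messages": [
        {"role": "user", "content": [{"type": "text", "text": "What is 2 + 2?"}]},
        {"role": "assistant", "content": [{"type": "text", "text": "4."}]},
    ],
}

mixed_batch = collator([image_example, text_example])
print(mixed_batch["input_ids"].shape)  # padded token ids for both examples
# Image features are also returned; the exact key (e.g. "pixel_values")
# depends on the processor.

# A batch containing only text examples also works: images is set to None
# before the processor call, so no vision features are produced.
text_only_batch = collator([text_example, text_example])
print(text_only_batch["input_ids"].shape)

The 1x1 dummy keeps every example in a mixed batch image-bearing so the processor returns consistently shaped tensors; whether an image that small is accepted depends on the processor's minimum-size requirements.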