From 97f36cd4b954ae3c0fda29c2f400d8b87712c3ce Mon Sep 17 00:00:00 2001 From: coco <1228759711@qq.com> Date: Wed, 22 May 2024 09:09:34 +0000 Subject: [PATCH] enhance mixtoken for qwen --- paddlemix/datasets/collator.py | 3 +++ paddlemix/processors/qwen_vl_processing.py | 3 +-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/paddlemix/datasets/collator.py b/paddlemix/datasets/collator.py index 42daa9594..bcd8f56dd 100644 --- a/paddlemix/datasets/collator.py +++ b/paddlemix/datasets/collator.py @@ -138,6 +138,9 @@ def __call__(self, data_list): if "images" in raw_data: if isinstance(raw_data["images"], list): + if not isinstance(raw_data["images"][0], list): + raw_data["images"] = [raw_data["images"]] + raw_data["images"] = [self.processor.image_processor(path) for path in raw_data["images"]] raw_data["images"] = paddle.stack(x=raw_data["images"], axis=0) images.append(raw_data["images"]) diff --git a/paddlemix/processors/qwen_vl_processing.py b/paddlemix/processors/qwen_vl_processing.py index ce587353b..6ce584a7c 100644 --- a/paddlemix/processors/qwen_vl_processing.py +++ b/paddlemix/processors/qwen_vl_processing.py @@ -128,8 +128,7 @@ def train_preprocess(self, sources, system_message: str = "You are a helpful ass labels=target[: self.max_len], ) if len(image_path) > 0: - inputs["images"] = self.image_processor(image_path) - + inputs["images"] = image_path return inputs def batch_decode(self, *args, **kwargs):