
Mismatch Issue Between Input Image and Model Prediction #37

Open

loseen01 opened this issue Nov 17, 2024 · 7 comments

loseen01 commented Nov 17, 2024

Hello everyone,

I'm facing an issue with matrix handling when inputting an image, such as the provided file (exsample.tif). The final prediction shows a mismatch between the input image and what the model expects. This is surprising, since the input image is supposed to be one of the images used during model training. Has anyone encountered this issue, or does anyone have suggestions on how to resolve it?
I'm a beginner, so any help is appreciated.

Dadatata-JZ (Collaborator) commented

Hi @loseen01

I'm sorry, but I don't fully understand the specific problem you're encountering without additional context. It sounds like a typical mismatch in input dimensions. Have you ensured that your image is resized to the dimensions the model requires? You need to confirm that your input tensor matches the model's expected size.

Feel free to post some snapshots or code if you want us to take a quick look. Also, if you are new to PyTorch for CV applications, you may want to go through a tutorial like this first:

https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
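
As a quick sanity check, something along these lines will tell you whether the preprocessed tensor has the shape the model expects (a minimal sketch; 224x224 and the ImageNet normalization constants are the ones used elsewhere in this thread, and the path is illustrative):

import torch
from torchvision import transforms
from PIL import Image

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),  # force exactly 224x224 regardless of aspect ratio
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

image = Image.open("./exsample/exsample.tif").convert("RGB")  # drop alpha/extra channels
x = preprocess(image).unsqueeze(0)  # add the batch dimension

# The model expects a float tensor of shape [batch, 3, 224, 224].
assert x.shape == (1, 3, 224, 224), f"unexpected input shape: {tuple(x.shape)}"
print(x.shape, x.dtype)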

sara-hashemi commented

Hi, I also encountered the same issue and would appreciate any input on it.
I used the exact code you provided for patch-level feature extraction:

import torch, torchvision
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from models.ctran import ctranspath

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
trnsfrms_val = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ]
)

model = ctranspath()
model.head = nn.Identity()
td = torch.load(r'./model_weight/CHIEF_CTransPath.pth')
model.load_state_dict(td['model'], strict=True)
model.eval()

image = Image.open("./exsample/exsample.tif")
image = trnsfrms_val(image).unsqueeze(dim=0)
with torch.no_grad():
    patch_feature_emb = model(image)  # Extracted features (torch.Tensor) with shape [1, 768]
    print(patch_feature_emb.size())

The exsample.tif image is provided in your repo. Based on my understanding, the images should be in TIF format and preprocessed (background removal, as mentioned in one of your other comments), and then CTransPath is applied to them. The transformation transforms.Resize(224) gives us tiles with 224x224 dimensions, which are the input to CTransPath. When I run this code I get the error below (please note I inserted a couple of print lines to see what tensor is being passed to the model):

Patch shape after preprocessing/before model input: torch.Size([1, 3, 224, 224])
Traceback (most recent call last):
  File "/home/sagemaker-user/CHIEF/featureExt.py", line 206, in <module>
    process_svs_directory(input_dir, output_dir, model)
  File "/home/sagemaker-user/CHIEF/featureExt.py", line 194, in process_svs_directory
    extract_features_from_patches(patches, model, output_dir, svs_name)
  File "/home/sagemaker-user/CHIEF/featureExt.py", line 157, in extract_features_from_patches
    feature = model(image)  # Extract feature
  File "/home/sagemaker-user/.conda/envs/chief_env_39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sagemaker-user/.conda/envs/chief_env_39/lib/python3.9/site-packages/timm/models/swin_transformer.py", line 541, in forward
    x = self.forward_features(x)
  File "/home/sagemaker-user/.conda/envs/chief_env_39/lib/python3.9/site-packages/timm/models/swin_transformer.py", line 530, in forward_features
    x = self.patch_embed(x)
  File "/home/sagemaker-user/.conda/envs/chief_env_39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sagemaker-user/.conda/envs/chief_env_39/lib/python3.9/site-packages/timm/models/layers/patch_embed.py", line 35, in forward
    x = self.proj(x)
  File "/home/sagemaker-user/.conda/envs/chief_env_39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sagemaker-user/.conda/envs/chief_env_39/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 399, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/sagemaker-user/.conda/envs/chief_env_39/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 395, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same

Can you please advise on this matter? Thank you

Dadatata-JZ (Collaborator) commented Dec 6, 2024

@sara-hashemi
Hi, check whether you have installed ctranspath correctly (please read the installation note carefully, because it requires an old timm version). Once you've made sure your environment passes the test for extracting features with ctranspath, you can switch back and load the CHIEF weights (using the same env).
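
A quick way to confirm the environment is set up as expected (a minimal sketch; 0.5.4 is the timm version mentioned below in this thread):

import timm
import torch

# ctranspath is built against an old timm release; a newer one breaks model construction.
print("timm:", timm.__version__)          # expected: 0.5.4
print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())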

sara-hashemi commented Dec 6, 2024

@Dadatata-JZ
I used timm version 0.5.4 upon installation. I have followed all the steps mentioned on the page, as below:
1- Installing openslide
2- Installing requirements.txt
3- Cloning CHIEF via git clone https://github.com/hms-dbmi/CHIEF.git
4- Downloading all the pre-trained models from Google Drive to the specified location in the CHIEF folder
It wasn't mentioned that CTransPath should be installed. Is that a prerequisite, and could you please elaborate further?

sara-hashemi commented Dec 9, 2024

> @sara-hashemi Hi, check whether you have installed ctranspath correctly (please read the installation note carefully, because it requires an old timm version). Once you've made sure your environment passes the test for extracting features with ctranspath, you can switch back and load the CHIEF weights (using the same env).

I redid the whole installation and prerequisites process this morning, including installing timm-0.5.4 and requirements.txt. I am running your code on an AWS GPU server and it keeps getting stuck in the initial phase. I have downloaded all the pre-trained files/weights available on your Google Drive into a folder (model_weights) and am running the sample code below from your GitHub page:

=========================================

import torch, torchvision
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from models.ctran import ctranspath

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
trnsfrms_val = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ]
)

model = ctranspath()
model.head = nn.Identity()
td = torch.load(r'./model_weight/CHIEF_CTransPath.pth')
model.load_state_dict(td['model'], strict=True)
model.eval()

image = Image.open("./exsample/exsample.tif")
image = trnsfrms_val(image).unsqueeze(dim=0)
with torch.no_grad():
    patch_feature_emb = model(image)  # Extracted features (torch.Tensor) with shape [1, 768]
    print(patch_feature_emb.size())

=========================================

This should extract features from the one provided image (exsample.tif) as a tensor of shape 1x768. However, the same error is shown: RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same

  • Could you please let me know what other requirements should be met beyond the ones mentioned on the GitHub page?

  • Have you had any experience resolving such errors?

  • Should all images be in TIF format? The TCGA samples are in .svs format, and we would need to know whether certain formats should be converted (see the sketch after this list).
    Thanks!
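
For context, this is roughly how we currently read the .svs slides and cut patches in our featureExt.py (an illustrative sketch using OpenSlide; the path and the exhaustive grid are simplified, not our exact code):

import openslide

slide = openslide.OpenSlide("./slides/example.svs")  # hypothetical path
width, height = slide.dimensions  # level-0 (full resolution) dimensions
patch_size = 224

patches = []
for y in range(0, height - patch_size + 1, patch_size):
    for x in range(0, width - patch_size + 1, patch_size):
        # read_region returns an RGBA PIL image; convert to RGB before the transforms
        patch = slide.read_region((x, y), 0, (patch_size, patch_size)).convert("RGB")
        patches.append(patch)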

Dadatata-JZ (Collaborator) commented

> RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor)

@sara-hashemi This is a PyTorch question: your model is on the GPU and your data is on the CPU.
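
In other words, the input tensor has to be moved to the same device as the model before the forward pass. A minimal sketch of the fix, where model and image are from the snippet you posted above:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = model.to(device)  # put the model on the chosen device
image = image.to(device)  # move the input tensor to the same device

with torch.no_grad():
    patch_feature_emb = model(image)  # both weights and input now live on `device`
    print(patch_feature_emb.size())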

loseen01 (Author) commented

Hi @Dadatata-JZ
Thank you very much for your response. I resized my photos, but I am still facing problems, and I thought it might be because of the resize process: could the image have been changed or distorted in the process? Is it possible for this to happen?
Also, regarding the origin of the tumor: when I run the following code, it gives me a different result on the same image with each run. When I used a seed to fix the randomness, the result became fixed across all images. For example, thyroid is reported as the origin for all images, and if you change the seed number, the result is again fixed, but on another origin, for example breast, and so on (see the seeding sketch after the code).
This is the code I used to predict the origin of the tumor (without a seed). Do you have any idea what the possible cause is?

import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import os

from models.ctran import ctranspath
from Downstream.Tumor_origin.src.network import CHIEF_Tumor_origin

# Increase PIL's MAX_IMAGE_PIXELS limit to handle large images.
# ⚠️ Caution: Disabling the pixel limit can expose your application to potential DoS attacks.
# Ensure that all input images are from trusted sources.
Image.MAX_IMAGE_PIXELS = None  # Disables the limit

# Constants
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

class TumorOriginPredictor:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize feature extractor
        self.feature_extractor = self._init_feature_extractor()

        # Initialize classifier
        self.classifier = self._init_classifier()

        # Setup transforms
        self.transforms = transforms.Compose([
            transforms.Resize((224, 224)),  # Updated to resize to 224x224
            transforms.ToTensor(),
            transforms.Normalize(mean=MEAN, std=STD)
        ])

        # Label mapping
        self.label_dict = {
            0: 'Prostate', 1: 'Lung', 2: 'Endometrial', 3: 'Breast',
            4: 'Head Neck', 5: 'Colorectal', 6: 'Thyroid', 7: 'Skin',
            8: 'Esophagogastric', 9: 'Ovarian', 10: 'Glioma', 11: 'Bladder',
            12: 'Adrenal', 13: 'Colon', 14: 'Germ Cell', 15: 'Pancreatobiliary',
            16: 'Liver', 17: 'Cervix'
        }

    def _init_feature_extractor(self):
        model = ctranspath()
        model.head = nn.Identity()

        # Load feature extractor weights
        weights = torch.load('./model_weight/CHIEF_CTransPath.pth', map_location=torch.device('cpu'))
        model.load_state_dict(weights['model'], strict=True)

        model = model.to(self.device)
        model.eval()
        return model

    def _init_classifier(self):
        model = CHIEF_Tumor_origin(n_classes=18)  # 18 tumor types

        # Load classifier weights
        # NOTE: strict=False silently skips any keys that do not match, so layers
        # whose weights fail to load keep their random initialization.
        weights = torch.load('./model_weight/CHIEF_finetune.pth', map_location=torch.device('cpu'))
        model.load_state_dict(weights, strict=False)

        model = model.to(self.device)
        model.eval()
        return model

    def predict(self, image_path):
        """
        Predict tumor origin from an image path

        Args:
            image_path (str): Path to the image file

        Returns:
            dict: Dictionary containing predicted class and probabilities
        """
        try:
            # Load and preprocess image
            image = Image.open(image_path)
            image = self.transforms(image).unsqueeze(0)
            image = image.to(self.device)
        except Image.DecompressionBombError as e:
            print(f"Error: The image is too large and may be a decompression bomb. Details: {e}")
            return {}
        except FileNotFoundError:
            print(f"Error: The image file '{image_path}' does not exist.")
            return {}
        except Exception as e:
            print(f"An unexpected error occurred while opening the image: {e}")
            return {}

        try:
            # Extract features
            with torch.no_grad():
                # Get patch features
                patch_features = self.feature_extractor(image)  # Shape: [1, 768]

                # Reshape features for classifier
                # The classifier expects features of shape [N, 768] where N is number of patches
                patch_features = patch_features.unsqueeze(0)  # Shape: [1, 1, 768]

                # Get predictions
                logits, probabilities = self.classifier(patch_features)

                # Get predicted class
                pred_class = torch.argmax(probabilities, dim=1).item()

                # Get class probabilities
                probs = probabilities.squeeze().cpu().numpy()

                # Create results dictionary
                results = {
                    'predicted_class': self.label_dict.get(pred_class, "Unknown"),
                    'confidence': float(probabilities[0][pred_class]),
                    'probabilities': {self.label_dict[i]: float(prob)
                                      for i, prob in enumerate(probs)}
                }

            return results
        except Exception as e:
            print(f"Error during prediction: {str(e)}")
            return {}

def main():
    # Example usage
    predictor = TumorOriginPredictor()

    # Example image path
    image_path = "./exsample/a.tiff"

    if not os.path.exists(image_path):
        print(f"Error: Image file not found at {image_path}")
        return

    results = predictor.predict(image_path)

    if results:
        print("\nTumor Origin Prediction Results:")
        print(f"Predicted Class: {results['predicted_class']}")
        print(f"Confidence: {results['confidence']:.4f}")

        print("\nClass Probabilities:")
        for class_name, prob in results['probabilities'].items():
            print(f"{class_name}: {prob:.4f}")
    else:
        print("Prediction could not be completed due to errors.")

if __name__ == "__main__":
    main()
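
For reference, the seeding I mentioned above was along these lines (a minimal sketch assuming standard PyTorch seeding; not necessarily my exact code):

import random
import numpy as np
import torch

def set_seed(seed: int = 42):
    # Fix the common sources of randomness so runs are repeatable.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)  # with this in place, the predicted origin is identical on every run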
