codewithdark commited on
Commit
ab483ac
·
verified ·
1 Parent(s): bfa790c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +41 -14
README.md CHANGED
@@ -54,25 +54,52 @@ The model was fine-tuned using the following settings:
54
  To use the fine-tuned model for inference, simply load the model from Hugging Face's Model Hub and input a chest X-ray image:
55
 
56
  ```python
57
- from transformers import ViTForImageClassification, ViTFeatureExtractor
58
- import torch
59
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- # Load model and feature extractor
62
- model = ViTForImageClassification.from_pretrained('codewithdark/vit-chest-xray')
63
- feature_extractor = ViTFeatureExtractor.from_pretrained('codewithdark/vit-chest-xray')
64
 
65
- # Prepare an image for prediction
66
- image = Image.open('path_to_chest_xray_image.jpg')
 
67
 
68
- # Preprocess the image and make predictions
69
- inputs = feature_extractor(images=image, return_tensors="pt")
70
- outputs = model(**inputs)
71
- logits = outputs.logits
72
- predictions = torch.sigmoid(logits).squeeze()
73
 
74
- # Display predictions
75
- print(predictions)
 
 
 
76
  ```
77
 
78
  ### Fine-Tuning
 
54
  To use the fine-tuned model for inference, simply load the model from Hugging Face's Model Hub and input a chest X-ray image:
55
 
56
  ```python
 
 
57
  from PIL import Image
58
+ import torch
59
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
60
+
61
+ # Load model and processor
62
+ processor = AutoImageProcessor.from_pretrained("codewithdark/vit-chest-xray")
63
+ model = AutoModelForImageClassification.from_pretrained("codewithdark/vit-chest-xray")
64
+
65
+ # Define label columns (class names)
66
+ label_columns = ['Cardiomegaly', 'Edema', 'Consolidation', 'Pneumonia', 'No Finding']
67
+
68
+ # Step 1: Load and preprocess the image
69
+ image_path = "/content/images.jpeg" # Replace with your image path
70
+
71
+ # Open the image
72
+ image = Image.open(image_path)
73
+
74
+ # Ensure the image is in RGB mode (required by most image classification models)
75
+ if image.mode != 'RGB':
76
+ image = image.convert('RGB')
77
+ print("Image converted to RGB.")
78
+
79
+ # Step 2: Preprocess the image using the processor
80
+ inputs = processor(images=image, return_tensors="pt")
81
+
82
+ # Step 3: Make a prediction (using the model)
83
+ with torch.no_grad(): # Disable gradient computation during inference
84
+ outputs = model(**inputs)
85
 
86
+ # Step 4: Extract logits and get the predicted class index
87
+ logits = outputs.logits # Raw logits from the model
88
+ predicted_class_idx = torch.argmax(logits, dim=-1).item() # Get the class index
89
 
90
+ # Step 5: Map the predicted index to a class label
91
+ # You can also use `model.config.id2label`, but we'll use `label_columns` for this task
92
+ predicted_class_label = label_columns[predicted_class_idx]
93
 
94
+ # Output the results
95
+ print(f"Predicted Class Index: {predicted_class_idx}")
96
+ print(f"Predicted Class Label: {predicted_class_label}")
 
 
97
 
98
+ '''
99
+ Output :
100
+ Predicted Class Index: 4
101
+ Predicted Class Label: No Finding
102
+ '''
103
  ```
104
 
105
  ### Fine-Tuning