Ivan committed on
Commit
1375deb
·
1 Parent(s): c876b97

Create demo

Browse files
Files changed (5) hide show
  1. app.py +43 -0
  2. artifacts/ball.png +0 -0
  3. artifacts/panda.jpg +0 -0
  4. requirements.txt +2 -0
  5. utils.py +47 -0
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from utils import model_initialization, prediction
3
+ from PIL import Image
4
+ from typing import Dict, Any
5
+
6
+
7
def gradio_interface(image: Image.Image) -> Dict[str, Any]:
    """
    Classify an uploaded image and return the prediction payload.

    This is the callback wired into the Gradio UI: it obtains the pipeline
    from ``model_initialization()`` (called with its default arguments) and
    forwards the image to ``prediction``.

    Args:
        image (Image.Image): The input image uploaded by the user.

    Returns:
        Dict[str, Any]: Whatever ``prediction`` returns for this image, i.e. the
        most promising label together with its confidence score.
    """
    # Fetch the pipeline and run inference in one step.
    return prediction(model_initialization(), image)
25
+
26
+
27
# Define the Gradio interface
demo = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="pil", label="Upload Image"),  # Accepts PIL Image input
    outputs=gr.JSON(label="Prediction Details"),  # Outputs as JSON
    title="RESNET WILL NEVER DIE. Image Classification with ResNet-18",
    description=(
        # Adjacent string literals are concatenated verbatim, so the first
        # fragment must end with a space to avoid rendering "usingResNet-18".
        "Welcome to the Image Classification Demo! Upload an image to classify it using "
        "ResNet-18 model. The model will predict the most likely label along with its confidence score."
    ),
    theme="soft",
    examples=[["artifacts/ball.png"], ["artifacts/panda.jpg"]],
)

# Launch the Gradio app only when this file is executed as a script
if __name__ == "__main__":
    demo.launch()
artifacts/ball.png ADDED
artifacts/panda.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers==4.46.3
2
+ torch==2.5.1
utils.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline, Pipeline
2
+ from functools import lru_cache
3
+ from typing import Optional, Dict, Any
4
+ import numpy as np
5
+
6
+
7
@lru_cache(maxsize=128)
def model_initialization(task: str = "image-classification", model_name: str = "microsoft/resnet-18") -> Pipeline:
    """
    Build (and memoize) a Hugging Face pipeline for the given task and model.

    The function is wrapped in ``lru_cache``, so calling it again with the same
    arguments returns the previously created pipeline instead of constructing
    a new one.

    Args:
        task (str): The task type, e.g., "image-classification".
        model_name (str): The name or path of the model to use.

    Returns:
        Pipeline: A Hugging Face pipeline object ready for inference.
    """
    return pipeline(task, model=model_name)
22
+
23
+
24
def prediction(pipe: "Pipeline", img: Any) -> Optional[Dict[str, Any]]:
    """
    Perform image classification on the given image using the specified pipeline.

    Args:
        pipe (Pipeline): The initialized hf pipeline object. It is only called as
            ``pipe(img)`` and is expected to return a list of dicts that each
            have a "label" and a "score" entry.
        img (Any): The image to classify, in whatever form ``pipe`` accepts
            (the Gradio app passes a PIL image).

    Returns:
        Optional[Dict[str, Any]]: A dictionary containing the most promising label and its
        confidence score (rounded to two decimals), or None if no results are returned.
    """
    results = pipe(img)

    # Check for "no results" before touching them: this also covers a pipeline
    # that returns None, which would otherwise fail on the lookup below.
    if not results:
        return None

    # max() selects the top-scoring entry (the first one on ties) without
    # reordering the list object handed back by the pipeline.
    best = max(results, key=lambda item: item["score"])

    return {
        "most_promising_label": best["label"],
        "confidence": round(best["score"], 2),
    }