# Hugging Face Spaces status at time of capture: Build error
import cv2 | |
import numpy as np | |
import urllib.request | |
import matplotlib.pyplot as plt | |
from tensorflow.keras.optimizers import Adam | |
from huggingface_hub import from_pretrained_keras | |
# Download the pretrained Keras classifier from the Hugging Face Hub, then
# compile it with the Adam optimizer (learning rate 1e-5), categorical
# cross-entropy loss and an accuracy metric.
reloaded_model = from_pretrained_keras("ShaharAdar/best-model-try")
reloaded_model.compile(
    optimizer=Adam(learning_rate=1e-5),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
# Human-readable labels for the classifier's outputs. make_prediction maps the
# argmax index of a prediction straight into this list, so the order is
# assumed to match the model's output units (confirm against training setup).
class_names = [
    'Clams',
    'Corals',
    'Crabs',
    'Dolphin',
    'Eel',
    'Fish',
    'Jelly Fish',
    'Lobster',
    'Nudibranchs',
    'Octopus',
    'Otter',
    'Penguin',
    'Puffers',
    'Sea Rays',
    'Sea Urchins',
    'Seahorse',
    'Seal',
    'Sharks',
    'Shrimp',
    'Squid',
    'Starfish',
    'Turtle_Tortoise',
    'Whale',
]
def fetch_image(filepath):
    """Read an image from a local file and return it in RGB channel order.

    Args:
        filepath: Path of the image file on disk.

    Returns:
        A 3-channel NumPy image array in RGB order, or ``None`` if the file
        could not be read or decoded. Errors are reported with ``print``
        instead of being raised (best-effort contract kept from before).
    """
    try:
        # IMREAD_COLOR always decodes to 3-channel BGR (grayscale is expanded,
        # alpha is dropped), so COLOR_BGR2RGB below and the
        # (1, 224, 224, 3) reshape in resize_2 work for any input image.
        image = cv2.imread(filepath, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None
            # rather than raising, so check for it explicitly.
            print("Error reading image:", f"could not read or decode {filepath!r}")
            return None
        # OpenCV decodes to BGR; matplotlib's imshow assumes RGB.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    except Exception as e:
        print("Error reading image:", e)
        return None  # Return None to indicate an error
def fetch_image_2(filepath):
    """Download an image from a URL and return it in RGB channel order.

    Args:
        filepath: URL of the image (any URL ``urllib.request.urlopen``
            accepts, e.g. ``https://...``). The parameter keeps its original
            name so existing callers are unaffected.

    Returns:
        The decoded image as a 3-channel NumPy array in RGB order.

    Raises:
        urllib.error.URLError: If the URL cannot be opened (includes
            ``HTTPError`` for HTTP error statuses).
        ValueError: If the response is empty or is not a decodable image.
    """
    # Read from the argument itself; the response is closed when done.
    with urllib.request.urlopen(filepath) as resp:
        data = resp.read()
    if not data:
        # cv2.imdecode asserts on an empty buffer; give a clearer error.
        raise ValueError(f"No image data received from {filepath!r}")
    buffer = np.frombuffer(data, dtype=np.uint8)
    # IMREAD_COLOR guarantees 3 channels (grayscale expanded, alpha dropped),
    # so the BGR->RGB conversion and resize_2's (1, 224, 224, 3) reshape work.
    image = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imdecode returns None instead of raising for undecodable bytes.
        raise ValueError(f"Could not decode an image from {filepath!r}")
    # Convert the image from BGR to RGB
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
def disp_img(image):
    """Show *image* with matplotlib, with no axis ticks and no grid lines."""
    plt.imshow(image)
    # Blank out both axes' tick marks so only the picture is visible.
    for set_ticks in (plt.xticks, plt.yticks):
        set_ticks([])
    plt.grid(False)
    plt.show()
def resize_2(img):
    """Resize an image to 224x224 and wrap it in a single-image batch.

    The result has shape (1, 224, 224, 3), presumably the model's expected
    input shape. The input must have exactly 3 channels; otherwise the
    reshape raises.
    """
    side = 224
    resized = cv2.resize(img, (side, side))
    return resized.reshape(1, side, side, 3)
def make_prediction(image):
    """Classify a preprocessed image batch, print the label and return it.

    Args:
        image: Input for ``reloaded_model.predict``; ``resize_2`` produces a
            (1, 224, 224, 3) array for this.

    Returns:
        The predicted class name (from ``class_names``) for the first image
        in the batch. The name is also printed, as before.
    """
    prediction = reloaded_model.predict(image)
    # Assumes predict returns (batch, num_classes) — TODO confirm. Taking the
    # argmax of the first row (instead of the flattened array) keeps the
    # index inside class_names even if more than one image is passed; for a
    # single image the result is unchanged.
    predicted_class = int(prediction[0].argmax())
    label = class_names[predicted_class]
    print('Predicted class: ', label)
    return label
def predict_class(url):
    """Download the image at *url*, display it and print its predicted class.

    Pipeline: fetch_image_2 (download + RGB), disp_img (show), resize_2
    (single-image batch), make_prediction (print label). A blank separator
    is printed at the end.
    """
    rgb = fetch_image_2(url)
    disp_img(rgb)
    batch = resize_2(rgb)
    make_prediction(batch)
    print("\n")
# Ask the user for an image URL and classify it. Surrounding whitespace from
# a pasted URL is stripped so urlopen receives a clean URL. predict_class
# prints the predicted label itself and returns nothing, so its return value
# is not printed (that would emit a stray "None").
img = input("Enter url of images you want to predict it's class:").strip()
predict_class(img)