catanddogs / app.py
itsdevice's picture
Upload 5 files
49bd98e verified
import io

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from tensorflow.keras.preprocessing.image import img_to_array
# Output labels, indexed by the thresholded score used in predict():
# index 0 when the score is <= 0.5, index 1 when it is > 0.5.
class_names = ['Cat', 'Dog']

# FastAPI application that the route decorators below attach to.
app = FastAPI(title="Cat vs Dog Classifier API")

# Keras model deserialized once, at import time, from 'model.h5'
# (a relative path, so it is resolved against the process working directory).
model = tf.keras.models.load_model('model.h5')
def _preprocess(contents: bytes) -> np.ndarray:
    """Decode raw upload bytes into a batch of one normalized RGB image.

    Args:
        contents: The raw bytes of the uploaded file.

    Returns:
        A float array holding one image with a leading batch axis, resized to
        224x224 and scaled to [0, 1]. (224x224 presumably matches the model's
        input size — confirm against model.h5.)

    Raises:
        OSError: If Pillow cannot identify or fully decode the bytes
            (``PIL.UnidentifiedImageError`` is a subclass of OSError).
        PIL.Image.DecompressionBombError: If the image exceeds Pillow's
            pixel-count safety limit.
    """
    image = Image.open(io.BytesIO(contents)).convert('RGB')
    image = image.resize((224, 224))  # Resize to match model input
    image_array = img_to_array(image) / 255.0  # Rescale to [0, 1]
    return np.expand_dims(image_array, axis=0)  # Add batch dimension


def _classify(image_array: np.ndarray) -> dict:
    """Run the model on a preprocessed batch and map its score to a label.

    Assumes the model emits one sigmoid score per image, where a score above
    0.5 selects class_names[1] ('Dog') — TODO confirm against training setup.

    Returns:
        ``{"predicted_class": <label>, "confidence": <float>}``, where
        confidence is the model's score for the chosen label.
    """
    # verbose=0 keeps Keras from printing a progress bar on every request.
    prediction = model.predict(image_array, verbose=0)
    score = float(prediction[0][0])
    is_positive = score > 0.5  # Sigmoid threshold
    return {
        "predicted_class": class_names[int(is_positive)],
        "confidence": score if is_positive else 1.0 - score,
    }


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image as one of ``class_names``.

    Args:
        file: Multipart upload. Its declared content type must start with
            ``image/``.

    Returns:
        JSONResponse with ``predicted_class`` and ``confidence``.

    Raises:
        HTTPException: 400 if the declared content type is missing or is not
            an image, or if the bytes cannot be decoded as an image; 500 for
            any other failure while reading, preprocessing or predicting.
    """
    # content_type may be None when the client omits it; treat that as
    # "not an image" instead of crashing on None.startswith (-> 500).
    if not (file.content_type or '').startswith('image/'):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        contents = await file.read()
        try:
            # Decoding and inference are synchronous, CPU-bound work. Running
            # them in the thread pool keeps the event loop responsive.
            image_array = await run_in_threadpool(_preprocess, contents)
        except (OSError, Image.DecompressionBombError) as e:
            # Corrupt, truncated, unsupported or oversized uploads are client
            # errors, not server faults.
            raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e
        # NOTE(review): the model is now invoked from worker threads, so
        # concurrent requests may call predict in parallel — confirm the
        # deployed TF/Keras version handles this.
        result = await run_in_threadpool(_classify, image_array)
    except HTTPException:
        # Let the 400 above propagate rather than re-wrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e

    return JSONResponse(result)
@app.get("/")
async def root():
    """Greet clients and point them at the prediction endpoint."""
    usage_hint = (
        "Welcome to the Cat vs Dog Classifier API. "
        "Use POST /predict/ to classify an image."
    )
    return dict(message=usage_hint)