img1 / main.py
Jamiiwej2903's picture
Update main.py
4db1b30 verified
raw
history blame
2.5 kB
import base64
import io
import json
import os
import tempfile

import ffmpeg
import requests
import uvicorn
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from huggingface_hub import InferenceClient
from PIL import Image
# The FastAPI application object; routes are registered on it via decorators.
app = FastAPI()

# Hugging Face model that the /generate_video/ endpoint sends requests to.
MODEL_ID = "stabilityai/stable-video-diffusion-img2vid-xt-1-1-tensorrt"
client = InferenceClient(MODEL_ID)
def _decode_frames(raw_response):
    """Turn the raw inference response into a list of PIL images.

    ``InferenceClient.post`` returns the response body as ``bytes``. Iterating
    bytes yields ints, so treating that body directly as a sequence of frames
    can never work.

    NOTE(review): the frame format returned by this model endpoint is not
    visible here. This assumes a JSON array of base64-encoded image files.
    TODO confirm against the deployed model, and adjust if it returns e.g. an
    MP4 or a single image instead.

    Raises:
        ValueError: if the body is not JSON, holds no frames, or contains
            invalid base64 (TypeError/OSError are also possible for items
            that are not strings or not images).
    """
    encoded_frames = json.loads(raw_response)
    if not isinstance(encoded_frames, list) or not encoded_frames:
        raise ValueError("inference response contained no frames")
    return [Image.open(io.BytesIO(base64.b64decode(item))) for item in encoded_frames]


def _encode_video(frames, fps):
    """Encode *frames* (PIL images) into an H.264 MP4 and return its bytes.

    Frames are saved as ``frame_000.png``, ``frame_001.png``, ... and read
    back through ffmpeg's numbered-image input at *fps* frames per second.
    The output video is written into the same TemporaryDirectory, so every
    temporary file is removed even when ffmpeg fails. The previous
    ``NamedTemporaryFile(delete=False)`` leaked on any error before
    ``os.unlink``.

    Raises:
        ffmpeg.Error: if the ffmpeg process exits with an error.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        for index, frame in enumerate(frames):
            frame.save(os.path.join(tmpdir, f"frame_{index:03d}.png"))
        video_path = os.path.join(tmpdir, "output.mp4")
        input_stream = ffmpeg.input(os.path.join(tmpdir, "frame_%03d.png"), framerate=fps)
        output_stream = ffmpeg.output(input_stream, video_path, vcodec="libx264", pix_fmt="yuv420p")
        ffmpeg.run(output_stream, overwrite_output=True)
        with open(video_path, "rb") as video_file:
            return video_file.read()


@app.post("/generate_video/")
async def generate_video_api(
    file: UploadFile = File(...),
    num_frames: int = Form(14),
    fps: int = Form(7),
):
    """Generate an MP4 video from an uploaded still image.

    Args:
        file: Uploaded image; must be readable by Pillow.
        num_frames: Number of frames requested from the model (form field,
            default 14); must be >= 1.
        fps: Frame rate used when encoding the MP4 (form field, default 7);
            must be >= 1.

    Returns:
        A ``video/mp4`` StreamingResponse containing the encoded video.

    Raises:
        HTTPException: 400 for non-positive parameters or an upload Pillow
            cannot parse; 500 for any failure calling the model, decoding
            its response, or encoding the video.
    """
    if num_frames < 1 or fps < 1:
        raise HTTPException(status_code=400, detail="num_frames and fps must be >= 1")

    image_content = await file.read()
    try:
        # Image.open is lazy; verify() actually parses the data, so garbage
        # uploads are rejected as client errors instead of surfacing as 500s.
        Image.open(io.BytesIO(image_content)).verify()
    except (OSError, SyntaxError, ValueError, Image.DecompressionBombError) as err:
        raise HTTPException(status_code=400, detail=f"Invalid image: {err}") from err

    payload = {
        # A PIL.Image is not JSON-serializable (the old payload raised
        # TypeError on every request), so the raw upload is sent base64-encoded.
        "inputs": base64.b64encode(image_content).decode("ascii"),
        "parameters": {
            "num_inference_steps": 25,
            "num_frames": num_frames,
        },
    }
    try:
        # client.post (HTTP) and ffmpeg (subprocess) block; running them in the
        # threadpool keeps this async endpoint from stalling the event loop.
        raw_response = await run_in_threadpool(client.post, json=payload)
        frames = await run_in_threadpool(_decode_frames, raw_response)
        video_content = await run_in_threadpool(_encode_video, frames, fps)
    except Exception as err:
        # Top-level request boundary: report any generation/encoding failure as 500,
        # keeping the original cause chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=f"An error occurred: {err}") from err

    return StreamingResponse(io.BytesIO(video_content), media_type="video/mp4")
if __name__ == "__main__":
    # Executed only when the file is run directly: serve the app with uvicorn
    # on every interface, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)