Sidharthan commited on
Commit
9d2f477
·
1 Parent(s): 301f7c4

Added the application files, including the inference endpoints and configs

Browse files
Files changed (3) hide show
  1. Dockerfile +28 -0
  2. app.py +78 -0
  3. requirements.txt +8 -0
Dockerfile ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use CUDA-compatible base image if you need GPU support
2
+ # For CPU-only:
3
+ FROM python:3.9-slim
4
+
5
+ # Set working directory
6
+ WORKDIR /app
7
+
8
+ # Install system dependencies
9
+ RUN apt-get update && apt-get install -y \
10
+ build-essential \
11
+ git \
12
+ && rm -rf /var/lib/apt/lists/*
13
+
14
+ # Copy requirements first to leverage Docker cache
15
+ COPY requirements.txt .
16
+
17
+ # Install Python dependencies
18
+ RUN pip install --no-cache-dir --upgrade pip && \
19
+ pip install --no-cache-dir -r requirements.txt
20
+
21
+ # Copy the application code
22
+ COPY . .
23
+
24
+ # Expose the port the app runs on
25
+ EXPOSE 7860
26
+
27
+ # Command to run the application
28
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ from fastapi import FastAPI, HTTPException
3
+ from pydantic import BaseModel
4
+ from transformers import AutoTokenizer
5
+ from peft import AutoPeftModelForCausalLM
6
+ import torch
7
+ from typing import Optional
8
+
9
+ app = FastAPI(title="Gemma Script Generator API")
10
+
11
+ # Load model and tokenizer
12
+ MODEL_NAME = "Sidharthan/gemma2_scripter"
13
+
14
+ try:
15
+ tokenizer = AutoTokenizer.from_pretrained(
16
+ MODEL_NAME,
17
+ trust_remote_code=True
18
+ )
19
+ model = AutoPeftModelForCausalLM.from_pretrained(
20
+ MODEL_NAME,
21
+ device_map="auto", # Will use CPU if GPU not available
22
+ trust_remote_code=True,
23
+ #load_in_4bit=True
24
+ )
25
+ except Exception as e:
26
+ print(f"Error loading model: {str(e)}")
27
+ raise
28
+
29
+ class GenerationRequest(BaseModel):
30
+ message: str
31
+ max_length: Optional[int] = 512
32
+ temperature: Optional[float] = 0.7
33
+ top_p: Optional[float] = 0.95
34
+ top_k: Optional[int] = 50
35
+ repetition_penalty: Optional[float] = 1.2
36
+
37
+ class GenerationResponse(BaseModel):
38
+ generated_text: str
39
+
40
+ @app.post("/generate", response_model=GenerationResponse)
41
+ async def generate_script(request: GenerationRequest):
42
+ try:
43
+ # Format prompt
44
+ prompt = request.message
45
+ # Tokenize input
46
+ inputs = tokenizer(prompt, return_tensors="pt")
47
+ if torch.cuda.is_available():
48
+ inputs = {k: v.cuda() for k, v in inputs.items()}
49
+
50
+ # Generate
51
+ outputs = model.generate(
52
+ **inputs,
53
+ max_length=request.max_length,
54
+ do_sample=True,
55
+ temperature=request.temperature,
56
+ top_p=request.top_p,
57
+ top_k=request.top_k,
58
+ repetition_penalty=request.repetition_penalty,
59
+ num_return_sequences=1,
60
+ pad_token_id=tokenizer.pad_token_id,
61
+ eos_token_id=tokenizer.eos_token_id
62
+ )
63
+
64
+ # Decode output
65
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
66
+
67
+ return GenerationResponse(generated_text=generated_text)
68
+
69
+ except Exception as e:
70
+ raise HTTPException(status_code=500, detail=str(e))
71
+
72
+ @app.get("/health")
73
+ async def health_check():
74
+ return {"status": "healthy"}
75
+
76
+ if __name__ == "__main__":
77
+ import uvicorn
78
+ uvicorn.run(app, host="0.0.0.0", port=7860)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn[standard]
3
+ torch
4
+ transformers
5
+ peft
6
+ pydantic
7
+ bitsandbytes
8
+ accelerate