sachin commited on
Commit
8243283
·
1 Parent(s): 9598ec0

add nim inference

Browse files
{recipes → inference}/mistral_inference.py RENAMED
File without changes
inference/nim_inference.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import requests, base64
3
+ import os
4
+ import json
5
+
6
+ def vision_inference(image_name):
7
+ try:
8
+ invoke_url = "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct/chat/completions"
9
+ stream = False
10
+
11
+ with open(image_name, "rb") as f:
12
+ image_b64 = base64.b64encode(f.read()).decode()
13
+
14
+ #assert len(image_b64) < 180_000, \
15
+ # "To upload larger images, use the assets API (see docs)"
16
+
17
+ api_key = os.environ["NIM_API_KEY"]
18
+
19
+ headers = {
20
+ "Authorization": f"Bearer {api_key}",
21
+ "Accept": "text/event-stream" if stream else "application/json"
22
+ }
23
+ payload = {
24
+ "model": 'meta/llama-3.2-11b-vision-instruct',
25
+ "messages": [
26
+ {
27
+ "role": "user",
28
+ "content": f'What is in this image? <img src="data:image/png;base64,{image_b64}" />'
29
+ }
30
+ ],
31
+ "max_tokens": 512,
32
+ "temperature": 1.00,
33
+ "top_p": 1.00,
34
+ "stream": stream
35
+ }
36
+
37
+ response = requests.post(invoke_url, headers=headers, json=payload)
38
+
39
+ if stream:
40
+ for line in response.iter_lines():
41
+ if line:
42
+ #print(line.decode("utf-8"))
43
+ data = line.decode("utf-8")
44
+ #content = json.loads(data)['choices'][0]['delta'].get('content', '')
45
+ else:
46
+ #print(response.json())
47
+ data = response.json()
48
+ content = data['choices'][0]['message']['content']
49
+
50
+ #print(content)
51
+ return content
52
+
53
+ except Exception as e: # Added general exception handling
54
+ print(f"Error: {e}")
55
+ return None
56
+
57
+ #image_name = "/home/gaganyatri/Pictures/hackathon/eat-health/fruit-stall-1.jpg"
58
+ #content = vision_inference(image_name)
59
+ #print(content)
recipes/engine.py CHANGED
@@ -2,7 +2,7 @@ import pandas as pd
2
  import numpy as np
3
  import requests
4
  import json
5
- from .mistral_inference import text_llm
6
  from django.core.files.storage import default_storage
7
 
8
  def execute_prompt(prompt, local=True):
 
2
  import numpy as np
3
  import requests
4
  import json
5
+ from inference.mistral_inference import text_llm
6
  from django.core.files.storage import default_storage
7
 
8
  def execute_prompt(prompt, local=True):
recipes/urls.py CHANGED
@@ -1,9 +1,10 @@
1
  from django.urls import path
2
  from .views import recipe_generate_route, execute_prompt_route_get
3
- from .views import VisionLLMView
4
 
5
  urlpatterns = [
6
  path('execute_prompt_get/', execute_prompt_route_get, name='execute_prompt_get'),
7
  path('recipe_generate/', recipe_generate_route, name='recipe_generate'),
8
  path('vision_llm_url/', VisionLLMView.as_view()),
 
9
  ]
 
1
  from django.urls import path
2
  from .views import recipe_generate_route, execute_prompt_route_get
3
+ from .views import VisionLLMView, NIMVisionLLMView
4
 
5
  urlpatterns = [
6
  path('execute_prompt_get/', execute_prompt_route_get, name='execute_prompt_get'),
7
  path('recipe_generate/', recipe_generate_route, name='recipe_generate'),
8
  path('vision_llm_url/', VisionLLMView.as_view()),
9
+ path('nim_vision_llm_url/', NIMVisionLLMView.as_view()),
10
  ]
recipes/views.py CHANGED
@@ -8,6 +8,7 @@ from mistralai import Mistral
8
  import os
9
  import base64
10
  import json
 
11
 
12
  class PromptSerializer(serializers.Serializer):
13
  prompt = serializers.CharField()
@@ -60,6 +61,7 @@ class VisionLLMView(APIView):
60
  #image_data = base64.b64decode(data['image'])
61
  #image_data = base64.b64decode(data['messages'][0]['image'][0])
62
  image_data = (data['messages'][0]['image'][0])
 
63
 
64
  # Define the messages for the chat
65
  messages = [
@@ -68,7 +70,7 @@ class VisionLLMView(APIView):
68
  "content": [
69
  {
70
  "type": "text",
71
- "text": data['messages'][0]['prompt']
72
  },
73
  {
74
  "type": "image_url",
@@ -83,7 +85,56 @@ class VisionLLMView(APIView):
83
  model=model,
84
  messages=messages
85
  )
 
 
86
  #print(chat_response.choices[0].message.content)
87
  # Return the content of the response
88
- return Response({"response": chat_response.choices[0].message.content})
 
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  import os
9
  import base64
10
  import json
11
+ import requests
12
 
13
  class PromptSerializer(serializers.Serializer):
14
  prompt = serializers.CharField()
 
61
  #image_data = base64.b64decode(data['image'])
62
  #image_data = base64.b64decode(data['messages'][0]['image'][0])
63
  image_data = (data['messages'][0]['image'][0])
64
+ prompt = data['messages'][0]['prompt']
65
 
66
  # Define the messages for the chat
67
  messages = [
 
70
  "content": [
71
  {
72
  "type": "text",
73
+ "text": prompt
74
  },
75
  {
76
  "type": "image_url",
 
85
  model=model,
86
  messages=messages
87
  )
88
+
89
+ content = chat_response.choices[0].message.content
90
  #print(chat_response.choices[0].message.content)
91
  # Return the content of the response
92
+ return Response({"response": content})
93
+
94
 
95
+ class NIMVisionLLMView(APIView):
96
+ def post(self, request, format=None):
97
+ try:
98
+ invoke_url = "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct/chat/completions"
99
+ stream = False
100
+ api_key = os.environ["NIM_API_KEY"]
101
+ data = request.data
102
+ image_data = (data['messages'][0]['image'][0])
103
+ prompt = data['messages'][0]['prompt']
104
+ headers = {
105
+ "Authorization": f"Bearer {api_key}",
106
+ "Accept": "text/event-stream" if stream else "application/json"
107
+ }
108
+ payload = {
109
+ "model": 'meta/llama-3.2-11b-vision-instruct',
110
+ "messages": [
111
+ {
112
+ "role": "user",
113
+ "content": f'{prompt} <img src="data:image/png;base64,{image_data}" />'
114
+ }
115
+ ],
116
+ "max_tokens": 512,
117
+ "temperature": 1.00,
118
+ "top_p": 1.00,
119
+ "stream": stream
120
+ }
121
+ response = requests.post(invoke_url, headers=headers, json=payload)
122
+
123
+ if stream:
124
+ for line in response.iter_lines():
125
+ if line:
126
+ #print(line.decode("utf-8"))
127
+ data = line.decode("utf-8")
128
+ #content = json.loads(data)['choices'][0]['delta'].get('content', '')
129
+ else:
130
+ #print(response.json())
131
+ data = response.json()
132
+ content = data['choices'][0]['message']['content']
133
+
134
+ #print(content)
135
+ return Response({"response": content})
136
+
137
+
138
+ except Exception as e: # Added general exception handling
139
+ print(f"Error: {e}")
140
+ return None