gabcares commited on
Commit
07eac76
·
verified ·
1 Parent(s): 2832c92

Source code sepsis FastAPI

Browse files

- RESTFul API
- GraphQL

Files changed (8) hide show
  1. Dockerfile +25 -0
  2. assets/favicon.ico +0 -0
  3. config.py +75 -0
  4. graph_ql.py +151 -0
  5. main.py +9 -0
  6. requirements.txt +198 -0
  7. rest.py +201 -0
  8. utils/pipeline_helper.py +23 -0
Dockerfile ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11.9-slim
2
+
3
+ # Copy requirements file
4
+ COPY requirements.txt .
5
+
6
+ # Update pip
7
+ RUN pip --timeout=3000 install --no-cache-dir --upgrade pip
8
+
9
+ # Install dependecies
10
+ RUN pip --timeout=3000 install --no-cache-dir -r requirements.txt
11
+
12
+ # Make project directory
13
+ RUN mkdir -p /src/api/
14
+
15
+ # Set working directory
16
+ WORKDIR /src/api
17
+
18
+ # Copy API
19
+ COPY . .
20
+
21
+ # Expose app port
22
+ EXPOSE 7860
23
+
24
+ # Start application
25
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
assets/favicon.ico ADDED
config.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ # ENV when using standalone uvicorn server running FastAPI in api directory
4
+ ENV_PATH = Path('../../env/online.env')
5
+
6
+ ONE_DAY_SEC = 24*60*60
7
+
8
+ ONE_WEEK_SEC = ONE_DAY_SEC*7
9
+
10
+ PIPELINE_FUNCTION_URL = "https://raw.githubusercontent.com/D0nG4667/sepsis_prediction_full_stack/main/dev/models/pipeline_func/pipeline_functions.joblib"
11
+
12
+ RANDOM_FOREST_URL = "https://raw.githubusercontent.com/D0nG4667/sepsis_prediction_full_stack/main/dev/models/RandomForestClassifier.joblib"
13
+
14
+ XGBOOST_URL = "https://raw.githubusercontent.com/D0nG4667/sepsis_prediction_full_stack/main/dev/models/XGBClassifier.joblib"
15
+
16
+ ADABOOST_URL = "https://raw.githubusercontent.com/D0nG4667/sepsis_prediction_full_stack/main/dev/models/AdaBoostClassifier.joblib"
17
+
18
+ CATBOOST_URL = "https://raw.githubusercontent.com/D0nG4667/sepsis_prediction_full_stack/main/dev/models/CatBoostClassifier.joblib"
19
+
20
+ DECISION_TREE_URL = "https://raw.githubusercontent.com/D0nG4667/sepsis_prediction_full_stack/main/dev/models/DecisionTreeClassifier.joblib"
21
+
22
+ KNN_URL = "https://raw.githubusercontent.com/D0nG4667/sepsis_prediction_full_stack/main/dev/models/KNeighborsClassifier.joblib"
23
+
24
+ LGBM_URL = "https://raw.githubusercontent.com/D0nG4667/sepsis_prediction_full_stack/main/dev/models/LGBMClassifier.joblib"
25
+
26
+ LOG_REG_URL = "https://raw.githubusercontent.com/D0nG4667/sepsis_prediction_full_stack/main/dev/models/LogisticRegression.joblib"
27
+
28
+ SVC_URL = "https://raw.githubusercontent.com/D0nG4667/sepsis_prediction_full_stack/main/dev/models/SVC.joblib"
29
+
30
+ ENCODER_URL = "https://raw.githubusercontent.com/D0nG4667/sepsis_prediction_full_stack/main/dev/models/enc/encoder.joblib"
31
+
32
+ ALL_MODELS = {
33
+ "AdaBoostClassifier": ADABOOST_URL,
34
+ "CatBoostClassifier": CATBOOST_URL,
35
+ "DecisionTreeClassifier": DECISION_TREE_URL,
36
+ "KNeighborsClassifier": KNN_URL,
37
+ "LGBMClassifier": LGBM_URL,
38
+ "LogisticRegression": LOG_REG_URL,
39
+ "RandomForestClassifier": RANDOM_FOREST_URL,
40
+ "SupportVectorClassifier": SVC_URL,
41
+ "XGBoostClassifier": XGBOOST_URL
42
+ }
43
+
44
+ DESCRIPTION = """
45
+ This API identifies ICU patients at risk of developing sepsis using `9 models` of which `Random Forest Classifier` and `XGBoost Classifier` are the best.\n
46
+
47
+ The models were trained on [The John Hopkins University datasets at Kaggle](https://www.kaggle.com/datasets/chaunguynnghunh/sepsis?select=README.md).\n
48
+
49
+ ### Features
50
+ `PRG:` Plasma glucose\n
51
+ `PL:` Blood Work Result-1 (mu U/ml)\n
52
+ `PR:` Blood Pressure (mm Hg)\n
53
+ `SK:` Blood Work Result-2 (mm)\n
54
+ `TS:` Blood Work Result-3 (mu U/ml)\n
55
+ `M11:` Body mass index (weight in kg/(height in m)^2\n
56
+ `BD2:` Blood Work Result-4 (mu U/ml)\n
57
+ `Age:` patients age (years)\n
58
+ `Insurance:` If a patient holds a valid insurance card\n
59
+
60
+ ### Results
61
+ **Sepsis prediction:** *Positive* if a patient in ICU will develop a sepsis, and *Negative* otherwise\n
62
+
63
+ **Sepsis probability:** In percentage\n
64
+
65
+ ### GraphQL API
66
+ To explore the GraphQL sub-application (built-with strawberry) to this RESTFul API click the link below.\n
67
+ 🍓[GraphQL](/graphql)
68
+
69
+ ### Let's Connect
70
+ 👨‍⚕️ `Gabriel Okundaye`\n
71
+ [<img src="https://upload.wikimedia.org/wikipedia/commons/c/ca/LinkedIn_logo_initials.png" alt="LinkedIn" width="20" height="20"> LinkendIn](https://www.linkedin.com/in/dr-gabriel-okundaye)
72
+
73
+ [<img src="https://github.githubassets.com/images/modules/logos_page/GitHub-Mark.png" alt="GitHub" width="20" height="20"> GitHub](https://github.com/D0nG4667/sepsis_prediction_full_stack)
74
+
75
+ """
graph_ql.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import strawberry
2
+ from strawberry.asgi import GraphQL
3
+
4
+ import pandas as pd
5
+ import joblib
6
+ from sklearn.pipeline import Pipeline
7
+ from sklearn.preprocessing._label import LabelEncoder
8
+
9
+ import httpx
10
+ from io import BytesIO
11
+
12
+ from typing import Tuple, List, Optional, Union
13
+ from enum import Enum
14
+
15
+
16
+ from config import RANDOM_FOREST_URL, XGBOOST_URL, ENCODER_URL
17
+
18
+ import logging
19
+
20
+
21
+ # API input features
22
+
23
+ @strawberry.enum
24
+ class ModelChoice(Enum):
25
+ RandomForestClassifier = RANDOM_FOREST_URL
26
+ XGBoostClassifier = XGBOOST_URL
27
+
28
+
29
+ @strawberry.input
30
+ class SepsisFeatures:
31
+ prg: List[int]
32
+ pl: List[int]
33
+ pr: List[int]
34
+ sk: List[int]
35
+ ts: List[int]
36
+ m11: List[float]
37
+ bd2: List[float]
38
+ age: List[int]
39
+ insurance: List[int]
40
+
41
+
42
+ @strawberry.type
43
+ class Url:
44
+ url: str
45
+ pipeline_url: str
46
+ encoder_url: str
47
+
48
+
49
+ @strawberry.type
50
+ class ResultData:
51
+ prediction: List[str]
52
+ probability: List[float]
53
+
54
+
55
+ @strawberry.type
56
+ class PredictionResponse:
57
+ execution_msg: str
58
+ execution_code: int
59
+ result: ResultData
60
+
61
+
62
+ @strawberry.type
63
+ class ErrorResponse:
64
+ execution_msg: str
65
+ execution_code: int
66
+ error: Optional[str]
67
+
68
+
69
+ logging.basicConfig(level=logging.ERROR,
70
+ format='%(asctime)s - %(levelname)s - %(message)s')
71
+
72
+
73
+ async def url_to_data(url: Url) -> BytesIO:
74
+ async with httpx.AsyncClient() as client:
75
+ response = await client.get(url)
76
+ response.raise_for_status() # Ensure we catch any HTTP errors
77
+ # Convert response content to BytesIO object
78
+ data = BytesIO(response.content)
79
+ return data
80
+
81
+
82
+ # Load the model pipelines and encoder
83
+ async def load_pipeline(pipeline_url: Url, encoder_url: Url) -> Tuple[Pipeline, LabelEncoder]:
84
+ pipeline, encoder = None, None
85
+ try:
86
+ pipeline: Pipeline = joblib.load(await url_to_data(pipeline_url))
87
+ encoder: LabelEncoder = joblib.load(await url_to_data(encoder_url))
88
+ except Exception as e:
89
+ logging.error(
90
+ "Omg, an error occurred in loading the pipeline resources: %s", e)
91
+ finally:
92
+ return pipeline, encoder
93
+
94
+
95
+ async def pipeline_classifier(pipeline: Pipeline, encoder: LabelEncoder, data: SepsisFeatures) -> Union[ErrorResponse, PredictionResponse]:
96
+ msg = 'Execution failed'
97
+ code = 0
98
+ output = ErrorResponse(**{'execution_msg': msg,
99
+ 'execution_code': code, 'error': None})
100
+ try:
101
+ # Create dataframe
102
+ df = pd.DataFrame.from_dict(data.__dict__)
103
+
104
+ # Make prediction
105
+ preds = pipeline.predict(df)
106
+ preds_int = [int(pred) for pred in preds]
107
+
108
+ predictions = encoder.inverse_transform(preds_int)
109
+ probabilities_np = pipeline.predict_proba(df)
110
+
111
+ probabilities = [round(float(max(prob)*100), 2)
112
+ for prob in probabilities_np]
113
+
114
+ result = ResultData(**{"prediction": predictions,
115
+ "probability": probabilities}
116
+ )
117
+
118
+ msg = 'Execution was successful'
119
+ code = 1
120
+ output = PredictionResponse(
121
+ **{'execution_msg': msg,
122
+ 'execution_code': code, 'result': result}
123
+ )
124
+
125
+ except Exception as e:
126
+ error = f"Omg, pipeline classifier and/or encoder failure. {e}"
127
+
128
+ output = ErrorResponse(**{'execution_msg': msg,
129
+ 'execution_code': code, 'error': error})
130
+
131
+ finally:
132
+ return output
133
+
134
+
135
+ @strawberry.type
136
+ class Query:
137
+ @strawberry.field
138
+ async def predict_sepsis(self, model: ModelChoice, data: SepsisFeatures) -> Union[ErrorResponse, PredictionResponse]:
139
+ pipeline_url: Url = model.value
140
+ pipeline, encoder = await load_pipeline(pipeline_url, ENCODER_URL)
141
+
142
+ output = await pipeline_classifier(pipeline, encoder, data)
143
+
144
+ return output
145
+
146
+
147
+ # Create the GraphQL Schema
148
+ schema = strawberry.Schema(query=Query)
149
+
150
+ # Create the GraphQL application
151
+ graphql_app = GraphQL(schema)
main.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi.responses import RedirectResponse
2
+
3
+ from graph_ql import graphql_app
4
+ from rest import app
5
+
6
+
7
+ # Add Graph QL Application to the FastAPI RESTFul Application
8
+ app.add_route("/graphql", graphql_app)
9
+ app.add_websocket_route("/graphql", graphql_app)
requirements.txt ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # aiocache==0.12.2
2
+ aiohttp==3.9.5
3
+ aiosignal==1.3.1
4
+ altair==5.3.0
5
+ annotated-types==0.7.0
6
+ anyio==4.2.0
7
+ argon2-cffi==21.3.0
8
+ argon2-cffi-bindings==21.2.0
9
+ asttokens==2.4.1
10
+ async-lru==2.0.4
11
+ attrs==23.1.0
12
+ Babel==2.11.0
13
+ beautifulsoup4==4.12.3
14
+ bleach==4.1.0
15
+ blinker==1.8.2
16
+ Brotli==1.0.9
17
+ cachetools==5.4.0
18
+ catboost==1.2.3
19
+ certifi==2024.7.4
20
+ cffi==1.16.0
21
+ charset-normalizer==3.3.2
22
+ click==8.1.7
23
+ colorama==0.4.6
24
+ comm==0.2.2
25
+ contourpy==1.2.1
26
+ cycler==0.12.1
27
+ debugpy==1.8.2
28
+ decorator==5.1.1
29
+ defusedxml==0.7.1
30
+ dnspython==2.6.1
31
+ email_validator==2.2.0
32
+ entrypoints==0.4
33
+ exceptiongroup==1.2.2
34
+ executing==2.0.1
35
+ extra-streamlit-components==0.1.71
36
+ Faker==26.0.0
37
+ fastapi==0.111.0
38
+ fastapi-cache2==0.2.1
39
+ fastapi-cli==0.0.4
40
+ fastjsonschema==2.16.2
41
+ favicon==0.7.0
42
+ filelock==3.15.4
43
+ fonttools==4.53.1
44
+ frozenlist==1.4.1
45
+ fsspec==2024.6.1
46
+ gitdb==4.0.11
47
+ GitPython==3.1.43
48
+ graphql-core==3.2.3
49
+ graphviz==0.20.3
50
+ h11==0.14.0
51
+ htbuilder==0.6.2
52
+ httpcore==1.0.5
53
+ httptools==0.6.1
54
+ httpx==0.27.0
55
+ huggingface-hub==0.24.1
56
+ idna==3.7
57
+ imbalanced-learn==0.12.3
58
+ importlib_metadata==8.0.0
59
+ inquirerpy==0.3.4
60
+ # ipykernel==6.29.5
61
+ # ipython==8.26.0
62
+ # ipywidgets==8.1.3
63
+ jedi==0.19.1
64
+ Jinja2==3.1.4
65
+ joblib==1.4.2
66
+ json5==0.9.6
67
+ jsonschema==4.19.2
68
+ jsonschema-specifications==2023.12.1
69
+ # jupyter_client==8.6.2
70
+ # jupyter_core==5.7.2
71
+ # jupyter-events==0.10.0
72
+ # jupyter-lsp==2.2.0
73
+ # jupyter_server==2.14.1
74
+ # jupyter_server_terminals==0.4.4
75
+ # jupyterlab==4.0.11
76
+ # jupyterlab-pygments==0.1.2
77
+ # jupyterlab_server==2.25.1
78
+ # jupyterlab_widgets==3.0.11
79
+ # kaleido==0.1.0.post1
80
+ kiwisolver==1.4.5
81
+ libcst==1.4.0
82
+ lightgbm==4.4.0
83
+ lxml==5.2.2
84
+ Markdown==3.6
85
+ markdown-it-py==3.0.0
86
+ markdownlit==0.0.7
87
+ MarkupSafe==2.1.3
88
+ # matplotlib==3.9.1
89
+ # matplotlib-inline==0.1.7
90
+ mdurl==0.1.2
91
+ mistune==2.0.4
92
+ more-itertools==10.3.0
93
+ multidict==6.0.5
94
+ # nbclient==0.8.0
95
+ # nbconvert==7.10.0
96
+ # nbformat==5.9.2
97
+ nest_asyncio==1.6.0
98
+ notebook_shim==0.2.3
99
+ numpy==1.26.4
100
+ orjson==3.10.6
101
+ overrides==7.4.0
102
+ packaging==24.1
103
+ pandas==2.2.2
104
+ pandocfilters==1.5.0
105
+ parso==0.8.4
106
+ pendulum==3.0.0
107
+ pfzy==0.3.4
108
+ pickleshare==0.7.5
109
+ pillow==10.4.0
110
+ pip==24.0
111
+ platformdirs==4.2.2
112
+ # plotly==5.22.0
113
+ prometheus-client==0.14.1
114
+ prompt_toolkit==3.0.47
115
+ protobuf==5.27.2
116
+ psutil==6.0.0
117
+ pure_eval==0.2.3
118
+ pyarrow==17.0.0
119
+ pycparser==2.21
120
+ pydantic==2.8.2
121
+ pydantic_core==2.20.1
122
+ pydeck==0.9.1
123
+ Pygments==2.18.0
124
+ pymdown-extensions==10.8.1
125
+ pyparsing==3.1.2
126
+ PySocks==1.7.1
127
+ python-dateutil==2.9.0
128
+ python-dotenv==1.0.1
129
+ python-json-logger==2.0.7
130
+ python-multipart==0.0.9
131
+ pytz==2024.1
132
+ # pywin32==306
133
+ # pywinpty==2.0.10
134
+ PyYAML==6.0.1
135
+ pyzmq==26.0.3
136
+ redis==5.0.7
137
+ referencing==0.35.1
138
+ requests==2.32.3
139
+ rfc3339-validator==0.1.4
140
+ rfc3986-validator==0.1.1
141
+ rich==13.7.1
142
+ rpds-py==0.10.6
143
+ scikit-learn==1.5.0
144
+ scipy==1.14.0
145
+ Send2Trash==1.8.2
146
+ setuptools==69.5.1
147
+ shellingham==1.5.4
148
+ six==1.16.0
149
+ # skops==0.10.0
150
+ smmap==5.0.1
151
+ sniffio==1.3.0
152
+ soupsieve==2.5
153
+ st-annotated-text==4.0.1
154
+ st-theme==1.2.3
155
+ stack-data==0.6.2
156
+ starlette==0.37.2
157
+ strawberry-graphql==0.236.2
158
+ # streamlit==1.36.0
159
+ # streamlit-camera-input-live==0.2.0
160
+ # streamlit-card==1.0.2
161
+ # streamlit-embedcode==0.1.2
162
+ # streamlit-extras==0.4.3
163
+ # streamlit-faker==0.0.3
164
+ # streamlit-image-coordinates==0.1.9
165
+ # streamlit-keyup==0.2.4
166
+ # streamlit-toggle-switch==1.0.2
167
+ # streamlit-vertical-slider==2.5.5
168
+ tabulate==0.9.0
169
+ tenacity==8.5.0
170
+ terminado==0.17.1
171
+ threadpoolctl==3.5.0
172
+ time-machine==2.14.2
173
+ tinycss2==1.2.1
174
+ toml==0.10.2
175
+ toolz==0.12.1
176
+ tornado==6.4.1
177
+ tqdm==4.66.4
178
+ traitlets==5.14.3
179
+ typer==0.12.3
180
+ typing_extensions==4.12.2
181
+ tzdata==2024.1
182
+ ujson==5.10.0
183
+ urllib3==2.2.2
184
+ uvicorn==0.30.1
185
+ validators==0.33.0
186
+ watchdog==4.0.1
187
+ watchfiles==0.22.0
188
+ wcwidth==0.2.13
189
+ webencodings==0.5.1
190
+ webp==0.4.0
191
+ websocket-client==1.8.0
192
+ websockets==12.0
193
+ wheel==0.43.0
194
+ widgetsnbextension==4.0.11
195
+ # win-inet-pton==1.1.0
196
+ xgboost==2.0.3
197
+ yarl==1.9.4
198
+ zipp==3.19.2
rest.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+
4
+ from collections.abc import AsyncIterator
5
+ from contextlib import asynccontextmanager
6
+
7
+ from fastapi import FastAPI, Query
8
+ from fastapi.responses import FileResponse
9
+ from fastapi.staticfiles import StaticFiles
10
+ from fastapi_cache import FastAPICache
11
+ from fastapi_cache.backends.redis import RedisBackend
12
+ from fastapi_cache.coder import PickleCoder
13
+ from fastapi_cache.decorator import cache
14
+ import logging
15
+
16
+ from redis import asyncio as aioredis
17
+
18
+ from pydantic import BaseModel, Field
19
+ from typing import Tuple, List, Union, Optional
20
+
21
+ from sklearn.pipeline import Pipeline
22
+ from sklearn.preprocessing._label import LabelEncoder
23
+ import joblib
24
+
25
+ import pandas as pd
26
+
27
+ import httpx
28
+ from io import BytesIO
29
+
30
+
31
+ from config import ONE_DAY_SEC, ONE_WEEK_SEC, XGBOOST_URL, RANDOM_FOREST_URL, ENCODER_URL, ENV_PATH, DESCRIPTION, ALL_MODELS
32
+
33
+ load_dotenv(ENV_PATH)
34
+
35
+
36
+ @asynccontextmanager
37
+ async def lifespan(_: FastAPI) -> AsyncIterator[None]:
38
+ url = os.getenv("REDIS_URL")
39
+ username = os.getenv("REDIS_USERNAME")
40
+ password = os.getenv("REDIS_PASSWORD")
41
+ redis = aioredis.from_url(url=url, username=username,
42
+ password=password, encoding="utf8", decode_responses=True)
43
+ FastAPICache.init(RedisBackend(redis), prefix="fastapi-cache")
44
+ yield
45
+
46
+
47
+ # FastAPI Object
48
+ app = FastAPI(
49
+ title='Sepsis classification',
50
+ version='1.0.0',
51
+ description=DESCRIPTION,
52
+ lifespan=lifespan,
53
+ )
54
+
55
+ app.mount("/assets", StaticFiles(directory="assets"), name="assets")
56
+
57
+
58
+ @app.get('/favicon.ico', include_in_schema=False)
59
+ async def favicon():
60
+ file_name = "favicon.ico"
61
+ file_path = os.path.join(app.root_path, "assets", file_name)
62
+ return FileResponse(path=file_path, headers={"Content-Disposition": "attachment; filename=" + file_name})
63
+
64
+
65
+ # API input features
66
+
67
+ class SepsisFeatures(BaseModel):
68
+ prg: List[int] = Field(description="PRG: Plasma glucose")
69
+ pl: List[int] = Field(description="PL: Blood Work Result-1 (mu U/ml)")
70
+ pr: List[int] = Field(description="PR: Blood Pressure (mm Hg)")
71
+ sk: List[int] = Field(description="SK: Blood Work Result-2 (mm)")
72
+ ts: List[int] = Field(description="TS: Blood Work Result-3 (mu U/ml)")
73
+ m11: List[float] = Field(
74
+ description="M11: Body mass index (weight in kg/(height in m)^2")
75
+ bd2: List[float] = Field(description="BD2: Blood Work Result-4 (mu U/ml)")
76
+ age: List[int] = Field(description="Age: patients age (years)")
77
+ insurance: List[int] = Field(
78
+ description="Insurance: If a patient holds a valid insurance card")
79
+
80
+
81
+ class Url(BaseModel):
82
+ url: str
83
+ pipeline_url: str
84
+ encoder_url: str
85
+
86
+
87
+ class ResultData(BaseModel):
88
+ prediction: List[str]
89
+ probability: List[float]
90
+
91
+
92
+ class PredictionResponse(BaseModel):
93
+ execution_msg: str
94
+ execution_code: int
95
+ result: ResultData
96
+
97
+
98
+ class ErrorResponse(BaseModel):
99
+ execution_msg: str
100
+ execution_code: int
101
+ error: Optional[str]
102
+
103
+
104
+ logging.basicConfig(level=logging.ERROR,
105
+ format='%(asctime)s - %(levelname)s - %(message)s')
106
+
107
+
108
+ # Load the model pipelines and encoder
109
+ # Cache for 1 day
110
+ @cache(expire=ONE_DAY_SEC, namespace='pipeline_resource', coder=PickleCoder)
111
+ async def load_pipeline(pipeline_url: Url, encoder_url: Url) -> Tuple[Pipeline, LabelEncoder]:
112
+ async def url_to_data(url: Url):
113
+ async with httpx.AsyncClient() as client:
114
+ response = await client.get(url)
115
+ response.raise_for_status() # Ensure we catch any HTTP errors
116
+ # Convert response content to BytesIO object
117
+ data = BytesIO(response.content)
118
+ return data
119
+
120
+ pipeline, encoder = None, None
121
+ try:
122
+ pipeline: Pipeline = joblib.load(await url_to_data(pipeline_url))
123
+ encoder: LabelEncoder = joblib.load(await url_to_data(encoder_url))
124
+ except Exception as e:
125
+ logging.error(
126
+ "Omg, an error occurred in loading the pipeline resources: %s", e)
127
+ finally:
128
+ return pipeline, encoder
129
+
130
+
131
+ # Endpoints
132
+
133
+ # Status endpoint: check if api is online
134
+ @app.get('/')
135
+ @cache(expire=ONE_WEEK_SEC, namespace='status_check') # Cache for 1 week
136
+ async def status_check():
137
+ return {"Status": "API is online..."}
138
+
139
+
140
+ @cache(expire=ONE_DAY_SEC, namespace='pipeline_classifier') # Cache for 1 day
141
+ async def pipeline_classifier(pipeline: Pipeline, encoder: LabelEncoder, data: SepsisFeatures) -> Union[ErrorResponse, PredictionResponse]:
142
+ msg = 'Execution failed'
143
+ code = 0
144
+ output = ErrorResponse(**{'execution_msg': msg,
145
+ 'execution_code': code, 'error': None})
146
+
147
+ try:
148
+ # Create dataframe
149
+ df = pd.DataFrame.from_dict(data.__dict__)
150
+
151
+ # Make prediction
152
+ preds = pipeline.predict(df)
153
+ preds_int = [int(pred) for pred in preds]
154
+
155
+ predictions = encoder.inverse_transform(preds_int)
156
+ probabilities_np = pipeline.predict_proba(df)
157
+
158
+ probabilities = [round(float(max(prob)*100), 2)
159
+ for prob in probabilities_np]
160
+
161
+ result = ResultData(**{"prediction": predictions,
162
+ "probability": probabilities})
163
+
164
+ msg = 'Execution was successful'
165
+ code = 1
166
+ output = PredictionResponse(
167
+ **{'execution_msg': msg,
168
+ 'execution_code': code, 'result': result}
169
+ )
170
+
171
+ except Exception as e:
172
+ error = f"Omg, pipeline classifier and/or encoder failure. {e}"
173
+ output = ErrorResponse(**{'execution_msg': msg,
174
+ 'execution_code': code, 'error': error})
175
+
176
+ finally:
177
+ return output
178
+
179
+
180
+ # Random forest endpoint: classify sepsis with random forest
181
+ @app.post('/api/v1/random_forest/prediction', tags=['Random Forest'])
182
+ async def random_forest_classifier(data: SepsisFeatures) -> Union[ErrorResponse, PredictionResponse]:
183
+ random_forest_pipeline, encoder = await load_pipeline(RANDOM_FOREST_URL, ENCODER_URL)
184
+ output = await pipeline_classifier(random_forest_pipeline, encoder, data)
185
+ return output
186
+
187
+
188
+ # Xgboost endpoint: classify sepsis with xgboost
189
+ @app.post('/api/v1/xgboost/prediction', tags=['XGBoost'])
190
+ async def xgboost_classifier(data: SepsisFeatures) -> Union[ErrorResponse, PredictionResponse]:
191
+ xgboost_pipeline, encoder = await load_pipeline(XGBOOST_URL, ENCODER_URL)
192
+ output = await pipeline_classifier(xgboost_pipeline, encoder, data)
193
+ return output
194
+
195
+
196
+ @app.post('/api/v1/prediction', tags=['All Models'])
197
+ async def query_sepsis_prediction(data: SepsisFeatures, model: str = Query('RandomForestClassifier', enum=list(ALL_MODELS.keys()))) -> Union[ErrorResponse, PredictionResponse]:
198
+ pipeline_url: Url = ALL_MODELS[model]
199
+ pipeline, encoder = await load_pipeline(pipeline_url, ENCODER_URL)
200
+ output = await pipeline_classifier(pipeline, encoder, data)
201
+ return output
utils/pipeline_helper.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from typing import Union
3
+
4
+ numerical_features = ['prg', 'pl', 'pr', 'sk', 'ts', 'm11', 'bd2', 'age']
5
+
6
+ categorical_features = ['insurance']
7
+
8
+ new_features = ['age_group']
9
+
10
+
11
+ def as_category(data: Union[pd.DataFrame | pd.Series]) -> Union[pd.DataFrame | pd.Series]:
12
+ return data.astype('category')
13
+
14
+
15
+ def feature_creation(df: pd.DataFrame) -> pd.DataFrame:
16
+ df_copy = df.copy()
17
+ if 'age_group' not in df_copy.columns and 'age' in df_copy.columns:
18
+ df_copy['age_group'] = df_copy['age'].apply(
19
+ lambda x: '60 and above' if x >= 60 else 'below 60')
20
+ df_copy['age_group'] = as_category(df_copy['age_group'])
21
+ df_copy.drop(columns='age', inplace=True)
22
+
23
+ return df_copy