ldhldh committed on
Commit
90e26fa
·
1 Parent(s): a519d6c

Upload 8 files

Browse files
Files changed (8) hide show
  1. app.py +34 -170
  2. config.py +31 -0
  3. data_structures.py +20 -0
  4. health.py +124 -0
  5. metrics.py +118 -0
  6. p2p_utils.py +67 -0
  7. pyproject.toml +10 -0
  8. state_updater.py +57 -0
app.py CHANGED
@@ -1,187 +1,51 @@
1
- from threading import Thread
2
- import gradio as gr
3
- import inspect
4
- from gradio import routes
5
- from typing import List, Type
6
 
7
- import requests, os, re, asyncio, queue, sys, git
8
- import math
9
- import time
10
- import datetime
11
- import requests, json
12
-
13
- from pprint import pprint
14
  import hivemind
15
- from petals.constants import PUBLIC_INITIAL_PEERS
16
- from health import fetch_health_state
17
-
18
- dht = hivemind.DHT(initial_peers=PUBLIC_INITIAL_PEERS, client_mode=True, start=True)
19
- model_name = "quantumaikr/llama-2-70b-fb16-korean"
20
-
21
- loop = asyncio.get_event_loop()
22
- # Monkey patch
23
- def get_types(cls_set: List[Type], component: str):
24
- docset = []
25
- types = []
26
- if component == "input":
27
- for cls in cls_set:
28
- doc = inspect.getdoc(cls)
29
- doc_lines = doc.split("\n")
30
- docset.append(doc_lines[1].split(":")[-1])
31
- types.append(doc_lines[1].split(")")[0].split("(")[-1])
32
- else:
33
- for cls in cls_set:
34
- doc = inspect.getdoc(cls)
35
- doc_lines = doc.split("\n")
36
- docset.append(doc_lines[-1].split(":")[-1])
37
- types.append(doc_lines[-1].split(")")[0].split("(")[-1])
38
- return docset, types
39
- routes.get_types = get_types
40
-
41
- # App code
42
-
43
- account_list = dict()
44
-
45
- account_list['id'] = "pass"
46
-
47
- name_list = dict()
48
- name_list['id'] = 'name'
49
-
50
- p2p_list = dict()
51
- p2p_list['id'] = '11111111'
52
-
53
- def chat(x):
54
-
55
- return "AI 응답입니다."
56
-
57
-
58
- def register(id, pw):
59
- if id in account_list:
60
- return "exist"
61
- else:
62
- account_list[id] = pw
63
- return "ok"
64
 
65
- def login(id, pw):
66
- if id in account_list:
67
- if account_list[id] == pw:
68
- return "ok"
69
- else:
70
- return "password error"
71
- else:
72
- return "no id"
73
 
74
- def add_name(id, name):
75
- name_list[id] = name
76
- return "ok"
77
 
78
- def get_name(id):
79
- if id in name_list:
80
- return name_list[id]
81
- else:
82
- return "no id"
83
 
84
- def get_id(name):
85
- reverse_dict= dict(map(reversed,name_list.items()))
86
- if name in reverse_dict:
87
- return reverse_dict[name]
88
- else:
89
- return "no name"
90
 
91
- def add_p(id, p_id):
92
- p2p_list[id] = p_id
93
- return "ok"
94
 
95
- def get_p(id):
96
- if id in p2p_list:
97
- return p2p_list[id]
98
- else:
99
- return "no id"
100
 
101
- def get_id_from_p2p(i):
102
- reverse_dict= dict(map(reversed,p2p_list.items()))
103
- if i in reverse_dict:
104
- return reverse_dict[i]
105
- else:
106
- return "no id"
107
 
108
- # Blockchain code
 
 
109
 
110
- def get_peers():
111
- data = fetch_health_state(dht)
112
- out = []
113
- for d in data['model_reports']:
114
- if d['name'] == model_name:
115
- for r in d['server_rows']:
116
- out.append(r['peer_id'])
117
 
118
- return out
 
 
119
 
120
- get_peers()
121
 
122
- with gr.Blocks() as demo:
123
- count = 0
124
- aa = gr.Interface(
125
- fn=chat,
126
- inputs=["text"],
127
- outputs="text",
128
- description="chat, ai 응답을 반환합니다.\n /run/predict",
 
129
  )
130
 
131
- rr = gr.Interface(
132
- fn=register,
133
- inputs=["text", "text"],
134
- outputs="text",
135
- description="register, 회원가입(성공시:ok, 중복시:exist 반환)\n /run/predict_1",
136
- )
137
-
138
- ll = gr.Interface(
139
- fn=login,
140
- inputs=["text", "text"],
141
- outputs="text",
142
- description="login, 로그인(성공시: ok, 실패시: password error, 아이디가 없으면: no id) \n /run/predict_2",
143
- )
144
-
145
- ad = gr.Interface(
146
- fn=add_name,
147
- inputs=["text", "text"],
148
- outputs="text",
149
- description="add_name, id로 닉네임 추가. ok 반환.\n /run/predict_3",
150
- )
151
-
152
- nn = gr.Interface(
153
- fn=get_name,
154
- inputs=["text"],
155
- outputs="text",
156
- description="get_name, id로 닉네임 반환(없으면 no id)\n /run/predict_4",
157
- )
158
 
159
- nnn = gr.Interface(
160
- fn=get_id,
161
- inputs=["text"],
162
- outputs="text",
163
- description="get_name, 닉네임으로 id 반환(없으면 no name)\n /run/predict_5",
164
- )
165
-
166
- adp = gr.Interface(
167
- fn=add_p,
168
- inputs=["text", "text"],
169
- outputs="text",
170
- description="add_p, id로 p2p id 추가. ok 반환. \n /run/predict_6",
171
- )
172
-
173
- nnp = gr.Interface(
174
- fn=get_p,
175
- inputs=["text"],
176
- outputs="text",
177
- description="get_p, id로 p2p id 반환. 없으면 no id. \n /run/predict_7",
178
- )
179
-
180
- nnp = gr.Interface(
181
- fn=get_id_from_p2p,
182
- inputs=["text"],
183
- outputs="text",
184
- description="get_p, p2p id로 일반 id 반환. 없으면 no id. \n /run/predict_8",
185
- )
186
-
187
- demo.queue(max_size=32).launch(enable_queue=True)
 
1
+ from functools import partial
 
 
 
 
2
 
 
 
 
 
 
 
 
3
  import hivemind
4
+ from flask import Flask, jsonify, request
5
+ from flask_cors import CORS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ import config
8
+ from p2p_utils import check_reachability
9
+ from state_updater import StateUpdaterThread
 
 
 
 
 
10
 
11
logger = hivemind.get_logger(__name__)

# Join the swarm as a client-mode DHT node using the configured initial peers.
logger.info("Connecting to DHT")
dht = hivemind.DHT(initial_peers=config.INITIAL_PEERS, client_mode=True, num_workers=32, start=True)

logger.info("Starting Flask app")
app = Flask(__name__)
CORS(app)

# The updater thread periodically re-renders the state; block until the first
# snapshot is available so that no route ever serves an empty response.
logger.info("Starting updater")
updater = StateUpdaterThread(dht, app, daemon=True)
updater.start()
updater.ready.wait()


@app.route("/")
def main_page():
    """Serve the most recently rendered HTML status page."""
    return updater.state_html


@app.route("/api/v1/state")
def api_v1_state():
    """Serve the most recent health state as a pre-serialized JSON document."""
    return app.response_class(response=updater.state_json, status=200, mimetype="application/json")


@app.route("/api/v1/is_reachable/<peer_id>")
def api_v1_is_reachable(peer_id):
    """Check (bypassing the cache) whether the given peer accepts connections.

    Responds with JSON fields ``success``, ``message`` (error text or null) and
    ``your_ip``. A malformed peer id yields HTTP 400 instead of an unhandled
    exception (HTTP 500).
    """
    try:
        peer_id = hivemind.PeerID.from_base58(peer_id)
    except ValueError as e:
        # NOTE(review): assumes base58 decoding failures surface as ValueError - confirm against hivemind
        return jsonify(success=False, message=f"Invalid peer ID: {e}", your_ip=request.remote_addr), 400

    rpc_info = dht.run_coroutine(partial(check_reachability, peer_id, use_cache=False))
    return jsonify(
        success=rpc_info["ok"],
        message=rpc_info.get("error"),
        your_ip=request.remote_addr,
    )


@app.route("/metrics")
@app.route("/api/prometheus")
def metrics():
    """Serve the most recent metrics in Prometheus text format."""
    return app.response_class(response=updater.prometheus_metrics, status=200, mimetype="text/plain")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from petals.constants import PUBLIC_INITIAL_PEERS
2
+
3
+ from data_structures import ModelInfo
4
+
5
+ INITIAL_PEERS = PUBLIC_INITIAL_PEERS
6
+
7
+ MODELS = [
8
+ ModelInfo(
9
+ dht_prefix="StableBeluga2-hf",
10
+ repository="https://huggingface.co/petals-team/StableBeluga2",
11
+ num_blocks=80,
12
+ ),
13
+ ModelInfo(
14
+ dht_prefix="falcon-180B-chat",
15
+ repository="https://huggingface.co/tiiuae/falcon-180B-chat",
16
+ num_blocks=80,
17
+ limited=True,
18
+ ),
19
+ ModelInfo(
20
+ dht_prefix="Llama-2-70b-chat-hf",
21
+ repository="https://huggingface.co/meta-llama/Llama-2-70b-chat-hf",
22
+ num_blocks=80,
23
+ ),
24
+ ModelInfo(
25
+ dht_prefix="Llama-2-70b-hf",
26
+ repository="https://huggingface.co/meta-llama/Llama-2-70b-hf",
27
+ num_blocks=80,
28
+ ),
29
+ ]
30
+
31
+ UPDATE_PERIOD = 60
data_structures.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ from urllib.parse import urlparse
3
+
4
+ import petals
5
+ import pydantic
6
+
7
+
8
@pydantic.dataclasses.dataclass
class ModelInfo(petals.data_structures.ModelInfo):
    """Petals model description extended with fields used by the health monitor."""

    # Prefix of the model's block UIDs in the DHT
    dht_prefix: Optional[str] = None
    # True for models listed in the local config; set to False for models found in the DHT index
    official: bool = True
    limited: bool = False

    @property
    def name(self) -> str:
        """Path component of the repository URL without the leading slash."""
        repo_path = urlparse(self.repository).path
        return repo_path.lstrip("/")

    @property
    def short_name(self) -> str:
        """Last path segment of ``name``."""
        return self.name.rsplit("/", 1)[-1]
health.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import time
3
+ from collections import Counter
4
+ from contextlib import suppress
5
+ from dataclasses import asdict
6
+ from functools import partial
7
+
8
+ import hivemind
9
+ import numpy as np
10
+ from multiaddr import Multiaddr
11
+ from petals.data_structures import UID_DELIMITER, ServerState
12
+ from petals.utils.dht import compute_spans, get_remote_module_infos
13
+
14
+ import config
15
+ from data_structures import ModelInfo
16
+ from p2p_utils import check_reachability_parallel, get_peers_ips, extract_peer_ip_info
17
+
18
+ logger = hivemind.get_logger(__name__)
19
+
20
+
21
def fetch_health_state(dht: hivemind.DHT) -> dict:
    """Collect a snapshot of the swarm state for the status page, JSON API and metrics.

    Args:
        dht: a running hivemind DHT node used for all lookups and reachability checks.

    Returns:
        A dict with the keys ``bootstrap_states``, ``top_contributors``, ``model_reports``,
        ``reachability_issues``, ``last_updated``, ``update_period`` and ``update_duration``.
    """
    start_time = time.perf_counter()

    # Deduplicate bootstrap peers while keeping the order of config.INITIAL_PEERS
    bootstrap_peer_ids = []
    for addr in config.INITIAL_PEERS:
        peer_id = hivemind.PeerID.from_base58(Multiaddr(addr)["p2p"])
        if peer_id not in bootstrap_peer_ids:
            bootstrap_peer_ids.append(peer_id)

    reach_infos = dht.run_coroutine(partial(check_reachability_parallel, bootstrap_peer_ids))
    bootstrap_states = ["online" if reach_infos[peer_id]["ok"] else "unreachable" for peer_id in bootstrap_peer_ids]

    # Start from the configured models, then append custom models announced in the DHT index
    models = config.MODELS[:]
    model_index = dht.get("_petals.models", latest=True)
    if model_index is not None and isinstance(model_index.value, dict):
        official_dht_prefixes = {model.dht_prefix for model in models}
        custom_models = []
        for dht_prefix, model in model_index.value.items():
            if dht_prefix in official_dht_prefixes:
                continue
            # Entries that fail to parse are skipped silently
            with suppress(TypeError, ValueError):
                model_info = ModelInfo.from_dict(model.value)
                if model_info.repository is None or not model_info.repository.startswith("https://huggingface.co/"):
                    continue
                model_info.dht_prefix = dht_prefix
                model_info.official = False
                custom_models.append(model_info)
        models.extend(sorted(custom_models, key=lambda info: (-info.num_blocks, info.dht_prefix)))
    logger.info(f"Fetching info for models {[info.name for info in models]}")

    # Fetch all blocks of all models in one request; results are sliced per model below
    block_uids = [f"{model.dht_prefix}{UID_DELIMITER}{i}" for model in models for i in range(model.num_blocks)]
    module_infos = get_remote_module_infos(dht, block_uids, latest=True)

    model_servers = {}
    all_servers = {}
    offset = 0
    for model in models:
        model_servers[model.dht_prefix] = compute_spans(
            module_infos[offset : offset + model.num_blocks], min_state=ServerState.OFFLINE
        )
        all_servers.update(model_servers[model.dht_prefix])
        offset += model.num_blocks

    online_servers = [peer_id for peer_id, span in all_servers.items() if span.state == ServerState.ONLINE]

    reach_infos.update(dht.run_coroutine(partial(check_reachability_parallel, online_servers, fetch_info=True)))

    peers_info = {}
    for peer in dht.run_coroutine(get_peers_ips):
        multiaddrs = [str(multiaddr) for multiaddr in peer.addrs]
        # A peer listed without addresses must not abort the whole refresh
        location = extract_peer_ip_info(multiaddrs[0]) if multiaddrs else {}
        peers_info[str(peer.peer_id)] = {"location": location, "multiaddrs": multiaddrs}

    top_contributors = Counter()
    model_reports = []
    for model in models:
        block_healthy = np.zeros(model.num_blocks, dtype=bool)
        server_rows = []
        for peer_id, span in sorted(model_servers[model.dht_prefix].items()):
            # Servers that were not probed (i.e. not online) are treated as reachable
            reachable = reach_infos[peer_id]["ok"] if peer_id in reach_infos else True
            state = span.state.name.lower() if reachable else "unreachable"
            if state == "online":
                block_healthy[span.start : span.end] = True

            show_public_name = state == "online" and span.length >= 10
            if model.official and span.server_info.public_name and show_public_name:
                top_contributors[span.server_info.public_name] += span.length

            row = {
                "short_peer_id": "..." + str(peer_id)[-6:],
                "peer_id": peer_id,
                "peer_ip_info": peers_info.get(str(peer_id), "unknown"),
                "show_public_name": show_public_name,
                "state": state,
                "span": span,
                "adapters": [dict(name=name, short_name=name.split("/")[-1]) for name in span.server_info.adapters],
                "pings_to_me": {
                    str(origin_id): origin.server_info.next_pings[str(peer_id)]
                    for origin_id, origin in model_servers[model.dht_prefix].items()
                    if origin.server_info.next_pings is not None and str(peer_id) in origin.server_info.next_pings
                },
            }
            if span.server_info.cache_tokens_left is not None:
                # We use num_blocks * 2 to account for both keys and values
                row["cache_tokens_left_per_block"] = span.server_info.cache_tokens_left // (span.length * 2)
            server_rows.append(row)

        model_reports.append(
            dict(
                name=model.name,
                short_name=model.short_name,
                # A model is healthy only if every block is served by an online server
                state="healthy" if block_healthy.all() else "broken",
                server_rows=server_rows,
                **asdict(model),
            )
        )

    reachability_issues = [
        dict(peer_id=peer_id, err=info["error"]) for peer_id, info in sorted(reach_infos.items()) if not info["ok"]
    ]

    return dict(
        bootstrap_states=bootstrap_states,
        top_contributors=top_contributors,
        model_reports=model_reports,
        reachability_issues=reachability_issues,
        last_updated=datetime.datetime.now(datetime.timezone.utc),
        update_period=config.UPDATE_PERIOD,
        update_duration=time.perf_counter() - start_time,
    )
metrics.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import Counter, defaultdict
2
+ from typing import List
3
+
4
+ import numpy as np
5
+
6
+
7
def get_servers_metrics(model_reports) -> List[str]:
    """Build server-level Prometheus lines from the per-model reports.

    Args:
        model_reports: iterable of dicts with a ``"server_rows"`` list; each row has a
            ``"span"`` whose ``server_info`` (possibly None) exposes ``next_pings``,
            ``using_relay`` and ``version``.

    Returns:
        Metric and comment lines, one string per line.
    """
    inf = float("inf")
    percentiles = (25, 50, 75, 90, 95)

    servers_num_total = 0
    servers_num_relay = 0
    pings = []
    num_ping_infs = 0
    version_counts = Counter()
    result = ["# SERVER LEVEL METRICS"]

    for report in model_reports:
        for server in report["server_rows"]:
            server_info = server["span"].server_info
            if server_info is None:
                continue

            next_pings = server_info.next_pings
            if next_pings is not None:
                # Only servers that report pings count towards servers_num_total
                servers_num_total += 1
                finite_pings = [ping for ping in next_pings.values() if ping != inf]
                pings.extend(finite_pings)
                num_ping_infs += len(next_pings) - len(finite_pings)

            if server_info.using_relay:
                servers_num_relay += 1

            if server_info.version:
                version_counts[server_info.version] += 1

    if servers_num_total > 0 and pings:
        peers_per_srv = (len(pings) + num_ping_infs) / servers_num_total
        pings_inf_share = num_ping_infs / (num_ping_infs + len(pings))
        result.append(f"peers_per_srv {peers_per_srv:.1f}")
        result.append(f"pings_inf_share {pings_inf_share:.3f}")

    result.append(f"servers_num_total {servers_num_total}")
    result.append(f"servers_num_relay {servers_num_relay}")

    if pings:
        result.append("# PINGS")
        for pct, value in zip(percentiles, np.percentile(pings, percentiles)):
            result.append(f'ping_pct{{pct="{pct}"}} {value:.4f}')

    result.append("# VERSIONS")
    for version_number, version_count in version_counts.items():
        result.append(f'server_version{{version_number="{version_number}"}} {version_count}')

    return result
59
+
60
+
61
def _min_or_zero(values) -> float:
    """Return the minimum of *values*, or 0.0 for an empty array (``ndarray.min`` would raise)."""
    return values.min() if values.size else 0.0


def _mean_or_zero(values) -> float:
    """Return the mean of *values*, or 0.0 for an empty array (``ndarray.mean`` would warn and give NaN)."""
    return values.mean() if values.size else 0.0


def get_models_metrics(model_reports) -> List[str]:
    """Build model-level Prometheus lines from the per-model reports.

    Args:
        model_reports: iterable of dicts with ``"dht_prefix"``, ``"num_blocks"`` and
            ``"server_rows"``; each row has a ``"state"`` string and a ``"span"`` with
            ``start``, ``end`` and ``server_info`` (possibly None).

    Returns:
        Metric and comment lines, one string per line.
    """
    rps_kinds = ("network_rps", "inference_rps", "forward_rps")
    result = ["# MODEL LEVEL METRICS"]

    for report in model_reports:
        model_name = report["dht_prefix"]
        num_blocks = report["num_blocks"]

        result.append(f"# MODEL: {model_name} {'-' * 50}")

        # One per-block accumulator per category, created on first access
        blocks = defaultdict(lambda size=num_blocks: np.zeros(size))

        for server in report["server_rows"]:
            span = server["span"]
            covered = slice(span.start, span.end)
            blocks["total"][covered] += 1
            blocks[server["state"]][covered] += 1

            if span.server_info is not None:
                for rps in rps_kinds:
                    rps_value = getattr(span.server_info, rps, 0)
                    if rps_value is not None:
                        blocks[rps][covered] += rps_value

        result.extend(
            [
                f'n_blocks{{model="{model_name}"}} {num_blocks}',
                f'servers_num{{model="{model_name}"}} {len(report["server_rows"])}',
                f'blocks_total{{model="{model_name}"}} {blocks["total"].sum()}',
                f'blocks_online_min{{model="{model_name}"}} {_min_or_zero(blocks["online"])}',
            ]
        )

        for block_state in ("online", "joining", "offline", "unreachable"):
            result.append(f'blocks{{model="{model_name}",state="{block_state}"}} {blocks[block_state].sum():.0f}')

        for rps in rps_kinds:
            rps_type = rps.split("_")[0]
            result.append(f'rps_avg{{model="{model_name}",rps="{rps_type}"}} {_mean_or_zero(blocks[rps]):.1f}')
            result.append(f'rps_min{{model="{model_name}",rps="{rps_type}"}} {_min_or_zero(blocks[rps]):.1f}')

    return result
102
+
103
+
104
def get_prometheus_metrics(state_dict) -> str:
    """Prepare metrics in the Prometheus text exposition format.

    Description: https://prometheus.io/docs/instrumenting/exposition_formats/

    Args:
        state_dict: health state with a ``"model_reports"`` entry and, optionally,
            ``"update_duration"`` (seconds).

    Returns:
        A multiline string with a single metric per line, terminated by a line feed.
    """
    result = ["# GENERAL METRICS"]

    # Skip the metric when the duration is unknown instead of failing on ``None``
    update_duration = state_dict.get("update_duration")
    if update_duration is not None:
        result.append(f"update_duration {update_duration:.1f}")

    result.extend(get_servers_metrics(state_dict["model_reports"]))

    result.extend(get_models_metrics(state_dict["model_reports"]))

    # The exposition format requires the last line to end with a line feed
    return "\n".join(result) + "\n"
p2p_utils.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import asyncio
3
+ import requests
4
+ import hivemind
5
+ import functools
6
+ from async_timeout import timeout
7
+ from petals.server.handler import TransformerConnectionHandler
8
+
9
+ info_cache = hivemind.TimedStorage()
10
+
11
+
12
async def check_reachability(peer_id, _, node, *, fetch_info=False, connect_timeout=5, expiration=300, use_cache=True):
    """Check whether a peer can be reached and return a dict describing the outcome.

    Args:
        peer_id: the peer to probe.
        _: unused positional argument (``check_reachability_parallel`` passes the DHT here).
        node: object exposing the ``p2p`` instance used to connect.
        fetch_info: if True, call the server's ``rpc_info`` and return its decoded payload;
            otherwise only connect and disconnect (for DHT-only bootstrap peers).
        connect_timeout: seconds allowed for the whole attempt.
        expiration: seconds for which the result stays in ``info_cache``.
        use_cache: if True, return an unexpired cached result when one exists.

    Returns:
        A dict with ``"ok"`` (bool) and, on failure, ``"error"`` (str). With ``fetch_info``
        it also carries the fields reported by the server.
    """
    if use_cache:
        entry = info_cache.get(peer_id)
        if entry is not None:
            return entry.value

    try:
        # ``async with`` is the supported form; the synchronous ``with timeout(...)``
        # was deprecated and later removed from async-timeout
        async with timeout(connect_timeout):
            if fetch_info:  # For Petals servers
                stub = TransformerConnectionHandler.get_stub(node.p2p, peer_id)
                response = await stub.rpc_info(hivemind.proto.runtime_pb2.ExpertUID())
                rpc_info = hivemind.MSGPackSerializer.loads(response.serialized_info)
                rpc_info["ok"] = True
            else:  # For DHT-only bootstrap peers
                await node.p2p._client.connect(peer_id, [])
                await node.p2p._client.disconnect(peer_id)
                rpc_info = {"ok": True}
    except Exception as e:
        # Actual connection error
        if not isinstance(e, asyncio.TimeoutError):
            message = str(e) if str(e) else repr(e)
            if message == "protocol not supported":
                # This may be returned when a server is joining, see https://github.com/petals-infra/health.petals.dev/issues/1
                # The result is returned without being cached
                return {"ok": True}
        else:
            message = f"Failed to connect in {connect_timeout:.0f} sec. Firewall may be blocking connections"
        rpc_info = {"ok": False, "error": message}

    info_cache.store(peer_id, rpc_info, hivemind.get_dht_time() + expiration)
    return rpc_info
42
+
43
+
44
async def check_reachability_parallel(peer_ids, dht, node, *, fetch_info=False):
    """Probe all given peers concurrently and return a ``{peer_id: result}`` mapping."""
    pending_checks = [check_reachability(pid, dht, node, fetch_info=fetch_info) for pid in peer_ids]
    results = await asyncio.gather(*pending_checks)
    return {pid: info for pid, info in zip(peer_ids, results)}
49
+
50
+
51
async def get_peers_ips(dht, dht_node):
    """Return whatever ``dht_node.p2p.list_peers()`` reports; ``dht`` is unused."""
    listed_peers = await dht_node.p2p.list_peers()
    return listed_peers
53
+
54
@functools.cache
def get_location(ip_address):
    """Look up geolocation data for an IP address via ip-api.com.

    Results, including the empty dict returned on failure, are memoized per address
    for the lifetime of the process.

    Args:
        ip_address: the address to look up.

    Returns:
        The decoded JSON response, or ``{}`` if the request fails, times out, returns
        a non-200 status or carries an invalid JSON body.
    """
    try:
        # Without a timeout a stalled connection would block the caller indefinitely
        response = requests.get(f"http://ip-api.com/json/{ip_address}", timeout=5)
        if response.status_code == 200:
            return response.json()
    except (requests.RequestException, ValueError):
        # Best-effort lookup: network errors and malformed bodies yield "no location"
        pass
    return {}
63
+
64
def extract_peer_ip_info(multiaddr_str):
    """Return location info for the first IPv4 address in a multiaddr string, or ``{}`` if none is found."""
    ip_match = re.search(r"/ip4/(\d+\.\d+\.\d+\.\d+)", multiaddr_str)
    if ip_match is None:
        return {}
    return get_location(ip_match.group(1))
pyproject.toml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.black]
2
+ line-length = 120
3
+ required-version = "22.3.0"
4
+
5
+ [tool.isort]
6
+ profile = "black"
7
+ line_length = 120
8
+ combine_as_imports = true
9
+ combine_star = true
10
+ known_local_folder = ["tests", "cli"]
state_updater.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import threading
3
+ import time
4
+ from dataclasses import asdict, is_dataclass
5
+ from enum import Enum
6
+
7
+ import hivemind
8
+ import simplejson
9
+ from flask import Flask, render_template
10
+
11
+ import config
12
+ from health import fetch_health_state
13
+ from metrics import get_prometheus_metrics
14
+
15
+ logger = hivemind.get_logger(__name__)
16
+
17
+
18
class StateUpdaterThread(threading.Thread):
    """Background thread that periodically refreshes the rendered health state.

    After each successful refresh the HTML page, the JSON document and the Prometheus
    text are stored in ``state_html``, ``state_json`` and ``prometheus_metrics``.
    ``ready`` is set once the first refresh has succeeded.
    """

    def __init__(self, dht: hivemind.DHT, app: Flask, **kwargs):
        """Store the DHT and Flask app; remaining keyword arguments go to ``threading.Thread``."""
        super().__init__(**kwargs)
        self.dht = dht
        self.app = app

        # All rendered views start as None until the first successful refresh
        self.state_json = self.state_html = self.prometheus_metrics = None
        self.ready = threading.Event()

    def run(self):
        """Refresh the state forever, aiming for one refresh per ``config.UPDATE_PERIOD`` seconds."""
        while True:
            start_time = time.perf_counter()
            try:
                state_dict = fetch_health_state(self.dht)
                # render_template needs an application context outside of a request
                with self.app.app_context():
                    state_html = render_template("index.html", **state_dict)
                prometheus_metrics = get_prometheus_metrics(state_dict)
                state_json = simplejson.dumps(state_dict, indent=2, ignore_nan=True, default=json_default)

                # Publish only after every view was rendered, so a failure above
                # never leaves the three views describing different snapshots
                self.state_html = state_html
                self.prometheus_metrics = prometheus_metrics
                self.state_json = state_json

                self.ready.set()
                logger.info(f"Fetched new state in {time.perf_counter() - start_time:.1f} sec")
            except Exception:
                # Keep serving the previous snapshot and retry on the next cycle
                logger.exception("Failed to update state:")

            delay = config.UPDATE_PERIOD - (time.perf_counter() - start_time)
            if delay < 0:
                logger.warning("Update took more than update_period, consider increasing it")
            time.sleep(max(delay, 0))
46
+
47
+
48
def json_default(value):
    """Convert values the JSON encoder cannot handle natively.

    Dataclasses become dicts, enums their lower-cased member name, peer ids their
    base58 form and datetimes a POSIX timestamp; anything else raises ``TypeError``.
    """
    if is_dataclass(value):
        converted = asdict(value)
    elif isinstance(value, Enum):
        converted = value.name.lower()
    elif isinstance(value, hivemind.PeerID):
        converted = value.to_base58()
    elif isinstance(value, datetime.datetime):
        converted = value.timestamp()
    else:
        raise TypeError(f"Can't serialize {repr(value)}")
    return converted