import datetime from concurrent.futures import as_completed from urllib import parse import pandas as pd import streamlit as st import wandb from requests_futures.sessions import FuturesSession from dashboard_utils.time_tracker import _log, simple_time_tracker URL_QUICKSEARCH = "https://huggingface.co/api/quicksearch?" WANDB_REPO = st.secrets["WANDB_REPO_INDIVIDUAL_METRICS"] CACHE_TTL = 100 MAX_DELTA_ACTIVE_RUN_SEC = 60 * 5 @st.cache(ttl=CACHE_TTL, show_spinner=False) @simple_time_tracker(_log) def get_new_bubble_data(): serialized_data_points, latest_timestamp = get_serialized_data_points() serialized_data = get_serialized_data(serialized_data_points, latest_timestamp) usernames = [] for item in serialized_data["points"][0]: usernames.append(item["profileId"]) profiles = get_profiles(usernames) return serialized_data, profiles @st.cache(ttl=CACHE_TTL, show_spinner=False) @simple_time_tracker(_log) def get_profiles(usernames): profiles = [] with FuturesSession(max_workers=32) as session: futures = [] for username in usernames: future = session.get(URL_QUICKSEARCH + parse.urlencode({"type": "user", "q": username})) future.username = username futures.append(future) for future in as_completed(futures): resp = future.result() username = future.username response = resp.json() avatarUrl = None if response["users"]: for user_candidate in response["users"]: if user_candidate["user"] == username: avatarUrl = response["users"][0]["avatarUrl"] break if not avatarUrl: avatarUrl = "/avatars/57584cb934354663ac65baa04e6829bf.svg" if avatarUrl.startswith("/avatars/"): avatarUrl = f"https://huggingface.co{avatarUrl}" profiles.append( {"id": username, "name": username, "src": avatarUrl, "url": f"https://huggingface.co/{username}"} ) return profiles @st.cache(ttl=CACHE_TTL, show_spinner=False) @simple_time_tracker(_log) def get_serialized_data_points(): api = wandb.Api() runs = api.runs(WANDB_REPO) serialized_data_points = {} latest_timestamp = None for run in runs: run_summary = run.summary._json_dict run_name = run.name state = run.state if run_name in serialized_data_points: if "_timestamp" in run_summary and "_step" in run_summary: timestamp = run_summary["_timestamp"] serialized_data_points[run_name]["Runs"].append( { "batches": run_summary["_step"], "runtime": run_summary["_runtime"], "loss": run_summary["train/loss"], "state": state, "velocity": run_summary["_step"] / run_summary["_runtime"], "date": datetime.datetime.utcfromtimestamp(timestamp), } ) if not latest_timestamp or timestamp > latest_timestamp: latest_timestamp = timestamp else: if "_timestamp" in run_summary and "_step" in run_summary: timestamp = run_summary["_timestamp"] serialized_data_points[run_name] = { "profileId": run_name, "Runs": [ { "batches": run_summary["_step"], "runtime": run_summary["_runtime"], "loss": run_summary["train/loss"], "state": state, "velocity": run_summary["_step"] / run_summary["_runtime"], "date": datetime.datetime.utcfromtimestamp(timestamp), } ], } if not latest_timestamp or timestamp > latest_timestamp: latest_timestamp = timestamp latest_timestamp = datetime.datetime.utcfromtimestamp(latest_timestamp) return serialized_data_points, latest_timestamp @st.cache(ttl=CACHE_TTL, show_spinner=False) @simple_time_tracker(_log) def get_serialized_data(serialized_data_points, latest_timestamp): serialized_data_points_v2 = [] max_velocity = 1 for run_name, serialized_data_point in serialized_data_points.items(): activeRuns = [] loss = 0 runtime = 0 batches = 0 velocity = 0 for run in serialized_data_point["Runs"]: if run["state"] == "running": run["date"] = run["date"].isoformat() activeRuns.append(run) loss += run["loss"] velocity += run["velocity"] loss = loss / len(activeRuns) if activeRuns else 0 runtime += run["runtime"] batches += run["batches"] new_item = { "date": latest_timestamp.isoformat(), "profileId": run_name, "batches": runtime, # "batches": batches quick and dirty fix "runtime": runtime, "activeRuns": activeRuns, } serialized_data_points_v2.append(new_item) serialized_data = {"points": [serialized_data_points_v2], "maxVelocity": max_velocity} return serialized_data def get_leaderboard(serialized_data): data_leaderboard = {"user": [], "runtime": []} for user_item in serialized_data["points"][0]: data_leaderboard["user"].append(user_item["profileId"]) data_leaderboard["runtime"].append(user_item["runtime"]) df = pd.DataFrame(data_leaderboard) df = df.sort_values("runtime", ascending=False) df["runtime"] = df["runtime"].apply(lambda x: datetime.timedelta(seconds=x)) df["runtime"] = df["runtime"].apply(lambda x: str(x)) df.reset_index(drop=True, inplace=True) df.rename(columns={"user": "User", "runtime": "Total time contributed"}, inplace=True) df["Rank"] = df.index + 1 df = df.set_index("Rank") return df def get_global_metrics(serialized_data): current_time = datetime.datetime.utcnow() num_contributing_users = len(serialized_data["points"][0]) num_active_users = 0 total_runtime = 0 for user_item in serialized_data["points"][0]: for run in user_item["activeRuns"]: date_run = datetime.datetime.fromisoformat(run["date"]) delta_time_sec = (current_time - date_run).total_seconds() if delta_time_sec < MAX_DELTA_ACTIVE_RUN_SEC: num_active_users += 1 break total_runtime += user_item["runtime"] total_runtime = datetime.timedelta(seconds=total_runtime) return { "num_contributing_users": num_contributing_users, "num_active_users": num_active_users, "total_runtime": total_runtime, }