Spaces:
Running
Running
import numpy as np | |
import json | |
from trueskill import TrueSkill | |
import paramiko | |
import io, os | |
import sys | |
import random | |
sys.path.append('../') | |
from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_VIDEO_SKILL | |
# One TrueSkill environment built with the library's default parameters,
# used by the rating helpers below.
trueskill_env = TrueSkill()

# SSH / SFTP handles to the remote server. Both start out unset and are
# (re)created on demand by create_ssh_matchmaker_client().
ssh_matchmaker_client = sftp_matchmaker_client = None
def create_ssh_matchmaker_client(server, port, user, password):
    """Open (or re-open) the module-wide SSH and SFTP connections.

    Any previously opened clients are closed first, so reconnecting does not
    leak sockets/transports. The module globals are replaced only once the
    new connection is fully established; if connecting fails, they stay
    ``None`` rather than pointing at a half-initialised client.

    Args:
        server: Hostname or IP address of the SSH server.
        port: SSH port.
        user: Login user name.
        password: Login password.

    Raises:
        Whatever ``paramiko.SSHClient.connect`` / ``open_sftp`` raise
        (authentication, socket and SSH protocol errors).
    """
    global ssh_matchmaker_client, sftp_matchmaker_client

    # Release handles left over from an earlier (possibly dead) connection.
    for stale in (sftp_matchmaker_client, ssh_matchmaker_client):
        if stale is not None:
            try:
                stale.close()
            except Exception as e:
                print(f"Error closing stale SSH/SFTP client: {e}")
    ssh_matchmaker_client = None
    sftp_matchmaker_client = None

    client = paramiko.SSHClient()
    client.load_system_host_keys()
    # NOTE(review): AutoAddPolicy silently trusts unknown host keys; consider
    # RejectPolicy plus a pinned known_hosts entry if MITM is a concern.
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        client.connect(server, port, user, password)
        # Send SSH keepalive packets after 60 s of inactivity.
        client.get_transport().set_keepalive(60)
        sftp = client.open_sftp()
    except Exception:
        # Don't leave a half-open client (and its socket) behind.
        client.close()
        raise

    ssh_matchmaker_client = client
    sftp_matchmaker_client = sftp
def is_connected():
    """Return True if the module-wide SSH and SFTP connections are usable.

    The connection counts as usable only if all of these hold:

    - both client handles exist;
    - the SSH client has a transport (``get_transport()`` returns ``None``
      for a closed or never-connected client);
    - that transport is active;
    - a cheap SFTP round-trip (``listdir('.')``) succeeds.

    Any error from the SFTP probe is printed and reported as "not connected".
    """
    if ssh_matchmaker_client is None or sftp_matchmaker_client is None:
        return False
    transport = ssh_matchmaker_client.get_transport()
    if transport is None or not transport.is_active():
        return False
    try:
        sftp_matchmaker_client.listdir('.')
    except Exception as e:
        # Health-check boundary: any failure just means "reconnect needed".
        print(f"Error checking SFTP connection: {e}")
        return False
    return True
def ucb_score(trueskill_diff, t, n, c=1.0):
    """Upper-confidence-bound score used to pick opponents.

    Candidates with a small skill gap to the reference player and few
    comparisons so far (small ``n``) receive the highest scores.

    Args:
        trueskill_diff: Skill gap(s) to the reference player; scalar or array.
        t: Total number of comparisons made so far.
        n: Comparison count(s) per candidate; scalar or array broadcastable
            with ``trueskill_diff``.
        c: Weight of the exploration term. It defaults to 1.0, the value that
            was previously hard-coded.

    Returns:
        ``-trueskill_diff + c * sqrt(2 * ln(t) / n)`` computed element-wise.
        Small epsilons guard ``log(0)`` and division by zero.
    """
    # ln(t) is negative for t < 1 (e.g. the very first round, t == 0). A
    # negative log makes the square root NaN, and a NaN poisons every score.
    # Clamp it at 0 so the exploration bonus is 0 until comparisons exist.
    log_t = np.maximum(np.log(t + 1e-5), 0.0)
    exploration_term = np.sqrt((2 * log_t) / (n + 1e-5))
    return -trueskill_diff + c * exploration_term
def update_trueskill(ratings, ranks):
    """Compute updated ratings for one match result via the shared TrueSkill env.

    Thin wrapper around ``trueskill_env.rate``. See the trueskill
    documentation for the expected structure of ``ratings`` (rating groups)
    and ``ranks``.
    """
    return trueskill_env.rate(ratings, ranks)
def serialize_rating(rating):
    """Turn a rating object into a JSON-friendly ``{'mu': ..., 'sigma': ...}`` dict."""
    mu, sigma = rating.mu, rating.sigma
    return dict(mu=mu, sigma=sigma)
def deserialize_rating(rating_dict):
    """Rebuild a TrueSkill Rating from a dict holding ``'mu'`` and ``'sigma'`` keys."""
    mu = rating_dict['mu']
    sigma = rating_dict['sigma']
    return trueskill_env.Rating(mu=mu, sigma=sigma)
def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
    """Persist the matchmaking state as JSON to the remote ``SSH_VIDEO_SKILL`` file.

    The payload is serialized before any network work, so an unencodable
    value fails fast without opening a connection. If the SSH/SFTP session
    is down, the function reconnects first. The remote file is opened with
    mode ``'w'`` and overwritten.

    NOTE(review): the write is not atomic. A concurrent reader could observe a
    partially written file; consider writing to a temp path and renaming.

    Args:
        ratings: Iterable of TrueSkill ratings, stored as mu/sigma pairs.
        comparison_counts: Pairwise comparison counts. This can be a numpy
            array or any nested sequence that ``np.asarray`` accepts.
        total_comparisons: Total number of comparisons. Numpy scalars are
            converted to the equivalent Python number.
    """
    global sftp_matchmaker_client
    # The json module cannot encode numpy scalars (e.g. np.int64 from a numpy
    # sum). .item() yields the equivalent native Python number.
    if isinstance(total_comparisons, np.generic):
        total_comparisons = total_comparisons.item()
    data = {
        'ratings': [serialize_rating(r) for r in ratings],
        # asarray(...).tolist() gives plain Python numbers and also accepts
        # nested lists, not only ndarrays.
        'comparison_counts': np.asarray(comparison_counts).tolist(),
        'total_comparisons': total_comparisons
    }
    json_data = json.dumps(data)

    if not is_connected():
        create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
    with sftp_matchmaker_client.open(SSH_VIDEO_SKILL, 'w') as f:
        f.write(json_data)
def load_json_via_sftp():
    """Read the persisted matchmaking state from the remote ``SSH_VIDEO_SKILL`` file.

    Reconnects first if the SSH/SFTP session is down.

    Returns:
        Tuple ``(ratings, comparison_counts, total_comparisons)``:
        - ``ratings``: list of ratings rebuilt with ``deserialize_rating``;
        - ``comparison_counts``: numpy array made from the stored counts;
        - ``total_comparisons``: the stored total, as decoded from JSON.
    """
    if not is_connected():
        create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
    with sftp_matchmaker_client.open(SSH_VIDEO_SKILL, 'r') as remote_file:
        payload = json.loads(remote_file.read())
    return (
        list(map(deserialize_rating, payload['ratings'])),
        np.array(payload['comparison_counts']),
        payload['total_comparisons'],
    )
# Player indices with special pairing rules, applied by forcing their UCB
# score to -inf:
#   - a member of _EXCLUSIVE_PLAYERS is never grouped with any other player
#     listed in either tuple;
#   - members of _RESTRICTED_PLAYERS may be grouped with each other, but not
#     with an _EXCLUSIVE_PLAYERS member.
_EXCLUSIVE_PLAYERS = (7, 10)
_RESTRICTED_PLAYERS = (6, 8, 9)


def _mask_special_pairings(ucb_scores, selected_player, exclusive, restricted):
    """Set UCB scores to -inf (in place) for special players barred from this group."""
    size = len(ucb_scores)
    # Ignore special indices beyond the current player count, so a small
    # num_players does not raise IndexError.
    exclusive = [p for p in exclusive if p < size]
    restricted = [p for p in restricted if p < size]
    special = exclusive + restricted

    if selected_player in exclusive:
        blocked = special
    elif selected_player in restricted:
        blocked = exclusive
    elif special:
        # The anchor is in neither group: admit the highest-scoring special
        # player, then apply that player's rule. max() keeps the first maximum,
        # so ties go to the earliest index in `special`.
        best = max(special, key=lambda p: ucb_scores[p])
        if best in exclusive:
            blocked = [p for p in special if p != best]
        else:
            blocked = exclusive
    else:
        blocked = []

    for p in blocked:
        ucb_scores[p] = -np.inf


def matchmaker_video(num_players, k_group=4):
    """Choose a group of players to compare next, returned in random order.

    Steps:

    1. Load the stored state and keep only the first ``num_players`` players.
    2. Anchor the group on the player with the fewest total comparisons
       (smallest row sum of ``comparison_counts``).
    3. Score every player with ``ucb_score``, using its distance in exposed
       TrueSkill to the anchor and its comparison count with the anchor.
    4. Exclude the anchor itself and apply the special pairing rules.
    5. Take the ``k_group - 1`` highest-scoring players as opponents.

    Args:
        num_players: Number of leading players in the stored state to consider.
        k_group: Total group size, including the anchor player.

    Returns:
        List of player indices as Python ints, shuffled with ``random.shuffle``.
    """
    ratings, comparison_counts, total_comparisons = load_json_via_sftp()
    ratings = ratings[:num_players]
    comparison_counts = comparison_counts[:num_players, :num_players]

    selected_player = int(np.argmin(comparison_counts.sum(axis=1)))

    # Exposed TrueSkill value of every player, from the shared module-level env.
    exposed = np.array([trueskill_env.expose(p) for p in ratings])
    trueskill_diff = np.abs(exposed - exposed[selected_player])
    ucb_scores = ucb_score(trueskill_diff, total_comparisons,
                           comparison_counts[selected_player])

    # The anchor must not be picked as its own opponent.
    ucb_scores[selected_player] = -np.inf
    _mask_special_pairings(ucb_scores, selected_player,
                           _EXCLUSIVE_PLAYERS, _RESTRICTED_PLAYERS)

    # argsort is ascending, so the best candidates sit at the end. The start
    # index is computed explicitly: `[-(k_group - 1):]` becomes `[0:]` when
    # k_group == 1, which would return every player.
    num_opponents = max(k_group - 1, 0)
    order = np.argsort(ucb_scores)
    opponents = order[max(order.size - num_opponents, 0):].tolist()

    model_ids = [selected_player] + opponents
    random.shuffle(model_ids)
    return model_ids