Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,714 Bytes
3427608 0e99a0b 02f8ed6 3427608 02f8ed6 3427608 02f8ed6 745f608 3427608 745f608 3427608 02f8ed6 745f608 3427608 02f8ed6 3427608 02f8ed6 745f608 02f8ed6 3427608 b177a48 3427608 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import numpy as np
import json
from trueskill import TrueSkill
import paramiko
import io, os
import sys
sys.path.append('../')
from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_SKILL
# Shared TrueSkill environment with library-default parameters; used by
# deserialize_rating() to build Rating objects and by update_trueskill() to rate.
trueskill_env = TrueSkill()
# SSH / SFTP connections to the matchmaker host. Both start as None and are
# (re)created on demand by create_ssh_matchmaker_client().
ssh_matchmaker_client = sftp_matchmaker_client = None
def create_ssh_matchmaker_client(server, port, user, password):
    """(Re)open the module-wide SSH and SFTP connections to the matchmaker host.

    Clients left over from a previous call are closed first, so the reconnect
    path in ``save_json_via_sftp`` / ``load_json_via_sftp`` does not leak the
    old socket and transport.

    Args:
        server: Hostname or address passed to ``SSHClient.connect``.
        port: SSH port.
        user: Login user name.
        password: Login password.

    Side effects:
        Rebinds the module globals ``ssh_matchmaker_client`` and
        ``sftp_matchmaker_client``. Connection errors from paramiko propagate.
    """
    global ssh_matchmaker_client, sftp_matchmaker_client
    # Best-effort cleanup of any previous (possibly dead) connection; a failure
    # here must not prevent opening a fresh one.
    for stale_client in (sftp_matchmaker_client, ssh_matchmaker_client):
        if stale_client is not None:
            try:
                stale_client.close()
            except Exception as e:
                print(f"Error closing stale matchmaker connection: {e}")
    # Don't keep a reference to the closed SFTP client if the reconnect fails.
    sftp_matchmaker_client = None

    ssh_matchmaker_client = paramiko.SSHClient()
    ssh_matchmaker_client.load_system_host_keys()
    # NOTE(review): AutoAddPolicy accepts unknown host keys without
    # verification; consider RejectPolicy plus a known_hosts entry.
    ssh_matchmaker_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    ssh_matchmaker_client.connect(server, port, user, password)
    # Send keepalive packets after 60 s without traffic, presumably so idle
    # sessions are not dropped between matchmaking calls.
    ssh_matchmaker_client.get_transport().set_keepalive(60)
    sftp_matchmaker_client = ssh_matchmaker_client.open_sftp()
def is_connected():
    """Report whether the module-wide SSH and SFTP connections are usable.

    Returns:
        False when no client has been created yet, when the SSH transport is
        missing or inactive, or when a trivial SFTP request fails; True
        otherwise. Never raises for a missing connection, so callers can use
        it as a plain "do I need to reconnect?" check.
    """
    global ssh_matchmaker_client, sftp_matchmaker_client
    # Nothing created yet (e.g. the very first save/load of the process).
    if ssh_matchmaker_client is None or sftp_matchmaker_client is None:
        return False
    # Check that the SSH transport exists and is active. get_transport()
    # returns None for a client that was never connected or was closed.
    transport = ssh_matchmaker_client.get_transport()
    if transport is None or not transport.is_active():
        return False
    # Check that the SFTP channel still answers a cheap request.
    try:
        sftp_matchmaker_client.listdir('.')  # list the session's current directory
    except Exception as e:
        print(f"Error checking SFTP connection: {e}")
        return False
    return True
def ucb_score(trueskill_diff, t, n, c=1.0):
    """Compute a UCB-style score for candidate opponents.

        score = -trueskill_diff + c * sqrt(2 * ln(t) / n)

    Smaller skill gaps and fewer past comparisons both raise the score.

    Args:
        trueskill_diff: Skill gap(s) between the anchor player and each
            candidate (scalar or array; matchmaker passes absolute gaps).
        t: Total number of comparisons recorded so far.
        n: Number of comparisons between the anchor and each candidate
            (scalar or array, broadcast against ``trueskill_diff``).
        c: Exploration weight. The default 1.0 matches the previously
            hard-coded value.

    Returns:
        The score(s), broadcast over the inputs. Never NaN for t >= 0 and n >= 0.
    """
    # The 1e-5 offsets avoid log(0) and division by zero for unseen pairs.
    log_t = np.log(t + 1e-5)
    # For t == 0, log(1e-5) < 0 and the square root would be NaN, which poisons
    # the later argsort. Clamp so exploration is simply 0 until comparisons
    # exist; for t >= 1 the value is unchanged.
    log_t = np.maximum(log_t, 0.0)
    exploration_term = np.sqrt((2 * log_t) / (n + 1e-5))
    return -trueskill_diff + c * exploration_term
def update_trueskill(ratings, ranks):
    """Rate one match with the shared TrueSkill environment.

    Args:
        ratings: Rating groups, forwarded unchanged to ``trueskill_env.rate``.
        ranks: Ranking of those groups, forwarded unchanged.

    Returns:
        Whatever ``trueskill_env.rate`` returns (the updated ratings).
    """
    return trueskill_env.rate(ratings, ranks)
def serialize_rating(rating):
    """Convert a rating object into a JSON-friendly ``{'mu', 'sigma'}`` dict."""
    return dict(mu=rating.mu, sigma=rating.sigma)
def deserialize_rating(rating_dict):
    """Rebuild a rating from a ``{'mu': ..., 'sigma': ...}`` dict.

    This is the inverse of ``serialize_rating``. A missing key raises KeyError.
    """
    mu = rating_dict['mu']
    sigma = rating_dict['sigma']
    return trueskill_env.Rating(mu=mu, sigma=sigma)
def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
    """Serialize the matchmaking state to JSON and write it to ``SSH_SKILL`` via SFTP.

    Reconnects first if ``is_connected()`` reports the link as unusable.

    Args:
        ratings: Iterable of rating objects, each converted with
            ``serialize_rating``.
        comparison_counts: Per-player comparison-count matrix. Converted with
            ``np.asarray``, so nested lists are accepted as well as ndarrays.
        total_comparisons: Total number of comparisons. NumPy scalars (e.g. the
            ``np.int64`` returned by ``.sum()``) are converted to plain Python
            numbers, because ``json`` cannot encode them.
    """
    if not is_connected():
        create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
    if isinstance(total_comparisons, np.generic):
        total_comparisons = total_comparisons.item()
    data = {
        'ratings': [serialize_rating(r) for r in ratings],
        'comparison_counts': np.asarray(comparison_counts).tolist(),
        'total_comparisons': total_comparisons,
    }
    # Serialize before opening the remote file, so an encoding error cannot
    # truncate the existing state file.
    json_data = json.dumps(data)
    with sftp_matchmaker_client.open(SSH_SKILL, 'w') as f:
        f.write(json_data)
def load_json_via_sftp():
    """Read the matchmaking state from ``SSH_SKILL`` over SFTP.

    Reconnects first if ``is_connected()`` reports the link as unusable.

    Returns:
        A ``(ratings, comparison_counts, total_comparisons)`` tuple:
        ratings rebuilt with ``deserialize_rating``, the comparison counts as
        a NumPy array, and the stored total as-is.
    """
    if not is_connected():
        create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
    with sftp_matchmaker_client.open(SSH_SKILL, 'r') as remote_file:
        state = json.load(remote_file)
    return (
        list(map(deserialize_rating, state['ratings'])),
        np.array(state['comparison_counts']),
        state['total_comparisons'],
    )
def matchmaker(num_players, k_group=4):
    """Pick a group of players for the next comparison.

    The anchor is the player with the fewest recorded comparisons (smallest
    row sum of ``comparison_counts``). The remaining slots go to the other
    players with the highest ``ucb_score``, which favours small gaps in
    exposed TrueSkill score and opponents the anchor has rarely faced.

    Args:
        num_players: Kept for API compatibility and currently unused. The
            player count is taken from the stored ratings.
        k_group: Desired group size including the anchor. The number of
            opponents is ``k_group - 1``, clamped to ``[0, players - 1]``, so
            the anchor is never returned twice.

    Returns:
        List of Python ``int`` player indices: the anchor first, then the
        opponents in ascending UCB order (best-scoring opponent last).
    """
    ratings, comparison_counts, total_comparisons = load_json_via_sftp()

    # Anchor: the least-compared player (np.argmin picks the first on ties).
    selected_player = int(np.argmin(comparison_counts.sum(axis=1)))

    # Use the module-level environment that also builds the ratings, rather
    # than a fresh TrueSkill() per call.
    trueskill_scores = np.array([trueskill_env.expose(r) for r in ratings])
    trueskill_diff = np.abs(trueskill_scores - trueskill_scores[selected_player])
    n = comparison_counts[selected_player]
    ucb_scores = ucb_score(trueskill_diff, total_comparisons, n)
    # Exclude the anchor itself from the opponent candidates.
    ucb_scores[selected_player] = -np.inf

    # Explicit count instead of the slice [-k_group + 1:], which returned every
    # player when k_group == 1 and included the anchor when k_group > players.
    num_opponents = min(max(k_group - 1, 0), len(ratings) - 1)
    order = np.argsort(ucb_scores)
    opponents = order[len(order) - num_opponents:].tolist()
    return [selected_player] + opponents
|