File size: 3,156 Bytes
3427608
 
 
 
 
 
0e99a0b
3427608
b177a48
 
3427608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b177a48
3427608
 
b177a48
 
 
 
 
 
 
3427608
 
 
 
 
 
 
 
 
 
 
 
 
b177a48
 
3427608
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import numpy as np
import json
from trueskill import TrueSkill
import paramiko
import io, os
import sys
from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_SKILL
trueskill_env = TrueSkill()
sys.path.append('../')
from model.models import IMAGE_GENERATION_MODELS


def ucb_score(trueskill_diff, t, n):
    """Return an upper-confidence-bound score for a candidate pairing.

    The score is the negated skill difference plus an exploration bonus
    sqrt(2 * ln(t) / n). A small epsilon (1e-5) is added to both t and n
    so that n == 0 does not divide by zero.

    Args:
        trueskill_diff: skill gap between the two candidates; a larger gap
            lowers the score.
        t: total number of comparisons so far.
        n: number of comparisons already made for this pairing.

    Returns:
        The UCB score as a numpy float.
    """
    eps = 1e-5
    bonus = np.sqrt((2 * np.log(t + eps)) / (n + eps))
    return 1.0 * bonus - trueskill_diff

def update_trueskill(ratings, ranks):
    """Run one TrueSkill update through the module-level environment.

    Args:
        ratings: rating groups in the form accepted by ``TrueSkill.rate``.
            This module passes a list of one-element lists.
        ranks: one rank per group. By TrueSkill convention, a lower rank
            means a better placement.

    Returns:
        The updated rating groups, exactly as returned by ``trueskill_env.rate``.
    """
    return trueskill_env.rate(ratings, ranks=ranks)

def serialize_rating(rating):
    """Convert a rating into a JSON-friendly dict with 'mu' and 'sigma' keys."""
    mu, sigma = rating.mu, rating.sigma
    return dict(mu=mu, sigma=sigma)

def deserialize_rating(rating_dict):
    """Rebuild a rating from a dict of the shape produced by ``serialize_rating``.

    Args:
        rating_dict: mapping with 'mu' and 'sigma' entries.

    Returns:
        A rating created by the module-level ``trueskill_env``.

    Raises:
        KeyError: if 'mu' or 'sigma' is missing from ``rating_dict``.
    """
    # ``create_rating`` is the documented factory on a TrueSkill environment.
    # ``Rating`` is a module-level class in the trueskill package and is not
    # documented as an attribute of the environment object.
    return trueskill_env.create_rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])

def create_ssh_client(server, port, user, password):
    """Open and return a connected ``paramiko.SSHClient``.

    Loads the system known-hosts file, auto-accepts unknown host keys, and
    authenticates with a username/password.

    Args:
        server: hostname or address to connect to.
        port: SSH port.
        user: login username.
        password: login password.

    Returns:
        A connected ``paramiko.SSHClient``. The caller is responsible for
        calling ``close()`` on it.

    Raises:
        Whatever ``SSHClient.connect`` raises (e.g. authentication or socket
        errors). The partially initialised client is closed before re-raising.
    """
    ssh = paramiko.SSHClient()
    try:
        ssh.load_system_host_keys()
        # SECURITY NOTE(review): AutoAddPolicy trusts any previously unseen host
        # key, which permits man-in-the-middle attacks. Consider RejectPolicy
        # with a pinned known_hosts entry.
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.connect(server, port=port, username=user, password=password)
    except Exception:
        # Don't leak the transport/socket when setup fails.
        ssh.close()
        raise
    return ssh

def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
    """Serialize the rating state to JSON and write it to ``SSH_SKILL`` over SFTP.

    The payload is built before connecting, so a serialization error never
    opens a connection. The SFTP session and the SSH connection are both
    closed before returning, even if the write fails.

    Args:
        ratings: sequence of rating objects exposing ``mu`` and ``sigma``.
        comparison_counts: array-like of pairwise comparison counts. It is
            converted with ``np.asarray(...).tolist()``, so an ndarray or a
            nested list both work.
        total_comparisons: value stored as-is under 'total_comparisons'.
            It must be JSON-serializable.
    """
    data = {
        'ratings': [serialize_rating(r) for r in ratings],
        'comparison_counts': np.asarray(comparison_counts).tolist(),
        'total_comparisons': total_comparisons,
    }
    json_data = json.dumps(data)

    ssh = create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
    try:
        sftp = ssh.open_sftp()
        try:
            with sftp.open(SSH_SKILL, 'w') as f:
                f.write(json_data)
        finally:
            sftp.close()
    finally:
        ssh.close()

def load_json_via_sftp():
    """Read the rating state from ``SSH_SKILL`` over SFTP.

    The SFTP session and the SSH connection are closed as soon as the file
    has been read, including when reading or JSON parsing fails.

    Returns:
        A tuple ``(ratings, comparison_counts, total_comparisons)``:
        - ratings: list of ratings rebuilt with ``deserialize_rating``.
        - comparison_counts: ``np.array`` built from the stored nested list.
        - total_comparisons: the stored value, unchanged.

    Raises:
        KeyError: if the stored JSON lacks 'ratings', 'comparison_counts'
            or 'total_comparisons'.
        json.JSONDecodeError: if the remote file is not valid JSON.
    """
    ssh = create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
    try:
        sftp = ssh.open_sftp()
        try:
            with sftp.open(SSH_SKILL, 'r') as f:
                data = json.load(f)
        finally:
            sftp.close()
    finally:
        ssh.close()

    ratings = [deserialize_rating(r) for r in data['ratings']]
    comparison_counts = np.array(data['comparison_counts'])
    total_comparisons = data['total_comparisons']
    return ratings, comparison_counts, total_comparisons


def update_skill(rank, model_names, k_group=4):
    """Apply one multi-model ranking to the stored ratings and persist the result.

    The stored state is loaded. Then every unordered pair of the given
    models gets a head-to-head TrueSkill update: the model with the lower
    ``rank`` value is passed as the winner (ranks [0, 1]). Pairs with equal
    ranks get no rating change, but their comparison counts are still
    incremented. ``total_comparisons`` is incremented once per call, and the
    updated state is saved back.

    Args:
        rank: per-model rank values, indexed in the same order as
            ``model_names``.
        model_names: names looked up in ``IMAGE_GENERATION_MODELS`` to find
            each model's index into the stored ratings and counts.
        k_group: accepted for interface compatibility; not used here.

    Raises:
        ValueError: if a name is not in ``IMAGE_GENERATION_MODELS``.
    """
    ratings, comparison_counts, total_comparisons = load_json_via_sftp()

    # Map each model name to its row/column in the stored tables.
    group = [IMAGE_GENERATION_MODELS.index(name) for name in model_names]
    print(group)

    n_players = len(group)
    for a in range(n_players):
        for b in range(a + 1, n_players):
            ia, ib = group[a], group[b]

            if rank[a] < rank[b]:
                outcome = [0, 1]
            elif rank[a] > rank[b]:
                outcome = [1, 0]
            else:
                outcome = None  # tie: leave ratings untouched

            if outcome is not None:
                updated = update_trueskill([[ratings[ia]], [ratings[ib]]], outcome)
                ratings[ia], ratings[ib] = updated[0][0], updated[1][0]

            # The pair counter is symmetric and is bumped even for ties.
            comparison_counts[ia, ib] += 1
            comparison_counts[ib, ia] += 1

    total_comparisons += 1

    save_json_via_sftp(ratings, comparison_counts, total_comparisons)