ksort commited on
Commit
745f608
·
1 Parent(s): b727799

Update ssh

Browse files
model/matchmaker.py CHANGED
@@ -18,8 +18,22 @@ def create_ssh_matchmaker_client(server, port, user, password):
18
  ssh_matchmaker_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
19
  ssh_matchmaker_client.connect(server, port, user, password)
20
 
21
- sftp_matchmaker_client = ssh_matchmaker_client.open_sftp()
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def ucb_score(trueskill_diff, t, n):
24
  exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
25
  ucb = -trueskill_diff + 1.0 * exploration_term
@@ -37,6 +51,8 @@ def deserialize_rating(rating_dict):
37
 
38
  def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
39
  global sftp_matchmaker_client
 
 
40
  data = {
41
  'ratings': [serialize_rating(r) for r in ratings],
42
  'comparison_counts': comparison_counts.tolist(),
@@ -48,6 +64,8 @@ def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
48
 
49
  def load_json_via_sftp():
50
  global sftp_matchmaker_client
 
 
51
  with sftp_matchmaker_client.open(SSH_SKILL, 'r') as f:
52
  data = json.load(f)
53
  ratings = [deserialize_rating(r) for r in data['ratings']]
 
18
  ssh_matchmaker_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
19
  ssh_matchmaker_client.connect(server, port, user, password)
20
 
21
+ transport = ssh_matchmaker_client.get_transport()
22
+ transport.set_keepalive(60)
23
 
24
+ sftp_matchmaker_client = ssh_matchmaker_client.open_sftp()
25
+ def is_connected():
26
+ global ssh_matchmaker_client, sftp_matchmaker_client
27
+ # 检查SSH连接是否正常
28
+ if not ssh_matchmaker_client.get_transport().is_active():
29
+ return False
30
+ # 检查SFTP连接是否正常
31
+ try:
32
+ sftp_matchmaker_client.listdir('.') # 尝试列出根目录
33
+ except Exception as e:
34
+ print(f"Error checking SFTP connection: {e}")
35
+ return False
36
+ return True
37
  def ucb_score(trueskill_diff, t, n):
38
  exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
39
  ucb = -trueskill_diff + 1.0 * exploration_term
 
51
 
52
  def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
53
  global sftp_matchmaker_client
54
+ if not is_connected():
55
+ create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
56
  data = {
57
  'ratings': [serialize_rating(r) for r in ratings],
58
  'comparison_counts': comparison_counts.tolist(),
 
64
 
65
  def load_json_via_sftp():
66
  global sftp_matchmaker_client
67
+ if not is_connected():
68
+ create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
69
  with sftp_matchmaker_client.open(SSH_SKILL, 'r') as f:
70
  data = json.load(f)
71
  ratings = [deserialize_rating(r) for r in data['ratings']]
serve/gradio_web.py CHANGED
@@ -375,7 +375,8 @@ def build_side_by_side_ui_anony(models):
375
  vote_textbox.submit(
376
  disable_vote,
377
  inputs=None,
378
- outputs=[vote_submit_btn, vote_mode_btn]
 
379
  ).then(
380
  text_response_rank_igm,
381
  inputs=[generate_ig0, generate_ig1, generate_ig2, generate_ig3, Top1_text, Top2_text, Top3_text, Top4_text, vote_textbox],
@@ -391,7 +392,8 @@ def build_side_by_side_ui_anony(models):
391
  vote_submit_btn.click(
392
  disable_vote,
393
  inputs=None,
394
- outputs=[vote_submit_btn, vote_mode_btn]
 
395
  ).then(
396
  text_response_rank_igm,
397
  inputs=[generate_ig0, generate_ig1, generate_ig2, generate_ig3, Top1_text, Top2_text, Top3_text, Top4_text, vote_textbox],
 
375
  vote_textbox.submit(
376
  disable_vote,
377
  inputs=None,
378
+ outputs=[vote_submit_btn, vote_mode_btn, \
379
+ A1_btn, A2_btn, A3_btn, A4_btn, B1_btn, B2_btn, B3_btn, B4_btn, C1_btn, C2_btn, C3_btn, C4_btn, D1_btn, D2_btn, D3_btn, D4_btn]
380
  ).then(
381
  text_response_rank_igm,
382
  inputs=[generate_ig0, generate_ig1, generate_ig2, generate_ig3, Top1_text, Top2_text, Top3_text, Top4_text, vote_textbox],
 
392
  vote_submit_btn.click(
393
  disable_vote,
394
  inputs=None,
395
+ outputs=[vote_submit_btn, vote_mode_btn, \
396
+ A1_btn, A2_btn, A3_btn, A4_btn, B1_btn, B2_btn, B3_btn, B4_btn, C1_btn, C2_btn, C3_btn, C4_btn, D1_btn, D2_btn, D3_btn, D4_btn]
397
  ).then(
398
  text_response_rank_igm,
399
  inputs=[generate_ig0, generate_ig1, generate_ig2, generate_ig3, Top1_text, Top2_text, Top3_text, Top4_text, vote_textbox],
serve/update_skill.py CHANGED
@@ -19,8 +19,22 @@ def create_ssh_skill_client(server, port, user, password):
19
  ssh_skill_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
20
  ssh_skill_client.connect(server, port, user, password)
21
 
22
- sftp_skill_client = ssh_skill_client.open_sftp()
 
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def ucb_score(trueskill_diff, t, n):
25
  exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
26
  ucb = -trueskill_diff + 1.0 * exploration_term
@@ -39,6 +53,8 @@ def deserialize_rating(rating_dict):
39
 
40
  def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
41
  global sftp_skill_client
 
 
42
  data = {
43
  'ratings': [serialize_rating(r) for r in ratings],
44
  'comparison_counts': comparison_counts.tolist(),
@@ -50,6 +66,8 @@ def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
50
 
51
  def load_json_via_sftp():
52
  global sftp_skill_client
 
 
53
  with sftp_skill_client.open(SSH_SKILL, 'r') as f:
54
  data = json.load(f)
55
  ratings = [deserialize_rating(r) for r in data['ratings']]
 
19
  ssh_skill_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
20
  ssh_skill_client.connect(server, port, user, password)
21
 
22
+ transport = ssh_skill_client.get_transport()
23
+ transport.set_keepalive(60)
24
 
25
+ sftp_skill_client = ssh_skill_client.open_sftp()
26
+ def is_connected():
27
+ global ssh_skill_client, sftp_skill_client
28
+ # 检查SSH连接是否正常
29
+ if not ssh_skill_client.get_transport().is_active():
30
+ return False
31
+ # 检查SFTP连接是否正常
32
+ try:
33
+ sftp_skill_client.listdir('.') # 尝试列出根目录
34
+ except Exception as e:
35
+ print(f"Error checking SFTP connection: {e}")
36
+ return False
37
+ return True
38
  def ucb_score(trueskill_diff, t, n):
39
  exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
40
  ucb = -trueskill_diff + 1.0 * exploration_term
 
53
 
54
  def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
55
  global sftp_skill_client
56
+ if not is_connected():
57
+ create_ssh_skill_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
58
  data = {
59
  'ratings': [serialize_rating(r) for r in ratings],
60
  'comparison_counts': comparison_counts.tolist(),
 
66
 
67
  def load_json_via_sftp():
68
  global sftp_skill_client
69
+ if not is_connected():
70
+ create_ssh_skill_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
71
  with sftp_skill_client.open(SSH_SKILL, 'r') as f:
72
  data = json.load(f)
73
  ratings = [deserialize_rating(r) for r in data['ratings']]
serve/upload.py CHANGED
@@ -17,7 +17,22 @@ def create_ssh_client(server, port, user, password):
17
  ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
18
  ssh_client.connect(server, port, user, password)
19
 
 
 
 
20
  sftp_client = ssh_client.open_sftp()
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def get_image_from_url(image_url):
23
  response = requests.get(image_url)
@@ -26,7 +41,8 @@ def get_image_from_url(image_url):
26
 
27
  def get_random_mscoco_prompt():
28
  global sftp_client
29
- create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
 
30
  num = random.randint(0, 2999)
31
  file = "{}.txt".format(num)
32
 
@@ -40,6 +56,8 @@ def get_random_mscoco_prompt():
40
 
41
  def create_remote_directory(remote_directory):
42
  global ssh_client
 
 
43
  stdin, stdout, stderr = ssh_client.exec_command(f'mkdir -p {SSH_LOG}/{remote_directory}')
44
  error = stderr.read().decode('utf-8')
45
  if error:
@@ -50,6 +68,8 @@ def create_remote_directory(remote_directory):
50
 
51
  def upload_ssh_all(states, output_dir, data, data_path):
52
  global sftp_client
 
 
53
  output_file_list = []
54
  image_list = []
55
  for i in range(len(states)):
@@ -71,4 +91,4 @@ def upload_ssh_all(states, output_dir, data, data_path):
71
  with io.BytesIO(json_data.encode('utf-8')) as json_byte_stream:
72
  sftp_client.putfo(json_byte_stream, data_path)
73
  print(f"Successfully uploaded JSON data to {data_path}")
74
- create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
 
17
  ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
18
  ssh_client.connect(server, port, user, password)
19
 
20
+ transport = ssh_client.get_transport()
21
+ transport.set_keepalive(60)
22
+
23
  sftp_client = ssh_client.open_sftp()
24
+ def is_connected():
25
+ global ssh_client, sftp_client
26
+ # 检查SSH连接是否正常
27
+ if not ssh_client.get_transport().is_active():
28
+ return False
29
+ # 检查SFTP连接是否正常
30
+ try:
31
+ sftp_client.listdir('.') # 尝试列出根目录
32
+ except Exception as e:
33
+ print(f"Error checking SFTP connection: {e}")
34
+ return False
35
+ return True
36
 
37
  def get_image_from_url(image_url):
38
  response = requests.get(image_url)
 
41
 
42
  def get_random_mscoco_prompt():
43
  global sftp_client
44
+ if not is_connected():
45
+ create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
46
  num = random.randint(0, 2999)
47
  file = "{}.txt".format(num)
48
 
 
56
 
57
  def create_remote_directory(remote_directory):
58
  global ssh_client
59
+ if not is_connected():
60
+ create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
61
  stdin, stdout, stderr = ssh_client.exec_command(f'mkdir -p {SSH_LOG}/{remote_directory}')
62
  error = stderr.read().decode('utf-8')
63
  if error:
 
68
 
69
  def upload_ssh_all(states, output_dir, data, data_path):
70
  global sftp_client
71
+ if not is_connected():
72
+ create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
73
  output_file_list = []
74
  image_list = []
75
  for i in range(len(states)):
 
91
  with io.BytesIO(json_data.encode('utf-8')) as json_byte_stream:
92
  sftp_client.putfo(json_byte_stream, data_path)
93
  print(f"Successfully uploaded JSON data to {data_path}")
94
+ # create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
serve/utils.py CHANGED
@@ -176,7 +176,7 @@ def enable_vote_buttons():
176
  def disable_vote_buttons():
177
  return tuple(gr.update(visible=False, interactive=False) for i in range(6))
178
  def disable_vote():
179
- return (gr.update(interactive=False), gr.update(interactive=False))
180
  def enable_vote_mode_buttons(mode):
181
  print(mode)
182
  if mode == "Best":
 
176
  def disable_vote_buttons():
177
  return tuple(gr.update(visible=False, interactive=False) for i in range(6))
178
  def disable_vote():
179
+ return (gr.update(interactive=False) for i in range(14))
180
  def enable_vote_mode_buttons(mode):
181
  print(mode)
182
  if mode == "Best":