Jonathan Malott commited on
Commit
1557730
·
1 Parent(s): eef0144

Updated sampling.py

Browse files
Files changed (3) hide show
  1. .gitignore +2 -2
  2. dalle/utils/sampling.py +11 -3
  3. minDALL-E +0 -1
.gitignore CHANGED
@@ -9,7 +9,7 @@ _exampleImages/
9
  _trash/
10
 
11
  1.3B.tar.gz
12
-
13
  stage1_last.ckpt
14
 
15
- stage2_last.ckpt
 
9
  _trash/
10
 
11
  1.3B.tar.gz
12
+
13
  stage1_last.ckpt
14
 
15
+ stage2_last.ckpt
dalle/utils/sampling.py CHANGED
@@ -68,6 +68,8 @@ def sampling(model: torch.nn.Module,
68
  pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
69
  pos_enc_tokens = get_positional_encoding(tokens, mode='1d')
70
 
 
 
71
  for cnt, h in enumerate(pbar):
72
  if code is None:
73
  code_ = None
@@ -93,8 +95,6 @@ def sampling(model: torch.nn.Module,
93
  else:
94
  past.append(present)
95
 
96
- st.session_state.bar = cnt/max_seq_len
97
-
98
  logits = cutoff_topk_logits(logits, top_k)
99
  probs = F.softmax(logits, dim=-1)
100
  probs = cutoff_topp_probs(probs, top_p)
@@ -102,6 +102,14 @@ def sampling(model: torch.nn.Module,
102
  idx = torch.multinomial(probs, num_samples=1).clone().detach()
103
  code = idx if code is None else torch.cat([code, idx], axis=1)
104
 
 
 
 
 
 
 
 
 
105
  del past
106
  return code
107
 
@@ -151,4 +159,4 @@ def sampling_igpt(model: torch.nn.Module,
151
  code = idx if code is None else torch.cat([code, idx], axis=1)
152
 
153
  del past
154
- return code
 
68
  pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
69
  pos_enc_tokens = get_positional_encoding(tokens, mode='1d')
70
 
71
+ #my_bar = st.progress(0)
72
+
73
  for cnt, h in enumerate(pbar):
74
  if code is None:
75
  code_ = None
 
95
  else:
96
  past.append(present)
97
 
 
 
98
  logits = cutoff_topk_logits(logits, top_k)
99
  probs = F.softmax(logits, dim=-1)
100
  probs = cutoff_topp_probs(probs, top_p)
 
102
  idx = torch.multinomial(probs, num_samples=1).clone().detach()
103
  code = idx if code is None else torch.cat([code, idx], axis=1)
104
 
105
+ #print(cnt/max_seq_len)
106
+ if(st.session_state.page != 0):
107
+ break
108
+
109
+ st.session_state.bar.progress(cnt/max_seq_len)
110
+
111
+ #my_bar.progress(cnt/max_seq_len)
112
+
113
  del past
114
  return code
115
 
 
159
  code = idx if code is None else torch.cat([code, idx], axis=1)
160
 
161
  del past
162
+ return code
minDALL-E DELETED
@@ -1 +0,0 @@
1
- Subproject commit e5480076b9634e9dc097e1892157ed2cf15a2f86