multimodalart HF staff commited on
Commit
31024de
·
1 Parent(s): a47c17f

Resize accordingly

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -128,6 +128,8 @@ def train(*inputs):
128
  if os.path.exists("model.ckpt"): os.remove("model.ckpt")
129
  if os.path.exists("hastrained.success"): os.remove("hastrained.success")
130
  file_counter = 0
 
 
131
  for i, input in enumerate(inputs):
132
  if(i < maximum_concepts-1):
133
  if(input):
@@ -139,7 +141,7 @@ def train(*inputs):
139
  for j, file_temp in enumerate(files):
140
  file = Image.open(file_temp.name)
141
  image = pad_image(file)
142
- image = image.resize((512, 512))
143
  extension = file_temp.name.split(".")[1]
144
  image = image.convert('RGB')
145
  image.save(f'instance_images/{prompt}_({j+1}).jpg', format="JPEG", quality = 100)
@@ -150,7 +152,7 @@ def train(*inputs):
150
  type_of_thing = inputs[-4]
151
  remove_attribution_after = inputs[-6]
152
  experimental_face_improvement = inputs[-9]
153
- which_model = inputs[-10]
154
  if(uses_custom):
155
  Training_Steps = int(inputs[-3])
156
  Train_text_encoder_for = int(inputs[-2])
@@ -172,7 +174,6 @@ def train(*inputs):
172
 
173
  stptxt = int((Training_Steps*Train_text_encoder_for)/100)
174
  gradient_checkpointing = False if which_model == "v1-5" else True
175
- resolution = 512 if which_model != "v2-768" else 768
176
  cache_latents = True if which_model != "v1-5" else False
177
  if (type_of_thing == "object" or type_of_thing == "style" or (type_of_thing == "person" and not experimental_face_improvement)):
178
  args_general = argparse.Namespace(
 
128
  if os.path.exists("model.ckpt"): os.remove("model.ckpt")
129
  if os.path.exists("hastrained.success"): os.remove("hastrained.success")
130
  file_counter = 0
131
+ which_model = inputs[-10]
132
+ resolution = 512 if which_model != "v2-768" else 768
133
  for i, input in enumerate(inputs):
134
  if(i < maximum_concepts-1):
135
  if(input):
 
141
  for j, file_temp in enumerate(files):
142
  file = Image.open(file_temp.name)
143
  image = pad_image(file)
144
+ image = image.resize((resolution, resolution))
145
  extension = file_temp.name.split(".")[1]
146
  image = image.convert('RGB')
147
  image.save(f'instance_images/{prompt}_({j+1}).jpg', format="JPEG", quality = 100)
 
152
  type_of_thing = inputs[-4]
153
  remove_attribution_after = inputs[-6]
154
  experimental_face_improvement = inputs[-9]
155
+
156
  if(uses_custom):
157
  Training_Steps = int(inputs[-3])
158
  Train_text_encoder_for = int(inputs[-2])
 
174
 
175
  stptxt = int((Training_Steps*Train_text_encoder_for)/100)
176
  gradient_checkpointing = False if which_model == "v1-5" else True
 
177
  cache_latents = True if which_model != "v1-5" else False
178
  if (type_of_thing == "object" or type_of_thing == "style" or (type_of_thing == "person" and not experimental_face_improvement)):
179
  args_general = argparse.Namespace(