vumichien commited on
Commit
e70f81d
·
1 Parent(s): 57067dc

Update ldm/modules/encoders/modules.py

Browse files
Files changed (1) hide show
  1. ldm/modules/encoders/modules.py +4 -4
ldm/modules/encoders/modules.py CHANGED
@@ -61,7 +61,7 @@ class FrozenT5Embedder(AbstractEncoder):
61
  super().__init__()
62
  self.tokenizer = T5Tokenizer.from_pretrained(version)
63
  self.transformer = T5EncoderModel.from_pretrained(version)
64
- self.device = device
65
  self.max_length = max_length # TODO: typical value?
66
  if freeze:
67
  self.freeze()
@@ -98,7 +98,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
98
  assert layer in self.LAYERS
99
  self.tokenizer = CLIPTokenizer.from_pretrained(version)
100
  self.transformer = CLIPTextModel.from_pretrained(version)
101
- self.device = device
102
  self.max_length = max_length
103
  if freeze:
104
  self.freeze()
@@ -148,7 +148,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
148
  del model.visual
149
  self.model = model
150
 
151
- self.device = device
152
  self.max_length = max_length
153
  if freeze:
154
  self.freeze()
@@ -194,7 +194,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
194
 
195
 
196
  class FrozenCLIPT5Encoder(AbstractEncoder):
197
- def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
198
  clip_max_length=77, t5_max_length=77):
199
  super().__init__()
200
  self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
 
61
  super().__init__()
62
  self.tokenizer = T5Tokenizer.from_pretrained(version)
63
  self.transformer = T5EncoderModel.from_pretrained(version)
64
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
65
  self.max_length = max_length # TODO: typical value?
66
  if freeze:
67
  self.freeze()
 
98
  assert layer in self.LAYERS
99
  self.tokenizer = CLIPTokenizer.from_pretrained(version)
100
  self.transformer = CLIPTextModel.from_pretrained(version)
101
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
102
  self.max_length = max_length
103
  if freeze:
104
  self.freeze()
 
148
  del model.visual
149
  self.model = model
150
 
151
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
152
  self.max_length = max_length
153
  if freeze:
154
  self.freeze()
 
194
 
195
 
196
  class FrozenCLIPT5Encoder(AbstractEncoder):
197
+ def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
198
  clip_max_length=77, t5_max_length=77):
199
  super().__init__()
200
  self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)