Surn commited on
Commit
20a0fad
·
1 Parent(s): feb9b54

Update to fix Collab launch

Browse files
app.py CHANGED
@@ -402,6 +402,27 @@ def ui(**kwargs):
402
  if __name__ == "__main__":
403
  parser = argparse.ArgumentParser()
404
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
  parser.add_argument(
406
  '--share', action='store_true', help='Share the gradio UI'
407
  )
@@ -418,6 +439,21 @@ if __name__ == "__main__":
418
  )
419
 
420
  args = parser.parse_args()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
  UNLOAD_MODEL = args.unload_model
422
  MOVE_TO_CPU = args.unload_to_cpu
423
  if args.cache:
 
402
  if __name__ == "__main__":
403
  parser = argparse.ArgumentParser()
404
 
405
+ parser.add_argument(
406
+ '--listen',
407
+ type=str,
408
+ default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
409
+ help='IP to listen on for connections to Gradio',
410
+ )
411
+ parser.add_argument(
412
+ '--username', type=str, default='', help='Username for authentication'
413
+ )
414
+ parser.add_argument(
415
+ '--password', type=str, default='', help='Password for authentication'
416
+ )
417
+ parser.add_argument(
418
+ '--server_port',
419
+ type=int,
420
+ default=0,
421
+ help='Port to run the server listener on',
422
+ )
423
+ parser.add_argument(
424
+ '--inbrowser', action='store_true', help='Open in browser'
425
+ )
426
  parser.add_argument(
427
  '--share', action='store_true', help='Share the gradio UI'
428
  )
 
439
  )
440
 
441
  args = parser.parse_args()
442
+
443
+ launch_kwargs = {}
444
+ launch_kwargs['server_name'] = args.listen
445
+
446
+ if args.username and args.password:
447
+ launch_kwargs['auth'] = (args.username, args.password)
448
+ if args.server_port:
449
+ launch_kwargs['server_port'] = args.server_port
450
+ if args.inbrowser:
451
+ launch_kwargs['inbrowser'] = args.inbrowser
452
+ if args.share:
453
+ launch_kwargs['share'] = args.share
454
+ launch_kwargs['favicon_path']= "./assets/favicon.ico"
455
+
456
+
457
  UNLOAD_MODEL = args.unload_model
458
  MOVE_TO_CPU = args.unload_to_cpu
459
  if args.cache:
audiocraft/__init__.py CHANGED
@@ -7,4 +7,4 @@
7
  # flake8: noqa
8
  from . import data, modules, models
9
 
10
- __version__ = '0.0.2a1'
 
7
  # flake8: noqa
8
  from . import data, modules, models
9
 
10
+ __version__ = '0.0.2a2'
audiocraft/models/lm.py CHANGED
@@ -363,7 +363,8 @@ class LMModel(StreamingModule):
363
  logits = logits.permute(0, 1, 3, 2) # [B, K, card, T]
364
  logits = logits[..., -1] # [B x K x card]
365
 
366
- if use_sampling:
 
367
  probs = torch.softmax(logits / temp, dim=-1)
368
  if top_p > 0.0:
369
  next_token = utils.sample_top_p(probs, p=top_p)
 
363
  logits = logits.permute(0, 1, 3, 2) # [B, K, card, T]
364
  logits = logits[..., -1] # [B x K x card]
365
 
366
+ # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
367
+ if use_sampling and temp > 0.0:
368
  probs = torch.softmax(logits / temp, dim=-1)
369
  if top_p > 0.0:
370
  next_token = utils.sample_top_p(probs, p=top_p)
audiocraft/models/musicgen.py CHANGED
@@ -36,13 +36,16 @@ class MusicGen:
36
  used to map audio to invertible discrete representations.
37
  lm (LMModel): Language model over discrete representations.
38
  """
39
- def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel):
40
  self.name = name
41
  self.compression_model = compression_model
42
  self.lm = lm
 
 
43
  self.device = next(iter(lm.parameters())).device
44
  self.generation_params: dict = {}
45
- self.set_generation_params(duration=15) # 15 seconds by default
 
46
  if self.device.type == 'cpu':
47
  self.autocast = TorchAutocast(enabled=False)
48
  else:
@@ -65,7 +68,7 @@ class MusicGen:
65
  return self.compression_model.channels
66
 
67
  @staticmethod
68
- def get_pretrained(name: str = 'melody', device='cuda'):
69
  """Return pretrained model, we provide four models:
70
  - small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small
71
  - medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-medium
@@ -73,6 +76,12 @@ class MusicGen:
73
  - large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-large
74
  """
75
 
 
 
 
 
 
 
76
  if name == 'debug':
77
  # used only for unit tests
78
  compression_model = get_debug_compression_model(device)
@@ -97,7 +106,7 @@ class MusicGen:
97
  def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
98
  top_p: float = 0.0, temperature: float = 1.0,
99
  duration: float = 30.0, cfg_coef: float = 3.0,
100
- two_step_cfg: bool = False, rep_penalty: float = None):
101
  """Set the generation parameters for MusicGen.
102
 
103
  Args:
@@ -112,9 +121,11 @@ class MusicGen:
112
  are padded but seems to have little impact in practice.
113
  rep_penalty (float, optional): If set, use repetition penalty during generation. Not Implemented.
114
  """
115
- assert duration <= 30, "The MusicGen cannot generate more than 30 seconds"
 
 
116
  self.generation_params = {
117
- 'max_gen_len': int(duration * self.frame_rate),
118
  'use_sampling': use_sampling,
119
  'temp': temperature,
120
  'top_k': top_k,
@@ -123,6 +134,10 @@ class MusicGen:
123
  'two_step_cfg': two_step_cfg,
124
  }
125
 
 
 
 
 
126
  def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor:
127
  """Generate samples in an unconditional manner.
128
 
@@ -317,20 +332,79 @@ class MusicGen:
317
  Returns:
318
  torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
319
  """
 
 
 
 
320
  def _progress_callback(generated_tokens: int, tokens_to_generate: int):
321
- print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r')
 
 
 
 
 
 
322
 
323
  if prompt_tokens is not None:
324
- assert self.generation_params['max_gen_len'] > prompt_tokens.shape[-1], \
325
  "Prompt is longer than audio to generate"
326
 
327
  callback = None
328
  if progress:
329
  callback = _progress_callback
330
 
331
- # generate by sampling from LM
332
- with self.autocast:
333
- gen_tokens = self.lm.generate(prompt_tokens, attributes, callback=callback, **self.generation_params)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
 
335
  # generate audio
336
  assert gen_tokens.dim() == 3
 
36
  used to map audio to invertible discrete representations.
37
  lm (LMModel): Language model over discrete representations.
38
  """
39
+ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel, max_duration: float = 30):
40
  self.name = name
41
  self.compression_model = compression_model
42
  self.lm = lm
43
+ self.max_duration = max_duration
44
+ self.duration = 15.0 # default duration
45
  self.device = next(iter(lm.parameters())).device
46
  self.generation_params: dict = {}
47
+ self.set_generation_params(duration=self.duration) # 15 seconds by default
48
+ self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
49
  if self.device.type == 'cpu':
50
  self.autocast = TorchAutocast(enabled=False)
51
  else:
 
68
  return self.compression_model.channels
69
 
70
  @staticmethod
71
+ def get_pretrained(name: str = 'melody', device=None):
72
  """Return pretrained model, we provide four models:
73
  - small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small
74
  - medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-medium
 
76
  - large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-large
77
  """
78
 
79
+ if device is None:
80
+ if torch.cuda.device_count():
81
+ device = 'cuda'
82
+ else:
83
+ device = 'cpu'
84
+
85
  if name == 'debug':
86
  # used only for unit tests
87
  compression_model = get_debug_compression_model(device)
 
106
  def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
107
  top_p: float = 0.0, temperature: float = 1.0,
108
  duration: float = 30.0, cfg_coef: float = 3.0,
109
+ two_step_cfg: bool = False, extend_stride: float = 18, rep_penalty: float = None):
110
  """Set the generation parameters for MusicGen.
111
 
112
  Args:
 
121
  are padded but seems to have little impact in practice.
122
  rep_penalty (float, optional): If set, use repetition penalty during generation. Not Implemented.
123
  """
124
+ assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
125
+ self.extend_stride = extend_stride
126
+ self.duration = duration
127
  self.generation_params = {
128
+ #'max_gen_len': int(duration * self.frame_rate),
129
  'use_sampling': use_sampling,
130
  'temp': temperature,
131
  'top_k': top_k,
 
134
  'two_step_cfg': two_step_cfg,
135
  }
136
 
137
+ def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
138
+ """Override the default progress callback."""
139
+ self._progress_callback = progress_callback
140
+
141
  def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor:
142
  """Generate samples in an unconditional manner.
143
 
 
332
  Returns:
333
  torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
334
  """
335
+ total_gen_len = int(self.duration * self.frame_rate)
336
+ max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
337
+ current_gen_offset: int = 0
338
+
339
  def _progress_callback(generated_tokens: int, tokens_to_generate: int):
340
+ generated_tokens += current_gen_offset
341
+ if self._progress_callback is not None:
342
+ # Note that total_gen_len might be quite wrong depending on the
343
+ # codebook pattern used, but with delay it is almost accurate.
344
+ self._progress_callback(generated_tokens, total_gen_len)
345
+ else:
346
+ print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
347
 
348
  if prompt_tokens is not None:
349
+ assert max_prompt_len >= prompt_tokens.shape[-1], \
350
  "Prompt is longer than audio to generate"
351
 
352
  callback = None
353
  if progress:
354
  callback = _progress_callback
355
 
356
+ if self.duration <= self.max_duration:
357
+ # generate by sampling from LM, simple case.
358
+ with self.autocast:
359
+ gen_tokens = self.lm.generate(
360
+ prompt_tokens, attributes,
361
+ callback=callback, max_gen_len=total_gen_len, **self.generation_params)
362
+
363
+ else:
364
+ # now this gets a bit messier, we need to handle prompts,
365
+ # melody conditioning etc.
366
+ ref_wavs = [attr.wav['self_wav'] for attr in attributes]
367
+ all_tokens = []
368
+ if prompt_tokens is None:
369
+ prompt_length = 0
370
+ else:
371
+ all_tokens.append(prompt_tokens)
372
+ prompt_length = prompt_tokens.shape[-1]
373
+
374
+ stride_tokens = int(self.frame_rate * self.extend_stride)
375
+
376
+ while current_gen_offset + prompt_length < total_gen_len:
377
+ time_offset = current_gen_offset / self.frame_rate
378
+ chunk_duration = min(self.duration - time_offset, self.max_duration)
379
+ max_gen_len = int(chunk_duration * self.frame_rate)
380
+ for attr, ref_wav in zip(attributes, ref_wavs):
381
+ wav_length = ref_wav.length.item()
382
+ if wav_length == 0:
383
+ continue
384
+ # We will extend the wav periodically if it not long enough.
385
+ # we have to do it here rather than in conditioners.py as otherwise
386
+ # we wouldn't have the full wav.
387
+ initial_position = int(time_offset * self.sample_rate)
388
+ wav_target_length = int(self.max_duration * self.sample_rate)
389
+ print(initial_position / self.sample_rate, wav_target_length / self.sample_rate)
390
+ positions = torch.arange(initial_position,
391
+ initial_position + wav_target_length, device=self.device)
392
+ attr.wav['self_wav'] = WavCondition(
393
+ ref_wav[0][:, positions % wav_length],
394
+ torch.full_like(ref_wav[1], wav_target_length))
395
+ with self.autocast:
396
+ gen_tokens = self.lm.generate(
397
+ prompt_tokens, attributes,
398
+ callback=callback, max_gen_len=max_gen_len, **self.generation_params)
399
+ if prompt_tokens is None:
400
+ all_tokens.append(gen_tokens)
401
+ else:
402
+ all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
403
+ prompt_tokens = gen_tokens[:, :, stride_tokens:]
404
+ prompt_length = prompt_tokens.shape[-1]
405
+ current_gen_offset += stride_tokens
406
+
407
+ gen_tokens = torch.cat(all_tokens, dim=-1)
408
 
409
  # generate audio
410
  assert gen_tokens.dim() == 3
audiocraft/modules/transformer.py CHANGED
@@ -25,6 +25,22 @@ from xformers import ops
25
  from .rope import RotaryEmbedding
26
  from .streaming import StreamingModule
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  def _is_profiled() -> bool:
30
  # Return true if we are currently running with a xformers profiler activated.
@@ -75,14 +91,22 @@ def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float =
75
 
76
  def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
77
  """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers"""
78
- bs, slen, n_kv_heads, head_dim = x.shape
79
  if n_rep == 1:
80
  return x
81
- return (
82
- x[:, :, :, None, :]
83
- .expand(bs, slen, n_kv_heads, n_rep, head_dim)
84
- .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
85
- )
 
 
 
 
 
 
 
 
 
86
 
87
 
88
  class LayerScale(nn.Module):
@@ -210,6 +234,7 @@ class StreamingMultiheadAttention(StreamingModule):
210
  # Return a causal mask, accounting for potentially stored past keys/values
211
  # We actually return a bias for the attention score, as this has the same
212
  # convention both in the builtin MHA in Pytorch, and Xformers functions.
 
213
  if self.memory_efficient:
214
  from xformers.ops import LowerTriangularMask
215
  if current_steps == 1:
@@ -222,7 +247,7 @@ class StreamingMultiheadAttention(StreamingModule):
222
  return LowerTriangularMask()
223
  if self._streaming_state:
224
  past_keys = self._streaming_state['past_keys']
225
- past_steps = past_keys.shape[1]
226
  else:
227
  past_steps = 0
228
 
@@ -239,6 +264,7 @@ class StreamingMultiheadAttention(StreamingModule):
239
  torch.full([], float('-inf'), device=device, dtype=dtype))
240
 
241
  def _complete_kv(self, k, v):
 
242
  if self.cross_attention:
243
  # With cross attention we assume all keys and values
244
  # are already available, and streaming is with respect
@@ -247,20 +273,20 @@ class StreamingMultiheadAttention(StreamingModule):
247
  # Complete the key/value pair using the streaming state.
248
  if self._streaming_state:
249
  pk = self._streaming_state['past_keys']
250
- nk = torch.cat([pk, k], dim=1)
251
  if v is k:
252
  nv = nk
253
  else:
254
  pv = self._streaming_state['past_values']
255
- nv = torch.cat([pv, v], dim=1)
256
  else:
257
  nk = k
258
  nv = v
259
 
260
- assert nk.shape[1] == nv.shape[1]
261
  offset = 0
262
  if self.past_context is not None:
263
- offset = max(0, nk.shape[1] - self.past_context)
264
  if self._is_streaming:
265
  self._streaming_state['past_keys'] = nk[:, offset:]
266
  if v is not k:
@@ -272,6 +298,8 @@ class StreamingMultiheadAttention(StreamingModule):
272
  return nk, nv
273
 
274
  def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
 
 
275
  # Apply rope embeddings to query and key tensors.
276
  assert self.rope is not None
277
  if 'past_keys' in self._streaming_state:
@@ -292,6 +320,11 @@ class StreamingMultiheadAttention(StreamingModule):
292
  assert not is_causal, ("new param added in torch 2.0.1 not supported, "
293
  "use the causal args in the constructor.")
294
 
 
 
 
 
 
295
  dtype = query.dtype
296
  if self._is_streaming:
297
  assert self.causal or self.cross_attention, \
@@ -324,8 +357,7 @@ class StreamingMultiheadAttention(StreamingModule):
324
  if self.qk_layer_norm is True:
325
  q = self.q_layer_norm(q)
326
  k = self.k_layer_norm(k)
327
- # q, k, v = [rearrange(x, "b t (h d) -> (b h) t d", h=self.num_heads) for x in [q, k, v]]
328
- q, k, v = [rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in [q, k, v]]
329
  else:
330
  if not _is_profiled():
331
  # profiling breaks that propertysomehow.
@@ -333,7 +365,11 @@ class StreamingMultiheadAttention(StreamingModule):
333
  assert value is key, "specialized implementation"
334
  projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
335
  if self.kv_repeat == 1:
336
- packed = rearrange(projected, "b t (p h d) -> b t p h d", p=3, h=self.num_heads)
 
 
 
 
337
  q, k, v = ops.unbind(packed, dim=2)
338
  else:
339
  embed_dim = self.embed_dim
@@ -344,16 +380,16 @@ class StreamingMultiheadAttention(StreamingModule):
344
  end = start + per_head_dim * kv_heads
345
  k = projected[:, :, start: end]
346
  v = projected[:, :, end:]
347
- q = rearrange(q, "b t (h d) -> b t h d", h=self.num_heads)
348
- k = rearrange(k, "b t (h d) -> b t h d", h=kv_heads)
349
- v = rearrange(v, "b t (h d) -> b t h d", h=kv_heads)
350
 
351
  if self.qk_layer_norm is True:
352
  assert self.kv_repeat == 1
353
- q, k = [rearrange(x, "b t h d -> b t (h d)") for x in [q, k]]
354
  q = self.q_layer_norm(q)
355
  k = self.k_layer_norm(k)
356
- q, k = [rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in [q, k]]
357
  if self.rope:
358
  q, k = self._apply_rope(q, k)
359
  k, v = self._complete_kv(k, v)
@@ -364,7 +400,11 @@ class StreamingMultiheadAttention(StreamingModule):
364
  q, k, v = [x.float() for x in [q, k, v]]
365
  if self.memory_efficient:
366
  p = self.dropout if self.training else 0
367
- x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
 
 
 
 
368
  else:
369
  # We include the dot product as float32, for consistency
370
  # with the other implementations that include that step
@@ -374,18 +414,21 @@ class StreamingMultiheadAttention(StreamingModule):
374
  # extend a bit the range of operations done in float32,
375
  # although this should make no difference.
376
  q = q / q.shape[-1] ** 0.5
 
 
377
  if self._is_streaming and self.safe_streaming and q.device.type == 'cuda':
378
  with torch.autocast(device_type=q.device.type, dtype=torch.float32):
379
- pre_w = torch.einsum("bqhc,bkhc->bhqk", q, k)
380
  else:
381
- pre_w = torch.einsum("bqhc,bkhc->bhqk", q, k)
382
  if attn_mask is not None:
383
  pre_w = pre_w + attn_mask
384
  w = torch.softmax(pre_w, dim=-1)
385
  w = F.dropout(w, self.dropout, training=self.training).to(v)
386
- x = torch.einsum("bhqk,bkhc->bqhc", w, v)
 
387
  x = x.to(dtype)
388
- x = rearrange(x, "b t h d -> b t (h d)", h=self.num_heads)
389
  x = self.out_proj(x)
390
  else:
391
  key, value = self._complete_kv(key, value)
 
25
  from .rope import RotaryEmbedding
26
  from .streaming import StreamingModule
27
 
28
+ _efficient_attention_backend: str = 'torch'
29
+
30
+
31
+ def set_efficient_attention_backend(backend: str = 'torch'):
32
+ # Using torch by default, it seems a bit faster on older P100 GPUs (~20% faster).
33
+ global _efficient_attention_backend
34
+ assert _efficient_attention_backend in ['xformers', 'torch']
35
+ _efficient_attention_backend = backend
36
+
37
+
38
+ def _get_attention_time_dimension() -> int:
39
+ if _efficient_attention_backend == 'torch':
40
+ return 2
41
+ else:
42
+ return 1
43
+
44
 
45
  def _is_profiled() -> bool:
46
  # Return true if we are currently running with a xformers profiler activated.
 
91
 
92
  def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
93
  """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers"""
 
94
  if n_rep == 1:
95
  return x
96
+ if _efficient_attention_backend == 'torch':
97
+ bs, n_kv_heads, slen, head_dim = x.shape
98
+ return (
99
+ x[:, :, None, :, :]
100
+ .expand(bs, n_kv_heads, n_rep, slen, head_dim)
101
+ .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
102
+ )
103
+ else:
104
+ bs, slen, n_kv_heads, head_dim = x.shape
105
+ return (
106
+ x[:, :, :, None, :]
107
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
108
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
109
+ )
110
 
111
 
112
  class LayerScale(nn.Module):
 
234
  # Return a causal mask, accounting for potentially stored past keys/values
235
  # We actually return a bias for the attention score, as this has the same
236
  # convention both in the builtin MHA in Pytorch, and Xformers functions.
237
+ time_dim = _get_attention_time_dimension()
238
  if self.memory_efficient:
239
  from xformers.ops import LowerTriangularMask
240
  if current_steps == 1:
 
247
  return LowerTriangularMask()
248
  if self._streaming_state:
249
  past_keys = self._streaming_state['past_keys']
250
+ past_steps = past_keys.shape[time_dim]
251
  else:
252
  past_steps = 0
253
 
 
264
  torch.full([], float('-inf'), device=device, dtype=dtype))
265
 
266
  def _complete_kv(self, k, v):
267
+ time_dim = _get_attention_time_dimension()
268
  if self.cross_attention:
269
  # With cross attention we assume all keys and values
270
  # are already available, and streaming is with respect
 
273
  # Complete the key/value pair using the streaming state.
274
  if self._streaming_state:
275
  pk = self._streaming_state['past_keys']
276
+ nk = torch.cat([pk, k], dim=time_dim)
277
  if v is k:
278
  nv = nk
279
  else:
280
  pv = self._streaming_state['past_values']
281
+ nv = torch.cat([pv, v], dim=time_dim)
282
  else:
283
  nk = k
284
  nv = v
285
 
286
+ assert nk.shape[time_dim] == nv.shape[time_dim]
287
  offset = 0
288
  if self.past_context is not None:
289
+ offset = max(0, nk.shape[time_dim] - self.past_context)
290
  if self._is_streaming:
291
  self._streaming_state['past_keys'] = nk[:, offset:]
292
  if v is not k:
 
298
  return nk, nv
299
 
300
  def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
301
+ # TODO: fix and verify layout.
302
+ assert _efficient_attention_backend == 'xformers', 'Rope not supported with torch attn.'
303
  # Apply rope embeddings to query and key tensors.
304
  assert self.rope is not None
305
  if 'past_keys' in self._streaming_state:
 
320
  assert not is_causal, ("new param added in torch 2.0.1 not supported, "
321
  "use the causal args in the constructor.")
322
 
323
+ time_dim = _get_attention_time_dimension()
324
+ if time_dim == 2:
325
+ layout = "b h t d"
326
+ else:
327
+ layout = "b t h d"
328
  dtype = query.dtype
329
  if self._is_streaming:
330
  assert self.causal or self.cross_attention, \
 
357
  if self.qk_layer_norm is True:
358
  q = self.q_layer_norm(q)
359
  k = self.k_layer_norm(k)
360
+ q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
 
361
  else:
362
  if not _is_profiled():
363
  # profiling breaks that propertysomehow.
 
365
  assert value is key, "specialized implementation"
366
  projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
367
  if self.kv_repeat == 1:
368
+ if time_dim == 2:
369
+ bound_layout = "b h p t d"
370
+ else:
371
+ bound_layout = "b t p h d"
372
+ packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
373
  q, k, v = ops.unbind(packed, dim=2)
374
  else:
375
  embed_dim = self.embed_dim
 
380
  end = start + per_head_dim * kv_heads
381
  k = projected[:, :, start: end]
382
  v = projected[:, :, end:]
383
+ q = rearrange(q, f"b t (h d) -> {layout}", h=self.num_heads)
384
+ k = rearrange(k, f"b t (h d) -> {layout}", h=kv_heads)
385
+ v = rearrange(v, f"b t (h d) -> {layout}", h=kv_heads)
386
 
387
  if self.qk_layer_norm is True:
388
  assert self.kv_repeat == 1
389
+ q, k = [rearrange(x, f"{layout} -> b t (h d)") for x in [q, k]]
390
  q = self.q_layer_norm(q)
391
  k = self.k_layer_norm(k)
392
+ q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]]
393
  if self.rope:
394
  q, k = self._apply_rope(q, k)
395
  k, v = self._complete_kv(k, v)
 
400
  q, k, v = [x.float() for x in [q, k, v]]
401
  if self.memory_efficient:
402
  p = self.dropout if self.training else 0
403
+ if _efficient_attention_backend == 'torch':
404
+ x = torch.nn.functional.scaled_dot_product_attention(
405
+ q, k, v, is_causal=attn_mask is not None, dropout_p=p)
406
+ else:
407
+ x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
408
  else:
409
  # We include the dot product as float32, for consistency
410
  # with the other implementations that include that step
 
414
  # extend a bit the range of operations done in float32,
415
  # although this should make no difference.
416
  q = q / q.shape[-1] ** 0.5
417
+ key_layout = layout.replace('t', 'k')
418
+ query_layout = layout
419
  if self._is_streaming and self.safe_streaming and q.device.type == 'cuda':
420
  with torch.autocast(device_type=q.device.type, dtype=torch.float32):
421
+ pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
422
  else:
423
+ pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
424
  if attn_mask is not None:
425
  pre_w = pre_w + attn_mask
426
  w = torch.softmax(pre_w, dim=-1)
427
  w = F.dropout(w, self.dropout, training=self.training).to(v)
428
+ # Key and value have the same format.
429
+ x = torch.einsum(f"b h t k, {key_layout} -> {layout}", w, v)
430
  x = x.to(dtype)
431
+ x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
432
  x = self.out_proj(x)
433
  else:
434
  key, value = self._complete_kv(key, value)