Downloaded multitask_unity_large.pt, how can I use it?

#14
by davidlandeo - opened

I downloaded the model multitask_unity_large.pt, but I don't know how to use it. How do I run each of these tasks?

  • Speech-to-speech translation (S2ST)
  • Speech-to-text translation (S2TT)
  • Text-to-speech translation (T2ST)
  • Text-to-text translation (T2TT)
  • Automatic speech recognition (ASR)

You can pass the path to the model checkpoint to the Translator class instead of the checkpoint name. See instructions-to-run-inference-with-seamlessm4t-models for details

What should that look like?
I put all the weights in a local folder, "seamless_weights", and passed
them to the Translator class:
import torch
from seamless_communication.models.inference import Translator

translator = Translator(
    "seamless_weights/multitask_unity_medium.pt",
    vocoder_name_or_card="seamless_weights/vocoder_36langs.pt",
    device=torch.device("cpu"),
)
It raises a ValueError exception:
ValueError: name must be a valid filename, but is 'seamless_weights/multitask_unity_medium.pt' instead.

What does 'valid filename' mean in this context?
It is a valid file name:
os.path.exists('seamless_weights/multitask_unity_medium.pt')
True
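
A likely explanation, offered as an assumption rather than a confirmed reading of the library code: the first argument is treated as the name of an asset card (such as the one described by seamlessM4T_medium.yaml), not as a filesystem path, so any name containing a path separator is rejected before the disk is ever consulted. The sketch below imitates that kind of check; is_plain_card_name is a hypothetical helper, not a seamless_communication or fairseq2 function.

```python
import os


def is_plain_card_name(name: str) -> bool:
    """Hypothetical check: True only if `name` contains no path separator."""
    separators = {"/", os.sep}
    if os.altsep:
        separators.add(os.altsep)
    return not any(sep in name for sep in separators)


# A bare card name passes; a path fails even when the file exists on disk.
card_ok = is_plain_card_name("seamlessM4T_medium")
path_ok = is_plain_card_name("seamless_weights/multitask_unity_medium.pt")
print(card_ok, path_ok)
```

If that is what happens, os.path.exists returning True is beside the point, because the check looks at the shape of the string and not at the filesystem.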

Next I tried editing the model card seamless_communication/assets/cards/seamlessM4T_medium.yaml
so that it loads the weights from the local folder:
#checkpoint: "https://huggingface.co/facebook/seamless-m4t-medium/resolve/main/multitask_unity_medium.pt"
checkpoint: "/home/local/seamless_communication/seamless_weights/multitask_unity_medium.pt"
but I got an exception saying that this path should be a valid URI...
It is quite unclear how to use local weights...
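
Since the error asks for a URI, one thing that may be worth trying is a file:// URI in place of the bare path. Whether the card loader accepts the file scheme is an assumption that nothing in this thread confirms. The sketch only shows how to build such a URI with the standard library, using the path from the card above.

```python
from pathlib import PurePosixPath

# Absolute path taken from the edited model card.
weights = PurePosixPath(
    "/home/local/seamless_communication/seamless_weights/multitask_unity_medium.pt"
)

# as_uri() requires an absolute path and returns a file:// URI.
uri = weights.as_uri()
print(uri)
```

The resulting string would then go in the checkpoint field of the card instead of the plain path.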

Update:

The seamless_communication models seem to be based on the fairseq2 library and use its classes to download the weights.
So, to achieve the goal above, we can subclass two fairseq2 classes:
from fairseq2.models.utils.model_loader import ModelLoader
from fairseq2.models.nllb.loader import NllbTokenizerLoader

In the subclasses, drop the requirement that the path to the weights be a URI, along these lines:

try:
    # Load the checkpoint.
    uri = card.field("checkpoint").as_uri()
    pathname = self.download_manager.download_checkpoint(
        uri, card.name, force=force, progress=progress
    )
except AssetCardError:
    # The field is not a valid URI, so use its raw value as a local path.
    pathname = card.field("checkpoint").data

Then use these subclasses in the corresponding seamless_communication classes.
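
The same fallback can be sketched without fairseq2. In this standalone version, resolve_checkpoint and fake_download are hypothetical names standing in for the loader method and the download manager, and the rule applied (http/https values are downloaded, anything else is treated as a local path) is an assumption about the intended behaviour, not the library's own logic.

```python
from pathlib import Path
from typing import Callable
from urllib.parse import urlparse


def resolve_checkpoint(value: str, download: Callable[[str], Path]) -> Path:
    """Return a local path for `value`, downloading only when it is a web URI."""
    scheme = urlparse(value).scheme
    # http/https values have to be fetched before they can be loaded.
    if scheme in ("http", "https"):
        return download(value)
    # Anything else is assumed to be a path on the local filesystem.
    return Path(value)


def fake_download(uri: str) -> Path:
    """Stand-in for a real download manager: maps a URI to a cache path."""
    return Path("cache") / Path(urlparse(uri).path).name


local = resolve_checkpoint("seamless_weights/multitask_unity_medium.pt", fake_download)
remote = resolve_checkpoint(
    "https://huggingface.co/facebook/seamless-m4t-medium/resolve/main/multitask_unity_medium.pt",
    fake_download,
)
print(local, remote)
```

Checking the scheme up front, instead of catching an exception, keeps the local-path case explicit.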

But this doesn't solve the problem when the weights are located on a remote share...
