File size: 839 Bytes
76b4794 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
from __future__ import annotations
from transformers import PretrainedConfig
from transformers import PreTrainedModel
from torch import nn
import torch
class FastTextJpConfig(PretrainedConfig):
    """Configuration object for ``FastTextJpModel``.

    This class defines no fields of its own; every keyword argument is handed
    straight to ``PretrainedConfig``. ``FastTextJpModel`` reads
    ``vocab_size`` and ``hidden_size`` from the config. No defaults for them
    are declared here, so callers are expected to supply both as keyword
    arguments.
    """

    # Identifier under which this config is registered with the auto classes.
    model_type = "fast_text_jp"

    def __init__(self, **kwargs):
        # Defer all argument handling to the base configuration class.
        super(FastTextJpConfig, self).__init__(**kwargs)
class FastTextJpModel(PreTrainedModel):
    """Embedding-only model that maps token ids to embedding vectors.

    The model holds a single ``nn.Embedding`` table of shape
    ``(config.vocab_size, config.hidden_size)``. Presumably this table is
    filled with pretrained FastText vectors when loaded via
    ``from_pretrained``; confirm against the published checkpoint.
    """

    config_class = FastTextJpConfig

    def __init__(self, config: FastTextJpConfig):
        super().__init__(config)
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)

    def forward(self, input_ids: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return the embedding vector for every id in ``input_ids``.

        Args:
            input_ids: Integer tensor of token ids, of any shape.
            **kwargs: Accepted so callers may pass extra arguments
                (e.g. ``attention_mask``); they are ignored.

        Returns:
            Tensor of shape ``(*input_ids.shape, hidden_size)``.
        """
        # Embed the caller's ids directly. They already live on the caller's
        # device, so no freshly created CPU tensor can cause a device mismatch
        # when the model runs on GPU.
        return self.word_embeddings(input_ids)
# Register the custom classes so that AutoConfig / AutoModel can resolve them
# (e.g. when the checkpoint is loaded with trust_remote_code=True).
FastTextJpConfig.register_for_auto_class(auto_class="AutoConfig")
FastTextJpModel.register_for_auto_class(auto_class="AutoModel")
|