NLP Course documentation

โมเดล

Hugging Face's logo
Join the Hugging Face community

and get access to the augmented documentation experience

to get started

โมเดล

Ask a Question Open In Colab Open In Studio Lab

ใน section นี้ เราจะมาดูวิธีการสร้างและการใช้งานโมเดล เราจะใช้คลาส AutoModel ซึ่งเป็นประโยชน์มากหากเราต้องการสร้างโมเดลใดๆ จาก checkpoint หนึ่งๆ

คลาส AutoModel และส่วนประกอบของมันทั้งหมดนั้น จริงๆแล้วก็เป็นเพียง wrapper ของโมเดลต่างๆที่มีอยู่ใน library มันเป็น wrapper ที่ฉลาดโดยที่มันสามารถเดาสถาปัตยกรรมของโมเดลที่เหมาะสมสำหรับ checkpoint ของคุณได้ และสร้างโมเดลด้วยสถาปัตยกรรมนั้น

แต่อย่างไรก็ตาม ถ้าคุณรู้ว่าคุณต้องการใช้โมเดลประเภทใด คุณสามารถใช้คลาสที่นิยามสถาปัตยกรรมนั้นได้โดยตรง เรามาดูกันว่ามันทำงานยังไงกับโมเดล BERT

สร้าง Transformer

สิ่งแรกที่เราจำเป็นต้องทำในการเริ่มสร้างโมเดล BERT นั้นก็คือการโหลดวัตถุกำหนดค่า(configuration object):

from transformers import BertConfig, BertModel

# Building the config
config = BertConfig()

# Building the model from the config
model = BertModel(config)

ใน configuration นั้นประกอบด้วยค่าของคุณสมบัติ(attributes) หลายอย่างๆ ที่ใช้สำหรับสร้างโมเดล:

print(config)
BertConfig {
  [...]
  "hidden_size": 768,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  [...]
}

ในขณะที่คุณยังไม่เห็นว่าคุณสมบัติต่างๆ เหล่าทำอะไรบ้าง คุณน่าจะพอจำบางส่วนได้: hidden_size ที่นิยามขนาดของเวคเตอร์ hidden_states, และ num_hidden_layers ที่นิยามจำนวนของเลเยอร์ที่โมเดล Transformer มี

วิธีการต่างๆในการโหลด

สร้างโมเดลจาก configuration พื้นฐาน และตั้งค่าเริ่มต้นด้วยค่าสุ่ม(random values):

from transformers import BertConfig, BertModel

config = BertConfig()
model = BertModel(config)

# โมเดลถูกกำหนดค่าเริ่มต้นด้วยการสุ่ม!

โมเดลสามารถอยู่ในสถานะนี้ได้ แต่มันจะให้ผลลัพธ์ที่แย่ออกมา; มันจำเป็นต้องผ่านการเรียนรู้ก่อน เราสามารถเทรนโมเดลจากโมเดลเปล่าๆ กับงานที่เรามีได้ แต่อย่างที่คุณเห็นใน Chapter 1, มันใช้เวลานานและข้อมูลจำนวนมาก โดยที่ไม่ได้มีประโยชน์อะไรเพิ่มขึ้นมาก เพื่อลดขึ้นตอนที่ไม่จำเป็นต่างๆ มันสำคัญอย่างยิ่งที่เราจะสามารถแชร์และนำโมเดลที่ผ่านการเทรนมาแล้วมาใช้ใหม่

การโหลดโมเดล Transformer ที่ผ่านการเทรนมาแล้วนั้นง่ายมาก เราสามารถเรียกใช้ from_pretrained():

from transformers import BertModel

model = BertModel.from_pretrained("bert-base-cased")

เหมือนที่คุณเห็นก่อนหน้านี้ เราสามารถที่จะแทนค่า BertModel ด้วยคลาส AutoModel` ที่คล้ายคลึงกัน จากนี้ไปเราจะใช้วิธีการนี้เพื่อเป็นการสร้างโค้ด checkpoint-agnostic; ถ้าโค้ดของคุณสามารถใช้งานได้กับหนึ่ง checkpoint มันก็ควรที่จะสามารถใช้กับอันอื่นได้ด้วย ซึ่งก็รวมถึง ไม่ว่าสถาปัตยกรรมจะแตกต่างกัน ตราบใดที่ checkpoint นั้นถูกเทรนมาสำหรับงานที่เหมือนกันก็ควรใช้ได้เหมือนกัน (ยกตัวอย่างเช่น งาน sentiment analysis)

ในตัวอย่างโค้ดด้านบน เราไม่ได้ใช้ BertConfig, แต่ใช้งานโมเดลที่ผ่านการเทรนมาแล้ว(pretrained) ผ่าน bert-base-cased identifier ซึ่งนี่เป็น checkpoint ของโมเดลที่โดนเทรนด้วยผู้ที่ประดิษฐ์ BERT เอง; คุณสามารถดูรายละเอียดเพิ่มเติมได้ที่ model card.

ถึงตอนนี้โมเดลนี้ได้ถูกสร้างและมีค่าตั้งต้นเท่ากับ weights ของ checkpoint มันสามารถถูกนำไปใช้สำหรับการอนุมาน(inference)ได้ทันทีกับงานที่มันถูกเทรนมา และมันสามารถถูกนำมาปรับจูนเพิ่มเติมให้เข้ากับงานใหม่ได้ การเทรนโมเดลที่ใช้ weights ของโมเดลที่ผ่านการเทรนมาแล้ว แทนที่การเทรนจากไม่มีอะไรเลยนั้น ทำให้เราได้ผลลัพธ์ที่ดีในเวลาอันรวดเร็ว

weights ได้ถูกดาวน์โหลด และ เก็บไว้ในโฟลเดอร์ cache(เมื่อเราทำการเรียกใช้งาน from_pretrained() อีกในอนาคต weights เหล่านี้จะไม่ถูกดาวน์โหลดซ้ำอีก) โดยโฟลเดอร์มีค่าเริ่มต้น(default) อยู่ที่ ~/.cache/huggingface/transformers คุณสามารถปรับเปลี่ยนโฟลเดอร์ cache ได้โดยตั้งค่า HF_HOME ใน environment variable

identifier ที่ใช้สำหรับโหลดโมเดลสามารถใช้ identifier ของโมเดลใดก็ได้บน Model Hub ตราบใดที่มันเข้ากันได้กับสถาปัตยกรรม BERT ลิสท์ของ BERT checkpoints ทั้งหมดที่มีอยู่สามารถดูได้จาก ที่นี่.

วิธีสำหรับการบันทึก

การบันทึกโมเดลนั้นเป็นอะไรง่ายพอๆกับการโหลด - เราใช้ save_pretrained() ซึ่งก็เปรียบเสมือนกับ from_pretrained():

model.save_pretrained("directory_on_my_computer")

นี่เป็นการบันทึกสองไฟล์ลงไปที่ฮาร์ดดิสของคุณ:

ls directory_on_my_computer

config.json pytorch_model.bin

ุุ้ถ้าคุณไปดูที่ไฟล์ config.json คุณจะพอนึกออกถึงคุณสมบัติ(attributes) ที่จำเป็นในการสร้างสถาปัตยกรรมของโมเดล ไฟล์นี้ประกอบด้วย metadata เช่น checkpoint เกิดมาจากที่ใด และ 🤗 Transformers เวอร์ชันใดที่คุณใช้ในการบันทึก checkpoint ล่าสุด

ไฟล์ pytorch_model.bin เป็นที่รู้จักในนาม state dictionary; มันประกอบด้วย weights ทัั้งหมดของโมเดลคุณ สองไฟล์ที่มีความเชื่อมโยงกัน ไฟล์ configuration จำเป็นที่จะต้องรู้สถาปัตยกรรมของโมเดลของคุณ ในขณะที่ weights ของโมเดลคุณ ก็คือ ตัวแปร(parameters) ของโมเดลคุณ

ใชโมเดล Transformer สำหรับการอนุมาน(inference)

ุถึงตรงนี้คุณรู้วิธีการโหลดและบันทึกโมเดลแล้ว งั้นมาลองใช้มันทำนายอะไรบางอย่างดูกัน โมดล Transformer นั้นสามารถประมวลผลตัวเลขได้อย่างเดียว ซึ่งตัวเลขเหล่านี้ก็ได้มาจากการสร้างขึ้นมาโดยใช้ tokenizer แต่ก่อนที่เราจะไปอธิบายกันถึง tokenizer เรามาลองค้นหากันดูว่าอินพุตแบบไหนที่สามารถใส่เข้าไปในโมเดลได้บ้าง

Tokenizers นั้นสามารถที่จะแปลงอินพุตไปเป็น tensors ที่เหมาะสมสำหรับ framework นั้นๆ แต่เพื่อช่วยให้คุณเข้าใจสิ่งที่เกิดขึ้น เราจะมาดูกันว่าอะไรที่จำเป็นต้องทำก่อนที่เราจะส่งอินพุตเข้าไปในโมเดล

สมมติว่าเรามีคำ สอง สาม คำ:

sequences = ["Hello!", "Cool.", "Nice!"]

tokenizer จะทำการแปลงคำเหล่านี้ไปเป็นดัชนีคำศัพท์(vocabulary indices) ซึ่งปกติจะเรียกว่า input IDs โดยตอนนี้แต่ละคำกลายเป็นลิสท์ของตัวเลข ผลลัพธ์ที่ได้ก็คือ:

encoded_sequences = [
    [101, 7592, 999, 102],
    [101, 4658, 1012, 102],
    [101, 3835, 999, 102],
]

นี่เป็นลิสท์ของคำที่ผ่านการเข้ารหัส(encoded): a list of lists, Tensors สามารถมีขนาดเป็นสี่เหลี่ยมจตุรัสเท่านั้น(ลองนึกถึงแมทริกซ์), “array” นี้มีขนาดเป็นสี่เหลี่ยมจตุรัสอยู่แล้ว ดังนั้นการแปลงมันไปเป็น tensor นั้นง่ายมาก:

import torch

model_inputs = torch.tensor(encoded_sequences)

ใช้ tensors เป็นอินพุตเข้าไปยังโมเดล

การใช้งาน tensor กับโมเดลนั้นง่ายมากๆ - เราก็แค่เรียกโมเดลพร้อมกับใส่อินพุต:

output = model(model_inputs)

ในขณะที่โมเดลสามารถรับตัวแปร(arguments) ต่างๆได้มากมาย แค่ input IDs เท่านั้นที่จำเป็น เดี๋ยวเราจะอธิบายกันอีกทีว่าตัวแปรตัวอื่นๆเอาไว้ทำอะไร และจำเป็นต้องใช้เมื่อไหร่, แต่ขั้นแรกเราต้องเข้าใจ Tokenizers ที่ใช้สร้างอินพุตที่โมเดล Transformer สามารถเข้าใจได้ก่อน

< > Update on GitHub