"""Launch a Hugging Face fine-tuning job on Amazon SageMaker.

The job runs the ``run_mlm.py`` example script from the ``transformers``
repository (checked out at the tag matching ``TRANSFORMERS_VERSION``) inside
the Hugging Face training container for transformers 4.17.0 / PyTorch 1.10.2 /
Python 3.8.

Configuration:
    SAGEMAKER_ROLE_NAME  Name of an IAM role SageMaker can assume for the job.
                         If unset, the template placeholder below is used,
                         which is not a real role name and must be replaced.
"""

import os

import boto3
import sagemaker
from sagemaker.huggingface import HuggingFace

# Single source of truth for the transformers version: the git tag of the
# training script and the container version must agree.
TRANSFORMERS_VERSION = '4.17.0'

# IAM role name; overridable via the environment so the placeholder does not
# have to be edited in source.
ROLE_NAME = os.environ.get(
    'SAGEMAKER_ROLE_NAME', '{IAM_ROLE_WITH_SAGEMAKER_PERMISSIONS}'
)

# Resolve the role's ARN, which the estimator takes as `role`.
iam_client = boto3.client('iam')
role = iam_client.get_role(RoleName=ROLE_NAME)['Role']['Arn']

# Hyperparameters handed to the estimator (and from there to run_mlm.py).
hyperparameters = {
    'model_name_or_path': 'dalle-mini/vqgan_imagenet_f16_16384',
    # Presumably SageMaker's model-artifact directory — confirm in SageMaker docs.
    'output_dir': '/opt/ml/model',
    # TODO: add the remaining hyperparameters. NOTE(review): run_mlm.py appears
    # to require training data ('dataset_name' or 'train_file') and 'do_train'
    # to actually train — confirm against
    # https://github.com/huggingface/transformers/tree/v4.17.0/examples/pytorch/language-modeling
    # NOTE(review): the checkpoint name suggests a VQGAN image model rather than
    # a masked language model; confirm run_mlm.py can load it.
}

# Where the estimator fetches the training script from; the branch is the
# release tag matching TRANSFORMERS_VERSION.
git_config = {
    'repo': 'https://github.com/huggingface/transformers.git',
    'branch': f'v{TRANSFORMERS_VERSION}',
}

# Hugging Face estimator describing the training job.
huggingface_estimator = HuggingFace(
    entry_point='run_mlm.py',
    # Presumably resolved relative to the cloned repository root when
    # git_config is given — confirm.
    source_dir='./examples/pytorch/language-modeling',
    instance_type='ml.p3.2xlarge',
    instance_count=1,
    role=role,
    git_config=git_config,
    transformers_version=TRANSFORMERS_VERSION,
    pytorch_version='1.10.2',
    py_version='py38',
    hyperparameters=hyperparameters,
)

# Start the training job.
huggingface_estimator.fit()