yuean commited on
Commit
384c499
·
1 Parent(s): 3121ac6

Update README.md

Browse files

基于图像分类模型"microsoft/resnet-50"进行微调,数据使用"yuean/EuroSAT-2750",使用Trainer进行简单训练:
checkpoint = "microsoft/resnet-50"
model = AutoModelForImageClassification.from_pretrained(
checkpoint,
num_labels=len(labels),
id2label=id2label,
label2id=label2id,
ignore_mismatched_sizes=True,
)
training_args = TrainingArguments(
output_dir="my_resnet50_model",
remove_unused_columns=False,
evaluation_strategy="epoch",
save_strategy="epoch",
learning_rate=5e-5,
per_device_train_batch_size=16,
gradient_accumulation_steps=4,
per_device_eval_batch_size=16,
num_train_epochs=3,
warmup_ratio=0.1,
logging_steps=10,
load_best_model_at_end=True,
metric_for_best_model="accuracy",
push_to_hub=True,
)

accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
predictions, labels = eval_pred
predictions = np.argmax(predictions, axis=1)
return accuracy.compute(predictions=predictions, references=labels)


trainer = Trainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
tokenizer=image_processor,
compute_metrics=compute_metrics,
)

trainer.train()

Files changed (1) hide show
  1. README.md +2 -0
README.md CHANGED
@@ -2,4 +2,6 @@
2
  metrics:
3
  - accuracy
4
  pipeline_tag: image-classification
 
 
5
  ---
 
2
  metrics:
3
  - accuracy
4
  pipeline_tag: image-classification
5
+ datasets:
6
+ - yuean/EuroSAT-2750
7
  ---