Spaces:
Running
Running
# Copyright 2024 the LlamaFactory team. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import json | |
import math | |
import os | |
from typing import Any, Dict, List | |
from transformers.trainer import TRAINER_STATE_NAME | |
from .logging import get_logger | |
from .packages import is_matplotlib_available | |
if is_matplotlib_available(): | |
import matplotlib.figure | |
import matplotlib.pyplot as plt | |
logger = get_logger(__name__) | |
def smooth(scalars: List[float]) -> List[float]:
    r"""
    Applies exponential-moving-average smoothing to a series (TensorBoard-style EMA).

    The smoothing weight is not fixed: it is derived from the series length through a
    rescaled sigmoid, so longer series are smoothed more strongly. The running average
    is seeded with the first element.

    Args:
        scalars: the raw values, in order.

    Returns:
        A new list of the same length holding the smoothed values; an empty list for
        empty input.
    """
    count = len(scalars)
    if count == 0:
        return []

    # Sigmoid of the series length, shifted and scaled into the range [0, 0.9).
    sigmoid = 1 / (1 + math.exp(-0.05 * count))
    decay = 1.8 * (sigmoid - 0.5)

    running = scalars[0]
    result = []
    for value in scalars:
        # Blend the previous smoothed value with the current raw value.
        running = running * decay + (1 - decay) * value
        result.append(running)

    return result
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
    r"""
    Plots loss curves in LlamaBoard.

    Draws the raw loss values and their smoothed counterpart against the step count
    on a single axes.

    Args:
        trainer_log: log entries. Entries whose ``"loss"`` is missing or ``None`` are
            skipped; every other entry must also provide ``"current_steps"``.

    Returns:
        A newly created figure containing the "original" and "smoothed" curves.
    """
    # Discard figures left over from earlier calls before creating a new one.
    plt.close("all")
    plt.switch_backend("agg")
    fig = plt.figure()
    ax = fig.add_subplot(111)

    steps, losses = [], []
    for log in trainer_log:
        loss = log.get("loss")
        # Compare against None rather than relying on truthiness, so that a
        # legitimate loss of exactly 0.0 is still plotted.
        if loss is not None:
            steps.append(log["current_steps"])
            losses.append(loss)

    ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
    ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
    ax.legend()
    ax.set_xlabel("step")
    ax.set_ylabel("loss")
    return fig
def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
    r"""
    Plots loss curves and saves the image.

    Loads the trainer state file (``TRAINER_STATE_NAME``) from ``save_dictionary`` and,
    for each requested key, writes ``training_<key>.png`` (with ``/`` in the key replaced
    by ``_``) into the same directory. A key with no entries in ``log_history`` is
    skipped with a warning.

    Args:
        save_dictionary: directory that holds the trainer state file and receives the
            generated images.
        keys: metric names to look up in each ``log_history`` entry. The default list
            is only read here, never mutated.
    """
    plt.switch_backend("agg")
    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
        data = json.load(f)

    for key in keys:
        steps, metrics = [], []
        for entry in data["log_history"]:
            if key in entry:
                steps.append(entry["step"])
                metrics.append(entry[key])

        if len(metrics) == 0:
            logger.warning(f"No metric {key} to plot.")
            continue

        fig = plt.figure()
        try:
            plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
            plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
            plt.title(f"training {key} of {save_dictionary}")
            plt.xlabel("step")
            plt.ylabel(key)
            plt.legend()
            figure_path = os.path.join(save_dictionary, f"training_{key.replace('/', '_')}.png")
            plt.savefig(figure_path, format="png", dpi=100)
        finally:
            # pyplot keeps every figure it creates registered until it is closed;
            # release it here so plotting many keys does not accumulate figures.
            plt.close(fig)

        print("Figure saved at:", figure_path)