File size: 505 Bytes
513aed0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import argparse
import torch
from common import flops_calculation_function
def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``model_path`` (str) and ``input_size``
        (list of two ints: height, width of the dummy input).

    Raises:
        SystemExit: (from argparse) when ``--model-path`` is missing or an
            option value is malformed.
    """
    parser = argparse.ArgumentParser(
        description="Print the cost of a checkpointed model on a dummy input."
    )
    parser.add_argument(
        "--model-path",
        type=str,
        # Required: without it torch.load(None) would fail with an obscure error.
        required=True,
        help="Path to models checkpoint (.pth file).",
    )
    parser.add_argument(
        "--input-size",
        type=int,
        nargs=2,
        metavar=("H", "W"),
        default=[480, 480],
        help="Height and width of the dummy input (default: 480 480).",
    )
    return parser.parse_args(argv)


def _load_model(path):
    """Load a checkpoint on CPU and return the object stored under ``"model"``.

    Raises:
        SystemExit: if the checkpoint has no ``"model"`` entry.
    """
    import inspect

    load_kwargs = {"map_location": "cpu"}
    # The checkpoint holds a full pickled model object, not just tensors.
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses such
    # objects. Opt out explicitly, but only on versions that have the
    # parameter (older versions already behave this way).
    # SECURITY: unpickling runs arbitrary code; only load trusted checkpoints.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    checkpoint = torch.load(path, **load_kwargs)
    try:
        return checkpoint["model"]
    except (KeyError, TypeError):
        raise SystemExit(
            f"error: checkpoint {path!r} has no 'model' entry"
        ) from None


def main(argv=None):
    """Load the model from ``--model-path`` and print its cost on a dummy input."""
    args = parse_args(argv)
    model = _load_model(args.model_path)
    height, width = args.input_size
    # Batch of one all-ones image with 3 channels; assumes the model takes
    # 3-channel input — TODO confirm for non-RGB models.
    dummy_input = torch.ones(1, 3, height, width)
    flops = flops_calculation_function(model, dummy_input)
    # NOTE(review): the unit of the returned value is defined by
    # flops_calculation_function; the "MMACs" label assumes it reports MMACs.
    print(f"MMACs = {flops}")


if __name__ == '__main__':
    main()