File size: 505 Bytes
513aed0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import argparse
import torch

from common import flops_calculation_function

def parse_args(argv=None):
    """Parse the command-line arguments of this script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``
            (argparse's default behaviour).

    Returns:
        argparse.Namespace with ``model_path``, ``input_size`` and
        ``in_channels`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        # Required: without it torch.load(None) fails with an opaque error
        # instead of a clear usage message.
        required=True,
        help="Path to models checkpoint (.pth file).",
    )
    parser.add_argument(
        "--input-size",
        type=int,
        default=480,
        help="Height and width of the square dummy input (default: 480).",
    )
    parser.add_argument(
        "--in-channels",
        type=int,
        default=3,
        help="Number of channels of the dummy input (default: 3).",
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Load ``checkpoint["model"]`` and print the result of
    ``flops_calculation_function`` on an all-ones dummy input.

    The dummy input has shape ``(1, in_channels, input_size, input_size)``;
    the defaults reproduce the original hard-coded ``(1, 3, 480, 480)``.

    Args:
        argv: Optional argument list, forwarded to ``parse_args``.
    """
    args = parse_args(argv)

    # SECURITY NOTE: torch.load unpickles the file, which can execute
    # arbitrary code -- only pass checkpoints from trusted sources.
    # NOTE(review): recent PyTorch releases default to weights_only=True,
    # which may refuse a checkpoint that stores a whole model object under
    # "model"; if so, pass weights_only=False explicitly (trusted files only).
    checkpoint = torch.load(args.model_path, map_location="cpu")
    model = checkpoint["model"]

    dummy_input = torch.ones(1, args.in_channels, args.input_size, args.input_size)
    flops = flops_calculation_function(model, dummy_input)

    # NOTE(review): the value is printed as MMACs although the helper and the
    # variable say "flops" -- confirm the unit flops_calculation_function returns.
    print(f"MMACs = {flops}")


if __name__ == '__main__':
    main()