File size: 998 Bytes
bd7e8e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from tensorflow import keras

from maxim import maxim
from maxim.configs import MAXIM_CONFIGS


def Model(variant=None, input_resolution=(256, 256), **kw) -> keras.Model:
    """Factory function to easily create a MAXIM model variant like "S-2".

    Args:
      variant: Key into ``MAXIM_CONFIGS`` selecting a preset configuration
          (e.g. 'S-1' | 'S-2' | 'S-3' | 'M-1' | 'M-2' | 'M-3'). When None,
          the model is configured purely from ``**kw``.
      input_resolution: (height, width) of the input images. The model input
          has shape ``(*input_resolution, 3)``.
      **kw: Extra keyword arguments forwarded to ``maxim.MAXIM``. Values given
          here take precedence over the preset selected by ``variant``. The
          special key ``name`` names the returned model and is not forwarded;
          it defaults to ``"maxim"`` when neither ``kw`` nor the preset
          provides one.

    Returns:
      A ``keras.Model`` named ``f"{name}_model"`` wrapping the MAXIM network.

    Raises:
      KeyError: If ``variant`` is not a key of ``MAXIM_CONFIGS``.
    """
    if variant is not None:
        try:
            config = MAXIM_CONFIGS[variant]
        except KeyError:
            # Keep the KeyError type for existing callers, but say what is valid.
            raise KeyError(
                f"Unknown MAXIM variant {variant!r}; "
                f"available variants: {list(MAXIM_CONFIGS)}"
            ) from None
        # setdefault: explicit keyword arguments override the preset values.
        for k, v in config.items():
            kw.setdefault(k, v)

    # A preset config may contain these keys. Drop them so they are not
    # forwarded to maxim.MAXIM. A preset's input_resolution is therefore
    # ignored in favour of the explicit argument.
    kw.pop("variant", None)
    kw.pop("input_resolution", None)
    # Fallback name so that calls without a variant or an explicit name work.
    model_name = kw.pop("name", "maxim")

    maxim_model = maxim.MAXIM(**kw)

    # Wrap the MAXIM network in a functional keras.Model with a 3-channel input.
    inputs = keras.Input((*input_resolution, 3))
    outputs = maxim_model(inputs)
    final_model = keras.Model(inputs, outputs, name=f"{model_name}_model")

    return final_model