|
import torchvision.models as models |
|
from model.densenetccnl import * |
|
from model.unetnc import * |
|
from model.gienet import * |
|
|
|
|
|
def get_model(name, n_classes=1, filters=64,version=None,in_channels=3, is_batchnorm=True, norm='batch', model_path=None, use_sigmoid=True, layers=3,img_size=512): |
|
model = _get_model_instance(name) |
|
|
|
|
|
if name == 'dnetccnl': |
|
model = model(img_size=128, in_channels=in_channels, out_channels=n_classes, filters=32) |
|
elif name == 'dnetccnl512': |
|
model = model(img_size=img_size, in_channels=in_channels, out_channels=n_classes, filters=32) |
|
elif name == 'unetnc': |
|
model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7) |
|
elif name == 'gie': |
|
model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7) |
|
elif name == 'giecbam': |
|
model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7) |
|
elif name == 'gie2head': |
|
model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7) |
|
elif name == 'giemask': |
|
model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7) |
|
elif name == 'giemask2': |
|
model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7) |
|
elif name == 'giedilated': |
|
model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7) |
|
elif name == 'bmp': |
|
model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7) |
|
elif name == 'displacement': |
|
model = model(n_classes=2, num_filter=32, BatchNorm='GN', in_channels=5) |
|
return model |
|
|
|
def _get_model_instance(name): |
|
try: |
|
return { |
|
'dnetccnl': dnetccnl, |
|
'dnetccnl512': dnetccnl512, |
|
'unetnc': UnetGenerator, |
|
'gie':GieGenerator, |
|
'giecbam':GiecbamGenerator, |
|
'giedilated':DilatedSingleUnet, |
|
'gie2head':Gie2headGenerator, |
|
'giemask':GiemaskGenerator, |
|
'giemask2':Giemask2Generator, |
|
'bmp':BmpGenerator, |
|
}[name] |
|
except: |
|
print('Model {} not available'.format(name)) |
|
|