jeremyLE-Ekimetrics commited on
Commit
9fcd62f
1 Parent(s): 7a7548d
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
  license: cc-by-4.0
3
- sdk: gradio
 
4
  colorFrom: blue
5
  pinned: false
6
  title: Biomap
7
  emoji: 🐢
8
  colorTo: green
9
- app_file: biomap/app.py
10
  ---
11
 
12
  # Welcome to the project inno-satellite-images-segmentation-gan
 
1
  ---
2
  license: cc-by-4.0
3
+ sdk: streamlit
4
+ sdk_version: 1.25.0
5
  colorFrom: blue
6
  pinned: false
7
  title: Biomap
8
  emoji: 🐢
9
  colorTo: green
10
+ app_file: biomap/streamlit_app.py
11
  ---
12
 
13
  # Welcome to the project inno-satellite-images-segmentation-gan
biomap/Untitled.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
biomap/checkpoint/model/model.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
  oid sha256:106fe1ea7f4f0819e360823374bce7840a1a150b39a2e45090612c159a25dfca
3
- size 95521785
 
1
  version https://git-lfs.github.com/spec/v1
2
  oid sha256:106fe1ea7f4f0819e360823374bce7840a1a150b39a2e45090612c159a25dfca
3
+ size 95521785
biomap/helper.py CHANGED
@@ -1,16 +1,21 @@
1
  import torch.multiprocessing
2
  import torchvision.transforms as T
3
  import numpy as np
4
- from utils import transform_to_pil, compute_biodiv_score, plot_imgs_labels
5
  from utils_gee import get_image
6
  from dateutil.relativedelta import relativedelta
 
 
7
  import datetime
8
  import matplotlib as mpl
9
  from joblib import Parallel, cpu_count, delayed
10
  import logging
11
  from inference import inference
 
 
12
 
13
- def inference_on_location(model, latitude=2.98, longitude=48.81, start_date=2020, end_date=2022, how="year"):
 
14
  """Performe an inference on the latitude and longitude between the start date and the end date
15
 
16
  Args:
@@ -47,6 +52,7 @@ def inference_on_location(model, latitude=2.98, longitude=48.81, start_date=2020
47
  dates = [d.strftime("%Y-%m-%d") for d in dates]
48
 
49
  all_image = Parallel(n_jobs=cpu_count(), prefer="threads")(delayed(get_image)(location, d1,d2) for d1, d2 in zip(dates[:-1],dates[1:]))
 
50
  outputs = inference(np.array(all_image), model)
51
 
52
  logging.info("Calculating Biodiversity Scores...")
@@ -61,8 +67,8 @@ def inference_on_location(model, latitude=2.98, longitude=48.81, start_date=2020
61
  # fig.save("test.png")
62
  return fig
63
 
64
-
65
- def inference_on_location_and_month(model, latitude = 2.98, longitude = 48.81, start_date = '2020-03-20'):
66
  """Performe an inference on the latitude and longitude between the start date and the end date
67
 
68
  Args:
@@ -83,7 +89,6 @@ def inference_on_location_and_month(model, latitude = 2.98, longitude = 48.81, s
83
  end_date = datetime.datetime.strptime(start_date, "%Y-%m-%d") + relativedelta(months=1)
84
  end_date = datetime.datetime.strftime(end_date, "%Y-%m-%d")
85
 
86
- logging.info("Getting Image...")
87
  img_test = get_image(location, start_date, end_date)
88
  outputs = inference(np.array([img_test]), model)
89
 
@@ -91,8 +96,8 @@ def inference_on_location_and_month(model, latitude = 2.98, longitude = 48.81, s
91
  score, score_details = compute_biodiv_score(outputs[0]["linear_preds"].detach().numpy())
92
  logging.info(f"Calculated Biodiversity Score : {score}")
93
  img, label, labeled_img = transform_to_pil(outputs[0])
94
-
95
- return img, labeled_img, score
96
 
97
 
98
  if __name__ == "__main__":
 
1
  import torch.multiprocessing
2
  import torchvision.transforms as T
3
  import numpy as np
4
+ from utils import transform_to_pil, compute_biodiv_score, plot_imgs_labels, plot_image
5
  from utils_gee import get_image
6
  from dateutil.relativedelta import relativedelta
7
+
8
+ from model import LitUnsupervisedSegmenter
9
  import datetime
10
  import matplotlib as mpl
11
  from joblib import Parallel, cpu_count, delayed
12
  import logging
13
  from inference import inference
14
+ import streamlit as st
15
+ import cv2
16
 
17
+ @st.cache_data(hash_funcs={LitUnsupervisedSegmenter: lambda dt: dt.name})
18
+ def inference_on_location(model, longitude=2.98, latitude=48.81, start_date=2020, end_date=2022, how="year"):
19
  """Performe an inference on the latitude and longitude between the start date and the end date
20
 
21
  Args:
 
52
  dates = [d.strftime("%Y-%m-%d") for d in dates]
53
 
54
  all_image = Parallel(n_jobs=cpu_count(), prefer="threads")(delayed(get_image)(location, d1,d2) for d1, d2 in zip(dates[:-1],dates[1:]))
55
+ # all_image = [cv2.imread("output/img.png") for i in range(len(dates))]
56
  outputs = inference(np.array(all_image), model)
57
 
58
  logging.info("Calculating Biodiversity Scores...")
 
67
  # fig.save("test.png")
68
  return fig
69
 
70
+ @st.cache_data(hash_funcs={LitUnsupervisedSegmenter: lambda dt: dt.name})
71
+ def inference_on_location_and_month(model, longitude = 2.98, latitude = 48.81, start_date = '2020-03-20'):
72
  """Performe an inference on the latitude and longitude between the start date and the end date
73
 
74
  Args:
 
89
  end_date = datetime.datetime.strptime(start_date, "%Y-%m-%d") + relativedelta(months=1)
90
  end_date = datetime.datetime.strftime(end_date, "%Y-%m-%d")
91
 
 
92
  img_test = get_image(location, start_date, end_date)
93
  outputs = inference(np.array([img_test]), model)
94
 
 
96
  score, score_details = compute_biodiv_score(outputs[0]["linear_preds"].detach().numpy())
97
  logging.info(f"Calculated Biodiversity Score : {score}")
98
  img, label, labeled_img = transform_to_pil(outputs[0])
99
+ fig = plot_image([start_date], [np.asarray(img)], [np.asarray(labeled_img)], [score_details], [score])
100
+ return fig
101
 
102
 
103
  if __name__ == "__main__":
biomap/inference.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch.multiprocessing
2
  import torchvision.transforms as T
3
  from utils import transform_to_pil
 
4
 
5
  preprocess = T.Compose(
6
  [
@@ -13,6 +14,7 @@ preprocess = T.Compose(
13
  )
14
 
15
  def inference(images, model):
 
16
  x = torch.stack([preprocess(image) for image in images]).cpu()
17
 
18
  with torch.no_grad():
@@ -47,7 +49,7 @@ if __name__ == "__main__":
47
  cfg = hydra.compose(config_name="my_train_config.yml")
48
 
49
  # Load the model
50
- model_path = "biomap/checkpoint/model/model.pt"
51
  saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
52
 
53
  nbclasses = cfg.dir_dataset_n_classes
 
1
  import torch.multiprocessing
2
  import torchvision.transforms as T
3
  from utils import transform_to_pil
4
+ import logging
5
 
6
  preprocess = T.Compose(
7
  [
 
14
  )
15
 
16
  def inference(images, model):
17
+ logging.info("Inference on Images")
18
  x = torch.stack([preprocess(image) for image in images]).cpu()
19
 
20
  with torch.no_grad():
 
49
  cfg = hydra.compose(config_name="my_train_config.yml")
50
 
51
  # Load the model
52
+ model_path = "checkpoint/model/model.pt"
53
  saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
54
 
55
  nbclasses = cfg.dir_dataset_n_classes
biomap/model.py CHANGED
@@ -10,6 +10,7 @@ import unet
10
  class LitUnsupervisedSegmenter(pl.LightningModule):
11
  def __init__(self, n_classes, cfg):
12
  super().__init__()
 
13
  self.cfg = cfg
14
  self.n_classes = n_classes
15
 
 
10
  class LitUnsupervisedSegmenter(pl.LightningModule):
11
  def __init__(self, n_classes, cfg):
12
  super().__init__()
13
+ self.name = "LitUnsupervisedSegmenter"
14
  self.cfg = cfg
15
  self.n_classes = n_classes
16
 
biomap/streamlit_app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit_folium import st_folium
3
+ import folium
4
+ import logging
5
+ import sys
6
+ import hydra
7
+ from plot_functions import *
8
+ import hydra
9
+
10
+ import torch
11
+ from model import LitUnsupervisedSegmenter
12
+ from helper import inference_on_location_and_month, inference_on_location
13
+
14
+ DEFAULT_LATITUDE = 48.81
15
+ DEFAULT_LONGITUDE = 2.98
16
+ DEFAULT_ZOOM = 5
17
+
18
+ MIN_YEAR = 2018
19
+ MAX_YEAR = 2024
20
+
21
+ FOLIUM_WIDTH = 925
22
+ FOLIUM_HEIGHT = 600
23
+
24
+
25
+ st.set_page_config(layout="wide")
26
+ @st.cache_resource
27
+ def init_cfg(cfg_name):
28
+ hydra.initialize(config_path="configs", job_name="corine")
29
+ return hydra.compose(config_name=cfg_name)
30
+
31
+ @st.cache_resource
32
+ def init_app(cfg_name) -> LitUnsupervisedSegmenter:
33
+ file_handler = logging.FileHandler(filename='biomap.log')
34
+ stdout_handler = logging.StreamHandler(stream=sys.stdout)
35
+ handlers = [file_handler, stdout_handler]
36
+
37
+ logging.basicConfig(handlers=handlers, encoding='utf-8', level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
38
+ # # Initialize hydra with configs
39
+ # GlobalHydra.instance().clear()
40
+
41
+ cfg = init_cfg(cfg_name)
42
+ logging.info(f"config : {cfg}")
43
+ nbclasses = cfg.dir_dataset_n_classes
44
+ model = LitUnsupervisedSegmenter(nbclasses, cfg)
45
+ model = model.cpu()
46
+ logging.info(f"Model Initialiazed")
47
+
48
+ model_path = "biomap/checkpoint/model/model.pt"
49
+ saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
50
+ logging.info(f"Model weights Loaded")
51
+ model.load_state_dict(saved_state_dict)
52
+ return model
53
+
54
+ def app(model):
55
+ if "infered" not in st.session_state:
56
+ st.session_state["infered"] = False
57
+
58
+ st.markdown("<h1 style='text-align: center;'>🐢 Biomap by Ekimetrics 🐢</h1>", unsafe_allow_html=True)
59
+ st.markdown("<h2 style='text-align: center;'>Estimate Biodiversity score in the world with the help of land use.</h2>", unsafe_allow_html=True)
60
+ st.markdown("<p style='text-align: center;'>The segmentation is an association of UNet and DinoV1 trained on the dataset CORINE.</p>", unsafe_allow_html=True)
61
+ st.markdown("<p style='text-align: center;'>Land use is divided into 6 differents classes :Each class is assigned a GBS score from 0 to 1</p>", unsafe_allow_html=True)
62
+ st.markdown("<p style='text-align: center;'>Buildings : 0.1 | Infrastructure : 0.1 | Cultivation : 0.4 | Wetland : 0.9 | Water : 0.9 | Natural green : 1 </p>", unsafe_allow_html=True)
63
+ st.markdown("<p style='text-align: center;'>The score is then average on the full image.</p>", unsafe_allow_html=True)
64
+
65
+
66
+
67
+ col_1, col_2 = st.columns([0.5,0.5])
68
+ with col_1:
69
+ m = folium.Map(location=[DEFAULT_LATITUDE, DEFAULT_LONGITUDE], zoom_start=DEFAULT_ZOOM)
70
+
71
+ # The code below will be responsible for displaying
72
+ # the popup with the latitude and longitude shown
73
+ m.add_child(folium.LatLngPopup())
74
+ f_map = st_folium(m, width=FOLIUM_WIDTH, height=FOLIUM_HEIGHT)
75
+
76
+
77
+ selected_latitude = DEFAULT_LATITUDE
78
+ selected_longitude = DEFAULT_LONGITUDE
79
+
80
+ if f_map.get("last_clicked"):
81
+ selected_latitude = f_map["last_clicked"]["lat"]
82
+ selected_longitude = f_map["last_clicked"]["lng"]
83
+
84
+ with col_2:
85
+ tabs1, tabs2 = st.tabs(["TimeLapse", "Single Image"])
86
+ with tabs1:
87
+ lat = st.text_input("lattitude", value=selected_latitude)
88
+ long = st.text_input("longitude", value=selected_longitude)
89
+
90
+
91
+ years = list(range(MIN_YEAR, MAX_YEAR, 1))
92
+ start_date = st.selectbox("Start date", years)
93
+
94
+ end_years = [year for year in years if year > start_date]
95
+ end_date = st.selectbox("End date", end_years)
96
+
97
+ segment_interval = st.radio("Interval of time between two segmentation", options=['month','2months', 'year'],horizontal=True)
98
+ submit = st.button("Predict TimeLapse", use_container_width=True)
99
+ with tabs2:
100
+ lat = st.text_input("lat.", value=selected_latitude)
101
+ long = st.text_input("long.", value=selected_longitude)
102
+
103
+ date = st.text_input("date", "2021-01-01", placeholder="2021-01-01")
104
+
105
+ submit2 = st.button("Predict Single Image", use_container_width=True)
106
+
107
+
108
+ if submit:
109
+ fig = inference_on_location(model, lat, long, start_date, end_date, segment_interval)
110
+ st.session_state["infered"] = True
111
+ st.session_state["previous_fig"] = fig
112
+
113
+ if submit2:
114
+ fig = inference_on_location_and_month(model, lat, long, date)
115
+ st.session_state["infered"] = True
116
+ st.session_state["previous_fig"] = fig
117
+
118
+ if st.session_state["infered"]:
119
+ st.plotly_chart(st.session_state["previous_fig"], use_container_width=True)
120
+
121
+
122
+
123
+ if __name__ == "__main__":
124
+ model = init_app("my_train_config.yml")
125
+ app(model)
biomap/utils copy.py ADDED
@@ -0,0 +1,675 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import os
3
+ from os.path import join
4
+ import io
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import torch.multiprocessing
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import wget
12
+
13
+ import datetime
14
+
15
+ from dateutil.relativedelta import relativedelta
16
+ from PIL import Image
17
+ from scipy.optimize import linear_sum_assignment
18
+ from torch._six import string_classes
19
+ from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
20
+ from torchmetrics import Metric
21
+ from torchvision import models
22
+ from torchvision import transforms as T
23
+ from torch.utils.tensorboard.summary import hparams
24
+ import matplotlib as mpl
25
+ from PIL import Image
26
+
27
+ import matplotlib as mpl
28
+
29
+ import torch.multiprocessing
30
+ import torchvision.transforms as T
31
+
32
+ import plotly.graph_objects as go
33
+ import plotly.express as px
34
+ import numpy as np
35
+ from plotly.subplots import make_subplots
36
+
37
+ import os
38
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
39
+
40
+ colors = ("red", "palegreen", "green", "steelblue", "blue", "yellow", "lightgrey")
41
+ class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background')
42
+ mapping_class = {
43
+ "Buildings": 1,
44
+ "Cultivation": 2,
45
+ "Natural green": 3,
46
+ "Wetland": 4,
47
+ "Water": 5,
48
+ "Infrastructure": 6,
49
+ "Background": 0,
50
+ }
51
+
52
+ score_attribution = {
53
+ "Buildings" : 0.,
54
+ "Cultivation": 0.3,
55
+ "Natural green": 1.,
56
+ "Wetland": 0.9,
57
+ "Water": 0.9,
58
+ "Infrastructure": 0.,
59
+ "Background": 0.
60
+ }
61
+ bounds = list(np.arange(len(mapping_class.keys()) + 1) + 1)
62
+ cmap = mpl.colors.ListedColormap(colors)
63
+ norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
64
+
65
+ def compute_biodiv_score(class_image):
66
+ """Compute the biodiversity score of an image
67
+
68
+ Args:
69
+ image (_type_): _description_
70
+
71
+ Returns:
72
+ biodiversity_score: the biodiversity score associated to the landscape of the image
73
+ """
74
+ score_matrice = class_image.copy().astype(int)
75
+ for key in mapping_class.keys():
76
+ score_matrice = np.where(score_matrice==mapping_class[key], score_attribution[key], score_matrice)
77
+ number_of_pixel = np.prod(list(score_matrice.shape))
78
+ score = np.sum(score_matrice)/number_of_pixel
79
+ score_details = {
80
+ key: np.sum(np.where(class_image == mapping_class[key], 1, 0))
81
+ for key in mapping_class.keys()
82
+ if key not in ["background"]
83
+ }
84
+ return score, score_details
85
+
86
+ def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores) :
87
+ scores = [0.89, 0.70, 0.3, 0.2]
88
+
89
+ # fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
90
+ # fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)
91
+
92
+ # # Scores
93
+ # scatters = [go.Scatter(
94
+ # x=months[:i+1],
95
+ # y=scores[:i+1],
96
+ # mode="lines+markers+text",
97
+ # marker_color="black",
98
+ # text = [f"{score:.4f}" for score in scores[:i+1]],
99
+ # textposition="top center",
100
+
101
+ # ) for i in range(len(scores))]
102
+
103
+
104
+
105
+ # fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1)
106
+ # fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2)
107
+
108
+ # fig.add_trace(go.Pie(labels = class_names,
109
+ # values = [nb_values[0][key] for key in mapping_class.keys()],
110
+ # marker_colors = colors,
111
+ # name="Segment repartition",
112
+ # textposition='inside',
113
+ # texttemplate = "%{percent:.0%}",
114
+ # textfont_size=14
115
+ # ),
116
+ # row=1, col=3)
117
+
118
+
119
+ # fig.add_trace(scatters[0], row=1, col=4)
120
+ # # fig.update_traces(selector=dict(type='scatter'))
121
+
122
+ # number_frames = len(imgs)
123
+ # frames = [dict(
124
+ # name = k,
125
+ # data = [ fig2["frames"][k]["data"][0],
126
+ # fig3["frames"][k]["data"][0],
127
+ # go.Pie(labels = class_names,
128
+ # values = [nb_values[k][key] for key in mapping_class.keys()],
129
+ # marker_colors = colors,
130
+ # name="Segment repartition",
131
+ # textposition='inside',
132
+ # texttemplate = "%{percent:.0%}",
133
+ # textfont_size=14
134
+ # ),
135
+ # scatters[k]
136
+ # ],
137
+ # traces=[0, 1, 2, 3]
138
+ # ) for k in range(number_frames)]
139
+
140
+ # updatemenus = [dict(type='buttons',
141
+ # buttons=[dict(
142
+ # label='Play',
143
+ # method='animate',
144
+ # args=[
145
+ # [f'{k}' for k in range(number_frames)],
146
+ # dict(
147
+ # frame=dict(duration=500, redraw=False),
148
+ # transition=dict(duration=0),
149
+ # # easing='linear',
150
+ # # fromcurrent=True,
151
+ # # mode='immediate'
152
+ # )
153
+ # ])
154
+ # ],
155
+ # direction= 'left',
156
+ # pad=dict(r= 10, t=85),
157
+ # showactive=True, x= 0.1, y= 0.1, xanchor= 'right', yanchor= 'bottom')
158
+ # ]
159
+
160
+ # sliders = [{'yanchor': 'top',
161
+ # 'xanchor': 'left',
162
+ # 'currentvalue': {'font': {'size': 16}, 'prefix': 'Frame: ', 'visible': False, 'xanchor': 'right'},
163
+ # 'transition': {'duration': 500.0, 'easing': 'linear'},
164
+ # 'pad': {'b': 10, 't': 50},
165
+ # 'len': 0.9, 'x': 0.1, 'y': 0,
166
+ # 'steps': [{'args': [[k], {'frame': {'duration': 500.0, 'easing': 'linear', 'redraw': False},
167
+ # 'transition': {'duration': 0, 'easing': 'linear'}}],
168
+ # 'label': months[k], 'method': 'animate'} for k in range(number_frames)
169
+ # ]}]
170
+
171
+
172
+ # fig.update(frames=frames,
173
+ # layout={
174
+ # "xaxis1": {
175
+ # "autorange":True,
176
+ # 'showgrid': False,
177
+ # 'zeroline': False, # thick line at x=0
178
+ # 'visible': False, # numbers below
179
+ # },
180
+
181
+ # "yaxis1": {
182
+ # "autorange":True,
183
+ # 'showgrid': False,
184
+ # 'zeroline': False,
185
+ # 'visible': False,},
186
+
187
+ # "xaxis2": {
188
+ # "autorange":True,
189
+ # 'showgrid': False,
190
+ # 'zeroline': False,
191
+ # 'visible': False,
192
+ # },
193
+
194
+ # "yaxis2": {
195
+ # "autorange":True,
196
+ # 'showgrid': False,
197
+ # 'zeroline': False,
198
+ # 'visible': False,},
199
+
200
+
201
+ # "xaxis4": {
202
+ # "ticktext": months,
203
+ # "tickvals": months,
204
+ # "tickangle": 90,
205
+ # },
206
+ # "yaxis4": {
207
+ # 'range': [min(scores)*0.9, max(scores)* 1.1],
208
+ # 'showgrid': False,
209
+ # 'zeroline': False,
210
+ # 'visible': True
211
+ # },
212
+ # })
213
+ # fig.update_layout(
214
+ # updatemenus=updatemenus,
215
+ # sliders=sliders,
216
+ # # legend=dict(
217
+ # # yanchor= 'bottom',
218
+ # # xanchor= 'center',
219
+ # # orientation="h"),
220
+
221
+ # )
222
+ # Scores
223
+ fig = make_subplots(
224
+ rows=1, cols=4,
225
+ specs=[[{"type": "image"},{"type": "image"}, {"type": "pie"}, {"type": "scatter"}]],
226
+ subplot_titles=("Localisation visualization", "Labeled visualisation", "Segments repartition", "Biodiversity scores")
227
+ )
228
+
229
+ fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
230
+ fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)
231
+ pie_charts = [go.Pie(labels = class_names,
232
+ values = [nb_values[k][key] for key in mapping_class.keys()],
233
+ marker_colors = colors,
234
+ name="Segment repartition",
235
+ textposition='inside',
236
+ texttemplate = "%{percent:.0%}",
237
+ textfont_size=14,
238
+ )
239
+ for k in range(len(scores))]
240
+ scatters = [go.Scatter(
241
+ x=months[:i+1],
242
+ y=scores[:i+1],
243
+ mode="lines+markers+text",
244
+ marker_color="black",
245
+ text = [f"{score:.4f}" for score in scores[:i+1]],
246
+ textposition="top center",
247
+ ) for i in range(len(scores))]
248
+
249
+ fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1)
250
+ fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2)
251
+ fig.add_trace(pie_charts[0], row=1, col=3)
252
+ fig.add_trace(scatters[0], row=1, col=4)
253
+
254
+ start_date = datetime.datetime.strptime(months[0], "%Y-%m-%d") - relativedelta(months=1)
255
+ end_date = datetime.datetime.strptime(months[-1], "%Y-%m-%d") + relativedelta(months=1)
256
+ interval = [start_date.strftime("%Y-%m-%d"),end_date.strftime("%Y-%m-%d")]
257
+ fig.update_layout({
258
+ "xaxis": {
259
+ "autorange":True,
260
+ 'showgrid': False,
261
+ 'zeroline': False, # thick line at x=0
262
+ 'visible': False, # numbers below
263
+ },
264
+
265
+ "yaxis": {
266
+ "autorange":True,
267
+ 'showgrid': False,
268
+ 'zeroline': False,
269
+ 'visible': False,},
270
+
271
+ "xaxis1": {
272
+ "range":[0,imgs[0].shape[1]],
273
+ 'showgrid': False,
274
+ 'zeroline': False,
275
+ 'visible': False,
276
+ },
277
+
278
+ "yaxis1": {
279
+ "range":[imgs[0].shape[0],0],
280
+ 'showgrid': False,
281
+ 'zeroline': False,
282
+ 'visible': False,},
283
+
284
+
285
+ "xaxis3": {
286
+ "dtick":"M3",
287
+ "range":interval
288
+ },
289
+ "yaxis3": {
290
+ 'range': [min(scores)*0.9, max(scores)* 1.1],
291
+ 'showgrid': False,
292
+ 'zeroline': False,
293
+ 'visible': True
294
+ }}
295
+ )
296
+
297
+ frames = [dict(
298
+ name = k,
299
+ data = [ fig2["frames"][k]["data"][0],
300
+ fig3["frames"][k]["data"][0],
301
+ pie_charts[k],
302
+ scatters[k]
303
+ ],
304
+
305
+ traces=[0,1,2,3]
306
+ ) for k in range(len(scores))]
307
+
308
+
309
+ updatemenus = [dict(type='buttons',
310
+ buttons=[dict(label='Play',
311
+ method='animate',
312
+ args=[
313
+ [f'{k}' for k in range(len(scores))],
314
+ dict(
315
+ frame=dict(duration=500, redraw=False),
316
+ transition=dict(duration=0),
317
+ # easing='linear',
318
+ # fromcurrent=True,
319
+ # mode='immediate'
320
+ )
321
+ ]
322
+
323
+ )],
324
+ direction= 'left',
325
+ pad=dict(r= 10, t=85),
326
+ showactive =True, x= 0.1, y= 0, xanchor= 'right', yanchor= 'top')
327
+ ]
328
+
329
+ sliders = [{'yanchor': 'top',
330
+ 'xanchor': 'left',
331
+ 'currentvalue': {
332
+ 'font': {'size': 16},
333
+ 'visible': True,
334
+ 'xanchor': 'right'},
335
+ 'transition': {
336
+ 'duration': 500.0,
337
+ 'easing': 'linear'},
338
+ 'pad': {'b': 10, 't': 50},
339
+ 'len': 0.9, 'x': 0.1, 'y': 0,
340
+ 'steps': [{'args': [None, {'frame': {'duration': 500.0,'redraw': False},
341
+ 'transition': {'duration': 0}}],
342
+ 'label': k, 'method': 'animate'} for k in range(len(scores))
343
+ ]
344
+ }]
345
+
346
+ fig.update_layout(updatemenus=updatemenus,
347
+ sliders=sliders,
348
+ )
349
+ fig.update(frames=frames)
350
+ return fig
351
+
352
+
353
+ def transform_to_pil(output, alpha=0.3):
354
+ # Transform img with torch
355
+ img = torch.moveaxis(prep_for_plot(output['img']),-1,0)
356
+ img=T.ToPILImage()(img)
357
+
358
+ cmaplist = np.array([np.array(cmap(i)) for i in range(cmap.N)])
359
+ labels = np.array(output['linear_preds'])-1
360
+ label = T.ToPILImage()((cmaplist[labels]*255).astype(np.uint8))
361
+
362
+ # Overlay labels with img wit alpha
363
+ background = img.convert("RGBA")
364
+ overlay = label.convert("RGBA")
365
+
366
+ labeled_img = Image.blend(background, overlay, alpha)
367
+
368
+ return img, label, labeled_img
369
+
370
+
371
+ def prep_for_plot(img, rescale=True, resize=None):
372
+ if resize is not None:
373
+ img = F.interpolate(img.unsqueeze(0), resize, mode="bilinear")
374
+ else:
375
+ img = img.unsqueeze(0)
376
+
377
+ plot_img = unnorm(img).squeeze(0).cpu().permute(1, 2, 0)
378
+ if rescale:
379
+ plot_img = (plot_img - plot_img.min()) / (plot_img.max() - plot_img.min())
380
+ return plot_img
381
+
382
+
383
+ def add_plot(writer, name, step):
384
+ buf = io.BytesIO()
385
+ plt.savefig(buf, format='jpeg', dpi=100)
386
+ buf.seek(0)
387
+ image = Image.open(buf)
388
+ image = T.ToTensor()(image)
389
+ writer.add_image(name, image, step)
390
+ plt.clf()
391
+ plt.close()
392
+
393
+
394
+ @torch.jit.script
395
+ def shuffle(x):
396
+ return x[torch.randperm(x.shape[0])]
397
+
398
+
399
+ def add_hparams_fixed(writer, hparam_dict, metric_dict, global_step):
400
+ exp, ssi, sei = hparams(hparam_dict, metric_dict)
401
+ writer.file_writer.add_summary(exp)
402
+ writer.file_writer.add_summary(ssi)
403
+ writer.file_writer.add_summary(sei)
404
+ for k, v in metric_dict.items():
405
+ writer.add_scalar(k, v, global_step)
406
+
407
+
408
+ @torch.jit.script
409
+ def resize(classes: torch.Tensor, size: int):
410
+ return F.interpolate(classes, (size, size), mode="bilinear", align_corners=False)
411
+
412
+
413
+ def one_hot_feats(labels, n_classes):
414
+ return F.one_hot(labels, n_classes).permute(0, 3, 1, 2).to(torch.float32)
415
+
416
+
417
+ def load_model(model_type, data_dir):
418
+ if model_type == "robust_resnet50":
419
+ model = models.resnet50(pretrained=False)
420
+ model_file = join(data_dir, 'imagenet_l2_3_0.pt')
421
+ if not os.path.exists(model_file):
422
+ wget.download("http://6.869.csail.mit.edu/fa19/psets19/pset6/imagenet_l2_3_0.pt",
423
+ model_file)
424
+ model_weights = torch.load(model_file)
425
+ model_weights_modified = {name.split('model.')[1]: value for name, value in model_weights['model'].items() if
426
+ 'model' in name}
427
+ model.load_state_dict(model_weights_modified)
428
+ model = nn.Sequential(*list(model.children())[:-1])
429
+ elif model_type == "densecl":
430
+ model = models.resnet50(pretrained=False)
431
+ model_file = join(data_dir, 'densecl_r50_coco_1600ep.pth')
432
+ if not os.path.exists(model_file):
433
+ wget.download("https://cloudstor.aarnet.edu.au/plus/s/3GapXiWuVAzdKwJ/download",
434
+ model_file)
435
+ model_weights = torch.load(model_file)
436
+ # model_weights_modified = {name.split('model.')[1]: value for name, value in model_weights['model'].items() if
437
+ # 'model' in name}
438
+ model.load_state_dict(model_weights['state_dict'], strict=False)
439
+ model = nn.Sequential(*list(model.children())[:-1])
440
+ elif model_type == "resnet50":
441
+ model = models.resnet50(pretrained=True)
442
+ model = nn.Sequential(*list(model.children())[:-1])
443
+ elif model_type == "mocov2":
444
+ model = models.resnet50(pretrained=False)
445
+ model_file = join(data_dir, 'moco_v2_800ep_pretrain.pth.tar')
446
+ if not os.path.exists(model_file):
447
+ wget.download("https://dl.fbaipublicfiles.com/moco/moco_checkpoints/"
448
+ "moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar", model_file)
449
+ checkpoint = torch.load(model_file)
450
+ # rename moco pre-trained keys
451
+ state_dict = checkpoint['state_dict']
452
+ for k in list(state_dict.keys()):
453
+ # retain only encoder_q up to before the embedding layer
454
+ if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
455
+ # remove prefix
456
+ state_dict[k[len("module.encoder_q."):]] = state_dict[k]
457
+ # delete renamed or unused k
458
+ del state_dict[k]
459
+ msg = model.load_state_dict(state_dict, strict=False)
460
+ assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
461
+ model = nn.Sequential(*list(model.children())[:-1])
462
+ elif model_type == "densenet121":
463
+ model = models.densenet121(pretrained=True)
464
+ model = nn.Sequential(*list(model.children())[:-1] + [nn.AdaptiveAvgPool2d((1, 1))])
465
+ elif model_type == "vgg11":
466
+ model = models.vgg11(pretrained=True)
467
+ model = nn.Sequential(*list(model.children())[:-1] + [nn.AdaptiveAvgPool2d((1, 1))])
468
+ else:
469
+ raise ValueError("No model: {} found".format(model_type))
470
+
471
+ model.eval()
472
+ model.cuda()
473
+ return model
474
+
475
+
476
+ class UnNormalize(object):
477
+ def __init__(self, mean, std):
478
+ self.mean = mean
479
+ self.std = std
480
+
481
+ def __call__(self, image):
482
+ image2 = torch.clone(image)
483
+ for t, m, s in zip(image2, self.mean, self.std):
484
+ t.mul_(s).add_(m)
485
+ return image2
486
+
487
+
488
+ normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
489
+ unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
490
+
491
+
492
+ class ToTargetTensor(object):
493
+ def __call__(self, target):
494
+ return torch.as_tensor(np.array(target), dtype=torch.int64).unsqueeze(0)
495
+
496
+
497
+ def prep_args():
498
+ import sys
499
+
500
+ old_args = sys.argv
501
+ new_args = [old_args.pop(0)]
502
+ while len(old_args) > 0:
503
+ arg = old_args.pop(0)
504
+ if len(arg.split("=")) == 2:
505
+ new_args.append(arg)
506
+ elif arg.startswith("--"):
507
+ new_args.append(arg[2:] + "=" + old_args.pop(0))
508
+ else:
509
+ raise ValueError("Unexpected arg style {}".format(arg))
510
+ sys.argv = new_args
511
+
512
+
513
+ def get_transform(res, is_label, crop_type):
514
+ if crop_type == "center":
515
+ cropper = T.CenterCrop(res)
516
+ elif crop_type == "random":
517
+ cropper = T.RandomCrop(res)
518
+ elif crop_type is None:
519
+ cropper = T.Lambda(lambda x: x)
520
+ res = (res, res)
521
+ else:
522
+ raise ValueError("Unknown Cropper {}".format(crop_type))
523
+ if is_label:
524
+ return T.Compose([T.Resize(res, Image.NEAREST),
525
+ cropper,
526
+ ToTargetTensor()])
527
+ else:
528
+ return T.Compose([T.Resize(res, Image.NEAREST),
529
+ cropper,
530
+ T.ToTensor(),
531
+ normalize])
532
+
533
+
534
+ def _remove_axes(ax):
535
+ ax.xaxis.set_major_formatter(plt.NullFormatter())
536
+ ax.yaxis.set_major_formatter(plt.NullFormatter())
537
+ ax.set_xticks([])
538
+ ax.set_yticks([])
539
+
540
+
541
+ def remove_axes(axes):
542
+ if len(axes.shape) == 2:
543
+ for ax1 in axes:
544
+ for ax in ax1:
545
+ _remove_axes(ax)
546
+ else:
547
+ for ax in axes:
548
+ _remove_axes(ax)
549
+
550
+
551
+ class UnsupervisedMetrics(Metric):
552
+ def __init__(self, prefix: str, n_classes: int, extra_clusters: int, compute_hungarian: bool,
553
+ dist_sync_on_step=True):
554
+ # call `self.add_state`for every internal state that is needed for the metrics computations
555
+ # dist_reduce_fx indicates the function that should be used to reduce
556
+ # state from multiple processes
557
+ super().__init__(dist_sync_on_step=dist_sync_on_step)
558
+
559
+ self.n_classes = n_classes
560
+ self.extra_clusters = extra_clusters
561
+ self.compute_hungarian = compute_hungarian
562
+ self.prefix = prefix
563
+ self.add_state("stats",
564
+ default=torch.zeros(n_classes + self.extra_clusters, n_classes, dtype=torch.int64),
565
+ dist_reduce_fx="sum")
566
+
567
+ def update(self, preds: torch.Tensor, target: torch.Tensor):
568
+ with torch.no_grad():
569
+ actual = target.reshape(-1)
570
+ preds = preds.reshape(-1)
571
+ mask = (actual >= 0) & (actual < self.n_classes) & (preds >= 0) & (preds < self.n_classes)
572
+ actual = actual[mask]
573
+ preds = preds[mask]
574
+ self.stats += torch.bincount(
575
+ (self.n_classes + self.extra_clusters) * actual + preds,
576
+ minlength=self.n_classes * (self.n_classes + self.extra_clusters)) \
577
+ .reshape(self.n_classes, self.n_classes + self.extra_clusters).t().to(self.stats.device)
578
+
579
+ def map_clusters(self, clusters):
580
+ if self.extra_clusters == 0:
581
+ return torch.tensor(self.assignments[1])[clusters]
582
+ else:
583
+ missing = sorted(list(set(range(self.n_classes + self.extra_clusters)) - set(self.assignments[0])))
584
+ cluster_to_class = self.assignments[1]
585
+ for missing_entry in missing:
586
+ if missing_entry == cluster_to_class.shape[0]:
587
+ cluster_to_class = np.append(cluster_to_class, -1)
588
+ else:
589
+ cluster_to_class = np.insert(cluster_to_class, missing_entry + 1, -1)
590
+ cluster_to_class = torch.tensor(cluster_to_class)
591
+ return cluster_to_class[clusters]
592
+
593
+ def compute(self):
594
+ if self.compute_hungarian:
595
+ self.assignments = linear_sum_assignment(self.stats.detach().cpu(), maximize=True)
596
+ # print(self.assignments)
597
+ if self.extra_clusters == 0:
598
+ self.histogram = self.stats[np.argsort(self.assignments[1]), :]
599
+ if self.extra_clusters > 0:
600
+ self.assignments_t = linear_sum_assignment(self.stats.detach().cpu().t(), maximize=True)
601
+ histogram = self.stats[self.assignments_t[1], :]
602
+ missing = list(set(range(self.n_classes + self.extra_clusters)) - set(self.assignments[0]))
603
+ new_row = self.stats[missing, :].sum(0, keepdim=True)
604
+ histogram = torch.cat([histogram, new_row], axis=0)
605
+ new_col = torch.zeros(self.n_classes + 1, 1, device=histogram.device)
606
+ self.histogram = torch.cat([histogram, new_col], axis=1)
607
+ else:
608
+ self.assignments = (torch.arange(self.n_classes).unsqueeze(1),
609
+ torch.arange(self.n_classes).unsqueeze(1))
610
+ self.histogram = self.stats
611
+
612
+ tp = torch.diag(self.histogram)
613
+ fp = torch.sum(self.histogram, dim=0) - tp
614
+ fn = torch.sum(self.histogram, dim=1) - tp
615
+
616
+ iou = tp / (tp + fp + fn)
617
+ prc = tp / (tp + fn)
618
+ opc = torch.sum(tp) / torch.sum(self.histogram)
619
+
620
+ metric_dict = {self.prefix + "mIoU": iou[~torch.isnan(iou)].mean().item(),
621
+ self.prefix + "Accuracy": opc.item()}
622
+ return {k: 100 * v for k, v in metric_dict.items()}
623
+
624
+
625
+ def flexible_collate(batch):
626
+ r"""Puts each data field into a tensor with outer dimension batch size"""
627
+
628
+ elem = batch[0]
629
+ elem_type = type(elem)
630
+ if isinstance(elem, torch.Tensor):
631
+ out = None
632
+ if torch.utils.data.get_worker_info() is not None:
633
+ # If we're in a background process, concatenate directly into a
634
+ # shared memory tensor to avoid an extra copy
635
+ numel = sum([x.numel() for x in batch])
636
+ storage = elem.storage()._new_shared(numel)
637
+ out = elem.new(storage)
638
+ try:
639
+ return torch.stack(batch, 0, out=out)
640
+ except RuntimeError:
641
+ return batch
642
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
643
+ and elem_type.__name__ != 'string_':
644
+ if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
645
+ # array of string classes and object
646
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
647
+ raise TypeError(default_collate_err_msg_format.format(elem.dtype))
648
+
649
+ return flexible_collate([torch.as_tensor(b) for b in batch])
650
+ elif elem.shape == (): # scalars
651
+ return torch.as_tensor(batch)
652
+ elif isinstance(elem, float):
653
+ return torch.tensor(batch, dtype=torch.float64)
654
+ elif isinstance(elem, int):
655
+ return torch.tensor(batch)
656
+ elif isinstance(elem, string_classes):
657
+ return batch
658
+ elif isinstance(elem, collections.abc.Mapping):
659
+ return {key: flexible_collate([d[key] for d in batch]) for key in elem}
660
+ elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
661
+ return elem_type(*(flexible_collate(samples) for samples in zip(*batch)))
662
+ elif isinstance(elem, collections.abc.Sequence):
663
+ # check to make sure that the elements in batch have consistent size
664
+ it = iter(batch)
665
+ elem_size = len(next(it))
666
+ if not all(len(elem) == elem_size for elem in it):
667
+ raise RuntimeError('each element in list of batch should be of equal size')
668
+ transposed = zip(*batch)
669
+ return [flexible_collate(samples) for samples in transposed]
670
+
671
+ raise TypeError(default_collate_err_msg_format.format(elem_type))
672
+
673
+
674
+ if __name__ == "__main__":
675
+ fig = plot_imgs_labels(months, imgs, imgs_label, nb_values, scores)
biomap/utils.py CHANGED
@@ -3,6 +3,9 @@ import os
3
  from os.path import join
4
  import io
5
 
 
 
 
6
  import matplotlib.pyplot as plt
7
  import numpy as np
8
  import torch.multiprocessing
@@ -79,8 +82,73 @@ def compute_biodiv_score(class_image):
79
  }
80
  return score, score_details
81
 
82
- def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores) :
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
85
  fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)
86
 
@@ -91,12 +159,10 @@ def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores) :
91
  y=scores[:i+1],
92
  mode="lines+markers+text",
93
  marker_color="black",
94
- text = [f"{score:.4f}" for score in scores[:i+1]],
95
  textposition="top center"
96
  ) for i in range(len(scores))
97
  ]
98
- # scatters = [go.Scatter(y=scores[:i], mode="lines+markers+text", marker_color="black", text = scores[:i], textposition="top center") for i in range(len(scores))]
99
-
100
 
101
  # Scores
102
  fig = make_subplots(
@@ -152,7 +218,7 @@ def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores) :
152
  mode='immediate'
153
  )])],
154
  direction= 'left',
155
- pad=dict(r= 10, t=85),
156
  showactive =True, x= 0.1, y= 0.13, xanchor= 'right', yanchor= 'top')
157
  ]
158
 
@@ -174,17 +240,34 @@ def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores) :
174
  fr.update(
175
  layout={
176
  "xaxis": {
177
- "range": [0,imgs[0].shape[1]+i/100000]
 
 
 
178
  },
179
  "yaxis": {
180
- "range": [imgs[0].shape[0]+i/100000,0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  },
182
  })
183
-
184
- fr.update(layout_title_text= months[i])
185
 
186
-
187
- fig.update(layout_title_text= months[0])
 
188
  fig.update(
189
  layout={
190
  "xaxis": {
@@ -215,20 +298,14 @@ def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores) :
215
 
216
 
217
  "xaxis3": {
218
- "tickmode": "array",
219
- "ticktext": months,
220
- "tickvals": months,
221
- "range": [0,len(months)]
222
- # 'showgrid': False, # thin lines in the background
223
- # 'zeroline': False, # thick line at y=0
224
- # 'visible': True,
225
  },
226
  "yaxis3": {
227
- "range": [min(scores) * 0.9,max(scores) * 1.1],
228
- 'autorange': False,
229
- 'showgrid': False, # thin lines in the background
230
- 'zeroline': False, # thick line at y=0
231
- 'visible': True # thin lines in the background
232
  }
233
  }
234
  )
@@ -237,13 +314,14 @@ def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores) :
237
  fig.update_layout(updatemenus=updatemenus,
238
  sliders=sliders,
239
  legend=dict(
240
- yanchor= 'top',
241
- xanchor= 'left',
242
- orientation="h")
 
 
243
  )
244
 
245
 
246
-
247
  fig.update_layout(margin=dict(b=0, r=0))
248
  return fig
249
 
 
3
  from os.path import join
4
  import io
5
 
6
+ import datetime
7
+
8
+ from dateutil.relativedelta import relativedelta
9
  import matplotlib.pyplot as plt
10
  import numpy as np
11
  import torch.multiprocessing
 
82
  }
83
  return score, score_details
84
 
85
+ def plot_image(months, imgs, imgs_label, nb_values, scores):
86
+ fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
87
+ fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)
88
+
89
+ # Scores
90
+ fig = make_subplots(
91
+ rows=1, cols=4,
92
+ specs=[[{"type": "image"},{"type": "image"}, {"type": "pie"}, {"type": "indicator"}]],
93
+ subplot_titles=("Localisation visualization", "Labeled visualisation", "Segments repartition", "Biodiversity scores")
94
+ )
95
+
96
+ fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1)
97
+ fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2)
98
+
99
+ fig.add_trace(go.Pie(labels = class_names,
100
+ values = [nb_values[0][key] for key in mapping_class.keys()],
101
+ marker_colors = colors,
102
+ name="Segment repartition",
103
+ textposition='inside',
104
+ texttemplate = "%{percent:.0%}",
105
+ textfont_size=14
106
+ ),
107
+ row=1, col=3)
108
+
109
+
110
+ fig.add_trace(go.Indicator(value=scores[0]), row=1, col=4)
111
+ fig.update_layout(
112
+ legend=dict(
113
+ xanchor = "center",
114
+ yanchor="top",
115
+ y=-0.1,
116
+ x = 0.5,
117
+ orientation="h")
118
+ )
119
+ fig.update(
120
+ layout={
121
+ "xaxis": {
122
+ "range": [0,imgs[0].shape[1]+1/100000],
123
+ 'showgrid': False, # thin lines in the background
124
+ 'zeroline': False, # thick line at x=0
125
+ 'visible': False, # numbers below
126
+ },
127
 
128
+ "yaxis": {
129
+ "range": [imgs[0].shape[0]+1/100000,0],
130
+ 'showgrid': False, # thin lines in the background
131
+ 'zeroline': False, # thick line at y=0
132
+ 'visible': False,},
133
+ "xaxis1": {
134
+ "range": [0,imgs[0].shape[1]+1/100000],
135
+ 'showgrid': False, # thin lines in the background
136
+ 'zeroline': False, # thick line at x=0
137
+ 'visible': False, # numbers below
138
+ },
139
+
140
+ "yaxis1": {
141
+ "range": [imgs[0].shape[0]+1/100000,0],
142
+ 'showgrid': False, # thin lines in the background
143
+ 'zeroline': False, # thick line at y=0
144
+ 'visible': False,}
145
+
146
+ },)
147
+ fig.update_xaxes(row=1, col=2, visible=False)
148
+ fig.update_yaxes(row=1, col=2, visible=False)
149
+ return fig
150
+
151
+ def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores) :
152
  fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
153
  fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)
154
 
 
159
  y=scores[:i+1],
160
  mode="lines+markers+text",
161
  marker_color="black",
162
+ text = [f"{score:.2f}" for score in scores[:i+1]],
163
  textposition="top center"
164
  ) for i in range(len(scores))
165
  ]
 
 
166
 
167
  # Scores
168
  fig = make_subplots(
 
218
  mode='immediate'
219
  )])],
220
  direction= 'left',
221
+ pad=dict(t=85),
222
  showactive =True, x= 0.1, y= 0.13, xanchor= 'right', yanchor= 'top')
223
  ]
224
 
 
240
  fr.update(
241
  layout={
242
  "xaxis": {
243
+ "range": [0,imgs[0].shape[1]+i/100000],
244
+ 'showgrid': False, # thin lines in the background
245
+ 'zeroline': False, # thick line at x=0
246
+ 'visible': False, # numbers below
247
  },
248
  "yaxis": {
249
+ "range": [imgs[0].shape[0]+i/100000,0],
250
+ 'showgrid': False, # thin lines in the background
251
+ 'zeroline': False, # thick line at x=0
252
+ 'visible': False, # numbers below
253
+ },
254
+ "xaxis1": {
255
+ "range": [0,imgs[0].shape[1]+i/100000],
256
+ 'showgrid': False, # thin lines in the background
257
+ 'zeroline': False, # thick line at x=0
258
+ 'visible': False, # numbers below
259
+ },
260
+ "yaxis1": {
261
+ "range": [imgs[0].shape[0]+i/100000,0],
262
+ 'showgrid': False, # thin lines in the background
263
+ 'zeroline': False, # thick line at x=0
264
+ 'visible': False, # numbers below
265
  },
266
  })
 
 
267
 
268
+ start_date = datetime.datetime.strptime(months[0], "%Y-%m-%d") - relativedelta(months=1)
269
+ end_date = datetime.datetime.strptime(months[-1], "%Y-%m-%d") + relativedelta(months=1)
270
+ interval = [start_date.strftime("%Y-%m-%d"),end_date.strftime("%Y-%m-%d")]
271
  fig.update(
272
  layout={
273
  "xaxis": {
 
298
 
299
 
300
  "xaxis3": {
301
+ "dtick":"M3",
302
+ "range":interval
 
 
 
 
 
303
  },
304
  "yaxis3": {
305
+ 'range': [min(scores)*0.9, max(scores)* 1.1],
306
+ 'showgrid': False,
307
+ 'zeroline': False,
308
+ 'visible': True
 
309
  }
310
  }
311
  )
 
314
  fig.update_layout(updatemenus=updatemenus,
315
  sliders=sliders,
316
  legend=dict(
317
+ xanchor = "center",
318
+ yanchor="top",
319
+ y=-0.1,
320
+ x = 0.5,
321
+ orientation="h")
322
  )
323
 
324
 
 
325
  fig.update_layout(margin=dict(b=0, r=0))
326
  return fig
327
 
biomap/utils_gee.py CHANGED
@@ -12,9 +12,10 @@ service_account = '[email protected]'
12
  credentials = ee.ServiceAccountCredentials(service_account, os.path.join(os.path.dirname(__file__), '.private-key.json'))
13
  ee.Initialize(credentials)
14
 
15
- def get_image(location, d1, d2):
16
  logging.info(f"getting image for {d1} to {d2} at location {location}")
17
  img = extract_img(location, d1, d2)
 
18
  img_test = transform_ee_img(
19
  img, max=0.3
20
  )
@@ -125,7 +126,6 @@ def extract_np_from_url(url):
125
  temp1.append(temp2)
126
 
127
  data = np.array(temp1)
128
-
129
  return data
130
 
131
  #Fonction globale
@@ -145,7 +145,9 @@ def extract_img(location,start_date,end_date, width = 0.01 , len = 0.01,scale=5)
145
  """
146
  ee_img, geometry = extract_ee_img(location, width,start_date,end_date , len)
147
  url = get_url(ee_img, geometry, scale)
 
148
  img = extract_np_from_url(url)
 
149
 
150
  return img
151
 
 
12
  credentials = ee.ServiceAccountCredentials(service_account, os.path.join(os.path.dirname(__file__), '.private-key.json'))
13
  ee.Initialize(credentials)
14
 
15
+ def get_url(location, d1, d2):
16
  logging.info(f"getting image for {d1} to {d2} at location {location}")
17
  img = extract_img(location, d1, d2)
18
+
19
  img_test = transform_ee_img(
20
  img, max=0.3
21
  )
 
126
  temp1.append(temp2)
127
 
128
  data = np.array(temp1)
 
129
  return data
130
 
131
  #Fonction globale
 
145
  """
146
  ee_img, geometry = extract_ee_img(location, width,start_date,end_date , len)
147
  url = get_url(ee_img, geometry, scale)
148
+ logging.info(f"got url image for {start_date} to {end_date}")
149
  img = extract_np_from_url(url)
150
+ logging.info(f"Downloaded image for {start_date} to {end_date}")
151
 
152
  return img
153