File size: 5,684 Bytes
5129aaa
 
196b164
 
 
 
5129aaa
 
 
 
196b164
 
 
 
805f17c
196b164
 
 
 
 
 
 
 
d31681b
 
 
 
196b164
 
5129aaa
196b164
 
5129aaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67fe5dd
 
 
 
5129aaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d91013d
5129aaa
d91013d
5129aaa
d91013d
5129aaa
 
 
 
 
 
 
5511c20
5129aaa
 
 
 
196b164
 
 
 
 
 
 
 
 
 
 
5129aaa
 
 
 
 
 
 
 
 
 
 
 
196b164
 
 
 
5129aaa
 
196b164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5129aaa
d31681b
67fe5dd
d31681b
 
67fe5dd
 
 
 
 
d31681b
67fe5dd
 
5129aaa
196b164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5129aaa
5511c20
 
 
 
 
 
 
5129aaa
 
196b164
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
from datasets import load_dataset as _load_dataset
from os import environ
from PIL import Image
import numpy as np
import json

from pyarrow.parquet import ParquetFile
from pyarrow import Table as pa_Table
from datasets import Dataset

# Hugging Face Hub repository that hosts every EarthView subset.
DATASET = "satellogic/EarthView"

# Registry of known subsets, keyed by subset name (insertion order is the order
# reported by get_subsets). Per-subset entries:
#   shards - number of parquet shards; used to build shard file names
#   config - dataset config name; get_config falls back to the subset name when absent
#   path   - repository directory of the shards; get_path falls back to the subset name when absent
sets = dict(
    satellogic=dict(shards=7863),
    sentinel_1=dict(shards=1763),
    neon=dict(config="default", shards=607, path="data"),
    sentinel_2=dict(shards=19997),
)

def get_subsets():
    """Return a live view of the registered subset names, in registry order."""
    subset_names = sets.keys()
    return subset_names

def get_nshards(subset):
    """Return the number of parquet shards registered for *subset* (KeyError if unknown)."""
    entry = sets[subset]
    return entry["shards"]

def get_path(subset):
    """Return the repository directory holding *subset*'s shards; defaults to the subset name."""
    entry = sets[subset]
    return entry["path"] if "path" in entry else subset

def get_config(subset):
    """Return the dataset config name for *subset*; defaults to the subset name."""
    entry = sets[subset]
    return entry["config"] if "config" in entry else subset

def load_dataset(subset, dataset="satellogic/EarthView", split="train", shards=None, streaming=True, **kwargs):
    """
    Load one EarthView subset (optionally only some of its shards) from the Hugging Face Hub.

    subset: name of the subset, a key of ``sets`` (e.g. "satellogic", "neon", "sentinel_1", "sentinel_2")
    dataset: Hub repository id to load from
    split: split name; also used as the prefix of the shard file names
    shards: None to load every shard, a single shard index, or an iterable of shard indices
    streaming: forwarded to ``datasets.load_dataset``
    kwargs: any further arguments forwarded to ``datasets.load_dataset``

    The Hub token is taken from the HF_TOKEN environment variable when it is set.
    returns the object produced by ``datasets.load_dataset`` for the requested split
    """
    config = get_config(subset)
    path = get_path(subset)
    if shards is None:
        data_files = None
    else:
        if isinstance(shards, int):
            # Accept a lone shard index as well as an iterable of indices.
            shards = [shards]
        if subset == "sentinel_2":
            # sentinel_2 shards are grouped ten per sub-directory, numbered 0..9 "of-00010".
            files = [
                f"{path}/sentinel_2-{shard // 10}/{split}-{shard % 10:05d}-of-00010.parquet"
                for shard in shards
            ]
        else:
            nshards = get_nshards(subset)
            files = [f"{path}/{split}-{shard:05d}-of-{nshards:05d}.parquet" for shard in shards]
        data_files = {split: files}

    ds = _load_dataset(
        path=dataset,
        name=config,
        save_infos=True,
        split=split,
        data_files=data_files,
        streaming=streaming,
        token=environ.get("HF_TOKEN", None),
        **kwargs)

    return ds

def load_parquet(subset_or_filename, batch_size=100):
    """
    Load a whole parquet file into an in-memory ``datasets.Dataset``.

    subset_or_filename: a registered subset name (reads "dataset/<subset>/sample.parquet")
                        or anything ``pyarrow.parquet.ParquetFile`` accepts as a source
    batch_size: number of rows read per batch while scanning the file

    returns a Dataset holding every row of the file (empty when the file has no rows)
    """
    if subset_or_filename in get_subsets():
        filename = f"dataset/{subset_or_filename}/sample.parquet"
    else:
        filename = subset_or_filename

    # Close the reader once the table is built; from_batches drains the iterator eagerly,
    # so nothing reads from the file after the with-block ends.
    with ParquetFile(filename) as pqfile:
        batches = pqfile.iter_batches(batch_size=batch_size)
        # An explicit schema lets a zero-row file produce an empty table instead of raising.
        table = pa_Table.from_batches(batches, schema=pqfile.schema_arrow)
    return Dataset(table)

def _satellogic_to_images(item):
    """Replace the satellogic "rgb" and "1m" arrays of *item* with PIL images; return the image count."""
    # transpose(1, 2, 0) moves the leading axis last (presumably channel-first -> channel-last for PIL).
    item["rgb"] = [Image.fromarray(rgb.transpose(1, 2, 0)) for rgb in item["rgb"]]
    # Only the first plane of each "1m" frame is kept, as a single-band image.
    item["1m"] = [Image.fromarray(image[0, :, :]) for image in item["1m"]]
    return len(item["1m"])


def _sentinel_1_to_images(item):
    """Replace the sentinel_1 "10m" arrays of *item* with PIL images plus a ratio band; return the image count."""
    # Mapping of V and H to RGB. May not be correct
    # https://gis.stackexchange.com/questions/400726/creating-composite-rgb-images-from-sentinel-1-channels
    i10m = item["10m"]
    # Extra band: channel 0 / channel 1, scaled by 256. Clipped to the uint8 range because
    # ratios >= 1 exceed 255, and out-of-range float -> uint8 casts are platform dependent.
    ratio = i10m[:, 0, :, :] / (i10m[:, 1, :, :] + 0.01) * 256
    ratio = np.clip(ratio, 0, 255).astype("uint8")
    stacked = np.concatenate((i10m, np.expand_dims(ratio, 1)), 1)
    item["10m"] = [Image.fromarray(image.transpose(1, 2, 0)) for image in stacked]
    return len(item["10m"])


def _sentinel_2_to_images(item):
    """Replace the sentinel_2 image channels of *item* with PIL images and decode its JSON metadata lists."""
    for channel in ("10m", "20m", "rgb", "scl"):  # "40m" is not converted
        data = item[channel].transpose(0, 2, 3, 1)
        if channel == "20m":
            # Keep bands 0, 2 and 4 as a 3-band composite.
            data = data[:, :, :, [0, 2, 4]]
        mode = "L" if channel in ("10m", "scl") else "RGB"
        item[channel] = [Image.fromarray(frame.squeeze(), mode=mode) for frame in data]
        count = len(data)
    metadata = item["metadata"]
    for field in ("solarAngles", "tileGeometry", "viewIncidenceAngles"):
        # Each of these fields is a list of JSON strings; decode every entry.
        metadata[field] = [json.loads(s) for s in metadata[field]]
    return count


def _neon_to_images(item):
    """Replace the neon "rgb", "chm" and "1m" arrays of *item* with PIL images and fix its metadata."""
    item["rgb"] = [Image.fromarray(image.transpose(1, 2, 0)) for image in item["rgb"]]
    item["chm"] = [Image.fromarray(image[0]) for image in item["chm"]]

    # A very arbitrary conversion of the 369-band hyperspectral data to RGB:
    # the bands [:124], [124:247] and [247:] are each averaged into one channel.
    item["1m"] = [
        Image.fromarray(
            np.stack(
                (
                    np.average(image[:124], 0),
                    np.average(image[124:247], 0),
                    np.average(image[247:], 0),
                ),
                2,
            ).astype("uint8")
        )
        for image in item["1m"]
    ]
    count = len(item["rgb"])

    metadata = item["metadata"]
    bounds = metadata["bounds"]
    # Swap the two values of every consecutive pair: [a, b, c, d] -> [b, a, d, c].
    metadata["bounds"] = [bounds[i + 1 - j] for i in range(0, len(bounds), 2) for j in range(2)]

    # fix CRS
    metadata["epsg"] = "EPSG:4326"
    return count


# Per-subset converters used by item_to_images. Each one replaces arrays of the
# item it receives with PIL images (in place) and returns the image count.
_ITEM_CONVERTERS = {
    "satellogic": _satellogic_to_images,
    "sentinel_1": _sentinel_1_to_images,
    "sentinel_2": _sentinel_2_to_images,
    "neon": _neon_to_images,
}


def item_to_images(subset, item):
    """
    Converts the images within an item (arrays), as retrieved from the dataset, to proper PIL.Image

    subset: The name of the Subset, one of "satellogic", "sentinel_1", "sentinel_2", "neon"
    item: The item as retrieved from the subset; its "metadata" may be a dict or a JSON string

    returns a new dict: every non-metadata field converted to a uint8 array, the image fields
    replaced by lists of PIL.Image, and the metadata (a copy; the caller's item is not modified)
    extended with "count", the number of images
    raises ValueError if subset is not one of the names above
    """
    try:
        convert = _ITEM_CONVERTERS[subset]
    except KeyError:
        raise ValueError(
            f"unknown subset {subset!r}; expected one of {', '.join(_ITEM_CONVERTERS)}"
        ) from None

    metadata = item["metadata"]
    if isinstance(metadata, str):
        metadata = json.loads(metadata)
    else:
        # Shallow copy: the converters reassign metadata keys, which must not leak
        # back into (and break a second conversion of) the caller's item.
        metadata = dict(metadata)

    converted = {
        key: np.asarray(value).astype("uint8")
        for key, value in item.items()
        if key != "metadata"
    }
    converted["metadata"] = metadata
    metadata["count"] = convert(converted)
    return converted