"""Gradio viewer for browsing images and metadata from the EarthView dataset.

``DEBUG`` selects the data source: ``False`` streams the dataset via
``ev.load_dataset``, ``"samples"`` reads via ``ev.load_parquet``, and
``"random"`` renders random-noise images without touching the dataset.
"""
from datasets import load_dataset
from functools import partial
from pandas import DataFrame
import earthview as ev
import gradio as gr
import tqdm
import os

DEBUG = False  # False, "random", "samples"

if DEBUG == "random":
    import numpy as np


def _as_int(value, name):
    """Coerce a numeric Gradio input to ``int``.

    ``gr.Number``/``gr.Slider`` can hand values to handlers as floats
    (e.g. ``10.0``), which ``range``/``tqdm.trange`` reject, or as ``None``
    when the field has been cleared.

    Raises:
        gr.Error: if ``value`` is empty or not numeric.
    """
    if value is None:
        raise gr.Error(f"{name} must be a number")
    try:
        return int(value)
    except (TypeError, ValueError):
        raise gr.Error(f"{name} must be a number") from None


def open_dataset(dataset, subset, split, batch_size, shard, only_rgb, state):
    """Open an iterator over ``subset`` and return its first batch.

    Args:
        dataset: dataset id forwarded to ``ev.load_dataset``.
        subset: EarthView subset name (e.g. ``"satellogic"``).
        split: dataset split forwarded to ``ev.load_dataset``.
        batch_size: number of items to fetch for the first batch.
        shard: shard index to load, or ``-1`` to load all shards.
        only_rgb: if true, only RGB products are added to the gallery.
        state: session dict; ``"subset"`` and ``"dsi"`` (the item iterator)
            are stored in it for subsequent ``get_images`` calls.

    Returns:
        ``(shard slider update, images, metadata DataFrame, state)``, matching
        the ``outputs`` wired up in the UI.

    Raises:
        ValueError: if ``DEBUG`` is not one of the supported modes.
    """
    batch_size = _as_int(batch_size, "Batch Size")
    shard = _as_int(shard, "Shard")
    nshards = ev.get_nshards(subset)
    shards = None if shard == -1 else [shard]

    if DEBUG == "random":
        ds = range(batch_size)
    elif DEBUG == "samples":
        ds = ev.load_parquet(subset, batch_size=batch_size)
    elif not DEBUG:
        ds = ev.load_dataset(subset, dataset=dataset, split=split,
                             shards=shards, cache_dir="dataset")
    else:
        # Previously this fell through and failed later with UnboundLocalError.
        raise ValueError(f"Unknown DEBUG mode: {DEBUG!r}")

    state["subset"] = subset
    state["dsi"] = iter(ds)
    return (
        # NOTE(review): if ev.get_nshards returns a shard *count* (indices
        # 0..nshards-1), a maximum of nshards allows one index past the end
        # — confirm against earthview.
        gr.update(label=f"Shard (max {nshards})", value=shard, maximum=nshards),
        *get_images(batch_size, only_rgb, state),
        state,
    )


def get_images(batch_size, only_rgb, state):
    """Pull up to ``batch_size`` items from the session's dataset iterator.

    Args:
        batch_size: maximum number of dataset items to read.
        only_rgb: if true, skip the non-RGB image products.
        state: session dict populated by ``open_dataset``.

    Returns:
        ``(images, DataFrame(metadatas))``. Fewer items are returned when the
        iterator is exhausted before ``batch_size`` items were read.

    Raises:
        gr.Error: if no dataset has been opened in this session yet.
    """
    try:
        subset = state["subset"]
    except KeyError:
        raise gr.Error("You need to load a Dataset first") from None
    batch_size = _as_int(batch_size, "Batch Size")

    images = []
    metadatas = []
    for _ in tqdm.trange(batch_size, desc="Getting images"):
        if DEBUG == "random":
            images.append(np.random.randint(0, 255, (384, 384, 3)))
            if not only_rgb:
                images.append(np.random.randint(0, 255, (100, 100, 3)))
            metadatas.append({"bounds": [[1, 1, 4, 4]]})
            continue

        try:
            item = next(state["dsi"])
        except StopIteration:
            break  # iterator exhausted: return the partial batch
        item = ev.item_to_images(subset, item)
        if subset == "satellogic":
            images.extend(item["rgb"])
            if not only_rgb:
                images.extend(item["1m"])
        elif subset == "sentinel_1":
            images.extend(item["10m"])
        elif subset == "neon":
            images.extend(item["rgb"])
            if not only_rgb:
                images.extend(item["chm"])
                images.extend(item["1m"])
        metadatas.append(item["metadata"])
    return images, DataFrame(metadatas)


def update_shape(columns):
    """Return a gallery update setting its column count to ``columns``."""
    if columns is None:
        return gr.update()  # field cleared: leave the gallery unchanged
    return gr.update(columns=int(columns))


def new_state():
    """Create a fresh Gradio state holder initialised to an empty dict."""
    return gr.State({})


if __name__ == "__main__":
    with gr.Blocks(title="EarthView Viewer", fill_height=True) as demo:
        state = new_state()
        gr.Markdown(
            f"# Viewer for [{ev.DATASET}]"
            "(https://huggingface.co/datasets/satellogic/EarthView) Dataset"
        )

        # Created unrendered so they can be referenced by event wiring
        # before being placed in the layout below.
        batch_size = gr.Number(10, label="Batch Size", render=False)
        shard = gr.Slider(label="Shard", minimum=0, maximum=10000, step=1,
                          render=False)
        table = gr.DataFrame(render=False)
        gallery = gr.Gallery(
            label=ev.DATASET,
            interactive=False,
            object_fit="scale-down",
            columns=5,
            render=False,
        )

        with gr.Row():
            dataset = gr.Textbox(label="Dataset", value=ev.DATASET,
                                 interactive=False)
            subset = gr.Dropdown(choices=ev.get_subsets(), label="Subset",
                                 value="satellogic")
            split = gr.Textbox(label="Split", value="train")
            initial_shard = gr.Number(label="Initial shard", value=10,
                                      info="-1 for whole dataset")
            only_rgb = gr.Checkbox(label="Only RGB", value=True)
            gr.Button("Load (minutes)").click(
                open_dataset,
                inputs=[dataset, subset, split, batch_size, initial_shard,
                        only_rgb, state],
                outputs=[shard, gallery, table, state],
            )

        gallery.render()

        with gr.Row():
            batch_size.render()
            columns = gr.Number(5, label="Columns")
            columns.change(update_shape, [columns], [gallery])

        with gr.Row():
            shard.render()
            shard.release(
                open_dataset,
                inputs=[dataset, subset, split, batch_size, shard, only_rgb,
                        state],
                outputs=[shard, gallery, table, state],
            )
            btn = gr.Button("Next Batch (same shard)", scale=0)
            btn.click(get_images, [batch_size, only_rgb, state],
                      [gallery, table])

        table.render()

    # Launch only after the Blocks context has closed and finalized the app.
    demo.launch(show_api=False)