|
import math |
|
|
|
import pandas as pd |
|
|
|
import gradio as gr |
|
import datetime |
|
import numpy as np |
|
|
|
from dgl.data import YelpDataset |
|
|
|
import dgl |
|
import torch as th |
|
|
|
from dgl.dataloading import LaborSampler, NeighborSampler |
|
|
|
# Load the Yelp graph and place it on the chosen device.
data = YelpDataset()
device = 'cpu'
g = data[0].to(device)

# Fanout of 10 for each of the 3 sampling layers.
num_layers = 3
fanouts = [10] * num_layers

# Samplers under comparison, each paired with the label shown in the plot.
# zip(*pairs) splits the (label, sampler) pairs into two parallel lists.
names, samplers = map(list, zip(
    ('LABOR-1', LaborSampler(fanouts, importance_sampling=1)),
    ('LABOR-0', LaborSampler(fanouts, importance_sampling=0)),
    ('NS', NeighborSampler(fanouts)),
))

# Seed vertices: every vertex ID of the graph.
indices = th.arange(g.num_nodes(), device=device)

batch_size = 1024

# One shuffled DataLoader per sampler; an incomplete final batch is dropped.
loaders = [
    dgl.dataloading.DataLoader(
        g,
        indices,
        sampler,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )
    for sampler in samplers
]
|
|
|
def get_time():
    """Return the current local time as a naive ``datetime.datetime``."""
    current = datetime.datetime.now()
    return current
|
|
|
# Right edge of the x-window used by ``get_plot2``. Each call advances it by
# one full 2π span, and it wraps back to the start once it exceeds 1000.
plot_end = math.tau


def get_plot2(period=1):
    """Build a line-plot update of a sine wave over a sliding x-window.

    The window is ``[plot_end - 2π, plot_end)`` sampled at a step of 0.02, and
    the plotted curve is ``sin(2π * period * x)``. After building the update,
    the window slides forward by 2π (reset to its initial position once the
    right edge passes 1000).

    Args:
        period: Frequency multiplier for the sine wave.

    Returns:
        A ``gr.LinePlot.update`` carrying the sampled ``x``/``y`` frame.
    """
    global plot_end

    xs = np.arange(plot_end - math.tau, plot_end, 0.02)
    frame = pd.DataFrame({"x": xs, "y": np.sin(math.tau * period * xs)})
    plot_update = gr.LinePlot.update(
        value=frame,
        x="x",
        y="y",
        title="Plot (updates every second)",
        width=600,
        height=350,
    )

    # Slide the window for the next call, wrapping to keep x bounded.
    plot_end += math.tau
    if plot_end > 1000:
        plot_end = math.tau

    return plot_update
|
|
|
# Sampled input-vertex counts accumulated across calls to ``get_plot``: one row
# per call, one int per sampler (same order as ``names``).
results = []

# Batch size that the rows currently stored in ``results`` were sampled with.
_results_batch_size = None

# DataLoaders (one per sampler), keyed by batch size and built on first use.
_loaders_by_batch_size = {}


def _loaders_for(batch_size):
    """Return (and cache) one shuffled, drop-last DataLoader per sampler."""
    if batch_size not in _loaders_by_batch_size:
        _loaders_by_batch_size[batch_size] = [
            dgl.dataloading.DataLoader(
                g,
                indices,
                sampler,
                batch_size=batch_size,
                shuffle=True,
                drop_last=True,
            )
            for sampler in samplers
        ]
    return _loaders_by_batch_size[batch_size]


def get_plot(batch_size=1024):
    """Sample one minibatch per sampler and bar-plot the input-vertex counts.

    Each call draws the first minibatch from every sampler's DataLoader at the
    requested batch size and appends the per-sampler counts to ``results``.
    The plot shows every row recorded so far for the current batch size.
    ``results`` is cleared when the batch size changes, so sizes never mix.

    Args:
        batch_size: Seed vertices per minibatch. Floats (as produced by a
            ``gr.Number``) are truncated to int; values below 1 become 1.

    Returns:
        A ``gr.BarPlot.update`` in long format: one row per (sampler, call).
    """
    global _results_batch_size

    batch_size = max(1, int(batch_size))
    if batch_size != _results_batch_size:
        results.clear()
        _results_batch_size = batch_size

    # First minibatch from every loader at once. This is None when a loader
    # yields nothing (drop_last=True with batch_size > number of vertices).
    sampled = next(zip(*_loaders_for(batch_size)), None)
    if sampled is not None:
        # NOTE(review): s[0] is assumed to be a 1-D tensor of input node IDs
        # (DGL's (input_nodes, output_nodes, blocks) tuple). Use its length,
        # not its torch.Size, so each entry is a plain int.
        results.append([s[0].shape[0] for s in sampled])

    x, y = "sampler", "# vertices"
    # Sampler-major order: all counts of the first sampler, then the next ...
    rows = [(name, row[i]) for i, name in enumerate(names) for row in results]
    frame = pd.DataFrame(rows, columns=[x, y])

    return gr.BarPlot.update(
        value=frame,
        x=x,
        y=y,
        title="Number of sampled vertices",
        width=600,
        height=350,
    )
|
|
|
|
|
|
|
# Gradio UI: a clock, a batch-size input driving the sampler bar plot, and a
# small greeting widget.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # The clock is refreshed by the every=10 load event below.
            c_time2 = gr.Textbox(label="Current Time refreshed every 10 seconds")
            gr.Textbox(
                "Enter a batch size and press Enter to update the plot",
                label="",
            )
            # Not named ``batch_size``: that would rebind the module-level
            # integer config used when building the DataLoaders.
            batch_size_input = gr.Number(
                label="batch size", value=1024, show_label=True
            )
            plot = gr.BarPlot(show_label=False)
        with gr.Column():
            name = gr.Textbox(label="Enter your name")
            greeting = gr.Textbox(label="Greeting")
            button = gr.Button(value="Greet")
            button.click(lambda s: f"Hello {s}", name, greeting)

    # Refresh the clock and the sampler plot every 10 seconds.
    demo.load(get_time, None, c_time2, every=10)
    dep = demo.load(get_plot, None, plot, every=10)
    # Submitting a batch size cancels the default refresh loop and starts a
    # new one that passes the submitted value to get_plot.
    batch_size_input.submit(
        get_plot, batch_size_input, plot, every=10, cancels=[dep]
    )
|
|
|
if __name__ == "__main__":
    # NOTE(review): the every= refresh events registered on `demo` presumably
    # rely on the request queue being enabled. Confirm for the Gradio version.
    queued_demo = demo.queue()
    queued_demo.launch()