import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import datasets
import numpy as np
# import torch
from utils.tree_utils import parse_functions, get_docstrings, grab_before_comments, line_chr2char, node_str_idx, replace_function
from utils.html_utils import make_iframe, construct_embed
from utils.generation import combine_generation_kwargs, stream_generation, construct_model_context
PIPE = None
intro_text = """
# Welcome to the interactive shadercoding demo.
This gives you access to a filtered version of the [Shadertoys](https://huggingface.co/datasets/Vipitis/Shadertoys) dataset; only shaders that consist of a single pass are available.
It then lets you use code generation models to make alterations to parts of the shadercode.
## How To Use:
1. Load any Model for [`text-generation`](https://huggingface.co/models?pipeline_tag=text-generation) and hit ENTER.
2. Use the slider to sample a shader from the dataset.
    - The original shader will be embedded on the left; click on the title to get to the source.
    - The shadercode will be displayed on the right; this is interactive.
    - A preview of the currently displayed shadercode will be displayed on the lower left. (hover to advance time)
3. Use the dropdown to select a function to modify.
4. Press either button to make modifications to that function.
5. You can also edit the code manually.
"""
outro_text ="""
## Models to try (look at [ShaderEval](https://huggingface.co/spaces/Vipitis/ShaderEval) for an indication of how helpful they will be):
- [gpt2](https://huggingface.co/gpt2) baseline for language models, really struggles with shadercode.
- [bigscience/bloom-1b1](https://huggingface.co/bigscience/bloom-1b1) a newer and larger freely available model. Understands a bit of code.
- [codeparrot/codeparrot-small](https://huggingface.co/codeparrot/codeparrot-small) a model trained on code, but not on shadercode. Manages to grasp the patterns.
- [salesforce/codegen-2B-multi](https://huggingface.co/salesforce/codegen-2B-multi) a larger model that indicates some potential.
- [bigcode/santacoder](https://huggingface.co/bigcode/santacoder) a model trained on a subset of [TheStack](https://huggingface.co/datasets/bigcode/the-stack), struggles with shadercode.
- [Vipitis/santacoder-finetuned-the-stack-glsl](https://huggingface.co/Vipitis/santacoder-finetuned-the-stack-glsl) fine-tuned by me on the glsl subset of [TheStack](https://huggingface.co/datasets/bigcode/the-stack), is an improvement.
- [Vipitis/santacoder-finetuned-Shadertoys](https://huggingface.co/Vipitis/santacoder-finetuned-Shadertoys) fine-tuned by me on whole shaders from [Shadertoys](https://huggingface.co/datasets/Vipitis/Shadertoys). Does overfit quite a bit with greedy decoding.
- [Vipitis/santacoder-finetuned-Shadertoys-fine](https://huggingface.co/Vipitis/santacoder-finetuned-Shadertoys-fine) fine-tuned by me on just functions from [Shadertoys-fine](https://huggingface.co/datasets/Vipitis/Shadertoys-fine). Memorizes the exact function about half the time.
- [bigcode/starcoder](https://huggingface.co/bigcode/starcoder) a very large model which I haven't tried yet.
- **any other model you want to**
## TODO (feel free to contribute with a [Pull-Request](https://huggingface.co/Vipitis/santacoder-finetuned-the-stack-glsl/discussions?status=open&type=pull_request)):
- [x] use embedded Shadertoy for reference/attribution (done, but some errors)
- [~] working render implementation on CPU only space (as webgl via webglfundamentals, CSS needs fixing for iframe (or hijack Shadertoy iframe))
- [~] generate variations of return statements [ShaderEval task1](https://huggingface.co/spaces/Vipitis/ShaderEval) (needs to be reworked using the other parts)
- [x] generate whole functions (seems to work quite well)
- [] dropdown for model selection (from curated list or all supported models?)
- [] generation history stating which function and orig/generated returns. (use State ??). do it as comments in the code?
- [~] display errors/issues to the user (raise gr.Error could be one idea, but highlighting in the code would be awesome); currently adds a comment to the code.
- [~] generate whole shaders (via prompts guidance, recursive from errors) - prompt context is in progress.
- [x] accordion with generation parameters (as pipeline_kwargs?) look up starcoder playground and take "inspiration" from there (implemented for both buttons, untested)
- [] support FIM task for better model context
- [x] include some context for prompt (title, comments before a function) - now takes all comments directly before a function as well as all comments at the beginning inside a function. (misses comments between argument list and body)
- [] gradio examples
- [x] use GPU if available, respect memory restrictions (implemented via accelerate.Accelerator.device in utils.generation.py), tested with A750 successfully!
- [x] stream model generation (maybe in a new window?) - janky solution that sometimes hangs
- [] 2nd iFrame needs a lot of fixing (I am not a web developer, need help) BUG:background is white, so colors are wrong. Shadertoy uses black background (or we ignore alpha).
- [] (optional) filtering the dataset by license?
### Notes:
- this is meant as a resource to show code generation for a "creative" task.
- the goal is not to replace shader artists, but to be an assistant instead.
- the space still lacks quite a lot of features, but will continue to evolve.
- this demo can be useful to sanity check evaluation results, where the academic numbers come from.
- If you create a remix with these tools, please attribute the original creator of your starting point when sharing the results. (And perhaps share in the [discussion tab](https://huggingface.co/Vipitis/santacoder-finetuned-the-stack-glsl/discussions?status=open&type=discussion) too)
"""
new_shadertoy_code = """void mainImage( out vec4 fragColor, in vec2 fragCoord )
{
    // touch the slider to load a shader from the dataset or start coding from here.
    vec2 uv = fragCoord/iResolution.xy;
    vec3 col = 0.5 + 0.5*cos(iTime+uv.xyx+vec3(0,2,4));
    fragColor = vec4(col,1.0);
}"""
def grab_sample(sample_idx):
    # all_single_passes is the module-level dataset built in the __main__ block below.
    sample_pass = all_single_passes[int(sample_idx)] # cast in case the slider hands over a float
    sample_code = sample_pass["code"]
    sample_source = sample_pass["source"]
    sample_title = sample_pass["title"]
    sample_author = sample_pass["author"]
    source_iframe = construct_embed(sample_source)
    print(f"{source_iframe=}")
    # sample_funcs = _parse_functions(sample_code)
    # funcs = _parse_functions(sample_code)
    # func_identifiers = [f"{idx:2d}: {n.child_by_field_name('declarator').text.decode()}" for idx, n in enumerate(funcs)]
    # print(f"updating drop down to:{func_identifiers}")
    return sample_pass, sample_code, sample_title, source_iframe#, gr.Dropdown.update(choices=func_identifiers) #, sample_title, sample_author
def _make_pipeline(model_cp = "Vipitis/santacoder-finetuned-Shadertoys-fine"): #bad default model for testing
    global PIPE # without this declaration the assignment below would only create a local variable
    # if torch.cuda.is_available():
    #     device = "cuda"
    # else:
    #     device = "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_cp, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_cp, trust_remote_code=True)
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, trust_remote_code=True) #, device=device)
    PIPE = pipe # update the module-level pipeline
    print(f"loaded model {model_cp} as a pipeline")
    return pipe
def process_retn(retn):
    # keep only the text before the first semicolon, i.e. a single statement
    return retn.split(";")[0].strip()

def get_full_replacement(orig_code, retn_start_idx, retn_end_idx, prediction) -> str:
    """
    Patches the generated return statement into the code and returns the full altered code.
    """
    print(f"{orig_code[retn_start_idx:retn_end_idx]=}")
    generated = process_retn(prediction)
    print(f"{generated=}")
    variation = orig_code[:retn_start_idx] + generated + orig_code[retn_end_idx:]
    return variation
def alter_return(orig_code, func_idx, temperature, max_new_tokens, top_p, repetition_penalty, pipeline=PIPE): #the default is bound at definition time (None), so the global can't be passed this way
    """
    Replaces a return statement with a generated one.
    Args:
        orig_code (str): The original code.
        func_idx (int or str): The index selected in the dropdown. Despite the name, it is used as an index into the return statements found in the code.
        temperature (float): The temperature to use for generation.
        max_new_tokens (int): The maximum number of tokens to generate.
        top_p (float): The top_p to use for generation.
        repetition_penalty (float): The repetition_penalty to use for generation.
        pipeline (Pipeline): The pipeline to use for generation.
    Returns:
        str: The altered code.
    """
    if pipeline is None:
        print("no pipeline found, loading default one")
        pipeline = _make_pipeline()
    if isinstance(func_idx, str):
        print(f"{func_idx=}")
        func_idx = int(func_idx.split(":")[0].strip())
    elif isinstance(func_idx, int):
        pass
    else:
        raise gr.Error(f"func_idx must be int or str, not {type(func_idx)}")
    generation_kwargs = combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty)
    # collect (start, end) spans: start of each "return" and the ";" that terminates it
    retrns = []
    retrn_start_idx = orig_code.find("return")
    while retrn_start_idx != -1:
        retrn_end_idx = orig_code.find(";", retrn_start_idx)
        if retrn_end_idx == -1: #unterminated statement, nothing sensible to replace
            break
        retrns.append((retrn_start_idx, retrn_end_idx))
        retrn_start_idx = orig_code.find("return", retrn_end_idx)
    num_returns = len(retrns)
    if num_returns == 0:
        print("no return statement found, returning original code")
        return orig_code
    func_idx = int(max(0, min(func_idx, num_returns - 1))) #clamp to valid range, cast to int as a bodge.
    retrn_start_idx, retrn_end_idx = retrns[func_idx]
    model_context = orig_code[:retrn_start_idx] #TODO: maximal context?
    model_inp = model_context + "return"
    pipe_generation = pipeline(model_inp, return_full_text=False, **generation_kwargs)[0]["generated_text"] #pipeline kwargs are missing?!
    # +7 skips "return" plus the following space, so only the returned expression is replaced
    altered_code = get_full_replacement(orig_code, retrn_start_idx+7, retrn_end_idx, pipe_generation)
    return altered_code
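# Illustrative sketch, not wired into the app: the plain-text scan that alter_return
# uses to locate return statements, pulled out as a standalone helper so it can be
# tried without loading a model. The name `_find_return_spans` is hypothetical.
# Caveat: this is a naive substring search, so it also matches "return" inside
# comments or identifiers.
def _find_return_spans(code: str) -> list:
    """Collect (start, end) pairs: index of each "return" and of its terminating ";"."""
    spans = []
    start = code.find("return")
    while start != -1:
        end = code.find(";", start)
        if end == -1:  # unterminated statement, stop scanning
            break
        spans.append((start, end))
        start = code.find("return", end)
    return spans
# Usage: _find_return_spans("float f(){ return 1.0; }") gives [(11, 21)]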
def alter_body(old_code, func_id, funcs_list: list, prompt="", temperature=0.2, max_new_tokens=512, top_p=.95, repetition_penalty=1.2, pipeline=PIPE):
    """
    Replaces the body of a function with a generated one.
    This is a generator: every yielded string is a new value for the code component.
    Args:
        old_code (str): The original code.
        func_id (int or str): The index of the function to replace the body of (or the dropdown string starting with that index).
        funcs_list (list): The list of all functions in the code.
        prompt (str): The prompt(title) to use for generation. defaults to "".
        temperature (float): The temperature to use for generation. defaults to 0.2.
        max_new_tokens (int): The maximum number of tokens to generate. defaults to 512.
        top_p (float): The top_p to use for generation. defaults to 0.95.
        repetition_penalty (float): The repetition_penalty to use for generation. defaults to 1.2.
        pipeline (Pipeline): The pipeline to use for generation.
    Yields:
        str: The partial generation while streaming, then the altered code.
    """
    if isinstance(func_id, str):
        print(f"{func_id=}")
        func_id = int(func_id.split(":")[0].strip()) #undo their string casting?
    elif isinstance(func_id, int):
        pass
    else:
        raise gr.Error(f"func_id must be int or str, not {type(func_id)}")
    func_node = funcs_list[func_id]
    print(f"using for generation: {func_node=}")
    generation_kwargs = combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty)
    model_context = construct_model_context(func_node, prompt=prompt)[0]
    print(f"{model_context=}")
    body_node = func_node.child_by_field_name("body")
    body_start_idx, body_end_idx = node_str_idx(body_node)
    # generation = pipeline(model_context, return_full_text=False, **generation_kwargs)[0]["generated_text"]
    generation = "" #stays empty if the stream yields nothing
    for generation in stream_generation(model_context, pipeline, generation_kwargs):
        # print(f"{generation=}")
        yield model_context + generation #fix in between, do all the stuff in the end?
    # the last streamed item is taken as the complete generation
    print(f"{generation=}")
    ctx_with_generation = model_context + generation
    try:
        #strip the body
        first_gened_func = parse_functions(ctx_with_generation)[0] # truncate generation to a single function?
    except IndexError:
        print("generation wasn't a full function.")
        altered_code = old_code[:body_start_idx] + generation + "//the generation didn't complete the function!\n" + old_code[body_end_idx:] #needs a newline to break out of the comment.
        yield altered_code #a generator's return value never reaches the code block, so yield it
        return
    altered_code = replace_function(func_node, first_gened_func)
    yield altered_code #the final yield is what updates the code block; the dropdown is refreshed by the chained .then() event
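# Illustrative sketch, not wired into the app: the "stream, then finalize" generator
# pattern that alter_body relies on. Every yielded value replaces the output shown to
# the user, so the post-processed result has to be yielded as the last item; a value
# passed to `return` inside a generator is not part of the iterated values.
# The name `_stream_then_finalize` and its parameters are hypothetical, and it
# assumes each chunk is the cumulative text so far (as alter_body assumes of
# stream_generation).
def _stream_then_finalize(chunks, finalize):
    last = ""  # stays empty if the stream yields nothing
    for last in chunks:
        yield last  # intermediate update
    yield finalize(last)  # final, post-processed update
# Usage: list(_stream_then_finalize(["a", "ab"], str.upper)) gives ["a", "ab", "AB"]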
def list_dropdown_options(in_code): #only used for auto update, not on sample pick?
    funcs = parse_functions(in_code)
    func_identifiers = [f"{idx:2d}: {n.child_by_field_name('declarator').text.decode()}" for idx, n in enumerate(funcs)]
    # funcs = [n for n in funcs] #wrapped as set to avoid json issues?
    print(f"updating drop down to:{func_identifiers}")
    return funcs, gr.Dropdown(choices=func_identifiers)
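# Illustrative sketch, not wired into the app: the dropdown entries built above look
# like " 3: mainImage(...)", and both alter_return and alter_body turn that string
# back into an integer index. The shared parsing could live in one helper such as
# this one. The name `_dropdown_choice_to_index` is hypothetical, and it raises a
# plain TypeError instead of gr.Error so it can be tried without gradio.
def _dropdown_choice_to_index(choice) -> int:
    if isinstance(choice, int):
        return choice
    if isinstance(choice, str):
        # everything before the first colon is the (space padded) index
        return int(choice.split(":")[0].strip())
    raise TypeError(f"choice must be int or str, not {type(choice)}")
# Usage: _dropdown_choice_to_index(" 3: mainImage( out vec4 fragColor, in vec2 fragCoord )") gives 3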
if __name__ == "__main__": #works on huggingface?
    passes_dataset = datasets.load_dataset("Vipitis/Shadertoys")
    single_passes = passes_dataset.filter(lambda x: not x["has_inputs"] and x["num_passes"] == 1) #could also include shaders with no extra functions.
    # single_passes = single_passes.filter(lambda x: x["license"] not in "copyright") #to avoid any "do not display this" license?
    all_single_passes = datasets.concatenate_datasets([single_passes["train"], single_passes["test"]])
    num_samples = len(all_single_passes)
    with gr.Blocks() as demo:
        top_md = gr.Markdown(intro_text)
        model_cp = gr.Textbox(value="Vipitis/santacoder-finetuned-Shadertoys-fine", label="Model Checkpoint (Enter to load!)", interactive=True)
        sample_idx = gr.Slider(minimum=0, maximum=num_samples - 1, value=3211, label="pick sample from dataset", step=1.0) #upper bound follows the dataset size (was hard-coded to 10513)
        func_dropdown = gr.Dropdown(choices=["0: edit the Code (or load a shader) to update this dropdown"], label="choose a function to modify") #breaks if I add a string in before that? #TODO: use type="index" to get int - always gives None?
        prompt_text = gr.Textbox(value="the title used by the model as generation hint", label="prompt text", info="leave blank to skip", interactive=True)
        with gr.Accordion("Advanced settings", open=False): # from: https://huggingface.co/spaces/bigcode/bigcode-playground/blob/main/app.py
            with gr.Row():
                column_1, column_2 = gr.Column(), gr.Column()
                with column_1:
                    temperature = gr.Slider(
                        label="Temperature",
                        value=0.2, #start out at 0 to do greedy? or will there be an error?
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        interactive=True,
                        info="Higher values produce more diverse outputs",
                    )
                    max_new_tokens = gr.Slider(
                        label="Max new tokens",
                        value=265,
                        minimum=0,
                        maximum=2048, #this could be inferred from the model?
                        step=32,
                        interactive=True,
                        info="The maximum number of new tokens",
                    )
                with column_2:
                    top_p = gr.Slider(
                        label="Top-p (nucleus sampling)",
                        value=0.90,
                        minimum=0.0,
                        maximum=1,
                        step=0.05,
                        interactive=True,
                        info="Higher values sample more low-probability tokens",
                    )
                    repetition_penalty = gr.Slider(
                        label="Repetition penalty",
                        value=1.2,
                        minimum=1.0,
                        maximum=2.0,
                        step=0.05,
                        interactive=True,
                        info="Penalize repeated tokens",
                    )
        with gr.Row():
            gen_return_button = gr.Button("generate an alternate return statement", scale=0)
            gen_func_button = gr.Button("generate an alternate function body", scale=1)
        with gr.Row():
            with gr.Column():
                source_embed = gr.HTML('<iframe width="640" height="360" frameborder="0" src="" allowfullscreen></iframe>', label="How this shader originally renders")
                our_embed = gr.HTML(label="glsl render of the current code")
            sample_code = gr.Code(new_shadertoy_code, label="Current Code (will update changes you generate)", language=None)
        bot_md = gr.Markdown(outro_text)
        sample_pass = gr.State(value={})
        funcs = gr.State(value=[])
        pipe = gr.State(value=PIPE)
        pipe.value=_make_pipeline("Vipitis/santacoder-finetuned-Shadertoys-fine") # set a default like this?
        model_cp.submit(fn=_make_pipeline, inputs=[model_cp], outputs=[pipe]) # how can we trigger this on load?
        sample_idx.release(fn=grab_sample, inputs=[sample_idx], outputs=[sample_pass, sample_code, prompt_text, source_embed]) #funcs here?
        gen_return_button.click(fn=alter_return, inputs=[sample_code, func_dropdown, temperature, max_new_tokens, top_p, repetition_penalty, pipe], outputs=[sample_code])
        gen_func_button.click(fn=alter_body, inputs=[sample_code, func_dropdown, funcs, prompt_text, temperature, max_new_tokens, top_p, repetition_penalty, pipe], outputs=[sample_code]).then(
            fn=list_dropdown_options, inputs=[sample_code], outputs=[funcs, func_dropdown]
        )
        sample_code.change(fn=list_dropdown_options, inputs=[sample_code], outputs=[funcs, func_dropdown]).then(
            fn=make_iframe, inputs=[sample_code], outputs=[our_embed])
    demo.queue()
    demo.launch()