Spaces:
Runtime error
Runtime error
improved docstring extraction
Browse files- utils/__init__.py +2 -2
- utils/generation.py +2 -6
- utils/tree_utils.py +27 -5
utils/__init__.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
-
from .tree_utils import (parse_functions, get_docstrings, grab_before_comments, line_chr2char, replace_function, get_root, node_str_idx, give_tree)
|
2 |
from .html_utils import (make_iframe, make_script, construct_embed)
|
3 |
from .generation import (combine_generation_kwargs, stream_generation, construct_model_context)
|
4 |
|
5 |
-
tree_funcs = ["parse_functions", "get_docstrings", "grab_before_comments", "line_chr2char", "replace_function", "get_root", "node_str_idx", "give_tree"]
|
6 |
html_funcs = ["make_iframe", "make_script", "construct_embed"]
|
7 |
gen_funcs = ["combine_generation_kwargs", "stream_generation", "construct_model_context"]
|
8 |
|
|
|
1 |
+
from .tree_utils import (parse_functions, get_docstrings, grab_before_comments, line_chr2char, replace_function, get_root, node_str_idx, give_tree, full_func_head, has_docstrings)
|
2 |
from .html_utils import (make_iframe, make_script, construct_embed)
|
3 |
from .generation import (combine_generation_kwargs, stream_generation, construct_model_context)
|
4 |
|
5 |
+
tree_funcs = ["parse_functions", "get_docstrings", "grab_before_comments", "line_chr2char", "replace_function", "get_root", "node_str_idx", "give_tree", "full_func_head", "has_docstrings"]
|
6 |
html_funcs = ["make_iframe", "make_script", "construct_embed"]
|
7 |
gen_funcs = ["combine_generation_kwargs", "stream_generation", "construct_model_context"]
|
8 |
|
utils/generation.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
from transformers import TextIteratorStreamer
|
2 |
from threading import Thread
|
3 |
-
from .tree_utils import
|
4 |
|
5 |
def combine_generation_kwargs(temperature=2.0, max_new_tokens=512, top_p=0.95, repetition_penalty=1.2):
|
6 |
"""
|
@@ -47,11 +47,7 @@ def construct_model_context(func_node, prompt="") -> str:
|
|
47 |
"""
|
48 |
Constructs the model context from a function node.
|
49 |
"""
|
50 |
-
model_context = func_node
|
51 |
-
docstring = get_docstrings(func_node) #might be empty?
|
52 |
-
if docstring:
|
53 |
-
model_context = model_context + "\n" + docstring
|
54 |
-
model_context = grab_before_comments(func_node) + model_context #prepend comments
|
55 |
if prompt != "":
|
56 |
model_context = "//Title: " + prompt + "\n" + model_context #prepend user prompt/title
|
57 |
model_context = "//Language: Shadertoy GLSL fragment shader\n" + model_context #prepend system prompt, language hint
|
|
|
1 |
from transformers import TextIteratorStreamer
|
2 |
from threading import Thread
|
3 |
+
from .tree_utils import full_func_head, grab_before_comments
|
4 |
|
5 |
def combine_generation_kwargs(temperature=2.0, max_new_tokens=512, top_p=0.95, repetition_penalty=1.2):
|
6 |
"""
|
|
|
47 |
"""
|
48 |
Constructs the model context from a function node.
|
49 |
"""
|
50 |
+
model_context = grab_before_comments(func_node) + full_func_head(func_node) # (identifier + docstrings)
|
|
|
|
|
|
|
|
|
51 |
if prompt != "":
|
52 |
model_context = "//Title: " + prompt + "\n" + model_context #prepend user prompt/title
|
53 |
model_context = "//Language: Shadertoy GLSL fragment shader\n" + model_context #prepend system prompt, language hint
|
utils/tree_utils.py
CHANGED
@@ -56,13 +56,28 @@ def get_docstrings(func_node):
|
|
56 |
returns the docstring of a function node
|
57 |
"""
|
58 |
docstring = ""
|
59 |
-
for node in func_node.
|
60 |
-
if node.type == "comment"
|
61 |
-
docstring += node.text.decode()
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
64 |
return docstring
|
65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
def grab_before_comments(func_node):
|
68 |
"""
|
@@ -80,6 +95,13 @@ def grab_before_comments(func_node):
|
|
80 |
return precomment
|
81 |
return precomment
|
82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
def line_chr2char(text, line_idx, chr_idx):
|
84 |
"""
|
85 |
returns the character index at the given line and character index.
|
|
|
56 |
returns the docstring of a function node
|
57 |
"""
|
58 |
docstring = ""
|
59 |
+
for node in func_node.children:
|
60 |
+
if node.type == "comment": #comment in like the declarator
|
61 |
+
docstring += node.text.decode()
|
62 |
+
elif node.type == "compound_statement": #body below here
|
63 |
+
for body_node in node.children:
|
64 |
+
if body_node.type == "comment" or body_node.type == "{":
|
65 |
+
docstring += " " * body_node.start_point[1] #add in indentation
|
66 |
+
docstring += body_node.text.decode() + "\n"
|
67 |
+
else:
|
68 |
+
return docstring
|
69 |
return docstring
|
70 |
|
71 |
+
def full_func_head(func_node):
|
72 |
+
"""
|
73 |
+
returns function head including docstrings before any real body code
|
74 |
+
"""
|
75 |
+
cursor = func_node.child_by_field_name("body").walk()
|
76 |
+
cursor.goto_first_child()
|
77 |
+
while cursor.node.type == "comment" or cursor.node.type == "{":
|
78 |
+
cursor.goto_next_sibling()
|
79 |
+
end = cursor.node.start_point
|
80 |
+
return "\n".join(func_node.text.decode().split("\n")[:(end[0]-func_node.start_point[0])])
|
81 |
|
82 |
def grab_before_comments(func_node):
|
83 |
"""
|
|
|
95 |
return precomment
|
96 |
return precomment
|
97 |
|
98 |
+
def has_docstrings(func_node):
|
99 |
+
"""
|
100 |
+
returns whether a function node has a docstring
|
101 |
+
"""
|
102 |
+
return get_docstrings(func_node).strip() != "{" or grab_before_comments(func_node) != ""
|
103 |
+
|
104 |
+
|
105 |
def line_chr2char(text, line_idx, chr_idx):
|
106 |
"""
|
107 |
returns the character index at the given line and character index.
|