# Hugging Face Space (status at capture time: Sleeping)
import gradio as gr | |
from io import StringIO | |
from model import DecoderTransformer, Tokenizer | |
from huggingface_hub import hf_hub_download | |
import torch | |
import chess | |
import chess.svg | |
import chess.pgn | |
from svglib.svglib import svg2rlg | |
from reportlab.graphics import renderPM | |
from PIL import Image | |
import os | |
from uuid import uuid4 | |
# --- Model hyper-parameters --------------------------------------------------
# Passed positionally to DecoderTransformer below; they therefore have to be
# compatible with the checkpoint that is downloaded from the hub.
vocab_size = 36
n_embed = 384
context_size = 256
n_layer = 6
n_head = 6
dropout = 0.2

# All tensors and the model live on this device.
device = 'cpu'

# --- Download checkpoint + tokenizer from the Hugging Face hub ---------------
model_id = "philipp-zettl/chessPT"
model_path = hf_hub_download(repo_id=model_id, filename="chessPT-v0.5.pth")
tokenizer_path = hf_hub_download(repo_id=model_id, filename="tokenizer-v0.5.json")

# --- Instantiate the model and load the trained weights ----------------------
model = DecoderTransformer(vocab_size, n_embed, context_size, n_layer, n_head, dropout)
# Map the stored tensors onto the configured device instead of a second,
# hard-coded 'cpu' so the two settings cannot drift apart.
model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
model.to(device)
# This app only runs inference. Without eval mode the dropout configured above
# stays active and randomly perturbs every generation.
# NOTE(review): assumes DecoderTransformer is a torch.nn.Module (it is already
# used through load_state_dict()/to()) - confirm in model.py.
model.eval()

tokenizer = Tokenizer.from_pretrained(tokenizer_path)

# Placeholder picture, loaded once at start-up from the working directory.
invalid_move_plot = Image.open('./invalid_move.png')
def gen_image_from_svg(img, filename):
    """Rasterise SVG markup into an in-memory PIL image.

    The markup is written to ``<filename>.svg``, converted to
    ``<filename>.png`` with svglib/reportlab, loaded with PIL, and both
    temporary files are removed again.

    Args:
        img: SVG document as a string.
        filename: Path prefix (without extension) for the temporary files.
            Callers should pass a unique prefix; two concurrent calls using
            the same prefix overwrite and delete each other's files.

    Returns:
        A PIL image holding the rendered PNG, fully decoded in memory.
    """
    svg_path = filename + '.svg'
    png_path = f'{filename}.png'
    try:
        # SVG is XML; write it as UTF-8 instead of the platform default.
        with open(svg_path, 'w', encoding='utf-8') as f:
            f.write(img)
        drawing = svg2rlg(svg_path)
        renderPM.drawToFile(drawing, png_path, fmt="PNG")
        # Image.open() is lazy: decode the pixels (copy() forces a load) and
        # release the file handle *before* the PNG is deleted below.
        with Image.open(png_path) as png:
            plot = png.copy()
    finally:
        # Clean up even when rendering failed half-way through.
        for path in (svg_path, png_path):
            if os.path.exists(path):
                os.remove(path)
    return plot
def get_board(pgn):
    """Replay the main line of a PGN string and return the resulting board.

    Args:
        pgn: PGN movetext (headers optional), e.g. ``"1. e4 e5 2. Nf3"``.

    Returns:
        The ``chess.Board`` after all main-line moves were pushed, or ``None``
        when no game could be read from ``pgn`` (for example an empty string)
        or when the parser raised a ``ValueError`` such as "illegal san".
    """
    try:
        game = chess.pgn.read_game(StringIO(pgn))
    except ValueError:
        # Parsers that raise on bad movetext (e.g. "illegal san: ...").
        return None
    if game is None:
        # read_game() found nothing to parse. Previously this path ended in an
        # UnboundLocalError because ``board`` was never assigned.
        return None

    # NOTE(review): newer python-chess releases appear to record illegal SAN
    # in ``game.errors`` instead of raising, in which case the board returned
    # here reflects the moves up to the first bad token - confirm whether
    # callers want ``None`` for that case as well.
    board = game.board()
    for move in game.mainline_moves():
        board.push(move)
    return board
def gen_board_image(pgn):
    """Return the SVG rendering of the board that ``get_board(pgn)`` yields.

    Args:
        pgn: PGN movetext describing the game so far.

    Returns:
        Whatever ``chess.svg.board`` produces for that board (SVG markup).
    """
    return chess.svg.board(get_board(pgn))
def _count_plies(pgn):
    """Return how many moves ``get_board`` replays for ``pgn`` (None if unparsable)."""
    if not pgn.strip():
        # Nothing played yet.
        return 0
    try:
        board = get_board(pgn)
    except Exception:
        # Best effort: any parser failure simply means "not a usable PGN".
        return None
    return None if board is None else len(board.move_stack)


def gen_move(pgn, max_attempts=50):
    """Let the model append one legal move to ``pgn``.

    The model is sampled repeatedly; the first whitespace-delimited token of
    each continuation is taken as the candidate move. A candidate is accepted
    only if replaying the extended PGN yields exactly one more move than
    replaying ``pgn`` itself, so illegal, empty or non-move tokens are
    rejected regardless of how the PGN parser reports errors.

    Args:
        pgn: PGN movetext of the game so far (typically ending in ``"N. "``
            or a trailing space).
        max_attempts: Upper bound on sampling attempts. Prevents the request
            from hanging forever when the model keeps producing the same
            invalid token.

    Returns:
        ``pgn`` (right-stripped) followed by a space and the accepted move.

    Raises:
        ValueError: If ``pgn`` itself cannot be replayed.
        RuntimeError: If no legal move was produced within ``max_attempts``.
    """
    plies_before = _count_plies(pgn)
    if plies_before is None:
        raise ValueError(f'cannot continue an invalid PGN: {pgn!r}')

    # Shape the input from the actual token list instead of assuming one
    # token per character of ``pgn``.
    tokens = tokenizer.encode(pgn)
    model_input = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
    prefix = pgn.rstrip()

    for _ in range(max_attempts):
        with torch.no_grad():  # inference only, no autograd bookkeeping
            generated = model.generate(
                model_input,
                max_new_tokens=4,
                context_size=context_size,
                temperature=0.2,
            )
        decoded = tokenizer.decode(generated[0].tolist())
        print(f'checking {decoded}')

        # Assumes the decoded text starts with ``pgn`` (decode(encode(x)) == x).
        # split() without arguments skips leading whitespace, so a
        # continuation like " e5" yields "e5" rather than an empty move.
        continuation = decoded[len(pgn):].split()
        mv = continuation[0] if continuation else ''
        if mv:
            new_pgn = f'{prefix} {mv}'
            if _count_plies(new_pgn) == plies_before + 1:
                return new_pgn
        print(f'For {pgn} invalid "{mv}" {decoded}')

    raise RuntimeError(
        f'model produced no legal move for {pgn!r} in {max_attempts} attempts'
    )
def generate(prompt):
    """Continue a PGN prompt with the model and render the resulting position.

    Args:
        prompt: PGN movetext entered by the user.

    Returns:
        A tuple ``(pgn, plot)``: the decoded model output (prompt plus up to
        four newly generated tokens) and a PIL image of the board obtained by
        replaying that PGN.
    """
    # Build the (1, n_tokens) batch from the encoded tokens themselves; the
    # number of tokens is not guaranteed to equal ``len(prompt)``.
    tokens = tokenizer.encode(prompt)
    model_input = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)

    with torch.no_grad():  # inference only, no autograd bookkeeping
        generated = model.generate(
            model_input,
            max_new_tokens=4,
            context_size=context_size,
            temperature=0.2,
        )
    pgn = tokenizer.decode(generated[0].tolist())

    img = gen_board_image(pgn)
    # Unique prefix so concurrent requests never share temporary files.
    filename = f'./moves-{str(uuid4())}'
    plot = gen_image_from_svg(img, filename)
    return pgn, plot
# Gradio UI: a header plus two tabs ("Interactive" and "Next turn prediction").
with gr.Blocks() as demo:
    # The markdown text is kept flush-left so the string handed to Gradio is
    # exactly the text below, independent of any dedenting Gradio may apply.
    gr.Markdown("""
# ChessPT
Welcome to ChessPT.
The **C**hess-**P**re-trained-**T**ransformer.
The rules are simple:
- "Interactive": Play a game against the model
- "Next turn prediction": provide a PGN string of your current game, the model will predict the next token
""")

    def manual():
        """Build the "Next turn prediction" tab (PGN in, continuation + board out)."""
        with gr.Tab("Next turn prediction"):
            prompt = gr.Text(label="PGN")
            output = gr.Text(label="Next turn", interactive=False)
            img = gr.Image()
            submit = gr.Button("Submit")
            submit.click(generate, [prompt], [output, img])
            gr.Examples(
                [
                    ["1. e4", ],
                    ["1. e4 g6 2."],
                ],
                inputs=[prompt],
                outputs=[output, img],
                fn=generate
            )

    def interactive():
        """Build the "Interactive" tab where the user plays against the model."""
        with gr.Tab("Interactive"):
            color = gr.Dropdown(["white", "black"], value='white', label="Choose a color")
            start_button = gr.Button("Start Game")

            def start_game(c):
                """Start a new game; if the user plays black the model moves first.

                Returns the board image, the PGN state and the move counter (1).
                """
                pgn = '1. '
                if c == 'black':
                    pgn = gen_move(pgn)
                img = gen_board_image(pgn)
                # Unique temp-file prefix: a fixed name is shared by all
                # sessions, which then overwrite/delete each other's files.
                fn = f'foo-{uuid4()}'
                return gen_image_from_svg(img, fn), pgn, 1

            state = gr.Text(label='PGN', value='', interactive=False)
            game = gr.Image()
            move_counter = gr.State(value=1)
            start_button.click(
                start_game,
                inputs=[color],
                outputs=[game, state, move_counter]
            )
            next_move = gr.Text(label='Next move')
            gen_next_move_button = gr.Button("Submit")

            def count_plies(pgn_text):
                """Return how many moves ``get_board`` replays, or None if unparsable."""
                if not pgn_text.strip():
                    # No game state yet: nothing has been played.
                    return 0
                try:
                    board = get_board(pgn_text)
                except Exception:
                    # Best effort: a parser failure means "not a usable PGN".
                    return None
                return None if board is None else len(board.move_stack)

            def gen_next_move(pgn, new_move, move_ctr, c):
                """Apply the user's move, let the model answer, and update the UI.

                Args:
                    pgn: Current PGN state.
                    new_move: Move typed by the user (SAN).
                    move_ctr: Current full-move number.
                    c: Colour the user plays, ``'white'`` or ``'black'``.

                Returns:
                    ``(image, pgn, move_ctr)``. When the user's move is empty
                    or illegal, the invalid-move picture is shown and the PGN
                    state and counter are returned unchanged so the user can
                    retry.
                """
                user_move = (new_move or '').strip()
                candidate = pgn + (' ' if c == 'black' else '') + user_move + ' '

                # Accept the input only if it adds exactly one legal move.
                plies_before = count_plies(pgn)
                if (
                    not user_move
                    or plies_before is None
                    or count_plies(candidate) != plies_before + 1
                ):
                    return invalid_move_plot, pgn, move_ctr

                pgn = candidate
                if c == 'black':
                    # Black just completed the full move: open the next one
                    # for the model (white).
                    move_ctr += 1
                    pgn = f'{pgn.rstrip()} {move_ctr}. '
                print(f'gen for {pgn}')
                pgn = gen_move(pgn)
                print(f'got {pgn}')
                img = gen_board_image(pgn)
                if c == 'white':
                    # The model (black) completed the full move: open the
                    # next one for the user.
                    move_ctr += 1
                    pgn = f'{pgn.rstrip()} {move_ctr}. '
                return gen_image_from_svg(img, f'foo-bar-{uuid4()}'), pgn, move_ctr

            gen_next_move_button.click(
                gen_next_move,
                inputs=[state, next_move, move_counter, color],
                outputs=[game, state, move_counter]
            )

    # Tab order in the UI follows call order.
    interactive()
    manual()

demo.launch()