"""Gradio demo for ChessPT, a decoder transformer that continues PGN move text.

Two tabs are exposed:

* "Interactive": play a game against the model, one move at a time.
* "Next turn prediction": extend an arbitrary PGN prompt with a few sampled
  tokens and render the resulting position.
"""
import contextlib
import os
from io import StringIO
from uuid import uuid4

import chess
import chess.pgn
import chess.svg
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from reportlab.graphics import renderPM
from svglib.svglib import svg2rlg

from model import DecoderTransformer, Tokenizer

vocab_size = 36
n_embed = 384
context_size = 256
n_layer = 6
n_head = 6
dropout = 0.2
device = 'cpu'

# Sampling settings shared by every model call in this app.
temperature = 0.2
max_new_tokens = 4

model_id = "philipp-zettl/chessPT"
model_path = hf_hub_download(repo_id=model_id, filename="chessPT-v0.5.pth")
tokenizer_path = hf_hub_download(repo_id=model_id, filename="tokenizer-v0.5.json")

model = DecoderTransformer(vocab_size, n_embed, context_size, n_layer, n_head, dropout)
# NOTE(review): torch.load unpickles the checkpoint, so only load files from a
# repository you trust (consider weights_only=True on torch versions that support it).
model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
model.to(device)
# Inference only: switch off dropout (the model is constructed with dropout=0.2).
model.eval()

tokenizer = Tokenizer.from_pretrained(tokenizer_path)

# Placeholder shown instead of a board when a PGN or a move cannot be used.
# copy() loads the pixels so the file handle is closed immediately.
with Image.open('./invalid_move.png') as _invalid_move_file:
    invalid_move_plot = _invalid_move_file.copy()


def gen_image_from_svg(img, filename=None):
    """Rasterise an SVG string to an in-memory PIL image.

    The SVG is written to ``<filename>.svg`` and rendered to ``<filename>.png``
    with reportlab. The PNG is then loaded into memory, and both temporary
    files are removed, even if rendering fails.

    Args:
        img: SVG document as a string.
        filename: Path prefix for the temporary files. Defaults to a unique
            ``./moves-<uuid4>`` prefix so concurrent requests never collide.

    Returns:
        The rendered image as a ``PIL.Image.Image`` that no longer depends on
        the (deleted) PNG file.
    """
    if filename is None:
        filename = f'./moves-{uuid4()}'
    svg_path = f'{filename}.svg'
    png_path = f'{filename}.png'
    try:
        with open(svg_path, 'w', encoding='utf-8') as f:
            f.write(img)
        drawing = svg2rlg(svg_path)
        renderPM.drawToFile(drawing, png_path, fmt="PNG")
        # Image.open is lazy; copy() pulls the pixels into memory so the file
        # can be closed before it is deleted.
        with Image.open(png_path) as png:
            plot = png.copy()
    finally:
        for path in (svg_path, png_path):
            with contextlib.suppress(FileNotFoundError):
                os.remove(path)
    return plot


def get_board(pgn, *, strict=True):
    """Replay the main line of a PGN string and return the final position.

    Args:
        pgn: PGN movetext, e.g. ``'1. e4 e5 2. '``.
        strict: python-chess does not raise on illegal or malformed SAN while
            reading a game. It records the error in ``game.errors`` and keeps
            the moves parsed before it. When True (default), any such error
            makes this function return None. When False, the position after
            the legal prefix is returned.

    Returns:
        A ``chess.Board``, or None if no game could be read (e.g. empty input)
        or, in strict mode, the movetext contains an illegal or unparsable move.
    """
    try:
        game = chess.pgn.read_game(StringIO(pgn))
    except ValueError:
        return None
    if game is None:
        return None
    if strict and getattr(game, 'errors', None):
        return None
    board = game.board()
    for move in game.mainline_moves():
        board.push(move)
    return board


def gen_board_image(pgn):
    """Return an SVG string of the position reached by *pgn*.

    Illegal trailing moves are ignored, so the legal prefix is drawn. If the
    PGN cannot be read at all, ``chess.svg.board(None)`` draws an empty board.
    """
    return chess.svg.board(get_board(pgn, strict=False))


def _render_pgn(pgn):
    """Return a PIL image of the position reached by *pgn*, or the invalid-move placeholder."""
    board = get_board(pgn, strict=False)
    if board is None:
        return invalid_move_plot
    return gen_image_from_svg(chess.svg.board(board))


def _complete(prompt):
    """Sample up to ``max_new_tokens`` tokens after *prompt* and return the decoded text.

    Callers assume the decoded text starts with *prompt* and slice it off by
    length to get the continuation.
    """
    ids = torch.tensor(tokenizer.encode(prompt), dtype=torch.long, device=device).unsqueeze(0)
    with torch.no_grad():
        out = model.generate(ids, max_new_tokens=max_new_tokens,
                             context_size=context_size, temperature=temperature)
    return tokenizer.decode(out[0].tolist())


def _parse_legal_san(board, san):
    """Return the legal move *san* denotes on *board*, or None.

    Rejects empty text, unparsable/illegal/ambiguous SAN (python-chess raises
    ValueError subclasses for these), and the null move ``--``, which
    ``parse_san`` otherwise accepts.
    """
    if not san:
        return None
    try:
        move = board.parse_san(san)
    except ValueError:
        return None
    return move or None


def gen_move(pgn, max_attempts=100):
    """Ask the model for the next move after *pgn* and append it.

    The model is sampled repeatedly until the first whitespace-separated token
    of its continuation is a legal SAN move in the current position.

    Args:
        pgn: Legal PGN movetext ending where the next move goes, e.g. ``'1. e4 '``.
        max_attempts: Maximum number of samples before giving up.

    Returns:
        *pgn* with trailing whitespace stripped, followed by a space and the move.

    Raises:
        gr.Error: if *pgn* is not a legal game, the game is already over, or no
            legal move was sampled within *max_attempts* tries.
    """
    board = get_board(pgn)
    if board is None:
        raise gr.Error(f'Cannot continue an invalid game: {pgn!r}')
    if board.is_game_over():
        raise gr.Error('The game is over.')
    for _ in range(max_attempts):
        new_pgn = _complete(pgn)
        print(f'checking {new_pgn}')
        tokens = new_pgn[len(pgn):].split()
        mv = tokens[0] if tokens else ''
        if _parse_legal_san(board, mv) is not None:
            return f'{pgn.rstrip()} {mv}'
        print(f'For {pgn} invalid "{mv}" {new_pgn}')
    raise gr.Error(f'The model did not produce a legal move after {max_attempts} attempts.')


def generate(prompt):
    """Extend *prompt* with a few sampled tokens and render the resulting position.

    Returns:
        ``(pgn, image)``: the extended PGN text and a PIL image of the board,
        or the invalid-move placeholder if the PGN cannot be read at all.
    """
    pgn = _complete(prompt)
    return pgn, _render_pgn(pgn)


with gr.Blocks() as demo:
    gr.Markdown("""
# ChessPT

Welcome to ChessPT.

The **C**hess-**P**re-trained-**T**ransformer.

The rules are simple:
- "Interactive": Play a game against the model
- "Next turn prediction": provide a PGN string of your current game, the model will predict the next token
""")

    def manual():
        """Build the "Next turn prediction" tab."""
        with gr.Tab("Next turn prediction"):
            prompt = gr.Text(label="PGN")
            output = gr.Text(label="Next turn", interactive=False)
            img = gr.Image()
            submit = gr.Button("Submit")
            submit.click(generate, [prompt], [output, img])
            gr.Examples(
                [
                    ["1. e4", ],
                    ["1. e4 g6 2."],
                ],
                inputs=[prompt],
                outputs=[output, img],
                fn=generate,
            )

    def interactive():
        """Build the "Interactive" tab where the user plays against the model."""
        with gr.Tab("Interactive"):
            color = gr.Dropdown(["white", "black"], value='white', label="Choose a color")
            start_button = gr.Button("Start Game")

            def start_game(c):
                """Start a new game. If the user plays black, the model makes White's first move."""
                pgn = '1. '
                if c == 'black':
                    pgn = gen_move(pgn)
                return _render_pgn(pgn), pgn, 1

            state = gr.Text(label='PGN', value='', interactive=False)
            game = gr.Image()
            move_counter = gr.State(value=1)
            start_button.click(
                start_game,
                inputs=[color],
                outputs=[game, state, move_counter],
            )

            next_move = gr.Text(label='Next move')
            gen_next_move_button = gr.Button("Submit")

            def gen_next_move(pgn, new_move, move_ctr, c):
                """Apply the user's move, let the model reply, and return (image, pgn, move counter).

                The PGN state always ends where the user's next move goes, so it
                can be fed to the model as a prompt. An illegal user move leaves
                the state unchanged and shows the invalid-move placeholder.
                """
                board = get_board(pgn)
                if board is None:
                    raise gr.Error("Click 'Start Game' first.")
                if board.is_game_over():
                    raise gr.Error("The game is over. Click 'Start Game' to play again.")
                move = _parse_legal_san(board, (new_move or '').strip())
                if move is None:
                    return invalid_move_plot, pgn, move_ctr
                # Record the move in canonical SAN (e.g. "e2e4" -> "e4") so the
                # prompt looks like the PGN text the model works with.
                san = board.san(move)
                board.push(move)
                pgn += (' ' if c == 'black' else '') + san + ' '
                if board.is_game_over():
                    # The user's move ended the game; there is nothing to reply to.
                    return _render_pgn(pgn), pgn, move_ctr
                if c == 'black':
                    move_ctr += 1
                    pgn = f'{pgn.rstrip()} {move_ctr}. '
                print(f'gen for {pgn}')
                pgn = gen_move(pgn)
                print(f'got {pgn}')
                img = _render_pgn(pgn)
                if c == 'white':
                    move_ctr += 1
                    pgn = f'{pgn.rstrip()} {move_ctr}. '
                return img, pgn, move_ctr

            gen_next_move_button.click(
                gen_next_move,
                inputs=[state, next_move, move_counter, color],
                outputs=[game, state, move_counter],
            )

    interactive()
    manual()

demo.launch()