import datetime as dt
import logging
import random
import re

import chess
import streamlit as st
import streamlit.components.v1 as components
import streamlit_scrollable_textbox as stx
from gradio_client import Client
from st_bridge import bridge
from streamlit.components.v1 import html

from modules.chess import Chess
from modules.states import init_states
from modules.utility import set_page
|
|
|
# --- Page bootstrap ---------------------------------------------------------
# set_page() and init_states() are project helpers imported from modules.*;
# presumably they configure the page and seed shared session state - confirm
# against those modules.
set_page(page_icon="♟️", title='Chess vs LLaMA 3.1 405B')
init_states()

# Width handed to the board widget; the height of the embedding HTML component
# is derived from it where the board is rendered.
st.session_state.board_width = 400

# Client for the hosted Space that get_ai_move() asks for moves.
_LLAMA_SPACE = "xianbao/SambaNova-fast"
llama_client = Client(_LLAMA_SPACE)
|
|
|
|
|
# Defaults for the session keys this script reads. Streamlit re-executes the
# file on every interaction, so a key is only seeded while it is still missing;
# values set during earlier runs are left untouched.
_SESSION_DEFAULTS = {
    'player_color': 'white',
    'current_turn': 'white',
    'game_started': False,
    'curfen': "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
    'lastfen': None,
    # Built afresh on each run, so the empty history is never shared.
    'moves': {},
    'curside': 'white',
    # Read by the move handler and the board panel. Seeded here so those reads
    # do not rely on reset_game() having been called first.
    'game_over': False,
}

for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
|
|
|
# A UCI move: from-square, to-square and an optional promotion piece.
_UCI_MOVE_PATTERN = re.compile(r"\b[a-h][1-8][a-h][1-8][qrbn]?\b")

# How many times the model is asked before a random legal move is played.
_AI_MOVE_ATTEMPTS = 3


def get_ai_move(fen):
    """Ask the remote model for a move in the given position.

    The model is queried up to ``_AI_MOVE_ATTEMPTS`` times. The first reply
    that contains a UCI move which is legal in the position wins. If no
    attempt yields one (request failure, unparsable or illegal reply), a
    random legal move is returned so the game can always continue.

    Args:
        fen: The position to move in, as a FEN string.

    Returns:
        A legal move in UCI notation (e.g. ``"e2e4"``), or ``None`` when the
        position has no legal moves.

    Raises:
        ValueError: If ``fen`` cannot be parsed by ``chess.Board``.
    """
    board = chess.Board(fen)
    legal_moves = list(board.legal_moves)
    if not legal_moves:
        return None

    prompt = f"You are a chess engine. Given the following chess position in FEN notation: {fen}, suggest a good move. Respond with only the move in UCI notation (e.g., e2e4)."
    log = logging.getLogger(__name__)

    for attempt in range(1, _AI_MOVE_ATTEMPTS + 1):
        try:
            response = llama_client.predict(
                message=prompt,
                system_message="You are a chess engine assistant.",
                max_tokens=10,
                temperature=0.7,
                top_p=0.9,
                top_k=50,
                api_name="/chat",
            )
        except Exception:
            # Best effort: a failing remote call must not take the app down,
            # because the random fallback below can still produce a move.
            log.warning(
                "AI move request failed (attempt %d of %d)",
                attempt,
                _AI_MOVE_ATTEMPTS,
                exc_info=True,
            )
            continue

        # Models often wrap the move in extra text or change its case, so
        # pick the first UCI-looking token instead of requiring an exact reply.
        found = _UCI_MOVE_PATTERN.search(str(response).lower())
        if found is None:
            continue

        try:
            move = chess.Move.from_uci(found.group())
        except ValueError:
            # The parser can still reject a token that matched the pattern.
            continue

        if move in legal_moves:
            return move.uci()

    return random.choice(legal_moves).uci()
|
|
|
def reset_game(player_color):
    """Start a fresh game with the human playing ``player_color``.

    Restores the starting position, clears the move history and marks the
    game as started and not over. When the human plays black, the AI is asked
    for the opening move straight away; that move is applied and recorded in
    the history as white's move, and the turn passes to black.

    Args:
        player_color: Side chosen by the human (``'white'`` or ``'black'``).
    """
    start_fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
    state = st.session_state

    state.player_color = player_color
    state.curfen = start_fen
    state.moves = {}
    state.current_turn = 'white'
    state.game_started = True
    state.lastfen = None
    state.game_over = False

    # White moves first, so only a human playing black needs an AI opener.
    if state.player_color != 'black':
        return

    opening = get_ai_move(state.curfen)
    position = chess.Board(state.curfen)
    if not opening:
        return

    position.push(chess.Move.from_uci(opening))
    state.curfen = position.fen()
    # History entries are keyed by the position reached after the move.
    state.moves[state.curfen] = {
        'side': 'white',
        'curfen': state.curfen,
        'last_fen': start_fen,
        'last_move': opening,
        'data': None,
        'timestamp': str(dt.datetime.now()),
    }
    state.current_turn = 'black'
|
|
|
def check_game_end(board):
    """Report the result of ``board`` if the game has finished.

    Whenever the board has an outcome, ``st.session_state.game_over`` is set
    to ``True`` as a side effect.

    Args:
        board: Position to inspect; must provide ``outcome()``.

    Returns:
        ``"White"`` or ``"Black"`` for a decisive game, ``"Draw"`` when the
        outcome has no winner, or ``None`` while the game is still running.
    """
    result = board.outcome()
    if not result:
        return None

    st.session_state.game_over = True
    if result.winner is None:
        return "Draw"
    if result.winner:
        return "White"
    return "Black"
|
|
|
st.title("Chess vs LLaMA 3.1 405B")

# Control row: colour picker, new-game button and a short status read-out.
col1, col2, col3 = st.columns([1, 1, 1])

player_color = col1.selectbox(
    "Choose your color",
    ['white', 'black'],
    key='color_select',
)

if col2.button('Start New Game', key='start_game'):
    # Rebuild the game for the chosen colour, then rerun so the whole page is
    # drawn from the fresh state.
    reset_game(player_color)
    st.rerun()

col3.write(f"Current turn: {st.session_state.current_turn}")
col3.write(f"Your color: {st.session_state.player_color}")
|
|
|
|
|
# --- Apply the move reported by the browser board, then let the AI answer ---
data = bridge("my-bridge")

# A Streamlit component hands its last value back on every rerun until the
# front end sends a new one. Without this guard an unrelated rerun (changing
# the colour selector, starting a new game) would re-apply the previous move,
# throw away the AI's reply and query the model again.
# NOTE(review): assumes st_bridge follows that standard component behaviour
# and that two distinct moves never arrive as identical payloads - confirm.
_already_applied = (
    'last_bridge_data' in st.session_state
    and st.session_state.last_bridge_data == data
)

if (
    data is not None
    and st.session_state.game_started
    and not st.session_state.game_over
    and not _already_applied
):
    st.session_state.last_bridge_data = data

    # Record the human's move; history entries are keyed by the position
    # reached after the move.
    st.session_state.lastfen = st.session_state.curfen
    st.session_state.curfen = data['fen']
    # The payload carries the mover as a one-letter colour code.
    _mover = data['move']['color']
    st.session_state.curside = {'w': 'white', 'b': 'black'}.get(_mover, _mover)
    st.session_state.moves[st.session_state.curfen] = {
        'side': st.session_state.curside,
        'curfen': st.session_state.curfen,
        'last_fen': st.session_state.lastfen,
        'last_move': data['pgn'],
        'data': None,
        'timestamp': str(dt.datetime.now()),
    }
    st.session_state.current_turn = (
        'white' if st.session_state.curside == 'black' else 'black'
    )

    board = chess.Board(st.session_state.curfen)
    game_result = check_game_end(board)
    if game_result:
        st.success(f"Game Over! Winner: {game_result}")
    elif st.session_state.current_turn != st.session_state.player_color:
        # The position the AI is answering. It is stored as the entry's
        # 'last_fen' so that field always means "position before this move",
        # as it does for the human's entries.
        fen_before_ai = st.session_state.curfen
        ai_move = get_ai_move(fen_before_ai)
        if ai_move:
            board.push(chess.Move.from_uci(ai_move))
            st.session_state.curfen = board.fen()
            st.session_state.moves[st.session_state.curfen] = {
                'side': st.session_state.current_turn,
                'curfen': st.session_state.curfen,
                'last_fen': fen_before_ai,
                'last_move': ai_move,
                'data': None,
                'timestamp': str(dt.datetime.now()),
            }
            st.session_state.current_turn = st.session_state.player_color
            # The AI's reply may itself have ended the game.
            game_result = check_game_end(board)
            if game_result:
                st.success(f"Game Over! Winner: {game_result}")
        else:
            st.error("The AI couldn't make a move. The game may be over.")
|
|
|
|
|
def _yes_no(flag):
    """Render a boolean as the 'Yes'/'No' text shown in the status panel."""
    return 'Yes' if flag else 'No'


# Two panes: board and status on the left, move history on the right.
cols = st.columns([3, 2])

with cols[0]:
    if not st.session_state.game_started:
        # No game yet: show the getting-started instructions instead.
        st.info("Welcome to Chess vs LLaMA 3.1 405B!")
        for line in (
            "To start a new game:",
            "1. Choose your color (white or black)",
            "2. Click 'Start New Game'",
            "3. Make your moves on the chess board",
            "Enjoy playing against the AI!",
        ):
            st.write(line)
    else:
        width = st.session_state.board_width
        puzzle = Chess(width, st.session_state.curfen)
        # Extra height on top of the board width for the embedded component.
        components.html(
            puzzle.puzzle_board(),
            height=width + 75,
            scrolling=False,
        )
        board = chess.Board(st.session_state.curfen)

        status_col1, status_col2 = st.columns(2)
        with status_col1:
            st.write("Game Status:")
            st.write(f"Check: {_yes_no(board.is_check())}")
            st.write(f"Checkmate: {_yes_no(board.is_checkmate())}")
        with status_col2:
            # Zero-width space as a first row; presumably keeps the entries
            # level with those under the heading in the left column.
            st.write("\u200B")
            st.write(f"Stalemate: {_yes_no(board.is_stalemate())}")
            st.write(
                f"Insufficient material: {_yes_no(board.is_insufficient_material())}"
            )

        if st.session_state.game_over:
            st.success(f"Game Over! Winner: {check_game_end(board)}")

with cols[1]:
    if st.session_state.game_started:
        st.subheader("Move History")
        # One markdown entry per recorded move: timestamp without fractional
        # seconds, then the side and the move text.
        records = []
        for entry in st.session_state['moves'].values():
            records.append(
                f"##### {entry['timestamp'].split('.')[0]} \n {entry['side']} - {entry.get('last_move','')}"
            )
        stx.scrollableTextbox('\n\n'.join(records), height=400, border=True)
    else:
        st.image(
            "https://upload.wikimedia.org/wikipedia/commons/6/6f/ChessSet.jpg",
            caption="Chess pieces",
            use_column_width=True,
        )