import sys
import random
from pathlib import Path

import model
import representation
import tetris

script_dir = Path(__file__).parent.resolve()
checkpoints_dir = script_dir / "checkpoints"
checkpoints_dir.mkdir(exist_ok=True)
log_file_path = checkpoints_dir / "log.txt"

# if you want to start from a checkpoint, fill this in with the path to the .pth file.
# If you want to start from a new NN, leave it blank!
model_save_path = r""

# training settings
gamma: float = 0.5
epsilon: float = 0.2

# training config
batch_size: int = 100  # the number of experiences that will be collected and trained on
save_model_every_experiences: int = 5000

################

# construct/load model
tmodel: model.TetrisAI
if model_save_path is not None and model_save_path != "":
    print("Loading model checkpoint at '" + model_save_path + "'...")
    tmodel = model.TetrisAI(model_save_path)
    print("Model loaded!")
else:
    print("Constructing new model...")
    tmodel = model.TetrisAI()

# variables to track
experiences_trained: int = 0  # the number of experiences the model has been trained on
model_last_saved_at_experiences_trained: int = 0  # the number of trained experiences at the last save
on_checkpoint: int = 0

def log(path, content: str) -> None:
    if path is not None and str(path) != "":
        with open(path, "a") as f:
            f.write(content + "\n")

# training loop
while True:

    # collect X number of experiences
    gs: tetris.GameState = tetris.GameState()
    experiences: list[model.Experience] = []
    for ei in range(0, batch_size):

        # print!
        sys.stdout.write("\r" + "Collecting experience " + str(ei + 1) + " / " + str(batch_size) + "... ")
        sys.stdout.flush()

        # get board representation
        state_board: list[int] = representation.BoardState(gs)

        # select move to play
        move: int
        if random.random() < epsilon:  # if by chance we should select a random move
            move = random.randint(0, 3)  # choose a move at random
        else:
            predictions: list[float] = tmodel.predict(state_board)  # predict Q-values
            move = predictions.index(max(predictions))  # select the move (index) with the highest Q-value

        # play the move
        IllegalMovePlayed: bool = False
        MoveReward: float
        try:
            MoveReward = gs.drop(move)
        except tetris.InvalidDropException:
            # the model (or chance) tried to play an illegal move
            IllegalMovePlayed = True
            MoveReward = -3.0  # small penalty for illegal moves
        except Exception as ex:
            print("Unhandled exception in move execution: " + str(ex))
            input("Press enter key to continue, if you want to.")
            IllegalMovePlayed = True  # end the episode; without these two lines MoveReward would be unbound below
            MoveReward = 0.0

        # store this experience
        exp: model.Experience = model.Experience()
        exp.state = state_board
        exp.action = move
        exp.reward = MoveReward
        exp.next_state = representation.BoardState(gs)  # the state we find ourselves in now
        exp.done = gs.over() or IllegalMovePlayed  # it is over if the game is completed OR an illegal move was played
        experiences.append(exp)

        # if the game is over or an illegal move was played, reset the game!
        if gs.over() or IllegalMovePlayed:
            gs = tetris.GameState()

    print()

    # print average reward
    rewards: float = 0.0
    for exp in experiences:
        rewards = rewards + exp.reward
    status: str = "Average reward over those " + str(len(experiences)) + " experiences on model w/ " + str(experiences_trained) + " trained experiences: " + str(round(rewards / len(experiences), 2))
    log(log_file_path, status)
    print(status)

    # train!
    for ei in range(0, len(experiences)):
        exp = experiences[ei]

        # print training number
        sys.stdout.write("\r" + "Training on experience " + str(ei + 1) + " / " + str(len(experiences)) + "... ")
        sys.stdout.flush()
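        # Q-learning target (a standard DQN-style update): if the episode ended, the target is
        # just the observed reward; otherwise it is reward + gamma * max Q(next_state, a),
        # maximized over actions a, which is exactly what the branch below computes.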
") sys.stdout.flush() # determine new target based on the game ending or not (maybe we should factor in future rewards, maybe we shouldnt) new_target:float if exp.done: new_target = exp.reward else: max_q_of_next_state:float = max(tmodel.predict(exp.next_state)) new_target = exp.reward + (gamma * max_q_of_next_state) # blend immediate vs. future rewards # ask the model to predict again for this experiences state qvalues:list[float] = tmodel.predict(exp.state) # plug in the new target where it belongs qvalues[exp.action] = new_target # now train on the updated qvalues (with 1 changed) tmodel.train(exp.state, qvalues) experiences_trained = experiences_trained + 1 print("Training complete!") # save model! if (experiences_trained - model_last_saved_at_experiences_trained) >= save_model_every_experiences: print("Time to save model!") path = checkpoints_dir / f"checkpoint{on_checkpoint}.pth" tmodel.save(path) print("Checkpoint # " + str(on_checkpoint) + " saved to " + str(path) + "!") on_checkpoint = on_checkpoint + 1 model_last_saved_at_experiences_trained = experiences_trained print("Model saved to " + str(path) + "!")