diff --git a/.gitignore b/.gitignore index 81c71f9..5c8f4f6 100644 --- a/.gitignore +++ b/.gitignore @@ -180,3 +180,4 @@ test*.py logs/ logs settings.json +training_logs \ No newline at end of file diff --git a/README.md b/README.md index 9cbb7c4..6177693 100644 --- a/README.md +++ b/README.md @@ -1 +1,5 @@ -Fleet Commander is like Space Invaders but you are the enemy instead of the player. \ No newline at end of file +Fleet Commander is like Space Invaders but you are the enemy instead of the player. + +It uses AI (Reinforcement Learning) to control the Player, and you, as the Enemy, have to defeat it. + +You can train the model yourself, or use the default model that comes with the game. \ No newline at end of file diff --git a/game/play.py b/game/play.py index a60dec4..8601b32 100644 --- a/game/play.py +++ b/game/play.py @@ -34,7 +34,7 @@ class Game(arcade.gui.UIView): def on_show_view(self): super().on_show_view() - self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top") + self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top", align_x=5, align_y=-5) self.back_button.on_click = lambda event: self.main_exit() def main_exit(self): diff --git a/game/sprites.py b/game/sprites.py index 7f6520c..2de03f9 100644 --- a/game/sprites.py +++ b/game/sprites.py @@ -33,7 +33,7 @@ class Player(arcade.Sprite): # Not actually the player def update(self, model: PPO, enemies, bullets, width, height): if enemies: - nearest_enemy = min(enemies, key=lambda e: abs(e.center_y - self.center_y) + abs(e.center_x - self.center_x)) + nearest_enemy = min(enemies, key=lambda e: abs(e.center_x - self.center_x)) enemy_x = (nearest_enemy.center_x - self.center_x) / width enemy_y = (nearest_enemy.center_y -
self.center_y) / height else: diff --git a/menus/main.py b/menus/main.py index 25aded9..4342026 100644 --- a/menus/main.py +++ b/menus/main.py @@ -55,8 +55,8 @@ class Main(arcade.gui.UIView): self.play_button = self.box.add(arcade.gui.UITextureButton(text="Play", texture=button_texture, texture_hovered=button_hovered_texture, width=self.window.width / 2, height=150, style=big_button_style)) self.play_button.on_click = lambda event: self.play() - self.train_button = self.box.add(arcade.gui.UITextureButton(text="Train", texture=button_texture, texture_hovered=button_hovered_texture, width=self.window.width / 2, height=150, style=big_button_style)) - self.train_button.on_click = lambda event: self.train() + self.train_model_button = self.box.add(arcade.gui.UITextureButton(text="Train Model", texture=button_texture, texture_hovered=button_hovered_texture, width=self.window.width / 2, height=150, style=big_button_style)) + self.train_model_button.on_click = lambda event: self.train_model() self.settings_button = self.box.add(arcade.gui.UITextureButton(text="Settings", texture=button_texture, texture_hovered=button_hovered_texture, width=self.window.width / 2, height=150, style=big_button_style)) self.settings_button.on_click = lambda event: self.settings() @@ -68,3 +68,7 @@ class Main(arcade.gui.UIView): def settings(self): from menus.settings import Settings self.window.show_view(Settings(self.pypresence_client)) + + def train_model(self): + from menus.train_model import TrainModel + self.window.show_view(TrainModel(self.pypresence_client)) \ No newline at end of file diff --git a/menus/train_model.py b/menus/train_model.py index 7989a27..f8921bd 100644 --- a/menus/train_model.py +++ b/menus/train_model.py @@ -1,60 +1,177 @@ -import arcade, arcade.gui +import arcade, arcade.gui, threading, io, os, time -from utils.constants import button_style, MODEL_SETTINGS +from utils.constants import button_style, MODEL_SETTINGS, monitor_log_dir from utils.preload import 
button_texture, button_hovered_texture +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from io import BytesIO +from PIL import Image + from stable_baselines3 import PPO -from utils.ml import SpaceInvadersEnv +from stable_baselines3.common.monitor import Monitor +from stable_baselines3.common.logger import configure + +from utils.rl import SpaceInvadersEnv class TrainModel(arcade.gui.UIView): def __init__(self, pypresence_client): + super().__init__() + self.pypresence_client = pypresence_client self.pypresence_client.update(state="Model Training") - self.current_state = "settings" - self.anchor = self.add_widget(arcade.gui.UIAnchorLayout(size_hint=(1, 1))) self.box = self.anchor.add(arcade.gui.UIBoxLayout(space_between=10)) - self.settings = MODEL_SETTINGS.copy() + self.settings = { + setting: data[0] # default value + for setting, data in MODEL_SETTINGS.items() + } + self.labels = {} + + self.training = False + self.training_text = "" + + self.last_progress_update = time.perf_counter() def on_show_view(self): super().on_show_view() - self.show_menu(self.current_state) + self.show_menu() - def show_menu(self, state): - if state == "settings": - self.box.add(arcade.gui.UILabel("Settings", font_size=48)) + def main_exit(self): + from menus.main import Main + self.window.show_view(Main(self.pypresence_client)) - for setting, data in MODEL_SETTINGS: - default, min_value, max_value, step = data - self.box.add(arcade.gui.UILabel(text=f"{setting.replace('_', ' ').capitalize()}: {default}")) - - slider = self.box.add(arcade.gui.UISlider(value=default, min_value=min_value, max_value=max_value, step=step)) - slider._render_steps = lambda surface: None - slider.on_change = lambda e, key=setting: self.change_value(key, e.new_value) + def show_menu(self): + self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), 
anchor_x="left", anchor_y="top", align_x=5, align_y=-5) + self.back_button.on_click = lambda event: self.main_exit() - train_button = self.box.add(arcade.gui.UITextureButton(width=self.window.width / 2, height=self.window.height / 10, text="Train", style=button_style, texture=button_texture, texture_hovered=button_hovered_texture)) - train_button.on_click = lambda e: self.train() + self.box.add(arcade.gui.UILabel("Settings", font_size=36)) + + for setting, data in MODEL_SETTINGS.items(): + default, min_value, max_value, step = data + label = self.box.add(arcade.gui.UILabel(text=f"{setting.replace('_', ' ').capitalize()}: {default}", font_size=18)) + + slider = self.box.add(arcade.gui.UISlider(value=default, min_value=min_value, max_value=max_value, step=step, width=self.window.width / 2, height=self.window.height / 25)) + slider._render_steps = lambda surface: None + slider.on_change = lambda e, key=setting: self.change_value(key, e.new_value) + + self.labels[setting] = label + + train_button = self.box.add(arcade.gui.UITextureButton(width=self.window.width / 2, height=self.window.height / 10, text="Train", style=button_style, texture=button_texture, texture_hovered=button_hovered_texture)) + train_button.on_click = lambda e: self.start_training() def change_value(self, key, value): - ... 
- + self.labels[key].text = f"{key.replace('_', ' ').capitalize()}: {self.round_near_int(value)}" + self.settings[key] = self.round_near_int(value) + + def start_training(self): + self.box.clear() + + self.training = True + + self.training_label = self.box.add(arcade.gui.UILabel("No Output yet.", font_size=16, multiline=True, width=self.window.width / 2, height=self.window.height / 2)) + + self.plot_image_widget = self.box.add(arcade.gui.UIImage(texture=arcade.Texture.create_empty("empty", (1, 1)))) + self.plot_image_widget.visible = False + + threading.Thread(target=self.train, daemon=True).start() + + def on_update(self, delta_time): + if self.training and os.path.exists(os.path.join("training_logs", "progress.csv")) and time.perf_counter() - self.last_progress_update >= 0.5: + self.last_progress_update = time.perf_counter() + + try: + progress_df = pd.read_csv(os.path.join("training_logs", "progress.csv")) + except pd.errors.EmptyDataError: + return + + progress_text = "" + + for key, value in progress_df.items(): + progress_text += f"{key}: {round(value.iloc[-1], 6)}\n" + + self.training_text = progress_text + + if hasattr(self, "training_label"): + self.training_label.text = self.training_text + + def round_near_int(self, x, tol=1e-4): + nearest = round(x) + if abs(x - nearest) < tol: + return nearest + return x + def train(self): - env = SpaceInvadersEnv() + os.makedirs(monitor_log_dir, exist_ok=True) + env = Monitor(SpaceInvadersEnv(), filename=os.path.join(monitor_log_dir, "monitor.csv")) + model = PPO( "MlpPolicy", env, - n_steps=2048, - batch_size=64, - n_epochs=10, - learning_rate=3e-4, - verbose=1, - device="cpu", - gamma=0.99, - ent_coef=0.01, - clip_range=0.2 + n_steps=self.settings["n_steps"], + batch_size=self.settings["batch_size"], + n_epochs=self.settings["n_epochs"], + learning_rate=self.settings["learning_rate"], + verbose=1, + device="cpu", + gamma=self.settings["gamma"], + ent_coef=self.settings["ent_coef"], + 
clip_range=self.settings["clip_range"], ) - model.learn(1_000_000) - model.save("invader_agent") \ No newline at end of file + + new_logger = configure( + folder=monitor_log_dir, format_strings=["csv"] + ) + + model.set_logger(new_logger) + + model.learn(self.settings["learning_steps"]) + model.save("invader_agent") + + self.training = False + + self.plot_results(os.path.join(monitor_log_dir, "monitor.csv"), os.path.join(monitor_log_dir, "progress.csv")) + + def plot_results(self, log_path, loss_log_path): + df = pd.read_csv(log_path, skiprows=1) + + fig, axes = plt.subplots(2, 1, figsize=(6, 8), dpi=100) + + loss_df = pd.read_csv(loss_log_path) + + axes[0].plot(np.cumsum(df['l']), df['r'], label='Reward') + axes[0].set_title('PPO Training: Episodic Reward') + axes[0].set_xlabel('Total Timesteps') + axes[0].set_ylabel('Reward') + axes[0].grid(True) + + axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/policy_gradient_loss'], label='Policy Gradient Loss') + axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/value_loss'], label='Value Loss') + axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/explained_variance'], label='Explained Variance') + axes[1].set_title('PPO Training: Loss Functions') + axes[1].set_xlabel('Total Timesteps') + axes[1].set_ylabel('Loss Value') + axes[1].legend() + axes[1].grid(True) + + plt.tight_layout() + + buffer = BytesIO() + plt.savefig(buffer, format='png') + buffer.seek(0) + plt.close(fig) + + plot_texture = arcade.Texture(Image.open(buffer)) + + self.plot_image_widget.texture = plot_texture + self.plot_image_widget.size_hint = (None, None) + self.plot_image_widget.width = plot_texture.width + self.plot_image_widget.height = plot_texture.height + + self.plot_image_widget.visible = True + self.training_text = "Training finished. Plot displayed." 
\ No newline at end of file diff --git a/run.py b/run.py index 88a9d59..4e4d4ff 100644 --- a/run.py +++ b/run.py @@ -10,7 +10,6 @@ script_dir = os.path.dirname(os.path.abspath(__file__)) pyglet.resource.path.append(script_dir) pyglet.font.add_directory(os.path.join(script_dir, 'assets', 'fonts')) - from utils.utils import get_closest_resolution, print_debug_info, on_exception from utils.constants import log_dir, menu_background_color from menus.main import Main @@ -18,8 +17,6 @@ from arcade.experimental.controller_window import ControllerWindow sys.excepthook = on_exception -__builtins__.print = lambda *args, **kwargs: logging.debug(" ".join(map(str, args))) - if not log_dir in os.listdir(): os.makedirs(log_dir) diff --git a/train.py b/train.py deleted file mode 100644 index 26d7293..0000000 --- a/train.py +++ /dev/null @@ -1,20 +0,0 @@ -from stable_baselines3 import PPO -from utils.ml import SpaceInvadersEnv - -env = SpaceInvadersEnv() -model = PPO( - "MlpPolicy", - env, - n_steps=2048, - batch_size=64, - n_epochs=10, - learning_rate=3e-4, - verbose=1, - device="cpu", - gamma=0.99, - ent_coef=0.02, - clip_range=0.2, - gae_lambda=0.95 -) -model.learn(1_000_000) -model.save("invader_agent") \ No newline at end of file diff --git a/utils/constants.py b/utils/constants.py index 8d4b304..68beb95 100644 --- a/utils/constants.py +++ b/utils/constants.py @@ -12,18 +12,21 @@ PLAYER_ATTACK_SPEED = 0.75 BULLET_SPEED = 3 BULLET_RADIUS = 10 +# default, min, max, step MODEL_SETTINGS = { "n_steps": [2048, 256, 8192, 256], - "batch_size": 64, - "n_epochs": 10, - "learning_rate": 3e-4, - "gamma": 0.99, - "ent_coef": 0.01, - "clip_range": 0.2 + "batch_size": [64, 16, 512, 16], + "n_epochs": [10, 1, 50, 1], + "learning_rate": [3e-4, 1e-5, 1e-2, 1e-5], + "gamma": [0.99, 0.8, 0.9999, 0.001], + "ent_coef": [0.01, 0.0, 0.1, 0.001], + "clip_range": [0.2, 0.1, 0.4, 0.01], + "learning_steps": [500_000, 50_000, 25_000_000, 50_000] } menu_background_color = (30, 30, 47) log_dir = 'logs' 
+monitor_log_dir = "training_logs" discord_presence_id = 1438214877343907881 button_style = {'normal': UITextureButtonStyle(font_name="Roboto", font_color=arcade.color.BLACK), 'hover': UITextureButtonStyle(font_name="Roboto", font_color=arcade.color.BLACK), diff --git a/utils/ml.py b/utils/rl.py similarity index 85% rename from utils/ml.py rename to utils/rl.py index ea983e5..d282ebf 100644 --- a/utils/ml.py +++ b/utils/rl.py @@ -13,7 +13,7 @@ class SpaceInvadersEnv(gym.Env): self.height = height self.action_space = gym.spaces.Discrete(3) - self.observation_space = gym.spaces.Box(low=-10.0, high=10.0, shape=(9,), dtype=np.float32) + self.observation_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(9,), dtype=np.float32) self.enemies = [] self.bullets = [] @@ -25,8 +25,14 @@ class SpaceInvadersEnv(gym.Env): self.prev_bx = 2.0 self.steps_since_direction_change = 0 self.last_direction = 0 + self.max_steps = 1000 + self.current_step = 0 def reset(self, seed=None, options=None): + if seed is not None: + np.random.seed(seed) + random.seed(seed) + self.enemies = [] self.bullets = [] self.dir_history = [] @@ -36,6 +42,7 @@ class SpaceInvadersEnv(gym.Env): self.prev_bx = 2.0 self.steps_since_direction_change = 0 self.last_direction = 0 + self.current_step = 0 start_x = self.width * 0.15 start_y = self.height * 0.9 @@ -51,7 +58,7 @@ class SpaceInvadersEnv(gym.Env): def _nearest_enemy(self): if not self.enemies: return None - return min(self.enemies, key=lambda e: abs(e.center_y - self.player.center_y) + abs(e.center_x - self.player.center_x)) + return min(self.enemies, key=lambda e: abs(e.center_x - self.player.center_x)) def _lowest_enemy(self): if not self.enemies: @@ -104,18 +111,16 @@ class SpaceInvadersEnv(gym.Env): terminated = False truncated = False + self.current_step += 1 + if self.current_step >= self.max_steps: + truncated = True + nearest = self._nearest_enemy() if nearest is not None: enemy_x = (nearest.center_x - self.player.center_x) / float(self.width) 
else: enemy_x = 2.0 - prev_bullet = self._nearest_enemy_bullet() - if prev_bullet is not None: - prev_bx = (prev_bullet.center_x - self.player.center_x) / float(self.width) - else: - prev_bx = 2.0 - prev_x = self.player.center_x current_action_dir = 0 @@ -129,12 +134,15 @@ class SpaceInvadersEnv(gym.Env): t = time.perf_counter() if t - self.last_shot >= PLAYER_ATTACK_SPEED: self.last_shot = t + b = Bullet(self.player.center_x, self.player.center_y, 1) + self.bullets.append(b) + if enemy_x != 2.0 and abs(enemy_x) < 0.04: - reward += 8.0 + reward += 0.3 elif enemy_x != 2.0 and abs(enemy_x) < 0.1: - reward += 3.0 + reward += 0.1 if self.player.center_x > self.width: self.player.center_x = self.width @@ -145,8 +153,9 @@ class SpaceInvadersEnv(gym.Env): if current_action_dir != 0: if self.last_direction != 0 and current_action_dir != self.last_direction: - if self.steps_since_direction_change < 8: - reward -= 3.0 + if self.steps_since_direction_change < 3: + reward -= 0.1 + self.steps_since_direction_change = 0 else: self.steps_since_direction_change += 1 @@ -154,9 +163,9 @@ class SpaceInvadersEnv(gym.Env): if enemy_x != 2.0: if abs(enemy_x) < 0.03: - reward += 3.0 + reward += 0.1 elif abs(enemy_x) < 0.08: - reward += 1.0 + reward += 0.05 for b in list(self.bullets): b.center_y += b.direction_y * BULLET_SPEED @@ -178,7 +187,7 @@ class SpaceInvadersEnv(gym.Env): self.bullets.remove(b) except ValueError: pass - reward += 25.0 + reward += 1.0 break for b in list(self.bullets): @@ -188,14 +197,14 @@ class SpaceInvadersEnv(gym.Env): self.bullets.remove(b) except ValueError: pass - reward -= 100.0 + reward -= 5.0 terminated = True if not self.enemies: - reward += 200.0 + reward += 10.0 terminated = True - if self.enemies and random.random() < 0.02: + if self.enemies and random.random() < 0.05: e = random.choice(self.enemies) b = Bullet(e.center_x, e.center_y, -1) self.bullets.append(b) @@ -203,17 +212,14 @@ class SpaceInvadersEnv(gym.Env): curr_bullet = 
self._nearest_enemy_bullet() if curr_bullet is not None: curr_bx = (curr_bullet.center_x - self.player.center_x) / float(self.width) - curr_by = (curr_bullet.center_y - self.player.center_y) / float(self.height) else: curr_bx = 2.0 - curr_by = 2.0 - if prev_bx != 2.0 and curr_bx != 2.0: - if abs(curr_bx) > abs(prev_bx): - reward += 0.3 + if self.prev_bx != 2.0 and curr_bx != 2.0: + if abs(curr_bx) > abs(self.prev_bx): + reward += 0.02 - if curr_bx != 2.0 and abs(curr_bx) < 0.08 and curr_by < 0.5: - reward -= 0.3 + reward -= 0.01 obs = self._obs() self.prev_bx = curr_bx