Add model training with graphs and current stats, improve model with better rewarding system

This commit is contained in:
csd4ni3l
2025-11-15 18:54:37 +01:00
parent 32477def6a
commit 05f568a457
10 changed files with 204 additions and 92 deletions

1
.gitignore vendored
View File

@@ -180,3 +180,4 @@ test*.py
logs/ logs/
logs logs
settings.json settings.json
training_logs

View File

@@ -1 +1,5 @@
Fleet Commander is like Space Invaders but you are the enemy instead of the player. Fleet Commander is like Space Invaders but you are the enemy instead of the player.
It uses AI (Reinforcement Learning) for the Player, and you, the Enemy has to defeat it.
You can train yourself, or use the default model which comes with the game.

View File

@@ -34,7 +34,7 @@ class Game(arcade.gui.UIView):
def on_show_view(self): def on_show_view(self):
super().on_show_view() super().on_show_view()
self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top") self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top", align_x=5, align_y=-5)
self.back_button.on_click = lambda event: self.main_exit() self.back_button.on_click = lambda event: self.main_exit()
def main_exit(self): def main_exit(self):

View File

@@ -33,7 +33,7 @@ class Player(arcade.Sprite): # Not actually the player
def update(self, model: PPO, enemies, bullets, width, height): def update(self, model: PPO, enemies, bullets, width, height):
if enemies: if enemies:
nearest_enemy = min(enemies, key=lambda e: abs(e.center_y - self.center_y) + abs(e.center_x - self.center_x)) nearest_enemy = min(enemies, key=lambda e: abs(e.center_x - self.center_x))
enemy_x = (nearest_enemy.center_x - self.center_x) / width enemy_x = (nearest_enemy.center_x - self.center_x) / width
enemy_y = (nearest_enemy.center_y - self.center_y) / height enemy_y = (nearest_enemy.center_y - self.center_y) / height
else: else:

View File

@@ -55,8 +55,8 @@ class Main(arcade.gui.UIView):
self.play_button = self.box.add(arcade.gui.UITextureButton(text="Play", texture=button_texture, texture_hovered=button_hovered_texture, width=self.window.width / 2, height=150, style=big_button_style)) self.play_button = self.box.add(arcade.gui.UITextureButton(text="Play", texture=button_texture, texture_hovered=button_hovered_texture, width=self.window.width / 2, height=150, style=big_button_style))
self.play_button.on_click = lambda event: self.play() self.play_button.on_click = lambda event: self.play()
self.train_button = self.box.add(arcade.gui.UITextureButton(text="Train", texture=button_texture, texture_hovered=button_hovered_texture, width=self.window.width / 2, height=150, style=big_button_style)) self.train_model_button = self.box.add(arcade.gui.UITextureButton(text="Train Model", texture=button_texture, texture_hovered=button_hovered_texture, width=self.window.width / 2, height=150, style=big_button_style))
self.train_button.on_click = lambda event: self.train() self.train_model_button.on_click = lambda event: self.train_model()
self.settings_button = self.box.add(arcade.gui.UITextureButton(text="Settings", texture=button_texture, texture_hovered=button_hovered_texture, width=self.window.width / 2, height=150, style=big_button_style)) self.settings_button = self.box.add(arcade.gui.UITextureButton(text="Settings", texture=button_texture, texture_hovered=button_hovered_texture, width=self.window.width / 2, height=150, style=big_button_style))
self.settings_button.on_click = lambda event: self.settings() self.settings_button.on_click = lambda event: self.settings()
@@ -68,3 +68,7 @@ class Main(arcade.gui.UIView):
def settings(self): def settings(self):
from menus.settings import Settings from menus.settings import Settings
self.window.show_view(Settings(self.pypresence_client)) self.window.show_view(Settings(self.pypresence_client))
def train_model(self):
from menus.train_model import TrainModel
self.window.show_view(TrainModel(self.pypresence_client))

View File

@@ -1,60 +1,177 @@
import arcade, arcade.gui import arcade, arcade.gui, threading, io, os, time
from utils.constants import button_style, MODEL_SETTINGS from utils.constants import button_style, MODEL_SETTINGS, monitor_log_dir
from utils.preload import button_texture, button_hovered_texture from utils.preload import button_texture, button_hovered_texture
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from io import BytesIO
from PIL import Image
from stable_baselines3 import PPO from stable_baselines3 import PPO
from utils.ml import SpaceInvadersEnv from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.logger import configure
from utils.rl import SpaceInvadersEnv
class TrainModel(arcade.gui.UIView): class TrainModel(arcade.gui.UIView):
def __init__(self, pypresence_client): def __init__(self, pypresence_client):
super().__init__()
self.pypresence_client = pypresence_client self.pypresence_client = pypresence_client
self.pypresence_client.update(state="Model Training") self.pypresence_client.update(state="Model Training")
self.current_state = "settings"
self.anchor = self.add_widget(arcade.gui.UIAnchorLayout(size_hint=(1, 1))) self.anchor = self.add_widget(arcade.gui.UIAnchorLayout(size_hint=(1, 1)))
self.box = self.anchor.add(arcade.gui.UIBoxLayout(space_between=10)) self.box = self.anchor.add(arcade.gui.UIBoxLayout(space_between=10))
self.settings = MODEL_SETTINGS.copy() self.settings = {
setting: data[0] # default value
for setting, data in MODEL_SETTINGS.items()
}
self.labels = {}
self.training = False
self.training_text = ""
self.last_progress_update = time.perf_counter()
def on_show_view(self): def on_show_view(self):
super().on_show_view() super().on_show_view()
self.show_menu(self.current_state) self.show_menu()
def show_menu(self, state): def main_exit(self):
if state == "settings": from menus.main import Main
self.box.add(arcade.gui.UILabel("Settings", font_size=48)) self.window.show_view(Main(self.pypresence_client))
for setting, data in MODEL_SETTINGS: def show_menu(self):
self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top", align_x=5, align_y=-5)
self.back_button.on_click = lambda event: self.main_exit()
self.box.add(arcade.gui.UILabel("Settings", font_size=36))
for setting, data in MODEL_SETTINGS.items():
default, min_value, max_value, step = data default, min_value, max_value, step = data
self.box.add(arcade.gui.UILabel(text=f"{setting.replace('_', ' ').capitalize()}: {default}")) label = self.box.add(arcade.gui.UILabel(text=f"{setting.replace('_', ' ').capitalize()}: {default}", font_size=18))
slider = self.box.add(arcade.gui.UISlider(value=default, min_value=min_value, max_value=max_value, step=step)) slider = self.box.add(arcade.gui.UISlider(value=default, min_value=min_value, max_value=max_value, step=step, width=self.window.width / 2, height=self.window.height / 25))
slider._render_steps = lambda surface: None slider._render_steps = lambda surface: None
slider.on_change = lambda e, key=setting: self.change_value(key, e.new_value) slider.on_change = lambda e, key=setting: self.change_value(key, e.new_value)
self.labels[setting] = label
train_button = self.box.add(arcade.gui.UITextureButton(width=self.window.width / 2, height=self.window.height / 10, text="Train", style=button_style, texture=button_texture, texture_hovered=button_hovered_texture)) train_button = self.box.add(arcade.gui.UITextureButton(width=self.window.width / 2, height=self.window.height / 10, text="Train", style=button_style, texture=button_texture, texture_hovered=button_hovered_texture))
train_button.on_click = lambda e: self.train() train_button.on_click = lambda e: self.start_training()
def change_value(self, key, value): def change_value(self, key, value):
... self.labels[key].text = f"{key.replace('_', ' ').capitalize()}: {self.round_near_int(value)}"
self.settings[key] = self.round_near_int(value)
def start_training(self):
self.box.clear()
self.training = True
self.training_label = self.box.add(arcade.gui.UILabel("No Output yet.", font_size=16, multiline=True, width=self.window.width / 2, height=self.window.height / 2))
self.plot_image_widget = self.box.add(arcade.gui.UIImage(texture=arcade.Texture.create_empty("empty", (1, 1))))
self.plot_image_widget.visible = False
threading.Thread(target=self.train, daemon=True).start()
def on_update(self, delta_time):
if self.training and os.path.exists(os.path.join("training_logs", "progress.csv")) and time.perf_counter() - self.last_progress_update >= 0.5:
self.last_progress_update = time.perf_counter()
try:
progress_df = pd.read_csv(os.path.join("training_logs", "progress.csv"))
except pd.errors.EmptyDataError:
return
progress_text = ""
for key, value in progress_df.items():
progress_text += f"{key}: {round(value.iloc[-1], 6)}\n"
self.training_text = progress_text
if hasattr(self, "training_label"):
self.training_label.text = self.training_text
def round_near_int(self, x, tol=1e-4):
nearest = round(x)
if abs(x - nearest) < tol:
return nearest
return x
def train(self): def train(self):
env = SpaceInvadersEnv() os.makedirs(monitor_log_dir, exist_ok=True)
env = Monitor(SpaceInvadersEnv(), filename=os.path.join(monitor_log_dir, "monitor.csv"))
model = PPO( model = PPO(
"MlpPolicy", "MlpPolicy",
env, env,
n_steps=2048, n_steps=self.settings["n_steps"],
batch_size=64, batch_size=self.settings["batch_size"],
n_epochs=10, n_epochs=self.settings["n_epochs"],
learning_rate=3e-4, learning_rate=self.settings["learning_rate"],
verbose=1, verbose=1,
device="cpu", device="cpu",
gamma=0.99, gamma=self.settings["gamma"],
ent_coef=0.01, ent_coef=self.settings["ent_coef"],
clip_range=0.2 clip_range=self.settings["clip_range"],
) )
model.learn(1_000_000)
new_logger = configure(
folder=monitor_log_dir, format_strings=["csv"]
)
model.set_logger(new_logger)
model.learn(self.settings["learning_steps"])
model.save("invader_agent") model.save("invader_agent")
self.training = False
self.plot_results(os.path.join(monitor_log_dir, "monitor.csv"), os.path.join(monitor_log_dir, "progress.csv"))
def plot_results(self, log_path, loss_log_path):
df = pd.read_csv(log_path, skiprows=1)
fig, axes = plt.subplots(2, 1, figsize=(6, 8), dpi=100)
loss_df = pd.read_csv(loss_log_path)
axes[0].plot(np.cumsum(df['l']), df['r'], label='Reward')
axes[0].set_title('PPO Training: Episodic Reward')
axes[0].set_xlabel('Total Timesteps')
axes[0].set_ylabel('Reward')
axes[0].grid(True)
axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/policy_gradient_loss'], label='Policy Gradient Loss')
axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/value_loss'], label='Value Loss')
axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/explained_variance'], label='Explained Variance')
axes[1].set_title('PPO Training: Loss Functions')
axes[1].set_xlabel('Total Timesteps')
axes[1].set_ylabel('Loss Value')
axes[1].legend()
axes[1].grid(True)
plt.tight_layout()
buffer = BytesIO()
plt.savefig(buffer, format='png')
buffer.seek(0)
plt.close(fig)
plot_texture = arcade.Texture(Image.open(buffer))
self.plot_image_widget.texture = plot_texture
self.plot_image_widget.size_hint = (None, None)
self.plot_image_widget.width = plot_texture.width
self.plot_image_widget.height = plot_texture.height
self.plot_image_widget.visible = True
self.training_text = "Training finished. Plot displayed."

3
run.py
View File

@@ -10,7 +10,6 @@ script_dir = os.path.dirname(os.path.abspath(__file__))
pyglet.resource.path.append(script_dir) pyglet.resource.path.append(script_dir)
pyglet.font.add_directory(os.path.join(script_dir, 'assets', 'fonts')) pyglet.font.add_directory(os.path.join(script_dir, 'assets', 'fonts'))
from utils.utils import get_closest_resolution, print_debug_info, on_exception from utils.utils import get_closest_resolution, print_debug_info, on_exception
from utils.constants import log_dir, menu_background_color from utils.constants import log_dir, menu_background_color
from menus.main import Main from menus.main import Main
@@ -18,8 +17,6 @@ from arcade.experimental.controller_window import ControllerWindow
sys.excepthook = on_exception sys.excepthook = on_exception
__builtins__.print = lambda *args, **kwargs: logging.debug(" ".join(map(str, args)))
if not log_dir in os.listdir(): if not log_dir in os.listdir():
os.makedirs(log_dir) os.makedirs(log_dir)

View File

@@ -1,20 +0,0 @@
from stable_baselines3 import PPO
from utils.ml import SpaceInvadersEnv
env = SpaceInvadersEnv()
model = PPO(
"MlpPolicy",
env,
n_steps=2048,
batch_size=64,
n_epochs=10,
learning_rate=3e-4,
verbose=1,
device="cpu",
gamma=0.99,
ent_coef=0.02,
clip_range=0.2,
gae_lambda=0.95
)
model.learn(1_000_000)
model.save("invader_agent")

View File

@@ -12,18 +12,21 @@ PLAYER_ATTACK_SPEED = 0.75
BULLET_SPEED = 3 BULLET_SPEED = 3
BULLET_RADIUS = 10 BULLET_RADIUS = 10
# default, min, max, step
MODEL_SETTINGS = { MODEL_SETTINGS = {
"n_steps": [2048, 256, 8192, 256], "n_steps": [2048, 256, 8192, 256],
"batch_size": 64, "batch_size": [64, 16, 512, 16],
"n_epochs": 10, "n_epochs": [10, 1, 50, 1],
"learning_rate": 3e-4, "learning_rate": [3e-4, 1e-5, 1e-2, 1e-5],
"gamma": 0.99, "gamma": [0.99, 0.8, 0.9999, 0.001],
"ent_coef": 0.01, "ent_coef": [0.01, 0.0, 0.1, 0.001],
"clip_range": 0.2 "clip_range": [0.2, 0.1, 0.4, 0.01],
"learning_steps": [500_000, 50_000, 25_000_000, 50_000]
} }
menu_background_color = (30, 30, 47) menu_background_color = (30, 30, 47)
log_dir = 'logs' log_dir = 'logs'
monitor_log_dir = "training_logs"
discord_presence_id = 1438214877343907881 discord_presence_id = 1438214877343907881
button_style = {'normal': UITextureButtonStyle(font_name="Roboto", font_color=arcade.color.BLACK), 'hover': UITextureButtonStyle(font_name="Roboto", font_color=arcade.color.BLACK), button_style = {'normal': UITextureButtonStyle(font_name="Roboto", font_color=arcade.color.BLACK), 'hover': UITextureButtonStyle(font_name="Roboto", font_color=arcade.color.BLACK),

View File

@@ -13,7 +13,7 @@ class SpaceInvadersEnv(gym.Env):
self.height = height self.height = height
self.action_space = gym.spaces.Discrete(3) self.action_space = gym.spaces.Discrete(3)
self.observation_space = gym.spaces.Box(low=-10.0, high=10.0, shape=(9,), dtype=np.float32) self.observation_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(9,), dtype=np.float32)
self.enemies = [] self.enemies = []
self.bullets = [] self.bullets = []
@@ -25,8 +25,14 @@ class SpaceInvadersEnv(gym.Env):
self.prev_bx = 2.0 self.prev_bx = 2.0
self.steps_since_direction_change = 0 self.steps_since_direction_change = 0
self.last_direction = 0 self.last_direction = 0
self.max_steps = 1000
self.current_step = 0
def reset(self, seed=None, options=None): def reset(self, seed=None, options=None):
if seed is not None:
np.random.seed(seed)
random.seed(seed)
self.enemies = [] self.enemies = []
self.bullets = [] self.bullets = []
self.dir_history = [] self.dir_history = []
@@ -36,6 +42,7 @@ class SpaceInvadersEnv(gym.Env):
self.prev_bx = 2.0 self.prev_bx = 2.0
self.steps_since_direction_change = 0 self.steps_since_direction_change = 0
self.last_direction = 0 self.last_direction = 0
self.current_step = 0
start_x = self.width * 0.15 start_x = self.width * 0.15
start_y = self.height * 0.9 start_y = self.height * 0.9
@@ -51,7 +58,7 @@ class SpaceInvadersEnv(gym.Env):
def _nearest_enemy(self): def _nearest_enemy(self):
if not self.enemies: if not self.enemies:
return None return None
return min(self.enemies, key=lambda e: abs(e.center_y - self.player.center_y) + abs(e.center_x - self.player.center_x)) return min(self.enemies, key=lambda e: abs(e.center_x - self.player.center_x))
def _lowest_enemy(self): def _lowest_enemy(self):
if not self.enemies: if not self.enemies:
@@ -104,18 +111,16 @@ class SpaceInvadersEnv(gym.Env):
terminated = False terminated = False
truncated = False truncated = False
self.current_step += 1
if self.current_step >= self.max_steps:
truncated = True
nearest = self._nearest_enemy() nearest = self._nearest_enemy()
if nearest is not None: if nearest is not None:
enemy_x = (nearest.center_x - self.player.center_x) / float(self.width) enemy_x = (nearest.center_x - self.player.center_x) / float(self.width)
else: else:
enemy_x = 2.0 enemy_x = 2.0
prev_bullet = self._nearest_enemy_bullet()
if prev_bullet is not None:
prev_bx = (prev_bullet.center_x - self.player.center_x) / float(self.width)
else:
prev_bx = 2.0
prev_x = self.player.center_x prev_x = self.player.center_x
current_action_dir = 0 current_action_dir = 0
@@ -129,12 +134,15 @@ class SpaceInvadersEnv(gym.Env):
t = time.perf_counter() t = time.perf_counter()
if t - self.last_shot >= PLAYER_ATTACK_SPEED: if t - self.last_shot >= PLAYER_ATTACK_SPEED:
self.last_shot = t self.last_shot = t
b = Bullet(self.player.center_x, self.player.center_y, 1) b = Bullet(self.player.center_x, self.player.center_y, 1)
self.bullets.append(b) self.bullets.append(b)
if enemy_x != 2.0 and abs(enemy_x) < 0.04: if enemy_x != 2.0 and abs(enemy_x) < 0.04:
reward += 8.0 reward += 0.3
elif enemy_x != 2.0 and abs(enemy_x) < 0.1: elif enemy_x != 2.0 and abs(enemy_x) < 0.1:
reward += 3.0 reward += 0.1
if self.player.center_x > self.width: if self.player.center_x > self.width:
self.player.center_x = self.width self.player.center_x = self.width
@@ -145,8 +153,9 @@ class SpaceInvadersEnv(gym.Env):
if current_action_dir != 0: if current_action_dir != 0:
if self.last_direction != 0 and current_action_dir != self.last_direction: if self.last_direction != 0 and current_action_dir != self.last_direction:
if self.steps_since_direction_change < 8: if self.steps_since_direction_change < 3:
reward -= 3.0 reward -= 0.1
self.steps_since_direction_change = 0 self.steps_since_direction_change = 0
else: else:
self.steps_since_direction_change += 1 self.steps_since_direction_change += 1
@@ -154,9 +163,9 @@ class SpaceInvadersEnv(gym.Env):
if enemy_x != 2.0: if enemy_x != 2.0:
if abs(enemy_x) < 0.03: if abs(enemy_x) < 0.03:
reward += 3.0 reward += 0.1
elif abs(enemy_x) < 0.08: elif abs(enemy_x) < 0.08:
reward += 1.0 reward += 0.05
for b in list(self.bullets): for b in list(self.bullets):
b.center_y += b.direction_y * BULLET_SPEED b.center_y += b.direction_y * BULLET_SPEED
@@ -178,7 +187,7 @@ class SpaceInvadersEnv(gym.Env):
self.bullets.remove(b) self.bullets.remove(b)
except ValueError: except ValueError:
pass pass
reward += 25.0 reward += 1.0
break break
for b in list(self.bullets): for b in list(self.bullets):
@@ -188,14 +197,14 @@ class SpaceInvadersEnv(gym.Env):
self.bullets.remove(b) self.bullets.remove(b)
except ValueError: except ValueError:
pass pass
reward -= 100.0 reward -= 5.0
terminated = True terminated = True
if not self.enemies: if not self.enemies:
reward += 200.0 reward += 10.0
terminated = True terminated = True
if self.enemies and random.random() < 0.02: if self.enemies and random.random() < 0.05:
e = random.choice(self.enemies) e = random.choice(self.enemies)
b = Bullet(e.center_x, e.center_y, -1) b = Bullet(e.center_x, e.center_y, -1)
self.bullets.append(b) self.bullets.append(b)
@@ -203,17 +212,14 @@ class SpaceInvadersEnv(gym.Env):
curr_bullet = self._nearest_enemy_bullet() curr_bullet = self._nearest_enemy_bullet()
if curr_bullet is not None: if curr_bullet is not None:
curr_bx = (curr_bullet.center_x - self.player.center_x) / float(self.width) curr_bx = (curr_bullet.center_x - self.player.center_x) / float(self.width)
curr_by = (curr_bullet.center_y - self.player.center_y) / float(self.height)
else: else:
curr_bx = 2.0 curr_bx = 2.0
curr_by = 2.0
if prev_bx != 2.0 and curr_bx != 2.0: if self.prev_bx != 2.0 and curr_bx != 2.0:
if abs(curr_bx) > abs(prev_bx): if abs(curr_bx) > abs(self.prev_bx):
reward += 0.3 reward += 0.02
if curr_bx != 2.0 and abs(curr_bx) < 0.08 and curr_by < 0.5: reward -= 0.01
reward -= 0.3
obs = self._obs() obs = self._obs()
self.prev_bx = curr_bx self.prev_bx = curr_bx