mirror of
https://github.com/csd4ni3l/fleet-commander.git
synced 2026-01-01 04:23:47 +01:00
Add model training with graphs and current stats; improve the model with a better reward system
This commit is contained in:
@@ -55,8 +55,8 @@ class Main(arcade.gui.UIView):
|
||||
self.play_button = self.box.add(arcade.gui.UITextureButton(text="Play", texture=button_texture, texture_hovered=button_hovered_texture, width=self.window.width / 2, height=150, style=big_button_style))
|
||||
self.play_button.on_click = lambda event: self.play()
|
||||
|
||||
self.train_button = self.box.add(arcade.gui.UITextureButton(text="Train", texture=button_texture, texture_hovered=button_hovered_texture, width=self.window.width / 2, height=150, style=big_button_style))
|
||||
self.train_button.on_click = lambda event: self.train()
|
||||
self.train_model_button = self.box.add(arcade.gui.UITextureButton(text="Train Model", texture=button_texture, texture_hovered=button_hovered_texture, width=self.window.width / 2, height=150, style=big_button_style))
|
||||
self.train_model_button.on_click = lambda event: self.train_model()
|
||||
|
||||
self.settings_button = self.box.add(arcade.gui.UITextureButton(text="Settings", texture=button_texture, texture_hovered=button_hovered_texture, width=self.window.width / 2, height=150, style=big_button_style))
|
||||
self.settings_button.on_click = lambda event: self.settings()
|
||||
@@ -68,3 +68,7 @@ class Main(arcade.gui.UIView):
|
||||
def settings(self):
|
||||
from menus.settings import Settings
|
||||
self.window.show_view(Settings(self.pypresence_client))
|
||||
|
||||
def train_model(self):
|
||||
from menus.train_model import TrainModel
|
||||
self.window.show_view(TrainModel(self.pypresence_client))
|
||||
@@ -1,60 +1,177 @@
|
||||
import arcade, arcade.gui
|
||||
import arcade, arcade.gui, threading, io, os, time
|
||||
|
||||
from utils.constants import button_style, MODEL_SETTINGS
|
||||
from utils.constants import button_style, MODEL_SETTINGS, monitor_log_dir
|
||||
from utils.preload import button_texture, button_hovered_texture
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
from utils.ml import SpaceInvadersEnv
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.logger import configure
|
||||
|
||||
from utils.rl import SpaceInvadersEnv
|
||||
|
||||
class TrainModel(arcade.gui.UIView):
    """Menu view for configuring PPO hyperparameters, running training on a
    background thread, and displaying live progress text plus result plots."""

    def __init__(self, pypresence_client):
        super().__init__()

        self.pypresence_client = pypresence_client
        self.pypresence_client.update(state="Model Training")

        self.anchor = self.add_widget(arcade.gui.UIAnchorLayout(size_hint=(1, 1)))
        self.box = self.anchor.add(arcade.gui.UIBoxLayout(space_between=10))

        # Seed every tunable with its default value (first element of each
        # MODEL_SETTINGS tuple: (default, min, max, step)).
        self.settings = {
            setting: data[0]  # default value
            for setting, data in MODEL_SETTINGS.items()
        }
        self.labels = {}  # setting name -> UILabel showing its current value

        self.training = False
        self.training_text = ""

        self.last_progress_update = time.perf_counter()

    def on_show_view(self):
        super().on_show_view()

        self.show_menu()

    def main_exit(self):
        """Return to the main menu."""
        # Imported lazily to avoid a circular import between menu modules.
        from menus.main import Main
        self.window.show_view(Main(self.pypresence_client))

    def show_menu(self):
        """Build the hyperparameter settings menu: one label + slider per setting."""
        self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top", align_x=5, align_y=-5)
        self.back_button.on_click = lambda event: self.main_exit()

        self.box.add(arcade.gui.UILabel("Settings", font_size=36))

        for setting, data in MODEL_SETTINGS.items():
            default, min_value, max_value, step = data
            label = self.box.add(arcade.gui.UILabel(text=f"{setting.replace('_', ' ').capitalize()}: {default}", font_size=18))

            slider = self.box.add(arcade.gui.UISlider(value=default, min_value=min_value, max_value=max_value, step=step, width=self.window.width / 2, height=self.window.height / 25))
            # Suppress slider tick rendering. NOTE(review): _render_steps is a
            # private arcade API — confirm it still exists on arcade upgrades.
            slider._render_steps = lambda surface: None
            # Bind the setting name as a default argument so each slider
            # updates its own key (avoids the late-binding closure pitfall).
            slider.on_change = lambda e, key=setting: self.change_value(key, e.new_value)

            self.labels[setting] = label

        train_button = self.box.add(arcade.gui.UITextureButton(width=self.window.width / 2, height=self.window.height / 10, text="Train", style=button_style, texture=button_texture, texture_hovered=button_hovered_texture))
        train_button.on_click = lambda e: self.start_training()

    def change_value(self, key, value):
        """Store a slider change in self.settings and refresh its label."""
        rounded = self.round_near_int(value)
        self.labels[key].text = f"{key.replace('_', ' ').capitalize()}: {rounded}"
        self.settings[key] = rounded

    def start_training(self):
        """Replace the settings menu with progress widgets and launch training
        on a daemon thread so the UI stays responsive."""
        self.box.clear()

        self.training = True

        self.training_label = self.box.add(arcade.gui.UILabel("No Output yet.", font_size=16, multiline=True, width=self.window.width / 2, height=self.window.height / 2))

        self.plot_image_widget = self.box.add(arcade.gui.UIImage(texture=arcade.Texture.create_empty("empty", (1, 1))))
        self.plot_image_widget.visible = False

        threading.Thread(target=self.train, daemon=True).start()

    def on_update(self, delta_time):
        """Poll the CSV progress log (at most every 0.5 s) while training and
        mirror the latest row into the progress label."""
        # Fix: read from monitor_log_dir — train() configures the SB3 logger to
        # write there, so a hardcoded "training_logs" path could never match.
        progress_path = os.path.join(monitor_log_dir, "progress.csv")

        if self.training and os.path.exists(progress_path) and time.perf_counter() - self.last_progress_update >= 0.5:
            self.last_progress_update = time.perf_counter()

            try:
                progress_df = pd.read_csv(progress_path)
            except pd.errors.EmptyDataError:
                return  # logger created the file but has not flushed a row yet

            progress_text = ""

            for key, value in progress_df.items():
                progress_text += f"{key}: {round(value.iloc[-1], 6)}\n"

            self.training_text = progress_text

        # Runs every frame so the final "finished" text (set off-thread) shows.
        if hasattr(self, "training_label"):
            self.training_label.text = self.training_text

    def round_near_int(self, x, tol=1e-4):
        """Snap x to the nearest integer when within tol; otherwise return x
        unchanged (keeps fractional hyperparameters like learning_rate intact)."""
        nearest = round(x)
        if abs(x - nearest) < tol:
            return nearest
        return x

    def train(self):
        """Run PPO training (blocking; executed on the worker thread)."""
        os.makedirs(monitor_log_dir, exist_ok=True)
        env = Monitor(SpaceInvadersEnv(), filename=os.path.join(monitor_log_dir, "monitor.csv"))

        model = PPO(
            "MlpPolicy",
            env,
            n_steps=self.settings["n_steps"],
            batch_size=self.settings["batch_size"],
            n_epochs=self.settings["n_epochs"],
            learning_rate=self.settings["learning_rate"],
            verbose=1,
            device="cpu",
            gamma=self.settings["gamma"],
            ent_coef=self.settings["ent_coef"],
            clip_range=self.settings["clip_range"],
        )

        # Route SB3 progress output to CSV so on_update can display it live.
        new_logger = configure(
            folder=monitor_log_dir, format_strings=["csv"]
        )

        model.set_logger(new_logger)

        model.learn(self.settings["learning_steps"])
        model.save("invader_agent")

        self.training = False

        self.plot_results(os.path.join(monitor_log_dir, "monitor.csv"), os.path.join(monitor_log_dir, "progress.csv"))

    def plot_results(self, log_path, loss_log_path):
        """Render reward and loss curves to an in-memory PNG and show it in
        the UI image widget.

        NOTE(review): called from the worker thread but mutates UI widgets and
        creates an arcade texture — confirm arcade tolerates off-thread updates.
        """
        # A Monitor CSV starts with a JSON comment header line; skip it.
        df = pd.read_csv(log_path, skiprows=1)
        loss_df = pd.read_csv(loss_log_path)

        fig, axes = plt.subplots(2, 1, figsize=(6, 8), dpi=100)

        # Episode lengths accumulate into total timesteps for the x-axis.
        axes[0].plot(np.cumsum(df['l']), df['r'], label='Reward')
        axes[0].set_title('PPO Training: Episodic Reward')
        axes[0].set_xlabel('Total Timesteps')
        axes[0].set_ylabel('Reward')
        axes[0].grid(True)

        axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/policy_gradient_loss'], label='Policy Gradient Loss')
        axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/value_loss'], label='Value Loss')
        axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/explained_variance'], label='Explained Variance')
        axes[1].set_title('PPO Training: Loss Functions')
        axes[1].set_xlabel('Total Timesteps')
        axes[1].set_ylabel('Loss Value')
        axes[1].legend()
        axes[1].grid(True)

        plt.tight_layout()

        buffer = BytesIO()
        plt.savefig(buffer, format='png')
        buffer.seek(0)
        plt.close(fig)  # release the figure so repeated runs do not leak memory

        plot_texture = arcade.Texture(Image.open(buffer))

        self.plot_image_widget.texture = plot_texture
        self.plot_image_widget.size_hint = (None, None)
        self.plot_image_widget.width = plot_texture.width
        self.plot_image_widget.height = plot_texture.height

        self.plot_image_widget.visible = True
        self.training_text = "Training finished. Plot displayed."
|
||||
Reference in New Issue
Block a user