mirror of
https://github.com/csd4ni3l/fleet-commander.git
synced 2026-01-01 04:23:47 +01:00
Add 10 million timestep model, improve README, add difficulty/mode selector, score, make model training have instant graphs and also multiple envs for faster training, better plotting, improve RL model by including multiple players, better reward system, use EnemyFormation instead of single Enemies
This commit is contained in:
@@ -62,8 +62,8 @@ class Main(arcade.gui.UIView):
|
||||
self.settings_button.on_click = lambda event: self.settings()
|
||||
|
||||
def play(self):
|
||||
from game.play import Game
|
||||
self.window.show_view(Game(self.pypresence_client))
|
||||
from menus.mode_selector import ModeSelector
|
||||
self.window.show_view(ModeSelector(self.pypresence_client))
|
||||
|
||||
def settings(self):
|
||||
from menus.settings import Settings
|
||||
|
||||
72
menus/mode_selector.py
Normal file
72
menus/mode_selector.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import arcade, arcade.gui
|
||||
|
||||
from utils.preload import button_texture, button_hovered_texture
|
||||
from utils.constants import dropdown_style, button_style, DIFFICULTY_LEVELS, DIFFICULTY_SETTINGS
|
||||
|
||||
class ModeSelector(arcade.gui.UIView):
    """Pre-game view that lets the player pick a difficulty preset.

    A dropdown selects one of the presets in ``DIFFICULTY_LEVELS``; one
    slider per entry of ``DIFFICULTY_SETTINGS`` shows the preset's values.
    The sliders only react to input while the "Custom" preset is selected.
    Pressing "Play" starts ``Game`` with the currently selected settings.
    """

    def __init__(self, pypresence_client):
        """Create the layout containers and the initial ("Easy") settings.

        Args:
            pypresence_client: Rich-presence client; its ``update`` method is
                called here and the object is forwarded to ``Game``.
        """
        super().__init__()

        self.pypresence_client = pypresence_client
        self.pypresence_client.update(state="Selecting Mode")

        self.anchor = self.add_widget(arcade.gui.UIAnchorLayout(size_hint=(1, 1)))
        self.box = self.anchor.add(
            arcade.gui.UIBoxLayout(size_hint=(0.75, 0.75), space_between=10),
            anchor_x="center",
            anchor_y="center",
        )

        # Work on a copy of the preset. set_difficulty_values() and
        # change_value() write into self.settings; aliasing the module-level
        # dict would overwrite the shared "Easy" preset for the whole process.
        self.settings = dict(DIFFICULTY_LEVELS["Easy"])
        self.setting_sliders = {}  # setting key -> UISlider
        self.setting_labels = {}  # setting key -> UILabel

    def on_show_view(self):
        """Build the widgets: back button, dropdown, sliders and Play button."""
        super().on_show_view()

        self.back_button = self.anchor.add(
            arcade.gui.UITextureButton(
                texture=button_texture,
                texture_hovered=button_hovered_texture,
                text='<--',
                style=button_style,
                width=100,
                height=50,
            ),
            anchor_x="left",
            anchor_y="top",
            align_x=5,
            align_y=-5,
        )
        # NOTE(review): main_exit is not defined in this class as visible
        # here; confirm it is provided elsewhere, otherwise clicking the back
        # button raises AttributeError.
        self.back_button.on_click = lambda event: self.main_exit()

        self.box.add(arcade.gui.UILabel("Settings", font_size=32))

        self.box.add(arcade.gui.UISpace(height=self.window.height / 80))

        self.difficulty_selector = self.box.add(
            arcade.gui.UIDropdown(
                default="Easy",
                options=list(DIFFICULTY_LEVELS.keys()),
                active_style=dropdown_style,
                primary_style=dropdown_style,
                dropdown_style=dropdown_style,
                width=self.window.width / 2,
                height=self.window.height / 20,
            )
        )
        self.difficulty_selector.on_change = lambda event: self.set_difficulty_values(event.new_value)

        self.box.add(arcade.gui.UISpace(height=self.window.height / 80))

        for key, data in DIFFICULTY_SETTINGS.items():
            # data unpacks to (display name, slider minimum, slider maximum).
            default, name, min_value, max_value = DIFFICULTY_LEVELS["Easy"][key], *data

            label = self.box.add(arcade.gui.UILabel(text=f"{name}: {default}", font_size=14))

            slider = self.box.add(
                arcade.gui.UISlider(
                    value=default,
                    min_value=min_value,
                    max_value=max_value,
                    step=1,
                    width=self.window.width / 2,
                    height=self.window.height / 25,
                )
            )
            slider._render_steps = lambda surface: None
            # The dropdown starts on "Easy", so sliders start out disabled.
            self._set_slider_locked(slider, True)
            slider.on_change = lambda e, key=key: self.change_value(key, e.new_value)

            self.setting_sliders[key] = slider
            self.setting_labels[key] = label

        self.play_button = self.box.add(
            arcade.gui.UITextureButton(
                text="Play",
                width=self.window.width / 2,
                height=self.window.height / 15,
                texture=button_texture,
                texture_hovered=button_hovered_texture,
                style=button_style,
            )
        )
        self.play_button.on_click = lambda event: self.start_game()

    @staticmethod
    def _set_slider_locked(slider, locked):
        """Disable (``locked=True``) or re-enable a slider's input handlers."""
        if locked:
            # Shadow the handlers on the instance with no-ops.
            slider.on_event = lambda event: None
            slider.on_click = lambda event: None
        else:
            # Route back to the stock UISlider handlers.
            slider.on_event = lambda event: arcade.gui.UISlider.on_event(slider, event)
            slider.on_click = lambda event: arcade.gui.UISlider.on_click(slider, event)

    def set_difficulty_values(self, difficulty):
        """Apply a preset to settings, labels and sliders.

        Args:
            difficulty: Key into ``DIFFICULTY_LEVELS``. Sliders are unlocked
                only when this is ``"Custom"``.
        """
        for key, value in DIFFICULTY_LEVELS[difficulty].items():
            self.settings[key] = value
            self.setting_labels[key].text = f"{DIFFICULTY_SETTINGS[key][0]}: {value}"
            self.setting_sliders[key].value = value

        locked = difficulty != "Custom"
        for slider in self.setting_sliders.values():
            self._set_slider_locked(slider, locked)

    def change_value(self, key, value):
        """Store a slider's new value (truncated to int) and refresh its label."""
        self.settings[key] = int(value)
        self.setting_labels[key].text = f"{DIFFICULTY_SETTINGS[key][0]}: {int(value)}"

    def start_game(self):
        """Switch to the game view, passing along the selected settings."""
        # Imported here rather than at module level, as in the original.
        from game.play import Game

        self.window.show_view(Game(self.pypresence_client, self.settings))
|
||||
@@ -1,20 +1,26 @@
|
||||
import arcade, arcade.gui, threading, io, os, time
|
||||
import arcade, arcade.gui, threading, os, queue, time, shutil
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
from utils.constants import button_style, MODEL_SETTINGS, monitor_log_dir
|
||||
from utils.preload import button_texture, button_hovered_texture
|
||||
from utils.rl import SpaceInvadersEnv
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.logger import configure
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
|
||||
from utils.rl import SpaceInvadersEnv
|
||||
def make_env(rank: int, seed: int = 0):
    """Return a zero-argument factory for one monitored environment.

    Args:
        rank: Index of the environment; it is embedded in the monitor log
            file name (``monitor_<rank>.csv`` inside ``monitor_log_dir``).
        seed: Accepted for callers' convenience; not used by this function.

    Returns:
        A callable that, when invoked, creates a ``SpaceInvadersEnv`` wrapped
        in a ``Monitor``.
    """

    def _init():
        # Construction is deferred until the factory is actually called.
        base_env = SpaceInvadersEnv()
        return Monitor(
            base_env,
            filename=os.path.join(monitor_log_dir, f"monitor_{rank}.csv"),
        )

    return _init
|
||||
|
||||
class TrainModel(arcade.gui.UIView):
|
||||
def __init__(self, pypresence_client):
|
||||
@@ -24,22 +30,25 @@ class TrainModel(arcade.gui.UIView):
|
||||
self.pypresence_client.update(state="Model Training")
|
||||
|
||||
self.anchor = self.add_widget(arcade.gui.UIAnchorLayout(size_hint=(1, 1)))
|
||||
self.box = self.anchor.add(arcade.gui.UIBoxLayout(space_between=10))
|
||||
self.box = self.anchor.add(arcade.gui.UIBoxLayout(space_between=5))
|
||||
|
||||
self.settings = {
|
||||
setting: data[0] # default value
|
||||
setting: data[0]
|
||||
for setting, data in MODEL_SETTINGS.items()
|
||||
}
|
||||
|
||||
self.labels = {}
|
||||
|
||||
self.training = False
|
||||
self.training_text = ""
|
||||
self.training_text = "Starting training..."
|
||||
|
||||
self.result_queue = queue.Queue()
|
||||
self.training_thread = None
|
||||
|
||||
self.last_progress_update = time.perf_counter()
|
||||
|
||||
def on_show_view(self):
|
||||
super().on_show_view()
|
||||
|
||||
self.show_menu()
|
||||
|
||||
def main_exit(self):
|
||||
@@ -50,128 +59,188 @@ class TrainModel(arcade.gui.UIView):
|
||||
self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top", align_x=5, align_y=-5)
|
||||
self.back_button.on_click = lambda event: self.main_exit()
|
||||
|
||||
self.box.add(arcade.gui.UILabel("Settings", font_size=36))
|
||||
self.box.add(arcade.gui.UILabel("Settings", font_size=32))
|
||||
|
||||
for setting, data in MODEL_SETTINGS.items():
|
||||
default, min_value, max_value, step = data
|
||||
label = self.box.add(arcade.gui.UILabel(text=f"{setting.replace('_', ' ').capitalize()}: {default}", font_size=18))
|
||||
is_int = setting == "n_envs" or (abs(step - 1) < 1e-6 and abs(min_value - round(min_value)) < 1e-6)
|
||||
|
||||
val_text = str(int(default)) if is_int else str(default)
|
||||
label = self.box.add(arcade.gui.UILabel(text=f"{setting.replace('_', ' ').capitalize()}: {val_text}", font_size=14))
|
||||
|
||||
slider = self.box.add(arcade.gui.UISlider(value=default, min_value=min_value, max_value=max_value, step=step, width=self.window.width / 2, height=self.window.height / 25))
|
||||
slider._render_steps = lambda surface: None
|
||||
slider.on_change = lambda e, key=setting: self.change_value(key, e.new_value)
|
||||
slider.on_change = lambda e, key=setting, is_int_slider=is_int: self.change_value(key, e.new_value, is_int_slider)
|
||||
|
||||
self.labels[setting] = label
|
||||
|
||||
train_button = self.box.add(arcade.gui.UITextureButton(width=self.window.width / 2, height=self.window.height / 10, text="Train", style=button_style, texture=button_texture, texture_hovered=button_hovered_texture))
|
||||
train_button.on_click = lambda e: self.start_training()
|
||||
|
||||
def change_value(self, key, value):
|
||||
self.labels[key].text = f"{key.replace('_', ' ').capitalize()}: {self.round_near_int(value)}"
|
||||
self.settings[key] = self.round_near_int(value)
|
||||
def change_value(self, key, value, is_int=False):
|
||||
if is_int:
|
||||
val = int(round(value))
|
||||
self.settings[key] = val
|
||||
self.labels[key].text = f"{key.replace('_', ' ').capitalize()}: {val}"
|
||||
else:
|
||||
val = self.round_near_int(value)
|
||||
self.settings[key] = val
|
||||
self.labels[key].text = f"{key.replace('_', ' ').capitalize()}: {val}"
|
||||
|
||||
def start_training(self):
|
||||
self.box.clear()
|
||||
|
||||
self.training = True
|
||||
|
||||
self.training_label = self.box.add(arcade.gui.UILabel("No Output yet.", font_size=16, multiline=True, width=self.window.width / 2, height=self.window.height / 2))
|
||||
self.training_text = "Starting training..."
|
||||
self.training_label = self.box.add(arcade.gui.UILabel("Starting training...", font_size=16, multiline=True, width=self.window.width / 2, height=self.window.height / 2))
|
||||
|
||||
self.plot_image_widget = self.box.add(arcade.gui.UIImage(texture=arcade.Texture.create_empty("empty", (1, 1))))
|
||||
self.plot_image_widget.visible = False
|
||||
|
||||
threading.Thread(target=self.train, daemon=True).start()
|
||||
self.training_thread = threading.Thread(target=self.train, daemon=True)
|
||||
self.training_thread.start()
|
||||
|
||||
def on_update(self, delta_time):
|
||||
if self.training and os.path.exists(os.path.join("training_logs", "progress.csv")) and time.perf_counter() - self.last_progress_update >= 0.5:
|
||||
self.last_progress_update = time.perf_counter()
|
||||
try:
|
||||
result = self.result_queue.get_nowait()
|
||||
|
||||
try:
|
||||
progress_df = pd.read_csv(os.path.join("training_logs", "progress.csv"))
|
||||
except pd.errors.EmptyDataError:
|
||||
return
|
||||
if result["type"] == "text":
|
||||
self.training_text = result["message"]
|
||||
|
||||
progress_text = ""
|
||||
elif result["type"] == "plot":
|
||||
self.plot_image_widget.texture = result["image"]
|
||||
self.plot_image_widget.width = result["image"].width
|
||||
self.plot_image_widget.height = result["image"].height
|
||||
self.plot_image_widget.trigger_render()
|
||||
self.plot_image_widget.visible = True
|
||||
|
||||
for key, value in progress_df.items():
|
||||
progress_text += f"{key}: {round(value.iloc[-1], 6)}\n"
|
||||
elif result["type"] == "finished":
|
||||
self.training = False
|
||||
self.training_text = "Training finished."
|
||||
|
||||
self.training_text = progress_text
|
||||
except queue.Empty:
|
||||
if self.training and os.path.exists(os.path.join("training_logs", "progress.csv")) and all([os.path.exists(os.path.join(monitor_log_dir, f"monitor_{i}.csv.monitor.csv")) for i in range(int(self.settings["n_envs"]))]) and time.perf_counter() - self.last_progress_update >= 1:
|
||||
self.last_progress_update = time.perf_counter()
|
||||
self.plot_results()
|
||||
|
||||
if hasattr(self, "training_label"):
|
||||
self.training_label.text = self.training_text
|
||||
|
||||
def round_near_int(self, x, tol=1e-4):
|
||||
nearest = round(x)
|
||||
|
||||
if abs(x - nearest) < tol:
|
||||
return nearest
|
||||
|
||||
return x
|
||||
|
||||
def train(self):
|
||||
os.makedirs(monitor_log_dir, exist_ok=True)
|
||||
env = Monitor(SpaceInvadersEnv(), filename=os.path.join(monitor_log_dir, "monitor.csv"))
|
||||
if os.path.exists(monitor_log_dir):
|
||||
shutil.rmtree(monitor_log_dir)
|
||||
os.makedirs(monitor_log_dir)
|
||||
|
||||
n_envs = int(self.settings["n_envs"])
|
||||
env = DummyVecEnv([make_env(i) for i in range(n_envs)])
|
||||
|
||||
n_steps = int(self.settings["n_steps"])
|
||||
batch_size = int(self.settings["batch_size"])
|
||||
|
||||
total_steps_per_rollout = n_steps * max(1, n_envs)
|
||||
if total_steps_per_rollout % batch_size != 0:
|
||||
batch_size = max(64, total_steps_per_rollout // max(1, total_steps_per_rollout // batch_size))
|
||||
print(f"Warning: Adjusting batch size to {batch_size} for {n_envs} envs.")
|
||||
|
||||
model = PPO(
|
||||
"MlpPolicy",
|
||||
env,
|
||||
n_steps=self.settings["n_steps"],
|
||||
batch_size=self.settings["batch_size"],
|
||||
n_epochs=self.settings["n_epochs"],
|
||||
learning_rate=self.settings["learning_rate"],
|
||||
"MlpPolicy",
|
||||
env,
|
||||
n_steps=n_steps,
|
||||
batch_size=batch_size,
|
||||
n_epochs=int(self.settings["n_epochs"]),
|
||||
learning_rate=float(self.settings["learning_rate"]),
|
||||
verbose=1,
|
||||
device="cpu",
|
||||
gamma=self.settings["gamma"],
|
||||
ent_coef=self.settings["ent_coef"],
|
||||
clip_range=self.settings["clip_range"],
|
||||
gamma=float(self.settings["gamma"]),
|
||||
ent_coef=float(self.settings["ent_coef"]),
|
||||
clip_range=float(self.settings["clip_range"]),
|
||||
)
|
||||
|
||||
new_logger = configure(
|
||||
folder=monitor_log_dir, format_strings=["csv"]
|
||||
)
|
||||
|
||||
new_logger = configure(folder=monitor_log_dir, format_strings=["csv"])
|
||||
model.set_logger(new_logger)
|
||||
|
||||
model.learn(self.settings["learning_steps"])
|
||||
model.save("invader_agent")
|
||||
try:
|
||||
self.training = True
|
||||
model.learn(int(self.settings["learning_steps"]))
|
||||
model.save("invader_agent")
|
||||
except Exception as e:
|
||||
print(f"Error during training: {e}")
|
||||
self.result_queue.put({"type": "text", "message": f"Error:\n{e}"})
|
||||
finally:
|
||||
try:
|
||||
env.close()
|
||||
except Exception:
|
||||
pass
|
||||
self.result_queue.put({"type": "finished"})
|
||||
|
||||
self.training = False
|
||||
def plot_results(self):
|
||||
try:
|
||||
reward_df = pd.read_csv(os.path.join(monitor_log_dir, "progress.csv"))
|
||||
except pd.errors.EmptyDataError:
|
||||
return
|
||||
|
||||
self.plot_results(os.path.join(monitor_log_dir, "monitor.csv"), os.path.join(monitor_log_dir, "progress.csv"))
|
||||
all_monitor_files = [os.path.join(monitor_log_dir, f) for f in os.listdir(monitor_log_dir) if f.startswith("monitor_") and f.endswith(".csv")]
|
||||
try:
|
||||
df_list = [pd.read_csv(f, skiprows=1) for f in all_monitor_files]
|
||||
except pd.errors.EmptyDataError:
|
||||
return
|
||||
|
||||
def plot_results(self, log_path, loss_log_path):
|
||||
df = pd.read_csv(log_path, skiprows=1)
|
||||
monitor_df = pd.concat(df_list).sort_values(by='t')
|
||||
monitor_df['total_timesteps'] = monitor_df['l'].cumsum()
|
||||
|
||||
loss_log_path = os.path.join(monitor_log_dir, "progress.csv")
|
||||
loss_df = None
|
||||
if os.path.exists(loss_log_path):
|
||||
try:
|
||||
loss_df = pd.read_csv(loss_log_path)
|
||||
except Exception:
|
||||
loss_df = None
|
||||
|
||||
fig, axes = plt.subplots(2, 1, figsize=(6, 8), dpi=100)
|
||||
|
||||
loss_df = pd.read_csv(loss_log_path)
|
||||
|
||||
axes[0].plot(np.cumsum(df['l']), df['r'], label='Reward')
|
||||
if monitor_df is not None and 'total_timesteps' in monitor_df.columns and 'r' in monitor_df.columns:
|
||||
axes[0].plot(monitor_df['total_timesteps'], monitor_df['r'].rolling(window=10).mean(), label='Episodic Reward (Rolling 10)')
|
||||
elif reward_df is not None and 'time/total_timesteps' in reward_df.columns and 'rollout/ep_rew_mean' in reward_df.columns:
|
||||
axes[0].plot(reward_df['time/total_timesteps'], reward_df['rollout/ep_rew_mean'], label='Ep reward mean')
|
||||
else:
|
||||
axes[0].text(0.5, 0.5, "No reward data available", horizontalalignment='center', verticalalignment='center')
|
||||
|
||||
axes[0].set_title('PPO Training: Episodic Reward')
|
||||
axes[0].set_xlabel('Total Timesteps')
|
||||
axes[0].set_ylabel('Reward')
|
||||
axes[0].grid(True)
|
||||
|
||||
axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/policy_gradient_loss'], label='Policy Gradient Loss')
|
||||
axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/value_loss'], label='Value Loss')
|
||||
axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/explained_variance'], label='Explained Variance')
|
||||
axes[1].set_title('PPO Training: Loss Functions')
|
||||
axes[1].set_title('PPO Training: Loss & Variance')
|
||||
axes[1].set_xlabel('Total Timesteps')
|
||||
axes[1].set_ylabel('Loss Value')
|
||||
axes[1].legend()
|
||||
axes[1].set_ylabel('Value')
|
||||
axes[1].grid(True)
|
||||
|
||||
|
||||
if loss_df is not None and 'time/total_timesteps' in loss_df.columns and 'train/policy_gradient_loss' in loss_df.columns and 'train/value_loss' in loss_df.columns and 'train/explained_variance' in loss_df.columns:
|
||||
tcol = 'time/total_timesteps'
|
||||
axes[1].plot(loss_df[tcol], loss_df['train/policy_gradient_loss'], label='Policy Loss')
|
||||
axes[1].plot(loss_df[tcol], loss_df['train/value_loss'], label='Value Loss')
|
||||
axes[1].plot(loss_df[tcol], loss_df['train/explained_variance'], label='Explained Variance')
|
||||
|
||||
axes[1].legend()
|
||||
else:
|
||||
axes[1].text(0.5, 0.5, "No loss/variance data available", horizontalalignment='center', verticalalignment='center')
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
buffer = BytesIO()
|
||||
plt.savefig(buffer, format='png')
|
||||
plt.savefig(buffer, format='png', bbox_inches='tight')
|
||||
buffer.seek(0)
|
||||
plt.close(fig)
|
||||
|
||||
plot_texture = arcade.Texture(Image.open(buffer))
|
||||
pil_img = Image.open(buffer).convert("RGBA")
|
||||
|
||||
self.plot_image_widget.texture = plot_texture
|
||||
self.plot_image_widget.size_hint = (None, None)
|
||||
self.plot_image_widget.width = plot_texture.width
|
||||
self.plot_image_widget.height = plot_texture.height
|
||||
plot_texture = arcade.Texture(pil_img)
|
||||
|
||||
self.plot_image_widget.visible = True
|
||||
self.training_text = "Training finished. Plot displayed."
|
||||
self.result_queue.put({"type": "plot", "image": plot_texture})
|
||||
|
||||
Reference in New Issue
Block a user