Mirror of https://github.com/csd4ni3l/fleet-commander.git (synced 2026-01-01 04:23:47 +01:00)
Add 10 million timestep model, improve README, add difficulty/mode selector and score display, give model training live graphs and multiple envs for faster training, improve plotting, improve the RL model with multi-player support and a better reward system, use EnemyFormation instead of individual Enemy instances
@@ -1,20 +1,26 @@
-import arcade, arcade.gui, threading, io, os, time
+import arcade, arcade.gui, threading, os, queue, time, shutil

-import matplotlib.pyplot as plt
-import pandas as pd
-
 from utils.constants import button_style, MODEL_SETTINGS, monitor_log_dir
 from utils.preload import button_texture, button_hovered_texture
-from utils.rl import SpaceInvadersEnv
-
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+
 from io import BytesIO
 from PIL import Image
-from io import BytesIO

 from stable_baselines3 import PPO
 from stable_baselines3.common.monitor import Monitor
 from stable_baselines3.common.logger import configure
+from stable_baselines3.common.vec_env import DummyVecEnv

+from utils.rl import SpaceInvadersEnv
+def make_env(rank: int, seed: int = 0):
+    def _init():
+        env = SpaceInvadersEnv()
+        env = Monitor(env, filename=os.path.join(monitor_log_dir, f"monitor_{rank}.csv"))
+        return env
+    return _init

 class TrainModel(arcade.gui.UIView):
     def __init__(self, pypresence_client):
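A note on the vectorized setup introduced above: each make_env(rank) factory builds one SpaceInvadersEnv wrapped in its own Monitor (logging to monitor_<rank>.csv), and the factories are handed to DummyVecEnv, which steps every copy sequentially inside a single process. A minimal sketch, outside the commit, of how the same factories could feed SB3's SubprocVecEnv instead (build_vec_env and its parallel flag are hypothetical names):

from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

def build_vec_env(n_envs, parallel=False):
    # make_env(i) is the factory from the hunk above: it returns a closure that
    # creates one Monitor-wrapped SpaceInvadersEnv logging to monitor_<i>.csv.
    factories = [make_env(i) for i in range(n_envs)]
    # DummyVecEnv steps the copies in-process (lowest overhead); SubprocVecEnv
    # runs one process per copy, which only pays off when step() is CPU-heavy.
    return SubprocVecEnv(factories) if parallel else DummyVecEnv(factories)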
@@ -24,22 +30,25 @@ class TrainModel(arcade.gui.UIView):
         self.pypresence_client.update(state="Model Training")

         self.anchor = self.add_widget(arcade.gui.UIAnchorLayout(size_hint=(1, 1)))
-        self.box = self.anchor.add(arcade.gui.UIBoxLayout(space_between=10))
+        self.box = self.anchor.add(arcade.gui.UIBoxLayout(space_between=5))

         self.settings = {
-            setting: data[0] # default value
+            setting: data[0]
             for setting, data in MODEL_SETTINGS.items()
         }

         self.labels = {}

         self.training = False
-        self.training_text = ""
+        self.training_text = "Starting training..."

+        self.result_queue = queue.Queue()
+        self.training_thread = None
+
+        self.last_progress_update = time.perf_counter()

     def on_show_view(self):
         super().on_show_view()

         self.show_menu()

     def main_exit(self):
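The result_queue added to __init__ above is the only channel between the training thread and the UI thread: the worker puts small dicts whose "type" is "text", "plot", or "finished", and on_update() drains them with get_nowait(). A standalone sketch of that non-blocking drain pattern (the drain helper is illustrative, not part of the commit):

import queue

def drain(q, max_items=8):
    # Pull at most max_items pending messages without ever blocking the caller.
    messages = []
    for _ in range(max_items):
        try:
            messages.append(q.get_nowait())
        except queue.Empty:
            break
    return messages

# The worker side enqueues; the UI side drains once per frame.
q = queue.Queue()
q.put({"type": "text", "message": "rollout 1 done"})
q.put({"type": "finished"})
assert [m["type"] for m in drain(q)] == ["text", "finished"]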
@@ -50,128 +59,188 @@ class TrainModel(arcade.gui.UIView):
         self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top", align_x=5, align_y=-5)
         self.back_button.on_click = lambda event: self.main_exit()

-        self.box.add(arcade.gui.UILabel("Settings", font_size=36))
+        self.box.add(arcade.gui.UILabel("Settings", font_size=32))

         for setting, data in MODEL_SETTINGS.items():
             default, min_value, max_value, step = data
-            label = self.box.add(arcade.gui.UILabel(text=f"{setting.replace('_', ' ').capitalize()}: {default}", font_size=18))
+            is_int = setting == "n_envs" or (abs(step - 1) < 1e-6 and abs(min_value - round(min_value)) < 1e-6)
+
+            val_text = str(int(default)) if is_int else str(default)
+            label = self.box.add(arcade.gui.UILabel(text=f"{setting.replace('_', ' ').capitalize()}: {val_text}", font_size=14))

             slider = self.box.add(arcade.gui.UISlider(value=default, min_value=min_value, max_value=max_value, step=step, width=self.window.width / 2, height=self.window.height / 25))
             slider._render_steps = lambda surface: None
-            slider.on_change = lambda e, key=setting: self.change_value(key, e.new_value)
+            slider.on_change = lambda e, key=setting, is_int_slider=is_int: self.change_value(key, e.new_value, is_int_slider)

             self.labels[setting] = label

         train_button = self.box.add(arcade.gui.UITextureButton(width=self.window.width / 2, height=self.window.height / 10, text="Train", style=button_style, texture=button_texture, texture_hovered=button_hovered_texture))
         train_button.on_click = lambda e: self.start_training()

-    def change_value(self, key, value):
-        self.labels[key].text = f"{key.replace('_', ' ').capitalize()}: {self.round_near_int(value)}"
-        self.settings[key] = self.round_near_int(value)
+    def change_value(self, key, value, is_int=False):
+        if is_int:
+            val = int(round(value))
+            self.settings[key] = val
+            self.labels[key].text = f"{key.replace('_', ' ').capitalize()}: {val}"
+        else:
+            val = self.round_near_int(value)
+            self.settings[key] = val
+            self.labels[key].text = f"{key.replace('_', ' ').capitalize()}: {val}"

     def start_training(self):
         self.box.clear()

         self.training = True

-        self.training_label = self.box.add(arcade.gui.UILabel("No Output yet.", font_size=16, multiline=True, width=self.window.width / 2, height=self.window.height / 2))
+        self.training_text = "Starting training..."
+        self.training_label = self.box.add(arcade.gui.UILabel("Starting training...", font_size=16, multiline=True, width=self.window.width / 2, height=self.window.height / 2))

         self.plot_image_widget = self.box.add(arcade.gui.UIImage(texture=arcade.Texture.create_empty("empty", (1, 1))))
         self.plot_image_widget.visible = False

-        threading.Thread(target=self.train, daemon=True).start()
+        self.training_thread = threading.Thread(target=self.train, daemon=True)
+        self.training_thread.start()

     def on_update(self, delta_time):
-        if self.training and os.path.exists(os.path.join("training_logs", "progress.csv")) and time.perf_counter() - self.last_progress_update >= 0.5:
-            self.last_progress_update = time.perf_counter()
+        try:
+            result = self.result_queue.get_nowait()

-            try:
-                progress_df = pd.read_csv(os.path.join("training_logs", "progress.csv"))
-            except pd.errors.EmptyDataError:
-                return
+            if result["type"] == "text":
+                self.training_text = result["message"]

-            progress_text = ""
+            elif result["type"] == "plot":
+                self.plot_image_widget.texture = result["image"]
+                self.plot_image_widget.width = result["image"].width
+                self.plot_image_widget.height = result["image"].height
+                self.plot_image_widget.trigger_render()
+                self.plot_image_widget.visible = True

-            for key, value in progress_df.items():
-                progress_text += f"{key}: {round(value.iloc[-1], 6)}\n"
+            elif result["type"] == "finished":
+                self.training = False
+                self.training_text = "Training finished."

-            self.training_text = progress_text
+        except queue.Empty:
+            if self.training and os.path.exists(os.path.join("training_logs", "progress.csv")) and all([os.path.exists(os.path.join(monitor_log_dir, f"monitor_{i}.csv.monitor.csv")) for i in range(int(self.settings["n_envs"]))]) and time.perf_counter() - self.last_progress_update >= 1:
+                self.last_progress_update = time.perf_counter()
+                self.plot_results()

         if hasattr(self, "training_label"):
             self.training_label.text = self.training_text

     def round_near_int(self, x, tol=1e-4):
         nearest = round(x)

         if abs(x - nearest) < tol:
             return nearest

         return x

     def train(self):
-        os.makedirs(monitor_log_dir, exist_ok=True)
-        env = Monitor(SpaceInvadersEnv(), filename=os.path.join(monitor_log_dir, "monitor.csv"))
+        if os.path.exists(monitor_log_dir):
+            shutil.rmtree(monitor_log_dir)
+        os.makedirs(monitor_log_dir)
+
+        n_envs = int(self.settings["n_envs"])
+        env = DummyVecEnv([make_env(i) for i in range(n_envs)])
+
+        n_steps = int(self.settings["n_steps"])
+        batch_size = int(self.settings["batch_size"])
+
+        total_steps_per_rollout = n_steps * max(1, n_envs)
+        if total_steps_per_rollout % batch_size != 0:
+            batch_size = max(64, total_steps_per_rollout // max(1, total_steps_per_rollout // batch_size))
+            print(f"Warning: Adjusting batch size to {batch_size} for {n_envs} envs.")

         model = PPO(
-            "MlpPolicy",
-            env,
-            n_steps=self.settings["n_steps"],
-            batch_size=self.settings["batch_size"],
-            n_epochs=self.settings["n_epochs"],
-            learning_rate=self.settings["learning_rate"],
+            "MlpPolicy",
+            env,
+            n_steps=n_steps,
+            batch_size=batch_size,
+            n_epochs=int(self.settings["n_epochs"]),
+            learning_rate=float(self.settings["learning_rate"]),
             verbose=1,
             device="cpu",
-            gamma=self.settings["gamma"],
-            ent_coef=self.settings["ent_coef"],
-            clip_range=self.settings["clip_range"],
+            gamma=float(self.settings["gamma"]),
+            ent_coef=float(self.settings["ent_coef"]),
+            clip_range=float(self.settings["clip_range"]),
         )

-        new_logger = configure(
-            folder=monitor_log_dir, format_strings=["csv"]
-        )

+        new_logger = configure(folder=monitor_log_dir, format_strings=["csv"])
         model.set_logger(new_logger)

-        model.learn(self.settings["learning_steps"])
-        model.save("invader_agent")
+        try:
+            self.training = True
+            model.learn(int(self.settings["learning_steps"]))
+            model.save("invader_agent")
+        except Exception as e:
+            print(f"Error during training: {e}")
+            self.result_queue.put({"type": "text", "message": f"Error:\n{e}"})
+        finally:
+            try:
+                env.close()
+            except Exception:
+                pass
+            self.result_queue.put({"type": "finished"})

-        self.training = False
+    def plot_results(self):
+        try:
+            reward_df = pd.read_csv(os.path.join(monitor_log_dir, "progress.csv"))
+        except pd.errors.EmptyDataError:
+            return

-        self.plot_results(os.path.join(monitor_log_dir, "monitor.csv"), os.path.join(monitor_log_dir, "progress.csv"))
+        all_monitor_files = [os.path.join(monitor_log_dir, f) for f in os.listdir(monitor_log_dir) if f.startswith("monitor_") and f.endswith(".csv")]
+        try:
+            df_list = [pd.read_csv(f, skiprows=1) for f in all_monitor_files]
+        except pd.errors.EmptyDataError:
+            return

-    def plot_results(self, log_path, loss_log_path):
-        df = pd.read_csv(log_path, skiprows=1)
+        monitor_df = pd.concat(df_list).sort_values(by='t')
+        monitor_df['total_timesteps'] = monitor_df['l'].cumsum()

+        loss_log_path = os.path.join(monitor_log_dir, "progress.csv")
+        loss_df = None
+        if os.path.exists(loss_log_path):
+            try:
+                loss_df = pd.read_csv(loss_log_path)
+            except Exception:
+                loss_df = None

         fig, axes = plt.subplots(2, 1, figsize=(6, 8), dpi=100)

-        loss_df = pd.read_csv(loss_log_path)
-
-        axes[0].plot(np.cumsum(df['l']), df['r'], label='Reward')
+        if monitor_df is not None and 'total_timesteps' in monitor_df.columns and 'r' in monitor_df.columns:
+            axes[0].plot(monitor_df['total_timesteps'], monitor_df['r'].rolling(window=10).mean(), label='Episodic Reward (Rolling 10)')
+        elif reward_df is not None and 'time/total_timesteps' in reward_df.columns and 'rollout/ep_rew_mean' in reward_df.columns:
+            axes[0].plot(reward_df['time/total_timesteps'], reward_df['rollout/ep_rew_mean'], label='Ep reward mean')
+        else:
+            axes[0].text(0.5, 0.5, "No reward data available", horizontalalignment='center', verticalalignment='center')

         axes[0].set_title('PPO Training: Episodic Reward')
         axes[0].set_xlabel('Total Timesteps')
         axes[0].set_ylabel('Reward')
         axes[0].grid(True)

-        axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/policy_gradient_loss'], label='Policy Gradient Loss')
-        axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/value_loss'], label='Value Loss')
-        axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/explained_variance'], label='Explained Variance')
-        axes[1].set_title('PPO Training: Loss Functions')
+        axes[1].set_title('PPO Training: Loss & Variance')
         axes[1].set_xlabel('Total Timesteps')
-        axes[1].set_ylabel('Loss Value')
-        axes[1].legend()
+        axes[1].set_ylabel('Value')
         axes[1].grid(True)


+        if loss_df is not None and 'time/total_timesteps' in loss_df.columns and 'train/policy_gradient_loss' in loss_df.columns and 'train/value_loss' in loss_df.columns and 'train/explained_variance' in loss_df.columns:
+            tcol = 'time/total_timesteps'
+            axes[1].plot(loss_df[tcol], loss_df['train/policy_gradient_loss'], label='Policy Loss')
+            axes[1].plot(loss_df[tcol], loss_df['train/value_loss'], label='Value Loss')
+            axes[1].plot(loss_df[tcol], loss_df['train/explained_variance'], label='Explained Variance')
+
+            axes[1].legend()
+        else:
+            axes[1].text(0.5, 0.5, "No loss/variance data available", horizontalalignment='center', verticalalignment='center')

         plt.tight_layout()

         buffer = BytesIO()
-        plt.savefig(buffer, format='png')
+        plt.savefig(buffer, format='png', bbox_inches='tight')
         buffer.seek(0)
         plt.close(fig)

-        plot_texture = arcade.Texture(Image.open(buffer))
+        pil_img = Image.open(buffer).convert("RGBA")

-        self.plot_image_widget.texture = plot_texture
-        self.plot_image_widget.size_hint = (None, None)
-        self.plot_image_widget.width = plot_texture.width
-        self.plot_image_widget.height = plot_texture.height
+        plot_texture = arcade.Texture(pil_img)

-        self.plot_image_widget.visible = True
-        self.training_text = "Training finished. Plot displayed."
+        self.result_queue.put({"type": "plot", "image": plot_texture})
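On the monitor files read in plot_results() above: SB3's Monitor wrapper prepends one JSON comment line (starting with '#') before the r,l,t header, which is why every read uses skiprows=1; r is the episode reward, l the episode length, and t the elapsed wall-clock time. A standalone sketch of that loading step, assuming the monitor_log_dir layout created by this commit (load_episodes is an illustrative name):

import glob, os
import pandas as pd

def load_episodes(log_dir):
    # Monitor appends ".monitor.csv", so the files are monitor_<rank>.csv.monitor.csv;
    # each starts with a '#{...}' JSON line, then columns r, l, t.
    paths = glob.glob(os.path.join(log_dir, "monitor_*.csv"))
    frames = [pd.read_csv(p, skiprows=1) for p in paths]
    episodes = pd.concat(frames).sort_values("t")
    # Cumulative episode lengths approximate the global timestep axis, mirroring
    # monitor_df['total_timesteps'] in plot_results().
    episodes["total_timesteps"] = episodes["l"].cumsum()
    return episodes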
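Finally, the batch-size guard in train() above exists because PPO slices each rollout buffer of n_steps * n_envs transitions into minibatches of batch_size; when the division is not exact, SB3 warns and falls back to a truncated final minibatch. A worked sketch of that arithmetic (pick_batch_size is a hypothetical name; the fallback mirrors the commit's formula):

def pick_batch_size(n_steps, n_envs, requested, floor=64):
    # One PPO rollout collects n_steps transitions from each of the n_envs copies.
    rollout = n_steps * max(1, n_envs)
    if rollout % requested == 0:
        return requested
    # Same shape as the commit's fallback: nudge batch_size toward a value that
    # splits the rollout into (roughly) equal minibatches, never below `floor`.
    return max(floor, rollout // max(1, rollout // requested))

# 2048 steps * 4 envs = 8192 transitions: 512 divides it evenly (16 minibatches)...
assert pick_batch_size(2048, 4, 512) == 512
# ...while 3000 does not, so the guard bumps it to 8192 // 2 = 4096.
assert pick_batch_size(2048, 4, 3000) == 4096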