Add RL training, which doesn't work that well yet, and start building the UI for model training

This commit is contained in:
csd4ni3l
2025-11-15 15:56:56 +01:00
parent 032f38f4ce
commit 32477def6a
9 changed files with 524 additions and 17 deletions

View File

@@ -55,6 +55,9 @@ class Main(arcade.gui.UIView):
self.play_button = self.box.add(arcade.gui.UITextureButton(text="Play", texture=button_texture, texture_hovered=button_hovered_texture, width=self.window.width / 2, height=150, style=big_button_style))
self.play_button.on_click = lambda event: self.play()
self.train_button = self.box.add(arcade.gui.UITextureButton(text="Train", texture=button_texture, texture_hovered=button_hovered_texture, width=self.window.width / 2, height=150, style=big_button_style))
self.train_button.on_click = lambda event: self.train()
self.settings_button = self.box.add(arcade.gui.UITextureButton(text="Settings", texture=button_texture, texture_hovered=button_hovered_texture, width=self.window.width / 2, height=150, style=big_button_style))
self.settings_button.on_click = lambda event: self.settings()

60
menus/train_model.py Normal file
View File

@@ -0,0 +1,60 @@
import arcade, arcade.gui
from utils.constants import button_style, MODEL_SETTINGS
from utils.preload import button_texture, button_hovered_texture
from stable_baselines3 import PPO
from utils.ml import SpaceInvadersEnv
class TrainModel(arcade.gui.UIView):
    """Menu view for configuring and launching PPO training.

    The view starts in the ``"settings"`` state, which lists one label and
    one slider per entry of ``MODEL_SETTINGS`` followed by a "Train" button.
    """

    def __init__(self, pypresence_client):
        """Build the base layout and report the new state to rich presence.

        Args:
            pypresence_client: Object exposing ``update(state=...)``; it is
                stored on the view and updated to "Model Training".
        """
        # The base initialiser must run first: ``add_widget`` and
        # ``self.window`` (used in ``show_menu``) rely on state it sets up.
        super().__init__()

        self.pypresence_client = pypresence_client
        self.pypresence_client.update(state="Model Training")

        self.current_state = "settings"

        self.anchor = self.add_widget(arcade.gui.UIAnchorLayout(size_hint=(1, 1)))
        self.box = self.anchor.add(arcade.gui.UIBoxLayout(space_between=10))

        # Per-view working copy of the defaults.
        # NOTE(review): nothing writes to it yet because ``change_value`` is a stub.
        self.settings = MODEL_SETTINGS.copy()

    def on_show_view(self):
        """Populate the menu for the current state when the view is shown."""
        super().on_show_view()
        self.show_menu(self.current_state)

    def show_menu(self, state):
        """Add the widgets belonging to ``state`` to the box layout.

        Only the ``"settings"`` state is handled; any other value adds nothing.

        Args:
            state: Name of the menu state to display.
        """
        if state != "settings":
            return

        self.box.add(arcade.gui.UILabel("Settings", font_size=48))

        # Iterating a dict directly yields only its keys, so ask for the
        # (name, data) pairs explicitly. A sequence of pairs is still accepted.
        if isinstance(MODEL_SETTINGS, dict):
            entries = MODEL_SETTINGS.items()
        else:
            entries = MODEL_SETTINGS

        for setting, data in entries:
            default, min_value, max_value, step = data

            self.box.add(arcade.gui.UILabel(text=f"{setting.replace('_', ' ').capitalize()}: {default}"))

            slider = self.box.add(
                arcade.gui.UISlider(
                    value=default,
                    min_value=min_value,
                    max_value=max_value,
                    step=step,
                )
            )
            # Replace the slider's step rendering hook with a no-op.
            slider._render_steps = lambda surface: None
            # ``key=setting`` binds the current name at definition time so every
            # slider reports its own setting rather than the last one in the loop.
            slider.on_change = lambda e, key=setting: self.change_value(key, e.new_value)

        train_button = self.box.add(
            arcade.gui.UITextureButton(
                width=self.window.width / 2,
                height=self.window.height / 10,
                text="Train",
                style=button_style,
                texture=button_texture,
                texture_hovered=button_hovered_texture,
            )
        )
        train_button.on_click = lambda e: self.train()

    def change_value(self, key, value):
        """Handle a slider change for ``key``.

        Currently a placeholder: the new value is ignored.

        Args:
            key: Name of the setting whose slider moved.
            value: New slider value.
        """
        # TODO: store the value and refresh the matching label.
        ...

    def train(self):
        """Train a PPO agent on ``SpaceInvadersEnv`` and save it.

        The hyper-parameters below are hard-coded; the slider values are not
        consulted. ``learn`` is called directly here, so this method does not
        return until training has finished.
        NOTE(review): this is invoked from a button handler, so it presumably
        stalls the UI for the whole run - confirm.
        """
        env = SpaceInvadersEnv()
        model = PPO(
            "MlpPolicy",
            env,
            n_steps=2048,
            batch_size=64,
            n_epochs=10,
            learning_rate=3e-4,
            verbose=1,
            device="cpu",
            gamma=0.99,
            ent_coef=0.01,
            clip_range=0.2,
        )
        model.learn(1_000_000)
        model.save("invader_agent")