Add 10 million timestep model, improve README, add diffculty/mode selector, score, make model training have instant graphs and also multiple envs for faster training, better plotting, improve RL model by including multiple players, better reard system, use EnemyFormation instead of single Enemy-es

This commit is contained in:
csd4ni3l
2025-11-16 21:27:21 +01:00
parent fb7b45b6df
commit dce64e5d3f
10 changed files with 644 additions and 273 deletions

View File

@@ -2,7 +2,9 @@ Fleet Commander is like Space Invaders but you are the enemy instead of the play
It uses AI (Reinforcement Learning) for the Player, and you, the Enemy has to defeat it.
You can train yourself, or use the default model which comes with the game.
I know the game is too easy and is too simple, but please understand that doing RL isnt the easiest thing ever. I also did this so late so yeah.
You can train yourself, or use the default model(10 million timesteps) which comes with the game.
# Install steps:
@@ -18,4 +20,7 @@ You can train yourself, or use the default model which comes with the game.
- `pip3 install -r requirements.txt`
- `pip3 install torch --index-url https://download.pytorch.org/whl/cpu`
- `pip3 install stable_baselines3`
- `python3 run.py`
- `python3 run.py`
# Disclaimer
AI assistance was used in this project, since i never did any RL work before. But every instance of AI code was heavily modified by me.

View File

@@ -1,16 +1,17 @@
import arcade, arcade.gui, random, time
from utils.constants import button_style, ENEMY_ROWS, ENEMY_COLS, PLAYER_ATTACK_SPEED
from utils.constants import button_style, ENEMY_ATTACK_SPEED, ENEMY_SPEED
from utils.preload import button_texture, button_hovered_texture
from stable_baselines3 import PPO
from game.sprites import Enemy, Player, Bullet
from game.sprites import EnemyFormation, Player, Bullet
class Game(arcade.gui.UIView):
def __init__(self, pypresence_client):
def __init__(self, pypresence_client, settings):
super().__init__()
self.settings = settings
self.pypresence_client = pypresence_client
self.pypresence_client.update(state="Invading Space")
@@ -18,42 +19,53 @@ class Game(arcade.gui.UIView):
self.spritelist = arcade.SpriteList()
self.player = Player(self.window.width / 2 + random.randint(int(-self.window.width / 3), int(self.window.width / 3)), 100) # not actually player
self.spritelist.append(self.player)
self.last_player_shoot = time.perf_counter() # not actually player
self.players = []
for _ in range(settings["player_count"]):
self.players.append(Player(self.window.width / 2 + random.randint(int(-self.window.width / 3), int(self.window.width / 3)), 100)) # not actually player
self.spritelist.append(self.players[-1])
self.model = PPO.load("invader_agent.zip")
self.enemies: list[Enemy] = []
self.enemy_formation = EnemyFormation(self.window.width / 2 + random.randint(int(-self.window.width / 3), int(self.window.width / 3)), self.window.height * 0.9, self.spritelist, settings["enemy_rows"], settings["enemy_cols"])
self.player_bullets: list[Bullet] = []
self.enemy_bullets: list[Bullet] = []
self.summon_enemies()
self.player_respawns = settings["player_respawns"]
self.enemy_respawns = settings["enemy_respawns"]
self.score = 0
self.last_enemy_shoot = time.perf_counter()
self.game_over = False
def on_show_view(self):
super().on_show_view()
self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top", align_x=5, align_y=-5)
self.back_button.on_click = lambda event: self.main_exit()
self.score_label = self.anchor.add(arcade.gui.UILabel("Score: 0", font_size=24), anchor_x="center", anchor_y="top")
def main_exit(self):
from menus.main import Main
self.window.show_view(Main(self.pypresence_client))
def summon_enemies(self):
enemy_start_x = self.window.width * 0.15
enemy_start_y = self.window.height * 0.9
for row in range(ENEMY_ROWS):
for col in range(ENEMY_COLS):
enemy_sprite = Enemy(enemy_start_x + col * 100, enemy_start_y - row * 100)
self.spritelist.append(enemy_sprite)
self.enemies.append(enemy_sprite)
def on_update(self, delta_time):
for enemy in self.enemies:
enemy.update()
if self.game_over:
return
if self.window.keyboard[arcade.key.LEFT] or self.window.keyboard[arcade.key.A]:
self.enemy_formation.move(self.window.width, self.window.height, "x", -ENEMY_SPEED)
if self.window.keyboard[arcade.key.RIGHT] or self.window.keyboard[arcade.key.D]:
self.enemy_formation.move(self.window.width, self.window.height, "x", ENEMY_SPEED)
if self.window.keyboard[arcade.key.DOWN] or self.window.keyboard[arcade.key.S]:
self.enemy_formation.move(self.window.width, self.window.height, "y", -ENEMY_SPEED)
if self.window.keyboard[arcade.key.UP] or self.window.keyboard[arcade.key.W]:
self.enemy_formation.move(self.window.width, self.window.height, "y", ENEMY_SPEED)
if self.enemy_formation.enemies and self.window.keyboard[arcade.key.SPACE] and time.perf_counter() - self.last_enemy_shoot >= ENEMY_ATTACK_SPEED:
self.last_enemy_shoot = time.perf_counter()
enemy = self.enemy_formation.get_lowest_enemy()
self.shoot(enemy.center_x, enemy.center_y, -1)
bullets_to_remove = []
@@ -62,18 +74,21 @@ class Game(arcade.gui.UIView):
bullet_hit = False
if bullet.direction_y == 1:
for enemy in self.enemies:
for enemy in self.enemy_formation.enemies:
if bullet.rect.intersection(enemy.rect):
self.spritelist.remove(enemy)
self.enemies.remove(enemy)
self.player.target = None
self.enemy_formation.remove_enemy(enemy)
bullets_to_remove.append(bullet)
bullet_hit = True
break
else:
if bullet.rect.intersection(self.player.rect):
bullets_to_remove.append(bullet)
bullet_hit = True
for player in self.players:
if bullet.rect.intersection(player.rect):
self.spritelist.remove(player)
self.players.remove(player)
bullets_to_remove.append(bullet)
bullet_hit = True
self.score += 75
break
if not bullet_hit and bullet.center_y > self.window.height or bullet.center_y < 0:
bullets_to_remove.append(bullet)
@@ -86,20 +101,40 @@ class Game(arcade.gui.UIView):
elif bullet_to_remove in self.player_bullets:
self.player_bullets.remove(bullet_to_remove)
self.player.update(self.model, self.enemies, self.enemy_bullets, self.window.width, self.window.height) # not actually player
for player in self.players:
player.update(self.model, self.enemy_formation, self.enemy_bullets, self.window.width, self.window.height, self.player_respawns / self.settings["player_respawns"], self.enemy_respawns / self.settings["enemy_respawns"]) # not actually player
if self.player.center_x > self.window.width:
self.player.center_x = self.window.width
elif self.player.center_x < 0:
self.player.center_x = 0
if player.center_x > self.window.width:
player.center_x = self.window.width
elif player.center_x < 0:
player.center_x = 0
if self.player.shoot:
self.player.shoot = False
self.shoot(self.player.center_x, self.player.center_y, 1)
if player.shoot:
player.shoot = False
self.shoot(player.center_x, player.center_y, 1)
if time.perf_counter() - self.last_player_shoot >= PLAYER_ATTACK_SPEED:
self.last_player_shoot = time.perf_counter()
self.shoot(self.player.center_x, self.player.center_y, 1)
if len(self.players) == 0:
if self.player_respawns > 0:
self.player_respawns -= 1
for _ in range(self.settings["player_count"]):
self.players.append(Player(self.window.width / 2 + random.randint(int(-self.window.width / 3), int(self.window.width / 3)), 100)) # not actually player
self.spritelist.append(self.players[-1])
self.score += 300
else:
self.game_over = True
self.game_over_label = self.anchor.add(arcade.gui.UILabel("You (The enemies) won!", font_size=48), anchor_x="center", anchor_y="center")
elif len(self.enemy_formation.enemies) == 0:
if self.enemy_respawns > 0:
self.enemy_respawns -= 1
self.enemy_formation.create_formation(self.window.width / 2 + random.randint(int(-self.window.width / 3), int(self.window.width / 3)), self.window.height * 0.9)
else:
self.game_over = True
self.game_over_label = self.anchor.add(arcade.gui.UILabel("You lost! The Players win!", font_size=48), anchor_x="center", anchor_y="center")
self.score += 5 * delta_time
self.score_label.text = f"Score: {int(self.score)}"
def shoot(self, x, y, direction_y):
bullet = Bullet(x, y, direction_y)
@@ -112,11 +147,6 @@ class Game(arcade.gui.UIView):
bullets.append(bullet)
def on_key_press(self, symbol, modifiers):
if symbol == arcade.key.SPACE:
enemy = random.choice(self.enemies)
self.shoot(enemy.center_x, enemy.center_y, -1)
def on_draw(self):
super().on_draw()

View File

@@ -1,10 +1,8 @@
import arcade, time
import arcade, time, random, numpy as np
from stable_baselines3 import PPO
import numpy as np
from utils.constants import PLAYER_SPEED, BULLET_SPEED, BULLET_RADIUS, PLAYER_ATTACK_SPEED, ENEMY_COLS, ENEMY_ROWS
from utils.constants import PLAYER_SPEED, BULLET_SPEED, BULLET_RADIUS, PLAYER_ATTACK_SPEED
from utils.preload import player_texture, enemy_texture
class Bullet(arcade.Sprite):
@@ -16,9 +14,100 @@ class Bullet(arcade.Sprite):
def update(self):
self.center_y += self.direction_y * BULLET_SPEED
class Enemy(arcade.Sprite):
def __init__(self, x, y):
super().__init__(enemy_texture, center_x=x, center_y=y)
class EnemyFormation():
def __init__(self, start_x, start_y, spritelist: arcade.SpriteList | None, rows, cols):
self.grid = [[] for _ in range(rows)]
self.start_x = start_x
self.start_y = start_y
self.rows = rows
self.cols = cols
self.spritelist = spritelist
self.create_formation()
def create_formation(self, start_x=None, start_y=None):
if start_x:
self.start_x = start_x
if start_y:
self.start_y = start_y
del self.grid
self.grid = [[] for _ in range(self.rows)]
for row in range(self.rows):
for col in range(self.cols):
enemy_sprite = arcade.Sprite(enemy_texture, center_x=self.start_x + col * 100, center_y=self.start_y - row * 100)
if self.spritelist:
self.spritelist.append(enemy_sprite)
self.grid[row].append(enemy_sprite)
def remove_enemy(self, enemy):
if self.spritelist and enemy not in self.spritelist:
return
for row in range(self.rows):
for col in range(self.cols):
if self.grid[row][col] == enemy:
self.grid[row][col] = None
if self.spritelist:
self.spritelist.remove(enemy)
return
def get_lowest_enemy(self):
valid_cols = []
for col in range(self.cols):
row = self.rows - 1
while row >= 0 and self.grid[row][col] is None:
row -= 1
if row >= 0:
valid_cols.append((col, row))
if not valid_cols:
return None
col, row = random.choice(valid_cols)
return self.grid[row][col]
def move(self, width, height, direction_type, value):
if direction_type == "x":
wall_hit = False
for enemy in self.enemies:
self.start_x += value
enemy.center_x += value
if enemy.center_x + enemy.width / 2 > width or enemy.center_x < enemy.width / 2:
wall_hit = True
if wall_hit:
for enemy in self.enemies:
self.start_x -= value
enemy.center_x -= value
else:
wall_hit = False
for enemy in self.enemies:
self.start_x += value
enemy.center_y += value
if enemy.center_y + enemy.height / 2 > height or enemy.center_y < enemy.height / 2:
wall_hit = True
if wall_hit:
for enemy in self.enemies:
self.start_y -= value
enemy.center_y -= value
@property
def center_x(self):
return self.start_x + (self.cols / 2) * 100
@property
def enemies(self):
return [col for row in self.grid for col in row if not col == None]
class Player(arcade.Sprite): # Not actually the player
def __init__(self, x, y):
@@ -26,21 +115,19 @@ class Player(arcade.Sprite): # Not actually the player
self.last_target_change = time.perf_counter()
self.last_shoot = time.perf_counter()
self.target = None
self.shoot = False
self.player_speed = 0
def update(self, model: PPO, enemies, bullets, width, height):
if enemies:
nearest_enemy = min(enemies, key=lambda e: abs(e.center_x - self.center_x))
def update(self, model: PPO, enemy_formation, bullets, width, height, player_respawns_norm, enemy_respawns_norm):
if enemy_formation.enemies:
nearest_enemy = min(enemy_formation.enemies, key=lambda e: abs(e.center_x - self.center_x))
enemy_x = (nearest_enemy.center_x - self.center_x) / width
enemy_y = (nearest_enemy.center_y - self.center_y) / height
else:
enemy_x = 2
enemy_y = 2
enemy_count = len(enemies) / float(max(1, ENEMY_ROWS * ENEMY_COLS))
enemy_count = len(enemy_formation.enemies) / float(max(1, enemy_formation.rows * enemy_formation.cols))
player_x_norm = self.center_x / width
curr_bullet = min(bullets, key=lambda b: abs(b.center_x - self.center_x) + abs(b.center_y - self.center_y)) if bullets else None
@@ -51,18 +138,31 @@ class Player(arcade.Sprite): # Not actually the player
curr_bx = 2.0
curr_by = 2.0
lowest = max(enemies, key=lambda e: e.center_y) if enemies else None
lowest = max(enemy_formation.enemies, key=lambda e: e.center_y) if enemy_formation.enemies else None
if lowest is not None:
lowest_dy = (lowest.center_y - self.center_y) / float(height)
else:
lowest_dy = 2.0
enemy_dispersion = 0.0
if enemies:
xs = np.array([e.center_x for e in enemies], dtype=np.float32)
if enemy_formation.enemies:
xs = np.array([e.center_x for e in enemy_formation.enemies], dtype=np.float32)
enemy_dispersion = float(xs.std()) / float(width)
obs = np.array([player_x_norm, enemy_x, enemy_y, lowest_dy, curr_bx, curr_by, self.player_speed, enemy_count, enemy_dispersion], dtype=np.float32)
obs = np.array([
player_x_norm,
enemy_x,
enemy_y,
lowest_dy,
curr_bx,
curr_by,
self.player_speed,
enemy_count,
enemy_dispersion,
time.perf_counter() - self.last_shoot >= PLAYER_ATTACK_SPEED,
player_respawns_norm,
enemy_respawns_norm
], dtype=np.float32)
action, _ = model.predict(obs, deterministic=True)
self.prev_x = self.center_x
@@ -71,6 +171,8 @@ class Player(arcade.Sprite): # Not actually the player
elif action == 1:
self.center_x += PLAYER_SPEED
elif action == 2:
pass
elif action == 3:
t = time.perf_counter()
if t - self.last_shoot >= PLAYER_ATTACK_SPEED:
self.last_shoot = t

BIN
invader_agent.zip Normal file

Binary file not shown.

View File

@@ -62,8 +62,8 @@ class Main(arcade.gui.UIView):
self.settings_button.on_click = lambda event: self.settings()
def play(self):
from game.play import Game
self.window.show_view(Game(self.pypresence_client))
from menus.mode_selector import ModeSelector
self.window.show_view(ModeSelector(self.pypresence_client))
def settings(self):
from menus.settings import Settings

72
menus/mode_selector.py Normal file
View File

@@ -0,0 +1,72 @@
import arcade, arcade.gui
from utils.preload import button_texture, button_hovered_texture
from utils.constants import dropdown_style, button_style, DIFFICULTY_LEVELS, DIFFICULTY_SETTINGS
class ModeSelector(arcade.gui.UIView):
def __init__(self, pypresence_client):
super().__init__()
self.pypresence_client = pypresence_client
self.pypresence_client.update(state="Selecting Mode")
self.anchor = self.add_widget(arcade.gui.UIAnchorLayout(size_hint=(1, 1)))
self.box = self.anchor.add(arcade.gui.UIBoxLayout(size_hint=(0.75, 0.75), space_between=10), anchor_x="center", anchor_y="center")
self.settings = DIFFICULTY_LEVELS["Easy"]
self.setting_sliders = {}
self.setting_labels = {}
def on_show_view(self):
super().on_show_view()
self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top", align_x=5, align_y=-5)
self.back_button.on_click = lambda event: self.main_exit()
self.box.add(arcade.gui.UILabel("Settings", font_size=32))
self.box.add(arcade.gui.UISpace(height=self.window.height / 80))
self.difficulty_selector = self.box.add(arcade.gui.UIDropdown(default="Easy", options=list(DIFFICULTY_LEVELS.keys()), active_style=dropdown_style, primary_style=dropdown_style, dropdown_style=dropdown_style, width=self.window.width / 2, height=self.window.height / 20))
self.difficulty_selector.on_change = lambda event: self.set_difficulty_values(event.new_value)
self.box.add(arcade.gui.UISpace(height=self.window.height / 80))
for key, data in DIFFICULTY_SETTINGS.items():
default, name, min_value, max_value = DIFFICULTY_LEVELS["Easy"][key], *data
label = self.box.add(arcade.gui.UILabel(text=f"{name}: {default}", font_size=14))
slider = self.box.add(arcade.gui.UISlider(value=default, min_value=min_value, max_value=max_value, step=1, width=self.window.width / 2, height=self.window.height / 25))
slider._render_steps = lambda surface: None
slider.on_event = lambda event: None # disable slider for difficulties
slider.on_click = lambda event: None # disable slider for difficulties
slider.on_change = lambda e, key=key: self.change_value(key, e.new_value)
self.setting_sliders[key] = slider
self.setting_labels[key] = label
self.play_button = self.box.add(arcade.gui.UITextureButton(text="Play", width=self.window.width / 2, height=self.window.height / 15, texture=button_texture, texture_hovered=button_hovered_texture, style=button_style))
self.play_button.on_click = lambda event: self.start_game()
def set_difficulty_values(self, difficulty):
for key, value in DIFFICULTY_LEVELS[difficulty].items():
self.settings[key] = value
self.setting_labels[key].text = f"{DIFFICULTY_SETTINGS[key][0]}: {value}"
self.setting_sliders[key].value = value
for slider in self.setting_sliders.values():
if difficulty != "Custom":
slider.on_event = lambda event: None
slider.on_click = lambda event: None
else:
slider.on_event = lambda event, slider=slider: arcade.gui.UISlider.on_event(slider, event)
slider.on_click = lambda event, slider=slider: arcade.gui.UISlider.on_click(slider, event)
def change_value(self, key, value):
self.settings[key] = int(value)
self.setting_labels[key].text = f"{DIFFICULTY_SETTINGS[key][0]}: {int(value)}"
def start_game(self):
from game.play import Game
self.window.show_view(Game(self.pypresence_client, self.settings))

View File

@@ -1,20 +1,26 @@
import arcade, arcade.gui, threading, io, os, time
import arcade, arcade.gui, threading, os, queue, time, shutil
import matplotlib.pyplot as plt
import pandas as pd
from utils.constants import button_style, MODEL_SETTINGS, monitor_log_dir
from utils.preload import button_texture, button_hovered_texture
from utils.rl import SpaceInvadersEnv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from io import BytesIO
from PIL import Image
from io import BytesIO
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.logger import configure
from stable_baselines3.common.vec_env import DummyVecEnv
from utils.rl import SpaceInvadersEnv
def make_env(rank: int, seed: int = 0):
def _init():
env = SpaceInvadersEnv()
env = Monitor(env, filename=os.path.join(monitor_log_dir, f"monitor_{rank}.csv"))
return env
return _init
class TrainModel(arcade.gui.UIView):
def __init__(self, pypresence_client):
@@ -24,22 +30,25 @@ class TrainModel(arcade.gui.UIView):
self.pypresence_client.update(state="Model Training")
self.anchor = self.add_widget(arcade.gui.UIAnchorLayout(size_hint=(1, 1)))
self.box = self.anchor.add(arcade.gui.UIBoxLayout(space_between=10))
self.box = self.anchor.add(arcade.gui.UIBoxLayout(space_between=5))
self.settings = {
setting: data[0] # default value
setting: data[0]
for setting, data in MODEL_SETTINGS.items()
}
self.labels = {}
self.training = False
self.training_text = ""
self.training_text = "Starting training..."
self.result_queue = queue.Queue()
self.training_thread = None
self.last_progress_update = time.perf_counter()
def on_show_view(self):
super().on_show_view()
self.show_menu()
def main_exit(self):
@@ -50,128 +59,188 @@ class TrainModel(arcade.gui.UIView):
self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top", align_x=5, align_y=-5)
self.back_button.on_click = lambda event: self.main_exit()
self.box.add(arcade.gui.UILabel("Settings", font_size=36))
self.box.add(arcade.gui.UILabel("Settings", font_size=32))
for setting, data in MODEL_SETTINGS.items():
default, min_value, max_value, step = data
label = self.box.add(arcade.gui.UILabel(text=f"{setting.replace('_', ' ').capitalize()}: {default}", font_size=18))
is_int = setting == "n_envs" or (abs(step - 1) < 1e-6 and abs(min_value - round(min_value)) < 1e-6)
val_text = str(int(default)) if is_int else str(default)
label = self.box.add(arcade.gui.UILabel(text=f"{setting.replace('_', ' ').capitalize()}: {val_text}", font_size=14))
slider = self.box.add(arcade.gui.UISlider(value=default, min_value=min_value, max_value=max_value, step=step, width=self.window.width / 2, height=self.window.height / 25))
slider._render_steps = lambda surface: None
slider.on_change = lambda e, key=setting: self.change_value(key, e.new_value)
slider.on_change = lambda e, key=setting, is_int_slider=is_int: self.change_value(key, e.new_value, is_int_slider)
self.labels[setting] = label
train_button = self.box.add(arcade.gui.UITextureButton(width=self.window.width / 2, height=self.window.height / 10, text="Train", style=button_style, texture=button_texture, texture_hovered=button_hovered_texture))
train_button.on_click = lambda e: self.start_training()
def change_value(self, key, value):
self.labels[key].text = f"{key.replace('_', ' ').capitalize()}: {self.round_near_int(value)}"
self.settings[key] = self.round_near_int(value)
def change_value(self, key, value, is_int=False):
if is_int:
val = int(round(value))
self.settings[key] = val
self.labels[key].text = f"{key.replace('_', ' ').capitalize()}: {val}"
else:
val = self.round_near_int(value)
self.settings[key] = val
self.labels[key].text = f"{key.replace('_', ' ').capitalize()}: {val}"
def start_training(self):
self.box.clear()
self.training = True
self.training_label = self.box.add(arcade.gui.UILabel("No Output yet.", font_size=16, multiline=True, width=self.window.width / 2, height=self.window.height / 2))
self.training_text = "Starting training..."
self.training_label = self.box.add(arcade.gui.UILabel("Starting training...", font_size=16, multiline=True, width=self.window.width / 2, height=self.window.height / 2))
self.plot_image_widget = self.box.add(arcade.gui.UIImage(texture=arcade.Texture.create_empty("empty", (1, 1))))
self.plot_image_widget.visible = False
threading.Thread(target=self.train, daemon=True).start()
self.training_thread = threading.Thread(target=self.train, daemon=True)
self.training_thread.start()
def on_update(self, delta_time):
if self.training and os.path.exists(os.path.join("training_logs", "progress.csv")) and time.perf_counter() - self.last_progress_update >= 0.5:
self.last_progress_update = time.perf_counter()
try:
result = self.result_queue.get_nowait()
try:
progress_df = pd.read_csv(os.path.join("training_logs", "progress.csv"))
except pd.errors.EmptyDataError:
return
if result["type"] == "text":
self.training_text = result["message"]
progress_text = ""
elif result["type"] == "plot":
self.plot_image_widget.texture = result["image"]
self.plot_image_widget.width = result["image"].width
self.plot_image_widget.height = result["image"].height
self.plot_image_widget.trigger_render()
self.plot_image_widget.visible = True
for key, value in progress_df.items():
progress_text += f"{key}: {round(value.iloc[-1], 6)}\n"
elif result["type"] == "finished":
self.training = False
self.training_text = "Training finished."
self.training_text = progress_text
except queue.Empty:
if self.training and os.path.exists(os.path.join("training_logs", "progress.csv")) and all([os.path.exists(os.path.join(monitor_log_dir, f"monitor_{i}.csv.monitor.csv")) for i in range(int(self.settings["n_envs"]))]) and time.perf_counter() - self.last_progress_update >= 1:
self.last_progress_update = time.perf_counter()
self.plot_results()
if hasattr(self, "training_label"):
self.training_label.text = self.training_text
def round_near_int(self, x, tol=1e-4):
nearest = round(x)
if abs(x - nearest) < tol:
return nearest
return x
def train(self):
os.makedirs(monitor_log_dir, exist_ok=True)
env = Monitor(SpaceInvadersEnv(), filename=os.path.join(monitor_log_dir, "monitor.csv"))
if os.path.exists(monitor_log_dir):
shutil.rmtree(monitor_log_dir)
os.makedirs(monitor_log_dir)
n_envs = int(self.settings["n_envs"])
env = DummyVecEnv([make_env(i) for i in range(n_envs)])
n_steps = int(self.settings["n_steps"])
batch_size = int(self.settings["batch_size"])
total_steps_per_rollout = n_steps * max(1, n_envs)
if total_steps_per_rollout % batch_size != 0:
batch_size = max(64, total_steps_per_rollout // max(1, total_steps_per_rollout // batch_size))
print(f"Warning: Adjusting batch size to {batch_size} for {n_envs} envs.")
model = PPO(
"MlpPolicy",
env,
n_steps=self.settings["n_steps"],
batch_size=self.settings["batch_size"],
n_epochs=self.settings["n_epochs"],
learning_rate=self.settings["learning_rate"],
"MlpPolicy",
env,
n_steps=n_steps,
batch_size=batch_size,
n_epochs=int(self.settings["n_epochs"]),
learning_rate=float(self.settings["learning_rate"]),
verbose=1,
device="cpu",
gamma=self.settings["gamma"],
ent_coef=self.settings["ent_coef"],
clip_range=self.settings["clip_range"],
gamma=float(self.settings["gamma"]),
ent_coef=float(self.settings["ent_coef"]),
clip_range=float(self.settings["clip_range"]),
)
new_logger = configure(
folder=monitor_log_dir, format_strings=["csv"]
)
new_logger = configure(folder=monitor_log_dir, format_strings=["csv"])
model.set_logger(new_logger)
model.learn(self.settings["learning_steps"])
model.save("invader_agent")
try:
self.training = True
model.learn(int(self.settings["learning_steps"]))
model.save("invader_agent")
except Exception as e:
print(f"Error during training: {e}")
self.result_queue.put({"type": "text", "message": f"Error:\n{e}"})
finally:
try:
env.close()
except Exception:
pass
self.result_queue.put({"type": "finished"})
self.training = False
def plot_results(self):
try:
reward_df = pd.read_csv(os.path.join(monitor_log_dir, "progress.csv"))
except pd.errors.EmptyDataError:
return
self.plot_results(os.path.join(monitor_log_dir, "monitor.csv"), os.path.join(monitor_log_dir, "progress.csv"))
all_monitor_files = [os.path.join(monitor_log_dir, f) for f in os.listdir(monitor_log_dir) if f.startswith("monitor_") and f.endswith(".csv")]
try:
df_list = [pd.read_csv(f, skiprows=1) for f in all_monitor_files]
except pd.errors.EmptyDataError:
return
def plot_results(self, log_path, loss_log_path):
df = pd.read_csv(log_path, skiprows=1)
monitor_df = pd.concat(df_list).sort_values(by='t')
monitor_df['total_timesteps'] = monitor_df['l'].cumsum()
loss_log_path = os.path.join(monitor_log_dir, "progress.csv")
loss_df = None
if os.path.exists(loss_log_path):
try:
loss_df = pd.read_csv(loss_log_path)
except Exception:
loss_df = None
fig, axes = plt.subplots(2, 1, figsize=(6, 8), dpi=100)
loss_df = pd.read_csv(loss_log_path)
axes[0].plot(np.cumsum(df['l']), df['r'], label='Reward')
if monitor_df is not None and 'total_timesteps' in monitor_df.columns and 'r' in monitor_df.columns:
axes[0].plot(monitor_df['total_timesteps'], monitor_df['r'].rolling(window=10).mean(), label='Episodic Reward (Rolling 10)')
elif reward_df is not None and 'time/total_timesteps' in reward_df.columns and 'rollout/ep_rew_mean' in reward_df.columns:
axes[0].plot(reward_df['time/total_timesteps'], reward_df['rollout/ep_rew_mean'], label='Ep reward mean')
else:
axes[0].text(0.5, 0.5, "No reward data available", horizontalalignment='center', verticalalignment='center')
axes[0].set_title('PPO Training: Episodic Reward')
axes[0].set_xlabel('Total Timesteps')
axes[0].set_ylabel('Reward')
axes[0].grid(True)
axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/policy_gradient_loss'], label='Policy Gradient Loss')
axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/value_loss'], label='Value Loss')
axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/explained_variance'], label='Explained Variance')
axes[1].set_title('PPO Training: Loss Functions')
axes[1].set_title('PPO Training: Loss & Variance')
axes[1].set_xlabel('Total Timesteps')
axes[1].set_ylabel('Loss Value')
axes[1].legend()
axes[1].set_ylabel('Value')
axes[1].grid(True)
if loss_df is not None and 'time/total_timesteps' in loss_df.columns and 'train/policy_gradient_loss' in loss_df.columns and 'train/value_loss' in loss_df.columns and 'train/explained_variance' in loss_df.columns:
tcol = 'time/total_timesteps'
axes[1].plot(loss_df[tcol], loss_df['train/policy_gradient_loss'], label='Policy Loss')
axes[1].plot(loss_df[tcol], loss_df['train/value_loss'], label='Value Loss')
axes[1].plot(loss_df[tcol], loss_df['train/explained_variance'], label='Explained Variance')
axes[1].legend()
else:
axes[1].text(0.5, 0.5, "No loss/variance data available", horizontalalignment='center', verticalalignment='center')
plt.tight_layout()
buffer = BytesIO()
plt.savefig(buffer, format='png')
plt.savefig(buffer, format='png', bbox_inches='tight')
buffer.seek(0)
plt.close(fig)
plot_texture = arcade.Texture(Image.open(buffer))
pil_img = Image.open(buffer).convert("RGBA")
self.plot_image_widget.texture = plot_texture
self.plot_image_widget.size_hint = (None, None)
self.plot_image_widget.width = plot_texture.width
self.plot_image_widget.height = plot_texture.height
plot_texture = arcade.Texture(pil_img)
self.plot_image_widget.visible = True
self.training_text = "Training finished. Plot displayed."
self.result_queue.put({"type": "plot", "image": plot_texture})

2
run.py
View File

@@ -29,7 +29,7 @@ timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_filename = f"debug_{timestamp}.log"
logging.basicConfig(filename=f'{os.path.join(log_dir, log_filename)}', format='%(asctime)s %(name)s %(levelname)s: %(message)s', level=logging.DEBUG)
for logger_name_to_disable in ['arcade']:
for logger_name_to_disable in ['arcade', "matplotlib", "matplotlib.fontmanager", "PIL"]:
logging.getLogger(logger_name_to_disable).propagate = False
logging.getLogger(logger_name_to_disable).disabled = True

View File

@@ -3,25 +3,68 @@ from arcade.types import Color
from arcade.gui.widgets.buttons import UITextureButtonStyle, UIFlatButtonStyle
from arcade.gui.widgets.slider import UISliderStyle
ENEMY_ROWS = 3
ENEMY_COLS = 13
ENEMY_SPEED = 5
ENEMY_ATTACK_SPEED = 0.75
PLAYER_SPEED = 5 # not actually player
PLAYER_ATTACK_SPEED = 0.75
BULLET_SPEED = 3
BULLET_RADIUS = 10
BULLET_SPEED = 5
BULLET_RADIUS = 15
# default, min, max, step
MODEL_SETTINGS = {
"n_steps": [2048, 256, 8192, 256],
"batch_size": [64, 16, 512, 16],
"n_steps": [1024, 256, 8192, 256],
"batch_size": [128, 16, 512, 16],
"n_epochs": [10, 1, 50, 1],
"learning_rate": [3e-4, 1e-5, 1e-2, 1e-5],
"gamma": [0.99, 0.8, 0.9999, 0.001],
"ent_coef": [0.01, 0.0, 0.1, 0.001],
"ent_coef": [0.015, 0.0, 0.1, 0.001],
"clip_range": [0.2, 0.1, 0.4, 0.01],
"learning_steps": [500_000, 50_000, 25_000_000, 50_000]
"learning_steps": [1_000_000, 50_000, 25_000_000, 50_000],
"n_envs": (12, 1, 128, 1)
}
DIFFICULTY_SETTINGS = {
"enemy_rows": ["Enemy Rows", 1, 6],
"enemy_cols": ["Enemy Columns", 1, 7],
"enemy_respawns": ["Enemy Respawns", 1, 5],
"player_count": ["Player Count", 1, 10],
"player_respawns": ["Player Respawns", 1, 5]
}
DIFFICULTY_LEVELS = {
"Easy": {
"enemy_rows": 3,
"enemy_cols": 4,
"enemy_respawns": 5,
"player_count": 2,
"player_respawns": 2
},
"Medium": {
"enemy_rows": 3,
"enemy_cols": 5,
"enemy_respawns": 4,
"player_count": 4,
"player_respawns": 3
},
"Hard": {
"enemy_rows": 4,
"enemy_cols": 6,
"enemy_respawns": 3,
"player_count": 6,
"player_respawns": 4
},
"Extra Hard": {
"enemy_rows": 6,
"enemy_cols": 7,
"enemy_respawns": 2,
"player_count": 8,
"player_respawns": 5
},
"Custom": {
}
}
menu_background_color = (30, 30, 47)

View File

@@ -1,78 +1,87 @@
import gymnasium as gym
import numpy as np
import arcade
import time
import random
from game.sprites import Enemy, Player, Bullet
from utils.constants import PLAYER_SPEED, BULLET_SPEED, PLAYER_ATTACK_SPEED, ENEMY_ROWS, ENEMY_COLS
from game.sprites import EnemyFormation, Player, Bullet
from utils.constants import PLAYER_SPEED, BULLET_SPEED, ENEMY_SPEED, DIFFICULTY_LEVELS
class SpaceInvadersEnv(gym.Env):
def __init__(self, width=800, height=600):
def __init__(self, width=800, height=600, difficulty="Hard"):
self.width = width
self.height = height
self.action_space = gym.spaces.Discrete(3)
self.observation_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(9,), dtype=np.float32)
self.action_space = gym.spaces.Discrete(4)
self.observation_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(12,), dtype=np.float32)
if difficulty not in DIFFICULTY_LEVELS:
raise ValueError(f"Unknown difficulty: {difficulty}. Available: {list(DIFFICULTY_LEVELS.keys())}")
self.difficulty_settings = DIFFICULTY_LEVELS[difficulty]
self.enemies = []
self.bullets = []
self.dir_history = []
self.last_shot = 0.0
self.player = None
self.prev_x = 0.0
self.enemy_formation = None
self.player_speed = 0.0
self.prev_bx = 2.0
self.steps_since_direction_change = 0
self.last_direction = 0
self.max_steps = 1000
self.max_steps = 2000
self.current_step = 0
self.enemies_killed = 0
self.enemy_move_speed = ENEMY_SPEED
self.player_respawns = self.difficulty_settings["player_respawns"]
self.enemy_respawns = self.difficulty_settings["enemy_respawns"]
self.player_respawns_remaining = 0
self.enemy_respawns_remaining = 0
self.player_alive = True
self.player_attack_cooldown_steps = 5
self.current_cooldown = 0
def reset(self, seed=None, options=None):
if seed is not None:
np.random.seed(seed)
random.seed(seed)
self.enemies = []
self.bullets = []
self.dir_history = []
self.player = Player(self.width / 2 + random.randint(int(-self.width / 3), int(self.width / 3)), 100)
self.prev_x = self.player.center_x
self.player_speed = 0.0
self.prev_bx = 2.0
self.steps_since_direction_change = 0
self.last_direction = 0
self.current_step = 0
self.enemies_killed = 0
self.player_respawns_remaining = self.player_respawns
self.enemy_respawns_remaining = self.enemy_respawns
self.player_alive = True
self.current_cooldown = 0
start_x = self.width * 0.15
start_y = self.height * 0.9
for r in range(ENEMY_ROWS):
for c in range(ENEMY_COLS):
e = Enemy(start_x + c * 100, start_y - r * 100)
self.enemies.append(e)
self.enemy_formation = EnemyFormation(start_x, start_y, None,
self.difficulty_settings["enemy_rows"],
self.difficulty_settings["enemy_cols"])
self.last_shot = time.perf_counter()
return self._obs(), {}
def _nearest_enemy(self):
if not self.enemies:
if not self.enemy_formation.enemies:
return None
return min(self.enemies, key=lambda e: abs(e.center_x - self.player.center_x))
return min(self.enemy_formation.enemies, key=lambda e: abs(e.center_x - self.player.center_x))
def _lowest_enemy(self):
if not self.enemies:
if not self.enemy_formation.enemies:
return None
return max(self.enemies, key=lambda e: e.center_y)
return min(self.enemy_formation.enemies, key=lambda e: e.center_y)
def _nearest_enemy_bullet(self):
enemy_bullets = [b for b in self.bullets if b.direction_y == -1]
if not enemy_bullets:
return None
return min(enemy_bullets, key=lambda b: abs(b.center_x - self.player.center_x) + abs(b.center_y - self.player.center_y))
def _obs(self):
if self.enemies:
if self.enemy_formation.enemies and self.player_alive:
nearest = self._nearest_enemy()
enemy_x = (nearest.center_x - self.player.center_x) / float(self.width)
enemy_y = (nearest.center_y - self.player.center_y) / float(self.height)
@@ -81,31 +90,60 @@ class SpaceInvadersEnv(gym.Env):
enemy_y = 2.0
lowest = self._lowest_enemy()
if lowest is not None:
if lowest is not None and self.player_alive:
lowest_dy = (lowest.center_y - self.player.center_y) / float(self.height)
else:
lowest_dy = 2.0
nb = self._nearest_enemy_bullet()
if nb is not None:
if nb is not None and self.player_alive:
bx = (nb.center_x - self.player.center_x) / float(self.width)
by = (nb.center_y - self.player.center_y) / float(self.height)
else:
bx = 2.0
by = 2.0
enemy_count = len(self.enemies) / float(max(1, ENEMY_ROWS * ENEMY_COLS))
player_x_norm = self.player.center_x / float(self.width)
enemy_dispersion = 0.0
enemy_count = len(self.enemy_formation.enemies) / float(max(1, self.difficulty_settings["enemy_rows"] * self.difficulty_settings["enemy_cols"]))
player_x_norm = self.player.center_x / float(self.width) if self.player_alive else 0.5
if self.enemies:
xs = np.array([e.center_x for e in self.enemies], dtype=np.float32)
enemy_dispersion = 0.0
if self.enemy_formation.enemies:
xs = np.array([e.center_x for e in self.enemy_formation.enemies], dtype=np.float32)
enemy_dispersion = float(xs.std()) / float(self.width)
obs = np.array([player_x_norm, enemy_x, enemy_y, lowest_dy, bx, by, self.player_speed, enemy_count, enemy_dispersion], dtype=np.float32)
can_shoot = 1.0 if (self.player_alive and self.current_cooldown <= 0) else 0.0
player_respawns_norm = self.player_respawns_remaining / float(max(1, self.player_respawns))
enemy_respawns_norm = self.enemy_respawns_remaining / float(max(1, self.enemy_respawns))
obs = np.array([
player_x_norm,
enemy_x,
enemy_y,
lowest_dy,
bx,
by,
self.player_speed,
enemy_count,
enemy_dispersion,
can_shoot,
player_respawns_norm,
enemy_respawns_norm
], dtype=np.float32)
return obs
def _respawn_player(self):
self.player = Player(self.width / 2 + random.randint(int(-self.width / 3), int(self.width / 3)), 100)
self.player_alive = True
self.bullets = [b for b in self.bullets if b.direction_y == 1]
self.current_cooldown = 0
def _respawn_enemies(self):
self.enemy_formation.start_x = self.width * 0.15
self.enemy_formation.start_y = self.height * 0.9
self.enemy_formation.create_formation()
def step(self, action):
reward = 0.0
terminated = False
@@ -115,113 +153,125 @@ class SpaceInvadersEnv(gym.Env):
if self.current_step >= self.max_steps:
truncated = True
nearest = self._nearest_enemy()
if nearest is not None:
enemy_x = (nearest.center_x - self.player.center_x) / float(self.width)
else:
enemy_x = 2.0
if self.current_cooldown > 0:
self.current_cooldown -= 1
prev_x = self.player.center_x
current_action_dir = 0
if self.player_alive:
prev_x = self.player.center_x
if action == 0:
self.player.center_x -= PLAYER_SPEED
current_action_dir = -1
elif action == 1:
self.player.center_x += PLAYER_SPEED
current_action_dir = 1
elif action == 2:
t = time.perf_counter()
if t - self.last_shot >= PLAYER_ATTACK_SPEED:
self.last_shot = t
b = Bullet(self.player.center_x, self.player.center_y, 1)
if action == 0:
self.player.center_x -= PLAYER_SPEED
elif action == 1:
self.player.center_x += PLAYER_SPEED
elif action == 2:
pass
elif action == 3:
if self.current_cooldown <= 0:
self.current_cooldown = self.player_attack_cooldown_steps
reward += 0.01
b = Bullet(self.player.center_x, self.player.center_y, 1)
self.bullets.append(b)
else:
reward -= 0.05
self.bullets.append(b)
if enemy_x != 2.0 and abs(enemy_x) < 0.04:
if self.enemy_formation.enemies:
nearest = self._nearest_enemy()
alignment = abs(nearest.center_x - self.player.center_x) / self.width
if alignment < 0.025:
reward += 0.3
elif enemy_x != 2.0 and abs(enemy_x) < 0.1:
reward += 0.1
if self.player.center_x > self.width:
self.player.center_x = self.width
elif self.player.center_x < 0:
self.player.center_x = 0
self.player.center_x = np.clip(self.player.center_x, 0, self.width)
self.player_speed = (self.player.center_x - prev_x) / max(1e-6, PLAYER_SPEED)
self.player_speed = (self.player.center_x - prev_x) / max(1e-6, PLAYER_SPEED)
if self.enemy_formation.enemies and self.player_alive:
if self.enemy_formation.center_x < self.player.center_x:
self.enemy_formation.move(self.width, self.height, "x", self.enemy_move_speed)
elif self.enemy_formation.center_x > self.player.center_x:
self.enemy_formation.move(self.width, self.height, "x", -self.enemy_move_speed)
if random.random() < 0.02:
if random.random() < 0.5:
self.enemy_formation.move(self.width, self.height, "y", -self.enemy_move_speed)
else:
self.enemy_formation.move(self.width, self.height, "y", self.enemy_move_speed)
if current_action_dir != 0:
if self.last_direction != 0 and current_action_dir != self.last_direction:
if self.steps_since_direction_change < 3:
reward -= 0.1
bullets_to_remove = []
self.steps_since_direction_change = 0
else:
self.steps_since_direction_change += 1
self.last_direction = current_action_dir
if enemy_x != 2.0:
if abs(enemy_x) < 0.03:
reward += 0.1
elif abs(enemy_x) < 0.08:
reward += 0.05
for b in list(self.bullets):
for b in self.bullets:
b.center_y += b.direction_y * BULLET_SPEED
if b.center_y > self.height or b.center_y < 0:
try:
self.bullets.remove(b)
except ValueError:
pass
for b in list(self.bullets):
bullets_to_remove.append(b)
continue
if b.direction_y == 1:
for e in list(self.enemies):
for e in self.enemy_formation.enemies:
if arcade.check_for_collision(b, e):
try:
self.enemies.remove(e)
except ValueError:
pass
try:
self.bullets.remove(b)
except ValueError:
pass
reward += 1.0
self.enemy_formation.remove_enemy(e)
bullets_to_remove.append(b)
reward += 10.0
self.enemies_killed += 1
break
for b in list(self.bullets):
if b.direction_y == -1:
elif b.direction_y == -1 and self.player_alive:
if arcade.check_for_collision(b, self.player):
try:
self.bullets.remove(b)
except ValueError:
pass
reward -= 5.0
bullets_to_remove.append(b)
reward -= 10.0
self.player_alive = False
if self.player_respawns_remaining > 0:
self.player_respawns_remaining -= 1
self._respawn_player()
reward += 2.0
else:
terminated = True
for b in bullets_to_remove:
if b in self.bullets:
self.bullets.remove(b)
if self.player_alive:
lowest_enemy = self._lowest_enemy()
if lowest_enemy and lowest_enemy.center_y <= self.player.center_y:
reward -= 10.0
self.player_alive = False
if self.player_respawns_remaining > 0:
self.player_respawns_remaining -= 1
self._respawn_player()
else:
terminated = True
if not self.enemies:
reward += 10.0
terminated = True
if not self.enemy_formation.enemies:
reward += 50.0
if self.enemy_respawns_remaining > 0:
self.enemy_respawns_remaining -= 1
self._respawn_enemies()
reward += 20.0
else:
reward += 100.0
terminated = True
if self.enemies and random.random() < 0.05:
e = random.choice(self.enemies)
b = Bullet(e.center_x, e.center_y, -1)
self.bullets.append(b)
shooting_prob = 0.05 + (0.05 * (1.0 - len(self.enemy_formation.enemies) / (self.difficulty_settings["enemy_rows"] * self.difficulty_settings["enemy_cols"])))
if self.enemy_formation.enemies and random.random() < shooting_prob:
enemy = self.enemy_formation.get_lowest_enemy()
if enemy:
b = Bullet(enemy.center_x, enemy.center_y, -1)
self.bullets.append(b)
curr_bullet = self._nearest_enemy_bullet()
if curr_bullet is not None:
curr_bx = (curr_bullet.center_x - self.player.center_x) / float(self.width)
else:
curr_bx = 2.0
if self.player_alive:
edge_threshold = self.width * 0.15
if self.player.center_x < edge_threshold or self.player.center_x > self.width - edge_threshold:
reward -= 0.03
if self.prev_bx != 2.0 and curr_bx != 2.0:
if abs(curr_bx) > abs(self.prev_bx):
reward += 0.02
reward -= 0.01
reward -= 0.005
obs = self._obs()
self.prev_bx = curr_bx
return obs, float(reward), bool(terminated), bool(truncated), {}
return obs, float(reward), bool(terminated), bool(truncated), {
"enemies_killed": self.enemies_killed,
"step": self.current_step,
"player_respawns_remaining": self.player_respawns_remaining,
"enemy_respawns_remaining": self.enemy_respawns_remaining
}